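ARM 中常用的乘累加（multiply-accumulate）指令及其 NEON intrinsic（声明于 `<arm_neon.h>`）。下文“计算实现”给出逐元素的标量参考实现，其中 dst 对应 intrinsic 的累加器参数（a 或 r），src1、src2 对应另外两个输入。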

mla

fmla(32)

  • 指令集 >= ARMv7-A 且需 VFPv4（带融合乘加的 NEON，如 -mfpu=neon-vfpv4）；AArch64 上均可用
  • neon指令 float32x4_t vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c)，逐 lane 计算 a + b * c（融合乘加，只舍入一次）
  • 计算实现
float32_t src1[4];
float32_t src2[4];
float32_t dst[4];

for (int i = 0; i < 4; i++) {
    dst[i] += src1[i] * src2[i];
}

fmla(16)

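  • 指令集 >= ARMv8.2-A，且需半精度浮点扩展 FEAT_FP16（`__ARM_FEATURE_FP16_VECTOR_ARITHMETIC`，如 -march=armv8.2-a+fp16）；仅 ARMv8-A 基线不支持
  • neon指令 float16x8_t vfmaq_f16(float16x8_t a, float16x8_t b, float16x8_t c)，逐 lane 计算 a + b * c（融合乘加）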
  • 计算实现
/*
 * Scalar reference for vfmaq_f16(a, b, c) (FMLA, fp16): every lane computes
 * a[i] + b[i] * c[i]. dst plays the role of the accumulator `a` and is
 * updated in place; src1/src2 correspond to `b`/`c`.
 */
static void fmla_f16_ref(float16_t dst[8], const float16_t src1[8],
                         const float16_t src2[8])
{
    for (int i = 0; i < 8; i++) {
        /* NOTE(review): if float16_t arithmetic is promoted to float (as ACLE
         * specifies for __fp16), the float result is rounded again when it
         * is stored back to fp16, so this may differ from the single-rounding
         * hardware result in rare last-bit cases. */
        dst[i] += src1[i] * src2[i];
    }
}

dot

sdot

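  • 指令集 >= ARMv8.2-A，且需点积扩展 FEAT_DotProd（`__ARM_FEATURE_DOTPROD`，如 -march=armv8.2-a+dotprod；ARMv8.4-A 起为必选）
  • neon指令 int32x4_t vdotq_s32(int32x4_t r, int8x16_t a, int8x16_t b)，第 i 个 32 位 lane 累加 a、b 中下标 4i..4i+3 的 4 个 int8 乘积之和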
  • 计算实现
/*
 * Scalar reference for vdotq_s32(r, a, b) (SDOT): 32-bit lane i accumulates
 * the dot product of the four int8 elements a[4i..4i+3] and b[4i..4i+3].
 * dst is the accumulator `r`, updated in place; src1/src2 are `a`/`b`.
 */
static void sdot_s32_ref(int32_t dst[4], const int8_t src1[16],
                         const int8_t src2[16])
{
    for (int i = 0; i < 4; i++) {
        /* At most 4 * 128 * 128 = 65536 in magnitude, so this cannot overflow. */
        int32_t sum = 0;
        for (int j = 0; j < 4; j++) {
            sum += src1[i * 4 + j] * src2[i * 4 + j];
        }
        /* The hardware accumulate wraps modulo 2^32; adding in uint32_t avoids
         * signed-overflow UB (the conversion back to int32_t is
         * implementation-defined in ISO C; GCC and Clang wrap). */
        dst[i] = (int32_t)((uint32_t)dst[i] + (uint32_t)sum);
    }
}

usdot

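  • 指令集 >= ARMv8.6-A：属于 Int8 矩阵乘扩展 FEAT_I8MM（`__ARM_FEATURE_MATMUL_INT8`），ARMv8.2-A 起可通过 +i8mm 作为可选扩展启用；不属于 sdot 所在的 FEAT_DotProd
  • neon指令 int32x4_t vusdotq_s32(int32x4_t r, uint8x16_t a, int8x16_t b)，与 sdot 相同，但 a 为无符号 uint8、b 为有符号 int8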
  • 计算实现
/*
 * Scalar reference for vusdotq_s32(r, a, b) (USDOT): like SDOT, but `a` is
 * unsigned and `b` is signed. 32-bit lane i accumulates the dot product of
 * a[4i..4i+3] and b[4i..4i+3]. dst is the accumulator `r`, updated in place;
 * src1/src2 are `a`/`b`.
 */
static void usdot_s32_ref(int32_t dst[4], const uint8_t src1[16],
                          const int8_t src2[16])
{
    for (int i = 0; i < 4; i++) {
        /* At most 4 * 255 * 128 = 130560 in magnitude: no overflow here. */
        int32_t sum = 0;
        for (int j = 0; j < 4; j++) {
            sum += src1[i * 4 + j] * src2[i * 4 + j];
        }
        /* Wrapping (modulo 2^32) accumulate like the hardware, without
         * signed-overflow UB; the uint32_t -> int32_t conversion is
         * implementation-defined in ISO C and wraps on GCC and Clang. */
        dst[i] = (int32_t)((uint32_t)dst[i] + (uint32_t)sum);
    }
}

mmla

smmla

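  • 指令集 >= ARMv8.6-A（Int8 矩阵乘扩展 FEAT_I8MM，`__ARM_FEATURE_MATMUL_INT8`；ARMv8.2-A 起可通过 +i8mm 作为可选扩展启用）
  • neon指令 int32x4_t vmmlaq_s32(int32x4_t r, int8x16_t a, int8x16_t b)，a 视为 2x8 行主序矩阵 A，b 的两行是 8x2 矩阵 B 的两列，结果 A·B（2x2，行主序）累加到 r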
  • 计算实现
/*
 * Scalar reference for vmmlaq_s32(r, a, b) (SMMLA): a holds a 2x8 int8
 * matrix A (row-major) and b holds a 2x8 int8 matrix whose rows are the
 * columns of an 8x2 matrix B; the 2x2 int32 product A*B is added, row-major,
 * to the accumulator. dst is `r`, updated in place; src1/src2 are `a`/`b`.
 */
static void smmla_s32_ref(int32_t dst[4], const int8_t src1[16],
                          const int8_t src2[16])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            /* At most 8 * 128 * 128 = 131072 in magnitude: no overflow here. */
            int32_t sum = 0;
            for (int k = 0; k < 8; k++) {
                sum += src1[i * 8 + k] * src2[j * 8 + k];
            }
            /* Wrapping (modulo 2^32) accumulate like the hardware, without
             * signed-overflow UB; the uint32_t -> int32_t conversion is
             * implementation-defined in ISO C and wraps on GCC and Clang. */
            dst[i * 2 + j] = (int32_t)((uint32_t)dst[i * 2 + j] + (uint32_t)sum);
        }
    }
}

usmmla

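  • 指令集 >= ARMv8.6-A（Int8 矩阵乘扩展 FEAT_I8MM，`__ARM_FEATURE_MATMUL_INT8`；ARMv8.2-A 起可通过 +i8mm 作为可选扩展启用）
  • neon指令 int32x4_t vusmmlaq_s32(int32x4_t r, uint8x16_t a, int8x16_t b)，与 smmla 相同，但 a 为无符号 uint8、b 为有符号 int8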
  • 计算实现
/*
 * Scalar reference for vusmmlaq_s32(r, a, b) (USMMLA): like SMMLA, but `a`
 * is unsigned and `b` is signed. a holds a 2x8 uint8 matrix A (row-major),
 * b holds a 2x8 int8 matrix whose rows are the columns of an 8x2 matrix B,
 * and the 2x2 int32 product A*B is added, row-major, to the accumulator.
 * dst is `r`, updated in place; src1/src2 are `a`/`b`.
 */
static void usmmla_s32_ref(int32_t dst[4], const uint8_t src1[16],
                           const int8_t src2[16])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            /* At most 8 * 255 * 128 = 261120 in magnitude: no overflow here. */
            int32_t sum = 0;
            for (int k = 0; k < 8; k++) {
                sum += src1[i * 8 + k] * src2[j * 8 + k];
            }
            /* Wrapping (modulo 2^32) accumulate like the hardware, without
             * signed-overflow UB; the uint32_t -> int32_t conversion is
             * implementation-defined in ISO C and wraps on GCC and Clang. */
            dst[i * 2 + j] = (int32_t)((uint32_t)dst[i * 2 + j] + (uint32_t)sum);
        }
    }
}

bfmmla

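  • 指令集 >= ARMv8.6-A（BFloat16 扩展 FEAT_BF16，`__ARM_FEATURE_BF16_VECTOR_ARITHMETIC`；ARMv8.2-A 起可通过 +bf16 作为可选扩展启用）
  • neon指令 float32x4_t vbfmmlaq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b)，a 视为 2x4 行主序 bf16 矩阵 A，b 的两行是 4x2 矩阵 B 的两列，fp32 结果 A·B（2x2，行主序）累加到 r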
  • 计算实现
bfloat16_t src1[8];
bfloat16_t src2[8];
float32_t dst[4];

for (int i = 0; i < 2; i++) {
    for (int j = 0; j < 2; j++) {
        for (int k = 0; k < 4; k++) {
            dst[i * 2 + j] += src1[i * 4 + k] * src2[j * 4 + k];
        }
    }
}