arm中常用的乘累加操作。
mla
fmla(32)
- 指令集
>= armv7（需 VFPv4 / NEONv2，对应指令为 vfma；fmla 为 AArch64 下的助记符）
- neon指令
float32x4_t vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c)
- 计算实现
float32_t src1[4];
float32_t src2[4];
float32_t dst[4];
for (int i = 0; i < 4; i++) {
dst[i] += src1[i] * src2[i];
}
fmla(16)
- 指令集
>= armv8.2-a（需 fp16 扩展，如 -march=armv8.2-a+fp16）
- neon指令
float16x8_t vfmaq_f16(float16x8_t a, float16x8_t b, float16x8_t c)
- 计算实现
float16_t src1[8];
float16_t src2[8];
float16_t dst[8];
for (int i = 0; i < 8; i++) {
dst[i] += src1[i] * src2[i];
}
dot
sdot
- 指令集
>= armv8.2-a（需 dotprod 扩展，如 -march=armv8.2-a+dotprod）
- neon指令
int32x4_t vdotq_s32(int32x4_t r, int8x16_t a, int8x16_t b)
- 计算实现
int8_t src1[16];
int8_t src2[16];
int32_t dst[4];
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
dst[i] += src1[i * 4 + j] * src2[i * 4 + j];
}
}
usdot
- 指令集
>= armv8.2-a（需 i8mm 扩展，如 -march=armv8.2-a+i8mm；armv8.6-a 默认启用）
- neon指令
int32x4_t vusdotq_s32(int32x4_t r, uint8x16_t a, int8x16_t b)
- 计算实现
uint8_t src1[16];
int8_t src2[16];
int32_t dst[4];
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
dst[i] += src1[i * 4 + j] * src2[i * 4 + j];
}
}
mmla
smmla
- 指令集
>= armv8.6-a（i8mm 扩展；armv8.2-a 起可通过 +i8mm 启用）
- neon指令
int32x4_t vmmlaq_s32(int32x4_t r, int8x16_t a, int8x16_t b)
- 计算实现
int8_t src1[16];
int8_t src2[16];
int32_t dst[4];
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
for (int k = 0; k < 8; k++) {
dst[i * 2 + j] += src1[i * 8 + k] * src2[j * 8 + k];
}
}
}
usmmla
- 指令集
>= armv8.6-a（i8mm 扩展；armv8.2-a 起可通过 +i8mm 启用）
- neon指令
int32x4_t vusmmlaq_s32(int32x4_t r, uint8x16_t a, int8x16_t b)
- 计算实现
uint8_t src1[16];
int8_t src2[16];
int32_t dst[4];
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
for (int k = 0; k < 8; k++) {
dst[i * 2 + j] += src1[i * 8 + k] * src2[j * 8 + k];
}
}
}
bfmmla
- 指令集
>= armv8.6-a（bf16 扩展；armv8.2-a 起可通过 +bf16 启用）
- neon指令
float32x4_t vbfmmlaq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b)
- 计算实现
bfloat16_t src1[8];
bfloat16_t src2[8];
float32_t dst[4];
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
for (int k = 0; k < 4; k++) {
dst[i * 2 + j] += src1[i * 4 + k] * src2[j * 4 + k];
}
}
}