Skip to content

Instantly share code, notes, and snippets.

@jonahwilliams
Created November 4, 2024 06:08
Show Gist options
  • Save jonahwilliams/cc79f92afd2382d8acd48a2632dc8903 to your computer and use it in GitHub Desktop.
Save jonahwilliams/cc79f92afd2382d8acd48a2632dc8903 to your computer and use it in GitHub Desktop.
// mat_multiply_4x4_neon: computes D = A * B * C for 4x4 single-precision
// matrices stored column-major (16 contiguous floats; element (row, col) is
// at index col * 4 + row).
//
// Every element of A, B and C is read before the first write to D, so D may
// alias any of the inputs (e.g. mat_multiply_4x4_neon(M, B, C, M) updates M
// in place).
//
// vfmaq_laneq_f32 exists only in the A64 instruction set, while __ARM_NEON
// is also defined on 32-bit ARM, so the vector path is restricted to
// AArch64. All other targets get a portable scalar implementation with the
// same name and an ABI-identical signature (float32_t is a typedef of float).
// NOTE: the NEON path uses fused multiply-add; the scalar path is not
// guaranteed to fuse, so the two paths may differ in the last ulp.
#if defined(__ARM_NEON) && defined(__aarch64__)
#include <arm_neon.h>

// Returns M * v, where M is given by its four columns m0..m3.
// Accumulates m0*v[0] + m1*v[1] + m2*v[2] + m3*v[3] with fused multiply-adds.
static inline float32x4_t mat4_mul_column_neon(float32x4_t m0,
                                               float32x4_t m1,
                                               float32x4_t m2,
                                               float32x4_t m3,
                                               float32x4_t v) {
  float32x4_t acc = vmovq_n_f32(0);
  acc = vfmaq_laneq_f32(acc, m0, v, 0);
  acc = vfmaq_laneq_f32(acc, m1, v, 1);
  acc = vfmaq_laneq_f32(acc, m2, v, 2);
  acc = vfmaq_laneq_f32(acc, m3, v, 3);
  return acc;
}

static void mat_multiply_4x4_neon(const float32_t* A,
                                  const float32_t* B,
                                  const float32_t* C,
                                  float32_t* D) {
  // Columns of A.
  const float32x4_t a0 = vld1q_f32(A);
  const float32x4_t a1 = vld1q_f32(A + 4);
  const float32x4_t a2 = vld1q_f32(A + 8);
  const float32x4_t a3 = vld1q_f32(A + 12);

  // T = A * B: column j of T is A times column j of B.
  const float32x4_t t0 = mat4_mul_column_neon(a0, a1, a2, a3, vld1q_f32(B));
  const float32x4_t t1 = mat4_mul_column_neon(a0, a1, a2, a3, vld1q_f32(B + 4));
  const float32x4_t t2 = mat4_mul_column_neon(a0, a1, a2, a3, vld1q_f32(B + 8));
  const float32x4_t t3 = mat4_mul_column_neon(a0, a1, a2, a3, vld1q_f32(B + 12));

  // D = T * C: column j of D is T times column j of C. All four result
  // columns are computed before any store, which keeps aliasing D with an
  // input safe.
  const float32x4_t d0 = mat4_mul_column_neon(t0, t1, t2, t3, vld1q_f32(C));
  const float32x4_t d1 = mat4_mul_column_neon(t0, t1, t2, t3, vld1q_f32(C + 4));
  const float32x4_t d2 = mat4_mul_column_neon(t0, t1, t2, t3, vld1q_f32(C + 8));
  const float32x4_t d3 = mat4_mul_column_neon(t0, t1, t2, t3, vld1q_f32(C + 12));

  vst1q_f32(D, d0);
  vst1q_f32(D + 4, d1);
  vst1q_f32(D + 8, d2);
  vst1q_f32(D + 12, d3);
}
#else
// out = L * R for column-major 4x4 matrices. out must not alias L or R.
static void mat4_mul_scalar(const float* L, const float* R, float* out) {
  for (int col = 0; col < 4; col++) {
    for (int row = 0; row < 4; row++) {
      // Same summation order as the NEON path: k = 0, 1, 2, 3.
      float acc = 0.0f;
      for (int k = 0; k < 4; k++) {
        acc += L[k * 4 + row] * R[col * 4 + k];
      }
      out[col * 4 + row] = acc;
    }
  }
}

static void mat_multiply_4x4_neon(const float* A,
                                  const float* B,
                                  const float* C,
                                  float* D) {
  float ab[16];
  float abc[16];
  mat4_mul_scalar(A, B, ab);
  mat4_mul_scalar(ab, C, abc);
  // Copy out only after every input has been consumed so D may alias them.
  for (int i = 0; i < 16; i++) {
    D[i] = abc[i];
  }
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment