Created
November 4, 2024 06:08
-
-
Save jonahwilliams/cc79f92afd2382d8acd48a2632dc8903 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifdef __ARM_NEON | |
// D = A * B * C for 4x4 single-precision matrices.
//
// Each matrix is 16 contiguous floats in column-major order: every group of
// four floats is one column. Element k of a column of the right-hand operand
// weights column k of the left-hand operand.
//
// NOTE(review): vfmaq_laneq_f32 is an AArch64-only intrinsic, while
// __ARM_NEON is also defined for 32-bit ARM targets. Confirm whether the
// enclosing guard should be restricted to __aarch64__.

// Returns one column of L * R: given the columns l0..l3 of L and one column r
// of R, computes l0*r[0] + l1*r[1] + l2*r[2] + l3*r[3]. The terms are added
// with fused multiply-adds in exactly that order, starting from +0.
static inline float32x4_t mat4_column_neon(float32x4_t l0,
                                           float32x4_t l1,
                                           float32x4_t l2,
                                           float32x4_t l3,
                                           float32x4_t r) {
  float32x4_t acc = vmovq_n_f32(0);
  acc = vfmaq_laneq_f32(acc, l0, r, 0);
  acc = vfmaq_laneq_f32(acc, l1, r, 1);
  acc = vfmaq_laneq_f32(acc, l2, r, 2);
  acc = vfmaq_laneq_f32(acc, l3, r, 3);
  return acc;
}

// Multiplies A * B * C (column-major 4x4) and writes the product to D.
//
// Every element of A, B and C is read into registers before the first store
// to D. D may therefore be the same buffer as any of the inputs.
static void mat_multiply_4x4_neon(const float32_t* A,
                                  const float32_t* B,
                                  const float32_t* C,
                                  float32_t* D) {
  // Columns of A.
  const float32x4_t a0 = vld1q_f32(A);
  const float32x4_t a1 = vld1q_f32(A + 4);
  const float32x4_t a2 = vld1q_f32(A + 8);
  const float32x4_t a3 = vld1q_f32(A + 12);

  // T = A * B, built one column at a time from the columns of B.
  const float32x4_t t0 = mat4_column_neon(a0, a1, a2, a3, vld1q_f32(B));
  const float32x4_t t1 = mat4_column_neon(a0, a1, a2, a3, vld1q_f32(B + 4));
  const float32x4_t t2 = mat4_column_neon(a0, a1, a2, a3, vld1q_f32(B + 8));
  const float32x4_t t3 = mat4_column_neon(a0, a1, a2, a3, vld1q_f32(B + 12));

  // D = T * C. All columns of C are consumed before anything is stored, which
  // keeps in-place calls (D aliasing A, B or C) correct.
  const float32x4_t d0 = mat4_column_neon(t0, t1, t2, t3, vld1q_f32(C));
  const float32x4_t d1 = mat4_column_neon(t0, t1, t2, t3, vld1q_f32(C + 4));
  const float32x4_t d2 = mat4_column_neon(t0, t1, t2, t3, vld1q_f32(C + 8));
  const float32x4_t d3 = mat4_column_neon(t0, t1, t2, t3, vld1q_f32(C + 12));

  vst1q_f32(D, d0);
  vst1q_f32(D + 4, d1);
  vst1q_f32(D + 8, d2);
  vst1q_f32(D + 12, d3);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment