Skip to content

Instantly share code, notes, and snippets.

@thoughtpolice
Created December 28, 2021 02:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thoughtpolice/c0c602a21f0ed7b5a144252d5123d53f to your computer and use it in GitHub Desktop.
Save thoughtpolice/c0c602a21f0ed7b5a144252d5123d53f to your computer and use it in GitHub Desktop.
Matrix multiplication with Clang extended vector types
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
/*
 * llvm-mca region markers.
 *
 * On non-wasm targets each marker emits an assembler comment line,
 * "# LLVM-MCA-BEGIN <label>" or "# LLVM-MCA-END <label>", through a
 * volatile inline-asm statement. llvm-mca uses these comments to delimit
 * the code region it analyses. `label` must be a string literal, because it
 * is appended to the marker text by string-literal concatenation.
 *
 * On wasm targets both markers expand to a no-op statement.
 *
 * Each macro expands to exactly one statement with no trailing semicolon.
 * Call sites therefore write `LLVM_MCA_BEGIN("x");`, which no longer yields
 * a stray empty statement and is safe inside an unbraced if/else.
 */
#ifndef __wasm__
#define LLVM_MCA_BEGIN(label) __asm__ __volatile__("# LLVM-MCA-BEGIN " label)
#define LLVM_MCA_END(label) __asm__ __volatile__("# LLVM-MCA-END " label)
#else
#define LLVM_MCA_BEGIN(label) ((void)0)
#define LLVM_MCA_END(label) ((void)0)
#endif
/*
 * float32x4: four packed floats with 16-byte alignment.
 *
 * Clang's ext_vector_type is used when compiling with Clang, which is the
 * point of this file. GCC does not implement ext_vector_type. It only warns
 * and ignores the attribute, which would leave float32x4 as a plain
 * (over-aligned) float and break union fmat4x4 below. Under GCC we therefore
 * fall back to the equivalent vector_size form.
 *
 * Both forms support element-wise arithmetic, scalar broadcast in mixed
 * vector/scalar expressions, and v[i] subscripting. Note that
 * __builtin_shufflevector, used elsewhere in this file, is only provided by
 * recent GCC releases (12 and later).
 */
#if defined(__clang__)
typedef float float32x4 __attribute__((ext_vector_type(4), aligned(16)));
#elif defined(__GNUC__)
typedef float float32x4 __attribute__((vector_size(4 * sizeof(float)), aligned(16)));
#else
#error "float32x4 requires Clang ext_vector_type or GCC vector_size support"
#endif

/*
 * A 4x4 float matrix with two views of the same storage:
 *   m[i][j] - scalar element at row i, column j;
 *   row[i]  - row i as one float32x4 vector.
 * The union inherits float32x4's 16-byte alignment.
 */
union fmat4x4 {
    float m[4][4];
    float32x4 row[4];
};
void
matmult_ref(
union fmat4x4 *__restrict__ out,
const union fmat4x4 *__restrict__ A,
const union fmat4x4 *__restrict__ B
) {
union fmat4x4 t __attribute__((aligned(16))); // write to temp
LLVM_MCA_BEGIN("matmult_ref");
for (int i=0; i < 4; i++)
for (int j=0; j < 4; j++)
t.m[i][j] = A->m[i][0] * B->m[0][j]
+ A->m[i][1] * B->m[1][j]
+ A->m[i][2] * B->m[2][j]
+ A->m[i][3] * B->m[3][j];
LLVM_MCA_END("matmult_ref");
*out = t;
}
// this implementation is based on https://gist.github.com/rygorous/4172889
// minus the AVX/SSE intrinsics
static inline float32x4
matmult_lincomb_single(
const float32x4 V,
const union fmat4x4 *M
) {
float32x4 r;
r = __builtin_shufflevector(V, V, 0, 0, 0, 0) * M->row[0];
r += __builtin_shufflevector(V, V, 1, 1, 1, 1) * M->row[1];
r += __builtin_shufflevector(V, V, 2, 2, 2, 2) * M->row[2];
r += __builtin_shufflevector(V, V, 3, 3, 3, 3) * M->row[3];
return r;
}
/*
 * 4x4 matrix product out = A * B using vector arithmetic.
 *
 * Row i of the result is matmult_lincomb_single(A->row[i], B).
 *
 * All four result rows are computed inside the LLVM-MCA region and held in
 * a local array. They are written to *out only after the region ends, so
 * the stores are not part of the analysed code.
 *
 * out, A and B are restrict-qualified: the caller must not pass
 * overlapping matrices.
 */
void
matmult_lincomb(
union fmat4x4 *__restrict__ out,
const union fmat4x4 *__restrict__ A,
const union fmat4x4 *__restrict__ B
) {
    float32x4 rows[4];

    LLVM_MCA_BEGIN("matmult_lincomb");
    for (int i = 0; i < 4; ++i) {
        rows[i] = matmult_lincomb_single(A->row[i], B);
    }
    LLVM_MCA_END("matmult_lincomb");

    for (int i = 0; i < 4; ++i) {
        out->row[i] = rows[i];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment