@zeux
Last active January 10, 2024 02:46
Matrix-vector multiplication benchmark targeting Apple M1/M2/M3, expecting clang to vectorize the loops in the dotprod_fpN functions using half-precision multiply-adds. Requires OpenMP (libomp from Homebrew).
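(For reference: with N = 32768 the fp16 weight matrix occupies 32768 × 32768 × 2 bytes = 2 GiB and the fp8 matrix 1 GiB, and each matmul call streams the whole matrix once, so the printed GB/s figures are effectively a memory-bandwidth measurement.)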
// brew install libomp
// cc -o matbench matbench.c -O3 -ffast-math -Xclang -fopenmp -I/opt/homebrew/opt/libomp/include -L/opt/homebrew/opt/libomp/lib -lomp
// ./matbench
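// To check that clang actually vectorized the dot-product loops with half-precision
// multiply-adds, recompile with clang's optimization remarks enabled (for example,
// adding -Rpass=loop-vectorize -Rpass-missed=loop-vectorize to the command above);
// the remarks should report which loops were vectorized.
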
#include <assert.h>
#include <math.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

typedef _Float16 half;

// Decode an 8-bit weight by placing it in the top byte of an fp16, i.e. treating it
// as sign + 5-bit exponent + 2-bit mantissa (e.g. fp82half(0x3C) == 1.0).
static half fp82half(unsigned char v) {
    union {
        unsigned short u;
        half f;
    } u;
    u.u = v << 8;
    return u.f;
}

typedef half (*dotprod_t)(void* w, int n, int i, half* x);

// Dot product of row i of an fp16 matrix w with the fp16 vector x (n elements)
static half dotprod_fp16(void* w, int n, int i, half* x) {
    half* r = (half*)w + i * n;
    half val = 0;
    #pragma omp simd reduction(+ : val) simdlen(32)
    for (int j = 0; j < n; j++) {
        val += r[j] * x[j];
    }
    return val;
}

// Dot product of row i of an fp8 matrix w with the fp16 vector x (n elements),
// widening each fp8 weight to fp16 on the fly
static half dotprod_fp8(void* w, int n, int i, half* x) {
    char* r = (char*)w + i * n;
    half val = 0;
    #pragma omp simd reduction(+ : val) simdlen(32)
    for (int j = 0; j < n; j++) {
        val += fp82half(r[j]) * x[j];
    }
    return val;
}

static void matmul(half* xout, half* x, void* w, int n, int d, dotprod_t dotprod) {
    // W (d,n) @ x (n,) -> xout (d,)
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = dotprod(w, n, i, x);
        xout[i] = val;
    }
}

long time_in_ms() {
    // return time in milliseconds, for benchmarking the model speed
    struct timespec time;
    clock_gettime(CLOCK_REALTIME, &time);
    return time.tv_sec * 1000 + time.tv_nsec / 1000000;
}

int main() {
    int N = 32768;
    short* m16 = calloc(N * N, 2);
    char* m8 = calloc(N * N, 1);
    half* x = calloc(N, 2);
    half* xo = calloc(N, 2);

    int Tm = omp_get_max_threads();
    int Ts[] = {1, 2, 3, 4, 6, 8, 10, 12, 16};

    for (size_t Ti = 0; Ti < sizeof(Ts) / sizeof(Ts[0]); ++Ti) {
        int T = Ts[Ti];
        if (T > Tm) {
            break;
        }

        omp_set_num_threads(T);
        printf("OMP threads: %d\n", T);

        // warmup
        matmul(xo, x, m16, N, N, dotprod_fp16);
        matmul(xo, x, m8, N, N, dotprod_fp8);

        int reps = 10;

        // fp16
        long t0 = time_in_ms();
        for (int i = 0; i < reps; ++i) {
            matmul(xo, x, m16, N, N, dotprod_fp16);
        }

        // fp8
        long t1 = time_in_ms();
        for (int i = 0; i < reps; ++i) {
            matmul(xo, x, m8, N, N, dotprod_fp8);
        }
        long t2 = time_in_ms();

        printf("fp16: %.1f GB/s\n", ((double)N * N * 2 * reps / 1e9) / ((double)(t1 - t0) / 1000));
        printf("fp8 : %.1f GB/s\n", ((double)N * N * reps / 1e9) / ((double)(t2 - t1) / 1000));
    }

    return 0;
}