Skip to content

Instantly share code, notes, and snippets.

@bjourne
Created September 28, 2022 16:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bjourne/ce1e189b926d3aa3ff8772e9f5252cc2 to your computer and use it in GitHub Desktop.
Save bjourne/ce1e189b926d3aa3ff8772e9f5252cc2 to your computer and use it in GitHub Desktop.
// Notes:
//
// * unsigned int vs. int: makes a small difference for clang but
// probably not for gcc.
// * best tiling appears to be 256x256x256.
//
// 12.31 for two 8192 matrices
//
//
#include <assert.h>
#include <math.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <xmmintrin.h>
// Benchmark problem size: C (C_ROWS x C_COLS) = A (A_ROWS x A_COLS) *
// B (B_ROWS x B_COLS).  Like the tile sizes below, these can be overridden
// at compile time, e.g. -DA_ROWS=8192 -DA_COLS=8192 -DB_COLS=8192.
#ifndef A_ROWS
#define A_ROWS 101
#endif
#ifndef A_COLS
#define A_COLS 64
#endif
#define B_ROWS A_COLS
#ifndef B_COLS
#define B_COLS 64
#endif
#define C_ROWS A_ROWS
#define C_COLS B_COLS
// Buffer sizes in bytes.  The (size_t) cast makes the whole product size_t
// arithmetic so large overridden sizes cannot overflow int.
#define A_N_BYTES ((size_t)A_ROWS * A_COLS * sizeof(float))
#define B_N_BYTES ((size_t)B_ROWS * B_COLS * sizeof(float))
#define C_N_BYTES ((size_t)C_ROWS * C_COLS * sizeof(float))
// Cache-blocking tile sizes (rows of C, columns of C, shared dimension).
#ifndef TILE_I
#define TILE_I 32
#endif
#ifndef TILE_J
#define TILE_J 32
#endif
#ifndef TILE_K
#define TILE_K 32
#endif
// Number of worker threads used by the fast multiply.
#ifndef N_THREADS
#define N_THREADS 1
#endif
// Register block of the SSE kernel: SIMD_HEIGHT rows of C by SIMD_WIDTH
// columns (four 4-float vectors per row) are updated per step.
#define SIMD_HEIGHT 2
#define SIMD_WIDTH 16
// Smaller of a and b.  Every argument is parenthesized so operands with
// low-precedence operators (e.g. `x & 1`) compare correctly.  Each argument
// may be evaluated twice, so do not pass expressions with side effects.
#define MIN(a, b) (((a) > (b)) ? (b) : (a))
typedef unsigned int uint_t;
// Reference matrix product C = A * B using the naive triple loop.
//
// A is a_rows x a_cols and B is b_rows x b_cols, both row-major; the
// a_rows x b_cols result overwrites C (row stride b_cols).  The inner sum
// runs over b_rows, so a_cols is expected to equal b_rows.
void
mul_slow(float * restrict A,
         float * restrict B,
         float * restrict C,
         uint_t a_rows, uint_t a_cols,
         uint_t b_rows, uint_t b_cols) {
    uint_t row = 0;
    while (row < a_rows) {
        const uint_t a_base = a_cols * row;
        const uint_t c_base = b_cols * row;
        for (uint_t col = 0; col < b_cols; ++col) {
            float acc = 0.0f;
            for (uint_t n = 0; n < b_rows; ++n) {
                acc += A[a_base + n] * B[b_cols * n + col];
            }
            C[c_base + col] = acc;
        }
        ++row;
    }
}
// Debug helper: print a size x size row-major matrix to stdout, one row per
// line with each entry formatted as "%3.0f ", followed by an empty line.
void
print_mat(float *M, int size) {
    const float *cell = M;
    int row = 0;
    while (row < size) {
        for (int col = 0; col < size; ++col) {
            printf("%3.0f ", *cell);
            ++cell;
        }
        printf("\n");
        ++row;
    }
    printf("\n");
}
// SSE kernel for one tile: C[i0:i1, j0:j1] += A[i0:i1, k0:k1] * B[k0:k1, j0:j1],
// computed SIMD_HEIGHT (2) rows by SIMD_WIDTH (16) columns at a time.
//
// Apt is the row-pair-interleaved copy of A built by mul_fast: the pair
// starting at even row i lives at Apt[a_cols * i] as a_cols consecutive
// [A[i][k], A[i+1][k]] pairs (missing rows are zero padding).  Bpt points at
// the packed (k0, j0) tile of B, in which each 16-column strip is stored as
// consecutive rows of 16 floats.  C is row-major with row stride b_cols and
// is read/written with aligned SSE loads/stores, so every &C[b_cols*i + j]
// touched must be 16-byte aligned.
//
// Assumes i0 is even and (j1 - j0) is a multiple of SIMD_WIDTH; the strip
// offsets only line up with mul_fast's packing when k1 - k0 == TILE_K.
// a_rows and b_rows are accepted for signature symmetry but unused.
static void
mul_fast_tile_16x2(
    uint_t i0, uint_t i1,
    uint_t j0, uint_t j1,
    uint_t k0, uint_t k1,
    float * restrict Apt,
    float * restrict Bpt,
    float * restrict C,
    uint_t a_rows, uint_t a_cols,
    uint_t b_rows, uint_t b_cols
) {
    (void)a_rows;
    (void)b_rows;
    for (uint_t i = i0; i < i1; i += SIMD_HEIGHT) {
        // With an odd number of rows the last pair has no second row in C.
        // Its packed A values are zero padding, so the arithmetic is
        // harmless, but C must not be read or written past row i1 - 1:
        // doing so runs off the end of the caller's C buffer.
        bool has_row1 = i + 1 < i1;
        float * restrict Bptr = Bpt;
        float * restrict Cptr0 = &C[b_cols * i + j0];
        // Never dereferenced when !has_row1; aliasing Cptr0 in that case
        // just keeps the pointer arithmetic below in bounds.
        float *Cptr1 = has_row1 ? &C[b_cols * (i + 1) + j0] : Cptr0;
        for (uint_t j = j0; j < j1; j += SIMD_WIDTH) {
            __m128 acc00 = _mm_load_ps(Cptr0 + 0);
            __m128 acc01 = _mm_load_ps(Cptr0 + 4);
            __m128 acc02 = _mm_load_ps(Cptr0 + 8);
            __m128 acc03 = _mm_load_ps(Cptr0 + 12);
            __m128 acc10, acc11, acc12, acc13;
            if (has_row1) {
                acc10 = _mm_load_ps(Cptr1 + 0);
                acc11 = _mm_load_ps(Cptr1 + 4);
                acc12 = _mm_load_ps(Cptr1 + 8);
                acc13 = _mm_load_ps(Cptr1 + 12);
            } else {
                acc10 = acc11 = acc12 = acc13 = _mm_setzero_ps();
            }
            // Packed A stores the two rows' values for each k side by side.
            float * restrict Aptr = &Apt[a_cols * i + SIMD_HEIGHT * k0];
            for (uint_t k = k0; k < k1; k++) {
                __m128 a0 = _mm_set1_ps(*Aptr++);
                __m128 a1 = _mm_set1_ps(*Aptr++);
                __m128 b0 = _mm_load_ps(Bptr + 0);
                __m128 b1 = _mm_load_ps(Bptr + 4);
                __m128 b2 = _mm_load_ps(Bptr + 8);
                __m128 b3 = _mm_load_ps(Bptr + 12);
                Bptr += SIMD_WIDTH;
                acc00 = _mm_add_ps(acc00, _mm_mul_ps(a0, b0));
                acc01 = _mm_add_ps(acc01, _mm_mul_ps(a0, b1));
                acc02 = _mm_add_ps(acc02, _mm_mul_ps(a0, b2));
                acc03 = _mm_add_ps(acc03, _mm_mul_ps(a0, b3));
                acc10 = _mm_add_ps(acc10, _mm_mul_ps(a1, b0));
                acc11 = _mm_add_ps(acc11, _mm_mul_ps(a1, b1));
                acc12 = _mm_add_ps(acc12, _mm_mul_ps(a1, b2));
                acc13 = _mm_add_ps(acc13, _mm_mul_ps(a1, b3));
            }
            _mm_store_ps(Cptr0 + 0, acc00);
            _mm_store_ps(Cptr0 + 4, acc01);
            _mm_store_ps(Cptr0 + 8, acc02);
            _mm_store_ps(Cptr0 + 12, acc03);
            if (has_row1) {
                _mm_store_ps(Cptr1 + 0, acc10);
                _mm_store_ps(Cptr1 + 4, acc11);
                _mm_store_ps(Cptr1 + 8, acc12);
                _mm_store_ps(Cptr1 + 12, acc13);
            }
            Cptr0 += SIMD_WIDTH;
            Cptr1 += SIMD_WIDTH;
        }
    }
}
// One unit of work for a fast-multiply worker thread; consumed by mul_thread.
// Keep the field order stable: positional initializers of this struct
// depend on it.
typedef struct {
float *A, *Bp, *C;       // packed A, packed B panels, and the output matrix C
uint_t start_i, end_i;   // this job handles rows [start_i, end_i) of C
uint_t a_rows, a_cols;   // shape of the original (unpacked) A
uint_t b_rows, b_cols;   // shape of the original (unpacked) B
} mul_job_t;
// Worker entry point passed to pthread_create by mul_fast.
//
// `arg` points to a mul_job_t.  The job's rows [start_i, end_i) of C are
// processed in TILE_I x TILE_J blocks.  For each block the shared dimension
// is walked in TILE_K steps, and the arithmetic is delegated to
// mul_fast_tile_16x2.  A and Bp must be the packed buffers produced by
// mul_fast; C is updated in place.  Always returns NULL.
static void *
mul_thread(void *arg) {
    const mul_job_t *job = arg;
    float *A = job->A;
    float *Bp = job->Bp;
    float *C = job->C;
    uint_t a_rows = job->a_rows;
    uint_t a_cols = job->a_cols;
    uint_t b_rows = job->b_rows;
    uint_t b_cols = job->b_cols;
    // Clamp every tile to this job's own end row (not merely to a_rows), so
    // a worker can never write rows that belong to another worker's slice.
    uint_t end_i = MIN(job->end_i, a_rows);
    for (uint_t i = job->start_i; i < end_i; i += TILE_I) {
        uint_t imax = MIN(i + TILE_I, end_i);
        for (uint_t j = 0; j < b_cols; j += TILE_J) {
            uint_t jmax = MIN(j + TILE_J, b_cols);
            for (uint_t k = 0; k < a_cols; k += TILE_K) {
                uint_t kmax = MIN(k + TILE_K, a_cols);
                // Start of the (k, j) tile in packed B: k-tiles are
                // TILE_K * b_cols floats apart, j-tiles TILE_J * TILE_K.
                float *Bptr = &Bp[k * b_cols + j * TILE_K];
                mul_fast_tile_16x2(i, imax, j, jmax, k, kmax,
                                   A, Bptr, C,
                                   a_rows, a_cols,
                                   b_rows, b_cols);
            }
        }
    }
    return NULL;
}
/* static float *Abuf = NULL; */
/* static float *Bbuf = NULL; */
// Repack row-major B (b_rows x b_cols) into the panel order read by the
// fast kernel.  k-tiles are outermost and j-tiles inside them.  Within each
// TILE_K x TILE_J tile, every SIMD_WIDTH-column strip is stored as TILE_K
// consecutive rows of SIMD_WIDTH floats.  Requires b_rows % TILE_K == 0 and
// b_cols % TILE_J == 0.  Returns a malloc'd buffer of b_rows * b_cols
// floats, or NULL if allocation fails.
static float *
pack_b_panels(const float * restrict B, uint_t b_rows, uint_t b_cols) {
    size_t n = (size_t)b_rows * b_cols;
    float *Bbuf = malloc(n * sizeof *Bbuf);
    if (Bbuf == NULL) {
        return NULL;
    }
    float *dst = Bbuf;
    for (uint_t k = 0; k < b_rows; k += TILE_K) {
        for (uint_t j = 0; j < b_cols; j += TILE_J) {
            for (uint_t y = 0; y < TILE_J; y += SIMD_WIDTH) {
                for (uint_t x = 0; x < TILE_K; x++) {
                    const float *src = &B[(size_t)(k + x) * b_cols + j + y];
                    for (uint_t o = 0; o < SIMD_WIDTH; o++) {
                        *dst++ = src[o];
                    }
                }
            }
        }
    }
    assert((size_t)(dst - Bbuf) == n);
    return Bbuf;
}

// Interleave row-major A (a_rows x a_cols) so that each pair of rows
// (i, i+1), with i even, becomes a_cols consecutive [A[i][k], A[i+1][k]]
// pairs.  a_rows is rounded up to a multiple of SIMD_HEIGHT, and the extra
// row is zero-filled.  Returns a malloc'd buffer, or NULL if allocation
// fails.
static float *
pack_a_row_pairs(const float * restrict A, uint_t a_rows, uint_t a_cols) {
    uint_t padded_rows =
        (a_rows + SIMD_HEIGHT - 1) / SIMD_HEIGHT * SIMD_HEIGHT;
    size_t n = (size_t)padded_rows * a_cols;
    float *Abuf = malloc(n * sizeof *Abuf);
    if (Abuf == NULL) {
        return NULL;
    }
    float *dst = Abuf;
    for (uint_t i = 0; i < padded_rows; i += SIMD_HEIGHT) {
        for (uint_t k = 0; k < a_cols; k++) {
            for (uint_t r = i; r < i + SIMD_HEIGHT; r++) {
                *dst++ = r < a_rows ? A[(size_t)a_cols * r + k] : 0.0f;
            }
        }
    }
    assert((size_t)(dst - Abuf) == n);
    return Abuf;
}

// Compute C = A * B for row-major A (a_rows x a_cols) and B
// (b_rows x b_cols), overwriting the a_rows x b_cols matrix C.
//
// A and B are repacked into the layouts the SSE tile kernel expects.  The
// rows of C are then split into contiguous TILE_I-aligned slices, one per
// worker thread (N_THREADS).
//
// Preconditions (checked with assert): a_cols == b_rows, b_rows is a
// multiple of TILE_K, and b_cols is a multiple of TILE_J.  The kernel uses
// aligned SSE loads/stores on C, so C must be 16-byte aligned.  Aborts if
// the packing buffers cannot be allocated.
static void
mul_fast(float * restrict A,
         float * restrict B,
         float * restrict C,
         uint_t a_rows, uint_t a_cols,
         uint_t b_rows, uint_t b_cols) {
    assert(TILE_I % SIMD_HEIGHT == 0);
    assert(TILE_J % SIMD_WIDTH == 0);
    // Neither the packing loops nor the 16-wide kernel handle partial
    // tiles in k or j; violating these would read past the end of B.
    assert(a_cols == b_rows);
    assert(b_rows % TILE_K == 0);
    assert(b_cols % TILE_J == 0);

    if (a_rows == 0 || b_cols == 0) {
        return;
    }
    // The kernel accumulates into C (C += A * B), so start from zero to
    // overwrite C with the product, matching mul_slow.
    memset(C, 0, (size_t)a_rows * b_cols * sizeof *C);
    if (a_cols == 0) {
        return;
    }

    float *Bbuf = pack_b_panels(B, b_rows, b_cols);
    float *Abuf = pack_a_row_pairs(A, a_rows, a_cols);
    if (Abuf == NULL || Bbuf == NULL) {
        fprintf(stderr, "mul_fast: out of memory\n");
        free(Abuf);
        free(Bbuf);
        abort();
    }

    // Integer ceiling division; float ceil() loses precision for large sizes.
    uint_t n_i_tiles = (a_rows + TILE_I - 1) / TILE_I;
    uint_t tiles_per_thread = (n_i_tiles + N_THREADS - 1) / N_THREADS;
    pthread_t threads[N_THREADS];
    bool started[N_THREADS];
    mul_job_t jobs[N_THREADS];
    for (int t = 0; t < N_THREADS; t++) {
        uint_t start = MIN((uint_t)t * tiles_per_thread * TILE_I, a_rows);
        uint_t end = MIN((uint_t)(t + 1) * tiles_per_thread * TILE_I, a_rows);
        jobs[t] = (mul_job_t){
            .A = Abuf, .Bp = Bbuf, .C = C,
            .start_i = start, .end_i = end,
            .a_rows = a_rows, .a_cols = a_cols,
            .b_rows = b_rows, .b_cols = b_cols,
        };
        // Joining a thread that was never created is undefined behavior, so
        // record which threads started.  If creation fails, compute that
        // slice on this thread instead.
        started[t] =
            pthread_create(&threads[t], NULL, mul_thread, &jobs[t]) == 0;
        if (!started[t]) {
            mul_thread(&jobs[t]);
        }
    }
    for (int t = 0; t < N_THREADS; t++) {
        if (started[t]) {
            pthread_join(threads[t], NULL);
        }
    }
    free(Abuf);
    free(Bbuf);
}
// Benchmark driver.
//
// Fills A and B with pseudo-random values in [0, 5] and computes a
// reference product with mul_slow.  It then times mul_fast and prints one
// summary line: shapes, N_THREADS, tile sizes, seconds, and GFLOPS.
// Finally it checks that every element agrees with the reference to within
// 0.1.  Returns EXIT_FAILURE if allocation fails or any element mismatches,
// and the check still works when compiled with NDEBUG.
int
main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;
    float *A = malloc(A_N_BYTES);
    float *B = malloc(B_N_BYTES);
    float *C = calloc((size_t)C_ROWS * C_COLS, sizeof *C);
    float *c_ref = calloc((size_t)C_ROWS * C_COLS, sizeof *c_ref);
    if (A == NULL || B == NULL || C == NULL || c_ref == NULL) {
        fprintf(stderr, "out of memory\n");
        free(A);
        free(B);
        free(C);
        free(c_ref);
        return EXIT_FAILURE;
    }
    for (int i = 0; i < A_ROWS * A_COLS; i++) {
        A[i] = (rand() / (float)RAND_MAX) * 5;
    }
    for (int i = 0; i < B_ROWS * B_COLS; i++) {
        B[i] = (rand() / (float)RAND_MAX) * 5;
    }
    mul_slow(A, B, c_ref,
             A_ROWS, A_COLS, B_ROWS, B_COLS);

    struct timespec begin, end;
    clock_gettime(CLOCK_MONOTONIC_RAW, &begin);
    mul_fast(A, B, C,
             A_ROWS, A_COLS, B_ROWS, B_COLS);
    clock_gettime(CLOCK_MONOTONIC_RAW, &end);
    double delta = (end.tv_nsec - begin.tv_nsec) / 1000000000.0 +
        (end.tv_sec - begin.tv_sec);
    // Each (i, j, k) triple costs one multiply and one add, so the product
    // performs 2 * M * N * K floating-point operations.
    double gflops = 2.0 * A_ROWS * A_COLS * B_COLS
        / (delta * 1000.0 * 1000.0 * 1000.0);
    printf("[%4d,%4d] * [%4d,%4d] = [%4d, %4d] %8d %6d %6d %6d %6.2f %7.2f\n",
           A_ROWS, A_COLS, B_ROWS, B_COLS, C_ROWS, C_COLS,
           N_THREADS, TILE_I, TILE_J, TILE_K, delta, gflops);

    long mismatches = 0;
    for (int i = 0; i < C_ROWS; i++) {
        for (int j = 0; j < C_COLS; j++) {
            float v = C[C_COLS * i + j];
            float v2 = c_ref[C_COLS * i + j];
            if (fabs(v - v2) > 0.1) {
                printf("%d %d, %6.2f %6.2f\n", i, j, v, v2);
                mismatches++;
            }
        }
    }
    free(A);
    free(B);
    free(C);
    free(c_ref);
    return mismatches == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment