Skip to content

Instantly share code, notes, and snippets.

@bjourne
Created January 4, 2024 13:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bjourne/c2d0db48b2e50aaadf884e4450c6aa50 to your computer and use it in GitHub Desktop.
Save bjourne/c2d0db48b2e50aaadf884e4450c6aa50 to your computer and use it in GitHub Desktop.
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <xmmintrin.h>
#include <mmintrin.h>
#include "matrix.h"
#include "z_order.h"
using namespace nda;
// Make it easier to read the generated assembly for these functions.
// `__attribute__((noinline))` is a GCC/Clang extension that prevents the
// compiler from inlining the marked function into its callers, so each
// benchmark kernel is emitted as its own function body.
#define NOINLINE __attribute__((noinline))
// A textbook implementation of matrix multiplication, C = A * B. This is very
// simple, but it is slow, primarily because of poor locality of the loads of B:
// the reduction loop over k is innermost, so consecutive iterations walk down
// a column of B (B(k, j) with k varying).
//
// Each dot product is accumulated in a local variable and written to C(i, j)
// once. Accumulating directly into C(i, j) forces a load and store through C
// on every iteration of k, because the compiler generally cannot prove that C
// does not alias A or B; it would also give wrong results if it did alias.
//
// Assumes A.j() and B.i() describe the same reduction range (not checked).
template <typename T>
NOINLINE void multiply_reduce_cols(const_matrix_ref<T> A, const_matrix_ref<T> B, matrix_ref<T> C) {
  for (index_t i : C.i()) {
    for (index_t j : C.j()) {
      T sum = 0;
      for (index_t k : A.j()) {
        sum += A(i, k) * B(k, j);
      }
      C(i, j) = sum;
    }
  }
}
// Tiled matrix multiplication, C = A * B. Compared to the textbook
// multiply_reduce_cols:
// - The output is split into tile_rows x tile_cols tiles, each accumulated in
//   a small local buffer sized so it can stay in registers,
// - The reduction dimension k is split into chunks of tile_k,
// - The output tiles (io, jo) are traversed in z order, to improve locality,
// - Upcoming elements of A and B are prefetched in the inner loop.
// The original author reports ~90% of the theoretical peak performance of an
// AMD Ryzen 5800X for this kernel (their measurement, not re-verified here).
template <typename T>
NOINLINE void multiply_reduce_tiles_z_order(const_matrix_ref<T> A, const_matrix_ref<T> B, matrix_ref<T> C) {
  // Adjust this depending on the target architecture. For AVX2,
  // vectors are 256-bit.
  constexpr index_t vector_size = 32 / sizeof(T);
  constexpr index_t cache_line_size = 64 / sizeof(T);
  // We want the tiles to be as big as possible without spilling any
  // of the accumulator registers to the stack.
  constexpr index_t tile_rows = 4;
  constexpr index_t tile_cols = vector_size * 3;
  constexpr index_t tile_k = 256;
  // How many iterations of k ahead to prefetch. A(i, k) is read along k
  // within a row of A, whereas every k reads a different row B(k, :).
  constexpr index_t prefetch_dist_A = 8;
  constexpr index_t prefetch_dist_B = 4;
  // Last valid k index of A and B. Prefetch addresses are only formed for
  // in-range indices: computing a pointer past the end of the matrix data is
  // undefined behavior, even though the prefetch instruction would not fault.
  const index_t A_k_max = A.j().max();
  const index_t B_k_max = B.i().max();
  // TODO: It seems like z-ordering all of io, jo, ko should be best...
  // But this seems better, even without the added convenience for initializing
  // the output.
  for (auto ko : split(A.j(), tile_k)) {
    auto split_i = split<tile_rows>(C.i());
    auto split_j = split<tile_cols>(C.j());
    for_all_in_z_order(std::make_tuple(split_i, split_j), [&](auto io, auto jo) {
      // Make a reference to this tile of the output.
      auto C_ijo = C(io, jo);
      // Define an accumulator buffer, indexed with the same (i, j) as C_ijo.
      T buffer[tile_rows * tile_cols] = {0};
      auto accumulator = make_array_ref(buffer, make_compact(C_ijo.shape()));
      // Perform the matrix multiplication for this tile.
      for (index_t k : ko) {
        if (k + prefetch_dist_A <= A_k_max) {
          // Each row of the tile is a separate row of A, so prefetch every row.
          // Striding i by cache_line_size would only ever touch the first row
          // of a tile_rows-high tile; that stride only makes sense along the
          // dense j dimension (assumed row-major layout, as used for B below).
          for (index_t i : io) {
            _mm_prefetch(&A(i, k + prefetch_dist_A), _MM_HINT_T0);
          }
        }
        if (k + prefetch_dist_B <= B_k_max) {
          // Row k + prefetch_dist_B of B is contiguous along j: one prefetch
          // per cache line covers this tile's columns.
          for (index_t j = 0; j < jo.extent(); j += cache_line_size) {
            _mm_prefetch(&B(k + prefetch_dist_B, jo.min() + j), _MM_HINT_T0);
          }
        }
        for (index_t i : io) {
          for (index_t j : jo) {
            accumulator(i, j) += A(i, k) * B(k, j);
          }
        }
      }
      // Add the accumulators for this iteration of ko to the output.
      // Because we split the K dimension, we are doing this more than once per
      // tile of output. To avoid adding to overlapping regions more than once
      // (when `split<>` is applied to a dimension not divided by the split factor),
      // we need to only initialize the result for the first iteration of ko.
      if (ko.min() == A.j().min()) {
        for (index_t i : io) {
          for (index_t j : jo) {
            C_ijo(i, j) = accumulator(i, j);
          }
        }
      } else {
        for (index_t i : io) {
          for (index_t j : jo) {
            C_ijo(i, j) += accumulator(i, j);
          }
        }
      }
    });
  }
}
// 0.95s for NumPy
#define M 4096
#define K 4096
#define N 4096
int
main() {
srand(time(NULL));
matrix<double> A({M, K});
matrix<double> B({K, N});
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < K; j++) {
A(i, j) = rand() % 500;
}
}
for (size_t i = 0; i < K; i++) {
for (size_t j = 0; j < N; j++) {
B(i, j) = rand() % 500;
}
}
printf("Multiplying...\n");
matrix<double> C({M, N});
// 37 s for float, 28 s for double
//multiply_reduce_tiles_z_order<double>(A.cref(), B.cref(), C.ref());
multiply_reduce_cols<double>(A.cref(), B.cref(), C.ref());
// for (size_t i = 0; i < 5; i++) {
// for (size_t j = 0; j < 8; j++) {
// A(i, j) = rand() % 500;
// }
// }
//printf("%d\n", m(0, 3));
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment