Skip to content

Instantly share code, notes, and snippets.

@svaniksharma
Created May 7, 2023 05:42
Show Gist options
  • Save svaniksharma/9ad2fa148254ac74b02940326090b18d to your computer and use it in GitHub Desktop.
Save svaniksharma/9ad2fa148254ac74b02940326090b18d to your computer and use it in GitHub Desktop.
Optimizing Matrix Multiplication in Zig
const std = @import("std");
const mem = std.mem;
const meta = std.meta;
const math = std.math;
const Vector = meta.Vector;
const expect = std.testing.expect;
/// Allocates an `N` x `N` matrix of `f64`, stored as a slice of separately
/// allocated rows.
///
/// Every element is first set to zero. When `gen_rand` is true, each element is
/// then overwritten with `rand.float(f64)` from a `DefaultPrng` whose seed is
/// read from the operating system via `std.os.getrandom`.
///
/// Ownership: the caller owns the outer slice and every row and must release
/// them with `allocator` (row by row, or all at once through an arena).
/// Errors: returns any allocation or `getrandom` failure; in that case
/// everything allocated so far is freed before returning.
fn generateSquareMatrix(N: usize, allocator: mem.Allocator, gen_rand: bool) ![][]f64 {
    const matrix = try allocator.alloc([]f64, N);
    errdefer allocator.free(matrix);

    // Count of leading rows that hold a live allocation. The cleanup below must
    // only touch those, because the remaining entries of `matrix` are still
    // undefined when a row allocation fails part-way through.
    var rows_allocated: usize = 0;
    errdefer {
        for (matrix[0..rows_allocated]) |row| allocator.free(row);
    }

    for (matrix) |*row| {
        row.* = try allocator.alloc(f64, N);
        rows_allocated += 1;
        std.mem.set(f64, row.*, 0);
    }

    if (gen_rand) {
        var prng = std.rand.DefaultPrng.init(blk: {
            var seed: u64 = undefined;
            try std.os.getrandom(std.mem.asBytes(&seed));
            break :blk seed;
        });
        const rand = prng.random();
        for (matrix) |row| {
            for (row) |*cell| {
                cell.* = rand.float(f64);
            }
        }
    }
    return matrix;
}
/// Adds the matrix product `A * B` into `C` with the textbook triple loop.
///
/// The dimension is read from `A.len` and used as the bound of all three
/// loops, so the operands are indexed as square matrices of that size.
/// `C` is accumulated into, not overwritten: zero it first to get the plain
/// product.
fn naiveMatrixMultiply(C: anytype, A: anytype, B: anytype) void {
    const dim = A.len;
    var row: usize = 0;
    while (row < dim) : (row += 1) {
        var col: usize = 0;
        while (col < dim) : (col += 1) {
            var inner: usize = 0;
            while (inner < dim) : (inner += 1) {
                // Write straight into C on every step, exactly as the
                // running sum is defined.
                C[row][col] += A[row][inner] * B[inner][col];
            }
        }
    }
}
/// Adds the matrix product `A * B` into `C`, reading `B` through a transposed
/// copy so that the innermost loop walks both operands along contiguous memory.
///
/// The dimension is read from `B.len` and used for every loop bound, so the
/// operands are indexed as square matrices of that size. `C` is accumulated
/// into, not overwritten.
///
/// The transposed copy is a single temporary buffer of `n * n` elements
/// obtained from `std.heap.page_allocator` and released before returning.
/// Errors: fails only if that temporary buffer cannot be allocated.
fn transposeMatrixMultiply(C: anytype, A: anytype, B: anytype) !void {
    const n = B.len;
    const page_allocator = std.heap.page_allocator;

    // One contiguous row-major buffer instead of one allocation per row:
    // row `j` of the transpose lives at transposed[j * n ..][0..n].
    const transposed = try page_allocator.alloc(f64, n * n);
    defer page_allocator.free(transposed);

    for (0..n) |i| {
        for (0..n) |j| {
            transposed[i * n + j] = B[j][i];
        }
    }

    for (0..n) |i| {
        for (0..n) |j| {
            // Column `j` of B, laid out contiguously.
            const b_col = transposed[j * n ..][0..n];
            for (0..n) |k| {
                C[i][j] += A[i][k] * b_col[k];
            }
        }
    }
}
/// Adds the matrix product `A * B` into `C`, computing each dot product in
/// 32-lane `f64` vector chunks over a transposed copy of `B`, followed by a
/// scalar loop for the leftover elements.
///
/// The dimension is read from `B.len` and used for every loop bound, so the
/// operands are indexed as square matrices of that size. `C` is accumulated
/// into, not overwritten. Any dimension is accepted, including ones smaller
/// than the vector width.
///
/// The transposed copy is a single temporary buffer of `n * n` elements
/// obtained from `std.heap.page_allocator` and released before returning.
/// Errors: fails only if that temporary buffer cannot be allocated.
fn transposeSimdMatrixMultiply(C: anytype, A: anytype, B: anytype) !void {
    const n = B.len;
    const vec_len = 32;
    const Lanes = @Vector(vec_len, f64);
    const page_allocator = std.heap.page_allocator;

    // One contiguous row-major buffer: row `j` of the transpose lives at
    // transposed[j * n ..][0..n].
    const transposed = try page_allocator.alloc(f64, n * n);
    defer page_allocator.free(transposed);

    for (0..n) |i| {
        for (0..n) |j| {
            transposed[i * n + j] = B[j][i];
        }
    }

    for (0..n) |i| {
        for (0..n) |j| {
            // Column `j` of B, laid out contiguously.
            const b_col = transposed[j * n ..][0..n];

            var k: usize = 0;
            // Written as `k + vec_len <= n` rather than `k <= n - vec_len`:
            // the subtraction underflows `usize` whenever n < vec_len.
            while (k + vec_len <= n) : (k += vec_len) {
                const u: Lanes = A[i][k..][0..vec_len].*;
                const v: Lanes = b_col[k..][0..vec_len].*;
                C[i][j] += @reduce(.Add, u * v);
            }
            // Scalar tail for the elements that do not fill a whole vector.
            while (k < n) : (k += 1) {
                C[i][j] += A[i][k] * b_col[k];
            }
        }
    }
}
/// Adds the matrix product `A * B` into `C` one output row at a time, updating
/// 32 adjacent columns of that row per step with `f64` vectors and finishing
/// the leftover columns with a scalar loop.
///
/// Rows of `C` and `A` are walked in lockstep; the dimension used for the
/// column and inner loops is read from `B.len`, so the operands are indexed as
/// square matrices of that size. `C` is accumulated into, not overwritten.
/// Any dimension is accepted, including ones smaller than the vector width.
fn unrollSimdMatrixMultiply(C: anytype, A: anytype, B: anytype) void {
    const N = B.len;
    const vec_len = 32;
    const Lanes = @Vector(vec_len, f64);

    for (C, A) |*C_row, *A_row| {
        // `usize` to match the type of every bound and index it is used with.
        var j: usize = 0;
        // Written as `j + vec_len <= N` rather than `j <= N - vec_len`:
        // the subtraction underflows `usize` whenever N < vec_len.
        while (j + vec_len <= N) : (j += vec_len) {
            for (0..N) |k| {
                const u: Lanes = B[k][j..][0..vec_len].*;
                const y: Lanes = C_row.*[j..][0..vec_len].*;
                // Broadcast the single A element across all lanes.
                const w: Lanes = @splat(vec_len, A_row.*[k]);
                const updated: [vec_len]f64 = (u * w) + y;
                @memcpy(C_row.*[j .. j + vec_len], &updated);
            }
        }
        // Scalar tail for the columns that do not fill a whole vector.
        while (j < N) : (j += 1) {
            for (0..N) |k| {
                C_row.*[j] += A_row.*[k] * B[k][j];
            }
        }
    }
}
// Checks each optimised multiply against the naive triple loop on the same
// pair of pseudo-random 100x100 matrices.
test "matrix correct" {
    const n = 100;
    var A: [n][n]f64 = undefined;
    var B: [n][n]f64 = undefined;
    var C: [n][n]f64 = undefined;
    var D: [n][n]f64 = undefined;

    // Fixed seed so that a failure can be reproduced on the next run.
    var prng = std.rand.DefaultPrng.init(0x5EED);
    const rand = prng.random();

    // Initialise every element; the multiply routines read the full n x n
    // extent and accumulate into their output.
    for (0..n) |i| {
        for (0..n) |j| {
            A[i][j] = rand.float(f64);
            B[i][j] = rand.float(f64);
            C[i][j] = 0;
        }
    }

    // Reference result.
    naiveMatrixMultiply(&C, &A, &B);

    for (0..3) |variant| {
        // Each routine adds into its output, so start every run from zero.
        D = std.mem.zeroes([n][n]f64);
        switch (variant) {
            0 => try transposeMatrixMultiply(&D, &A, &B),
            1 => try transposeSimdMatrixMultiply(&D, &A, &B),
            else => unrollSimdMatrixMultiply(&D, &A, &B),
        }
        for (0..n) |i| {
            for (0..n) |j| {
                // An absolute tolerance is a shortcut that is good enough for this
                // demonstration: the vectorised variants sum in a different order,
                // so results differ by rounding only. For real code see
                // https://floating-point-gui.de/errors/comparison/
                try std.testing.expectApproxEqAbs(C[i][j], D[i][j], 1e-10);
            }
        }
    }
}
/// Builds two random `N` x `N` matrices and a zeroed output matrix, runs one
/// multiplication routine over them, and prints the elapsed time in
/// milliseconds to stderr. All matrices live in an arena that is released on
/// exit.
pub fn main() !void {
    const N = 1000; // Matrix dimensions: change if you want

    var matrix_arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer matrix_arena.deinit();
    const allocator = matrix_arena.allocator();

    const lhs = try generateSquareMatrix(N, allocator, true);
    const rhs = try generateSquareMatrix(N, allocator, true);
    const product = try generateSquareMatrix(N, allocator, false);

    // Time the selected multiplication (a profiling tool gives better numbers).
    // Uncomment the function you want to run.
    var timer = try std.time.Timer.start();
    naiveMatrixMultiply(product, lhs, rhs);
    // try transposeMatrixMultiply(product, lhs, rhs);
    // try transposeSimdMatrixMultiply(product, lhs, rhs);
    // unrollSimdMatrixMultiply(product, lhs, rhs);
    const elapsed_ns = @intToFloat(f64, timer.lap());

    // The timer reports nanoseconds; convert to milliseconds for display.
    const ns_per_ms = 1000000;
    std.debug.print("{e} ms\n", .{elapsed_ns / ns_per_ms});
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment