Created
May 7, 2023 05:42
-
-
Save svaniksharma/9ad2fa148254ac74b02940326090b18d to your computer and use it in GitHub Desktop.
Optimizing Matrix Multiplication in Zig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std"); | |
const mem = std.mem; | |
const meta = std.meta; | |
const math = std.math; | |
const Vector = meta.Vector; | |
const expect = std.testing.expect; | |
/// Allocate an `N` x `N` matrix as `N` separately allocated rows of `f64`.
///
/// Every element is 0, or, when `gen_rand` is true, a value drawn from
/// `std.crypto.random.float(f64)`, i.e. uniformly in [0, 1).
///
/// Caller owns the result: free each row and then the outer slice (or allocate
/// from an arena and drop it wholesale). If any allocation fails, everything
/// allocated so far is released before the error is returned.
fn generateSquareMatrix(N: usize, allocator: mem.Allocator, gen_rand: bool) ![][]f64 {
    const matrix = try allocator.alloc([]f64, N);
    errdefer allocator.free(matrix);

    // Number of rows successfully allocated; the errdefer frees exactly those,
    // because the remaining entries of `matrix` are still undefined.
    var rows_allocated: usize = 0;
    errdefer {
        for (matrix[0..rows_allocated]) |row| allocator.free(row);
    }

    for (matrix) |*row| {
        row.* = try allocator.alloc(f64, N);
        rows_allocated += 1;
        @memset(row.*, 0);
    }

    if (gen_rand) {
        // OS-seeded thread-local generator from the standard library; no manual seeding needed.
        const rand = std.crypto.random;
        for (matrix) |row| {
            for (row) |*x| x.* = rand.float(f64);
        }
    }
    return matrix;
}
/// Accumulate the product `A * B` into `C` using the textbook i-j-k loop:
/// `C[i][j] += A[i][k] * B[k][j]` for every `i`, `j`, `k` in `0..A.len`.
///
/// All three operands are indexed as square `A.len` x `A.len` matrices.
/// `C` is added to, not overwritten, so zero it first for a plain product.
fn naiveMatrixMultiply(C: anytype, A: anytype, B: anytype) void {
    const n = A.len;
    var row: usize = 0;
    while (row < n) : (row += 1) {
        var col: usize = 0;
        while (col < n) : (col += 1) {
            var inner: usize = 0;
            while (inner < n) : (inner += 1) {
                C[row][col] += A[row][inner] * B[inner][col];
            }
        }
    }
}
/// Accumulate `A * B` into `C` (`C[i][j] += sum_k A[i][k] * B[k][j]`) for square
/// `B.len` x `B.len` matrices.
///
/// `B` is first copied in transposed form, so the inner loop reads both `A[i]`
/// and the copied column of `B` sequentially instead of striding down a column.
///
/// `C` is added to, not overwritten (zero it for a plain product), and must not
/// alias `A` or `B`. The transposed copy is one `n * n` buffer obtained from
/// `std.heap.page_allocator` and freed before returning. An allocation failure
/// is returned as an error and leaves `C` untouched.
fn transposeMatrixMultiply(C: anytype, A: anytype, B: anytype) !void {
    const n = B.len;
    const allocator = std.heap.page_allocator;

    // Single contiguous buffer: tmp[j * n + k] == B[k][j], so row j of `tmp`
    // is column j of `B`.
    const tmp = try allocator.alloc(f64, n * n);
    defer allocator.free(tmp);
    for (0..n) |k| {
        for (0..n) |j| {
            tmp[j * n + k] = B[k][j];
        }
    }

    for (0..n) |i| {
        for (0..n) |j| {
            const b_col = tmp[j * n .. (j + 1) * n];
            // Sum in a local that starts from C[i][j]. This performs the same
            // additions in the same order as updating C[i][j] directly.
            var sum = C[i][j];
            for (b_col, 0..) |b_kj, k| {
                sum += A[i][k] * b_kj;
            }
            C[i][j] = sum;
        }
    }
}
/// Accumulate `A * B` into `C` for square `B.len` x `B.len` matrices.
///
/// Like the transposed variant, `B` is copied transposed into one contiguous
/// buffer. Each dot product `A[i] . column_j(B)` is then computed 32 lanes at a
/// time with `@Vector` multiply plus `@reduce(.Add, ...)`, and a scalar loop
/// handles the final `n % 32` elements. Any `n` is supported, including
/// `n < 32`, in which case only the scalar loop runs.
///
/// `C` is added to, not overwritten (zero it for a plain product), and must not
/// alias `A` or `B`. The transposed copy comes from `std.heap.page_allocator`
/// and is freed before returning. An allocation failure is returned as an error
/// and leaves `C` untouched.
fn transposeSimdMatrixMultiply(C: anytype, A: anytype, B: anytype) !void {
    const n = B.len;
    const vec_len = 32;
    const Vec = @Vector(vec_len, f64);
    const allocator = std.heap.page_allocator;

    // tmp[j * n + k] == B[k][j]: row j of `tmp` is column j of `B`.
    const tmp = try allocator.alloc(f64, n * n);
    defer allocator.free(tmp);
    for (0..n) |k| {
        for (0..n) |j| {
            tmp[j * n + k] = B[k][j];
        }
    }

    for (0..n) |i| {
        for (0..n) |j| {
            const b_col = tmp[j * n .. (j + 1) * n];
            var sum = C[i][j];
            var k: usize = 0;
            // Written as `k + vec_len <= n` rather than `k <= n - vec_len`:
            // the subtraction underflows whenever n < vec_len.
            while (k + vec_len <= n) : (k += vec_len) {
                const u: Vec = A[i][k..][0..vec_len].*;
                const v: Vec = b_col[k..][0..vec_len].*;
                sum += @reduce(.Add, u * v);
            }
            // Scalar tail for the remaining n % vec_len elements.
            while (k < n) : (k += 1) {
                sum += A[i][k] * b_col[k];
            }
            C[i][j] = sum;
        }
    }
}
/// Accumulate `A * B` into `C` for square `B.len` x `B.len` matrices without
/// transposing `B`.
///
/// For each row `i`, the columns of `C[i]` are processed 32 at a time:
/// - the 32-wide chunk of `C[i]` is held in a vector accumulator;
/// - for every `k`, `B[k][chunk] * A[i][k]` (broadcast) is added to it;
/// - the chunk is written back once.
///
/// Remaining `n % 32` columns use a scalar loop. Any `n` is supported,
/// including `n < 32`. Each `C[i][j]` receives its additions in the same `k`
/// order as the naive loop.
///
/// `C` is added to, not overwritten (zero it for a plain product), and must not
/// alias `A` or `B`. `C` and `A` must have the same number of rows.
fn unrollSimdMatrixMultiply(C: anytype, A: anytype, B: anytype) void {
    const N = B.len;
    const vec_len = 32;
    const Vec = @Vector(vec_len, f64);
    for (C, A) |*C_row, *A_row| {
        var j: usize = 0;
        // Written as `j + vec_len <= N` rather than `j <= N - vec_len`:
        // the subtraction underflows whenever N < vec_len.
        while (j + vec_len <= N) : (j += vec_len) {
            // Keep the C chunk in a register across the whole k loop instead of
            // reloading and storing it on every iteration.
            var acc: Vec = C_row.*[j..][0..vec_len].*;
            for (0..N) |k| {
                const b_chunk: Vec = B[k][j..][0..vec_len].*;
                const a_ik: Vec = @splat(vec_len, A_row.*[k]);
                acc += b_chunk * a_ik;
            }
            C_row.*[j..][0..vec_len].* = acc;
        }
        // Scalar tail for the remaining N % vec_len columns.
        while (j < N) : (j += 1) {
            for (0..N) |k| {
                C_row.*[j] += A_row.*[k] * B[k][j];
            }
        }
    }
}
test "matrix correct" {
    // 5 exercises the scalar-only path (n < 32 SIMD lanes); 100 exercises full
    // SIMD chunks plus a scalar tail.
    inline for ([_]usize{ 5, 100 }) |n| {
        var A: [n][n]f64 = undefined;
        var B: [n][n]f64 = undefined;
        var C: [n][n]f64 = undefined;
        var D: [n][n]f64 = undefined;

        // Fixed seed so a failure reproduces exactly on the next run.
        var prng = std.rand.DefaultPrng.init(0x5eed_1234);
        const rand = prng.random();

        // Initialise every element; C is the zeroed accumulator for the reference result.
        for (0..n) |i| {
            for (0..n) |j| {
                A[i][j] = rand.float(f64);
                B[i][j] = rand.float(f64);
                C[i][j] = 0;
            }
        }
        naiveMatrixMultiply(&C, &A, &B);

        // Compare every optimized variant with the naive reference. Each one
        // accumulates into D, so D is re-zeroed before every run.
        for (0..3) |variant| {
            for (&D) |*row| @memset(row, 0);
            switch (variant) {
                0 => try transposeMatrixMultiply(&D, &A, &B),
                1 => try transposeSimdMatrixMultiply(&D, &A, &B),
                2 => unrollSimdMatrixMultiply(&D, &A, &B),
                else => unreachable,
            }
            for (0..n) |i| {
                for (0..n) |j| {
                    // An absolute tolerance is adequate here: each entry is a sum
                    // of n products of values in [0, 1), so the magnitudes are
                    // bounded. For general float comparison, see
                    // https://floating-point-gui.de/errors/comparison/
                    try std.testing.expectApproxEqAbs(C[i][j], D[i][j], 1e-10);
                }
            }
        }
    }
}
/// Benchmark entry point.
///
/// Builds two random N x N operands and a zeroed result matrix, times one
/// multiplication routine with `std.time.Timer`, and prints the elapsed time in
/// milliseconds. All matrices live in one arena that is released on exit.
pub fn main() !void {
    // Matrix dimensions: change if you want
    const N = 1000;

    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena.deinit();
    const allocator = arena.allocator();

    // Random left/right operands, plus a zero-filled accumulator for the product.
    const A = try generateSquareMatrix(N, allocator, true);
    const B = try generateSquareMatrix(N, allocator, true);
    const C = try generateSquareMatrix(N, allocator, false);

    // Wall-clock timing only; a real profiler gives more trustworthy numbers.
    // Swap which routine runs by (un)commenting the calls below.
    var timer = try std.time.Timer.start();
    naiveMatrixMultiply(C, A, B);
    // try transposeMatrixMultiply(C, A, B);
    // try transposeSimdMatrixMultiply(C, A, B);
    // unrollSimdMatrixMultiply(C, A, B);
    const elapsed_ns = @intToFloat(f64, timer.lap());

    // Timer reports nanoseconds; divide by 1e6 to print milliseconds.
    const elapsed_ms = elapsed_ns / 1000000;
    std.debug.print("{e} ms\n", .{elapsed_ms});
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment