Skip to content

Instantly share code, notes, and snippets.

@travisstaloch
Last active January 4, 2024 07:54
Show Gist options
  • Save travisstaloch/8bb45b3b6a9502f7f83de466974755af to your computer and use it in GitHub Desktop.
Save travisstaloch/8bb45b3b6a9502f7f83de466974755af to your computer and use it in GitHub Desktop.
nice n-dimensional and generic matrix code
# Generates valid Zig multiply test-case code which may be used in matrix.zig.
import numpy as np

# Fixed seed so regenerating yields the same test data every time.
np.random.seed(42)

# Random digits 0..9 keep the printed literals short. The trailing two axes are
# the matrix dimensions (3x2 for A, 2x3 for B); the leading (2, 2, 2) axes are
# batch dimensions. Smaller hand-written arrays can be substituted here when a
# specific case is needed.
A = np.random.randint(0, 10, size=(2, 2, 2, 3, 2))
B = np.random.randint(0, 10, size=(2, 2, 2, 2, 3))

# Batched matrix product (operator form of np.matmul); shape (2, 2, 2, 3, 3).
C = A @ B
import io
def print_to_string(*args, **kwargs):
    """Return exactly what ``print(*args, **kwargs)`` would have written.

    Accepts the same arguments as :func:`print` (``sep``, ``end``, ...),
    except ``file``, which is supplied here. The in-memory buffer is closed
    by the ``with`` block even if ``print`` raises.
    """
    with io.StringIO() as output:
        print(*args, file=output, **kwargs)
        # Must read the value before the buffer is closed on block exit.
        return output.getvalue()
def mls(mat):
    r"""Render ``mat`` as the body of a Zig multiline string literal.

    Every newline in the printed form of ``mat`` (including the trailing one
    that ``print`` appends) is followed by ``\\``, so each subsequent line of
    output starts a new ``\\`` line in the generated Zig source.
    """
    return print_to_string(mat).replace("\n", "\n\\\\")
def uwp(tup):
    """Render a shape tuple as the comma-separated body of a Zig list literal.

    ``(2, 3)`` becomes ``"2, 3"`` and ``(4,)`` becomes ``"4"``. The string is
    built by joining the elements directly instead of printing the tuple and
    stripping its parentheses. That avoided approach left ``print``'s trailing
    newline inside the generated ``&.{ ... }`` list and depended on tuple repr
    details.
    """
    return ", ".join(str(dim) for dim in tup)
# Element type written into the generated call. "T" lets the snippet be pasted
# directly into the generic `testMuls(comptime T: type)` in matrix.zig, so every
# instantiation (u32, i32, f32) runs it. A hard-coded concrete type there
# silently ignores T. Use e.g. "u32" for a standalone call.
elem_type = "T"
s = """
try testMul({}, &.{{ {} }},
\\\\{}
, &.{{ {} }}
,
\\\\{}
,
&.{{ {} }},
\\\\{}
);
""".format(elem_type, uwp(A.shape), mls(A), uwp(B.shape), mls(B), uwp(C.shape), mls(C))
print(s)
//! This lib only does multiplication so far, but it can parse numpy output at
//! runtime and at comptime, and it has lots of nice helpers, including asArray(),
//! which allows multi-indexing, i.e. `mat.asArray()[i][j]`. There are lots of
//! working tests for 1d, 2d, 3d, 4d and 5d matrices with different shapes.
const std = @import("std");
const Allocator = std.mem.Allocator;
/// A dense, row-major matrix of `T` with a comptime-known `shape` (rank >= 1).
///
/// `items` holds all `len` elements contiguously; `asArray()` reinterprets them
/// as a nested array so elements can be multi-indexed, e.g. `m.asArray()[i][j]`.
pub fn Matrix(comptime T: type, comptime shape: []const usize) type {
    return struct {
        const Self = @This();

        /// Row-major element storage, `len` items long.
        items: []T,
        /// The dimensions of this matrix. Being a comptime field, `m.shape` is
        /// comptime-known even for runtime values, so it can be used in
        /// comptime assertions on `anytype` operands.
        comptime shape: []const usize = shape,

        /// Total number of elements: the product of all dimensions.
        pub const len = @reduce(.Mul, @as(
            @Vector(shape.len, usize),
            shape[0..shape.len].*,
        ));
        /// Nested array type covering the whole matrix, e.g. `[2][3]T` for shape {2, 3}.
        pub const Array = ArrayFromDim(0);

        /// Allocate uninitialized storage for `len` items.
        /// Caller owns the result and must call `deinit` with the same allocator.
        pub fn init(allocator: Allocator) !Self {
            return .{
                .items = try allocator.alloc(T, len),
            };
        }

        /// Allocate storage and set every element to `value`.
        pub fn initFilled(allocator: Allocator, value: T) !Self {
            const result = try init(allocator);
            result.fill(value);
            return result;
        }

        /// Allocate storage and copy the nested `array` into it.
        pub fn initArray(allocator: Allocator, array: Array) !Self {
            const result = try init(allocator);
            @memcpy(result.items, @as([*]const T, @ptrCast(&array)));
            return result;
        }

        /// Wrap existing array memory without copying or allocating; do not `deinit`.
        /// The pointer is const-cast, so writing through the result (e.g. using
        /// it as a `mul` destination) is only valid if that memory is mutable.
        pub fn initConst(items: *const Array) Self {
            return .{
                .items = @as([*]T, @constCast(@ptrCast(items)))[0..len],
            };
        }

        /// Report malformed parser input: a compile error when evaluated at
        /// comptime, otherwise log the message and panic.
        fn parseErr(comptime fmt: []const u8, args: anytype) noreturn {
            if (@inComptime())
                @compileError(std.fmt.comptimePrint(fmt, args))
            else {
                std.log.err(fmt, args);
                @panic("parse error");
            }
        }

        /// Parse numpy's printed representation of an array at comptime.
        /// Structural problems are compile errors; invalid numbers are
        /// returned as `std.fmt` parse errors.
        pub inline fn parse(
            comptime input: []const u8,
        ) !Self {
            comptime {
                // Scanning and number parsing cost several backward branches per
                // input byte; scale the quota so callers parsing large literals
                // don't have to raise it themselves (the quota only ever grows).
                @setEvalBranchQuota(10_000 + 200 * input.len);
                var items: [len]T = undefined;
                return parseBuf(input, &items);
            }
        }

        /// Parse numpy's printed representation of an array (e.g. `[[1 2] [3 4]]`)
        /// into `items`; the returned matrix borrows `items`. Works at runtime and
        /// at comptime. Whitespace is ignored everywhere, including newlines inside
        /// rows that numpy wrapped. Invalid numbers are returned as
        /// `std.fmt.parseInt`/`parseFloat` errors. Structural problems go through
        /// `parseErr`: unbalanced brackets, a number outside an innermost row, a
        /// wrong row length, or a wrong item count.
        pub fn parseBuf(
            input: []const u8,
            items: *[len]T,
        ) !Self {
            // Every innermost row must hold exactly this many items.
            const row_len = shape[shape.len - 1];
            var count: usize = 0;
            var depth: usize = 0;
            var i: usize = 0;
            while (i < input.len) : (i += 1) {
                switch (input[i]) {
                    ' ', '\n', '\t', '\r' => {},
                    '[' => depth += 1,
                    ']' => {
                        if (depth == 0)
                            parseErr("unmatched closing bracket at position {}", .{i});
                        depth -= 1;
                    },
                    '0'...'9', '-', '+' => {
                        if (depth != shape.len) parseErr(
                            "number at position {} is not inside an innermost row (depth {}, expected {})",
                            .{ i, depth, shape.len },
                        );
                        // The row runs to the next ']' and may span several lines.
                        const rbi = std.mem.indexOfScalarPos(u8, input, i, ']') orelse
                            parseErr(
                            "missing closing bracket at position {}",
                            .{i},
                        );
                        var row_count: usize = 0;
                        var it = std.mem.tokenizeAny(u8, input[i..rbi], " \t\r\n");
                        while (it.next()) |nr| : (row_count += 1) {
                            // Check before writing so surplus items can never run
                            // past the end of `items`.
                            if (count == len)
                                parseErr("too many items at position {}", .{i});
                            items[count] = switch (@typeInfo(T)) {
                                .Int => try std.fmt.parseInt(T, nr, 10),
                                .Float => try std.fmt.parseFloat(T, nr),
                                else => @compileError("Matrix.parse only supports integer and float element types"),
                            };
                            count += 1;
                        }
                        if (row_count != row_len) parseErr(
                            "row at position {} has {} items, expected {}",
                            .{ i, row_count, row_len },
                        );
                        i = rbi; // the loop increment then steps past this ']'
                        depth -= 1; // ...which closes the row consumed here
                    },
                    else => parseErr(
                        "unexpected character '{c}' at position {}",
                        .{ input[i], i },
                    ),
                }
            }
            if (depth != 0) parseErr("eof and missing closing bracket", .{});
            if (count != len)
                parseErr("eof and either too many or not enough items", .{});
            return .{ .items = items };
        }

        /// Free storage obtained from `init`, `initFilled` or `initArray`.
        pub fn deinit(self: Self, allocator: Allocator) void {
            if (@sizeOf(T) > 0) allocator.free(self.items);
        }

        /// Set every element to `value`.
        pub fn fill(self: Self, value: T) void {
            @memset(self.items, value);
        }

        /// View the flat storage as the nested `Array` type (no copy).
        pub inline fn asArray(self: Self) *Array {
            return @ptrCast(self.items.ptr);
        }

        /// Print each outermost sub-array on its own line.
        pub fn format(self: Self, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
            _ = fmt;
            _ = options;
            for (self.asArray()) |x| {
                try writer.print("{any}\n", .{x});
            }
        }

        /// Debug-print a `--message:AxBxC--` header followed by the matrix.
        pub fn dump(self: Self, message: []const u8) void {
            std.debug.print("--{s}:", .{message});
            for (shape, 0..) |s, i| {
                std.debug.print("{s}{}", .{ if (i != 0) "x" else "", s });
            }
            std.debug.print("--\n", .{});
            std.debug.print("{}", .{self});
        }

        /// Nested array type for dimensions `dim..` of `shape`, of any rank.
        /// For shape {2, 3, 4}: `ArrayFromDim(0) == [2][3][4]T` and
        /// `ArrayFromDim(1) == [3][4]T`.
        pub fn ArrayFromDim(comptime dim: usize) type {
            if (dim >= shape.len) @compileError("ArrayFromDim: dim is out of range for this shape");
            // Wrap from the innermost dimension outwards.
            var Result: type = T;
            var d: usize = shape.len;
            while (d > dim) {
                d -= 1;
                Result = [shape[d]]Result;
            }
            return Result;
        }

        /// Borrow `ptr` as a matrix of the trailing dimensions `dim..` of `shape`.
        pub fn subMatrix(ptr: [*]T, comptime dim: usize) Matrix(T, shape[dim..]) {
            const M = Matrix(T, shape[dim..]);
            return M{ .items = ptr[0..M.len] };
        }

        /// Write the matrix product `self x b` into the preallocated `dst`.
        /// - 1d: dot product of this {n} vector with the single row of a {1, n} `b`;
        ///   `dst` has shape {1}.
        /// - 2d: standard {m, k} x {k, n} -> {m, n} product.
        /// - 3d and up: batched over the leading axis, whose size must match
        ///   exactly in `self`, `b` and `dst` (no broadcasting).
        /// Shape mismatches are compile errors.
        pub const mul = switch (shape.len) {
            1 => mul1d,
            2 => mul2d,
            else => mulNd,
        };

        fn mul1d(a: Self, b: anytype, dst: anytype) void {
            comptime std.debug.assert(shape.len == 1);
            comptime std.debug.assert(b.shape.len == 2);
            comptime std.debug.assert(shape[0] == b.shape[1]);
            comptime std.debug.assert(b.shape[0] == 1);
            comptime std.debug.assert(dst.shape.len == 1 and dst.shape[0] == 1);
            var acc: T = 0;
            for (0..shape[0]) |i| {
                acc += a.items[i] * b.asArray()[0][i];
            }
            dst.items[0] = acc;
        }

        fn mul2d(a: Self, b: anytype, dst: anytype) void {
            comptime std.debug.assert(shape.len == 2);
            comptime std.debug.assert(b.shape.len == 2);
            comptime std.debug.assert(shape[1] == b.shape[0]);
            comptime std.debug.assert(dst.shape.len == 2);
            comptime std.debug.assert(dst.shape[0] == shape[0] and dst.shape[1] == b.shape[1]);
            const lhs = a.asArray();
            const rhs = b.asArray();
            const out = dst.asArray();
            for (0..shape[0]) |i| {
                for (0..b.shape[1]) |j| {
                    // Accumulate locally so `dst` is written once per element.
                    var acc: T = 0;
                    for (0..shape[1]) |k| {
                        acc += lhs[i][k] * rhs[k][j];
                    }
                    out[i][j] = acc;
                }
            }
        }

        fn mulNd(a: Self, b: anytype, dst: anytype) void {
            comptime std.debug.assert(shape.len >= 3);
            comptime std.debug.assert(b.shape.len == shape.len and dst.shape.len == shape.len);
            // Batch sizes must agree; a mismatch would otherwise index past an operand.
            comptime std.debug.assert(b.shape[0] == shape[0] and dst.shape[0] == shape[0]);
            for (0..shape[0]) |i| {
                const asubm = subMatrix(@ptrCast(&a.asArray()[i]), 1);
                const bsubm = @TypeOf(b).subMatrix(@ptrCast(&b.asArray()[i]), 1);
                const dsubm = @TypeOf(dst).subMatrix(@ptrCast(&dst.asArray()[i]), 1);
                // Comptime switch: only the prong for this rank is analyzed.
                switch (shape.len) {
                    3 => asubm.mul2d(bsubm, dsubm),
                    else => asubm.mulNd(bsubm, dsubm),
                }
            }
        }
    };
}
/// 1d: the vector {1, 2} times the single row {3, 4} gives 1*3 + 2*4 = 11.
fn testMul1d(comptime T: type) !void {
    const allocator = std.testing.allocator;
    const Vec = Matrix(T, &.{2});
    const Row = Matrix(T, &.{ 1, 2 });
    const Out = Matrix(T, &.{1});

    const vec = Vec.initConst(&.{ 1, 2 });
    const row = Row.initConst(&.{.{ 3, 4 }});
    const out = try Out.init(allocator);
    defer out.deinit(allocator);

    vec.mul(row, out);

    const expected = Out.initConst(&.{11});
    try std.testing.expectEqualSlices(T, expected.items, out.items);
}
// Run the 1d multiply check for an unsigned, a signed and a float element type.
test "1d mul" {
    inline for (.{ u8, i8, f32 }) |T| {
        try testMul1d(T);
    }
}
/// 2d: a 3x2 matrix of ones times a 2x3 matrix of twos is a 3x3 matrix of
/// fours (each entry sums two 1*2 products).
fn testMul2d(comptime T: type) !void {
    const allocator = std.testing.allocator;
    const Lhs = Matrix(T, &.{ 3, 2 });
    const Rhs = Matrix(T, &.{ 2, 3 });
    const Out = Matrix(T, &.{ 3, 3 });

    const lhs = try Lhs.init(allocator);
    defer lhs.deinit(allocator);
    lhs.fill(1);

    const rhs = Rhs.initConst(&.{
        .{ 2, 2, 2 },
        .{ 2, 2, 2 },
    });

    const out = try Out.init(allocator);
    defer out.deinit(allocator);
    lhs.mul(rhs, out);

    const expected = Out.initConst(&.{
        .{ 4, 4, 4 },
        .{ 4, 4, 4 },
        .{ 4, 4, 4 },
    });
    try std.testing.expectEqualSlices(T, expected.items, out.items);
}
// Run the 2d multiply check for an unsigned, a signed and a float element type.
test "2d mul" {
    inline for (.{ u8, i8, f32 }) |T| {
        try testMul2d(T);
    }
}
/// 3d: two stacked 2x3 by 3x2 products, checked against hand-computed values.
fn testMul3d(comptime T: type) !void {
    const allocator = std.testing.allocator;
    const Lhs = Matrix(T, &.{ 2, 2, 3 });
    const Rhs = Matrix(T, &.{ 2, 3, 2 });
    const Out = Matrix(T, &.{ 2, 2, 2 });

    const lhs = try Lhs.initArray(allocator, .{
        .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } },
        .{ .{ 7, 8, 9 }, .{ 10, 11, 12 } },
    });
    defer lhs.deinit(allocator);

    const rhs = try Rhs.initArray(allocator, .{
        .{ .{ 1, 2 }, .{ 3, 4 }, .{ 5, 6 } },
        .{ .{ 7, 8 }, .{ 9, 10 }, .{ 11, 12 } },
    });
    defer rhs.deinit(allocator);

    const out = try Out.init(allocator);
    defer out.deinit(allocator);
    lhs.mul(rhs, out);

    const expected = Out.initConst(&.{
        .{ .{ 22, 28 }, .{ 49, 64 } },
        .{ .{ 220, 244 }, .{ 301, 334 } },
    });
    try std.testing.expectEqualSlices(T, expected.items, out.items);
}
// Run the 3d multiply check; 16-bit types because the products exceed 255.
test "3d mul" {
    inline for (.{ u16, i16, f32 }) |T| {
        try testMul3d(T);
    }
}
/// Parse 1d, 2d and 3d numpy literals both at comptime (`parse`) and at runtime
/// into a stack buffer (`parseBuf`), checking the flattened items each time.
fn testParse(comptime T: type) !void {
    @setEvalBranchQuota(2000);
    const expectItems = std.testing.expectEqualSlices;

    // 1d
    const Vec = Matrix(T, &.{2});
    const vec_text = "[1 2]";
    const vec_ct = try Vec.parse(vec_text);
    try expectItems(T, &.{ 1, 2 }, vec_ct.items);
    var vec_buf: [2]T = undefined;
    const vec_rt = try Vec.parseBuf(vec_text, &vec_buf);
    try expectItems(T, &.{ 1, 2 }, vec_rt.items);

    // 2d
    const Square = Matrix(T, &.{ 2, 2 });
    const square_text = "[[1 2] [3 4]]";
    const square_ct = try Square.parse(square_text);
    try expectItems(T, &.{ 1, 2, 3, 4 }, square_ct.items);
    var square_buf: [4]T = undefined;
    const square_rt = try Square.parseBuf(square_text, &square_buf);
    try expectItems(T, &.{ 1, 2, 3, 4 }, square_rt.items);

    // 3d, split over lines with a blank line between planes and padding inside a row
    const Cube = Matrix(T, &.{ 2, 2, 2 });
    const cube_text =
        \\[[[1 2] [3 4]]
        \\
        \\ [[5 6] [ 7 8]]]
    ;
    const cube_ct = try Cube.parse(cube_text);
    try expectItems(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, cube_ct.items);
    var cube_buf: [8]T = undefined;
    const cube_rt = try Cube.parseBuf(cube_text, &cube_buf);
    try expectItems(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, cube_rt.items);
}
// Run the parser checks for an unsigned, a signed and a float element type.
test "parse" {
    inline for (.{ u32, i32, f32 }) |T| {
        try testParse(T);
    }
}
/// Parse `a_in`, `b_in` and the expected product `c_in` (numpy print output with
/// the given shapes) at comptime. Multiply `a` by `b` into a freshly allocated
/// matrix and compare it with the expected product.
fn testMul(
    comptime T: type,
    comptime a_shape: []const usize,
    comptime a_in: []const u8,
    comptime b_shape: []const usize,
    comptime b_in: []const u8,
    comptime c_shape: []const usize,
    comptime c_in: []const u8,
) !void {
    // All three literals are parsed at comptime within this one function. Float
    // parsing of the larger 4d/5d cases needs far more headroom than 10_000.
    @setEvalBranchQuota(500_000);
    const A = Matrix(T, a_shape);
    const a = try A.parse(a_in);
    const B = Matrix(T, b_shape);
    const b = try B.parse(b_in);
    const C = Matrix(T, c_shape);
    const c = try C.init(std.testing.allocator);
    defer c.deinit(std.testing.allocator);
    a.mul(b, c);
    const expected = try C.parse(c_in);
    try std.testing.expectEqualSlices(T, expected.items, c.items);
}

/// Run the numpy-generated multiply cases with element type `T`.
/// Every case must pass `T` (not a concrete type such as `u32`) so that each
/// instantiation of this function actually exercises it.
fn testMuls(comptime T: type) !void {
    // 3d
    try testMul(T, &.{ 3, 3, 2 },
        \\[[[6 3]
        \\ [7 4]
        \\ [6 9]]
        \\
        \\ [[2 6]
        \\ [7 4]
        \\ [3 7]]
        \\
        \\ [[7 2]
        \\ [5 4]
        \\ [1 7]]]
    , &.{ 3, 2, 4 },
        \\[[[5 1 4 0]
        \\ [9 5 8 0]]
        \\
        \\ [[9 2 6 3]
        \\ [8 2 4 2]]
        \\
        \\ [[6 4 8 6]
        \\ [1 3 8 1]]]
    , &.{ 3, 3, 4 },
        \\[[[ 57 21 48 0]
        \\ [ 71 27 60 0]
        \\ [111 51 96 0]]
        \\
        \\ [[ 66 16 36 18]
        \\ [ 95 22 58 29]
        \\ [ 83 20 46 23]]
        \\
        \\ [[ 44 34 72 44]
        \\ [ 34 32 72 34]
        \\ [ 13 25 64 13]]]
    );
    try testMul(T, &.{ 1, 3, 2 },
        \\[[[6 3]
        \\ [7 4]
        \\ [6 9]]]
        \\
    , &.{ 1, 2, 4 },
        \\[[[2 6 7 4]
        \\ [3 7 7 2]]]
        \\
    , &.{ 1, 3, 4 },
        \\[[[ 21 57 63 30]
        \\ [ 26 70 77 36]
        \\ [ 39 99 105 42]]]
        \\
    );
    try testMul(T, &.{ 1, 3, 1 },
        \\[[[6]
        \\ [3]
        \\ [7]]]
        \\
    , &.{ 1, 1, 3 },
        \\[[[4 6 9]]]
        \\
    , &.{ 1, 3, 3 },
        \\[[[24 36 54]
        \\ [12 18 27]
        \\ [28 42 63]]]
        \\
    );
    // 4d
    try testMul(T, &.{ 1, 1, 3, 1 },
        \\[[[[6]
        \\ [3]
        \\ [7]]]]
        \\
    , &.{ 1, 1, 1, 3 },
        \\[[[[4 6 9]]]]
        \\
    , &.{ 1, 1, 3, 3 },
        \\[[[[24 36 54]
        \\ [12 18 27]
        \\ [28 42 63]]]]
        \\
    );
    try testMul(T, &.{ 2, 2, 3, 2 },
        \\[[[[6 3]
        \\ [7 4]
        \\ [6 9]]
        \\
        \\ [[2 6]
        \\ [7 4]
        \\ [3 7]]]
        \\
        \\
        \\ [[[7 2]
        \\ [5 4]
        \\ [1 7]]
        \\
        \\ [[5 1]
        \\ [4 0]
        \\ [9 5]]]]
        \\
    , &.{ 2, 2, 2, 3 },
        \\[[[[8 0 9]
        \\ [2 6 3]]
        \\
        \\ [[8 2 4]
        \\ [2 6 4]]]
        \\
        \\
        \\ [[[8 6 1]
        \\ [3 8 1]]
        \\
        \\ [[9 8 9]
        \\ [4 1 3]]]]
        \\
    , &.{ 2, 2, 3, 3 },
        \\[[[[ 54 18 63]
        \\ [ 64 24 75]
        \\ [ 66 54 81]]
        \\
        \\ [[ 28 40 32]
        \\ [ 64 38 44]
        \\ [ 38 48 40]]]
        \\
        \\
        \\ [[[ 62 58 9]
        \\ [ 52 62 9]
        \\ [ 29 62 8]]
        \\
        \\ [[ 49 41 48]
        \\ [ 36 32 36]
        \\ [101 77 96]]]]
        \\
    );
    // 5d
    try testMul(T, &.{ 2, 2, 2, 3, 2 },
        \\[[[[[6 3]
        \\ [7 4]
        \\ [6 9]]
        \\
        \\ [[2 6]
        \\ [7 4]
        \\ [3 7]]]
        \\
        \\
        \\ [[[7 2]
        \\ [5 4]
        \\ [1 7]]
        \\
        \\ [[5 1]
        \\ [4 0]
        \\ [9 5]]]]
        \\
        \\
        \\
        \\ [[[[8 0]
        \\ [9 2]
        \\ [6 3]]
        \\
        \\ [[8 2]
        \\ [4 2]
        \\ [6 4]]]
        \\
        \\
        \\ [[[8 6]
        \\ [1 3]
        \\ [8 1]]
        \\
        \\ [[9 8]
        \\ [9 4]
        \\ [1 3]]]]]
        \\
    , &.{ 2, 2, 2, 2, 3 },
        \\[[[[[6 7 2]
        \\ [0 3 1]]
        \\
        \\ [[7 3 1]
        \\ [5 5 9]]]
        \\
        \\
        \\ [[[3 5 1]
        \\ [9 1 9]]
        \\
        \\ [[3 7 6]
        \\ [8 7 4]]]]
        \\
        \\
        \\
        \\ [[[[1 4 7]
        \\ [9 8 8]]
        \\
        \\ [[0 8 6]
        \\ [8 7 0]]]
        \\
        \\
        \\ [[[7 7 2]
        \\ [0 7 2]]
        \\
        \\ [[2 0 4]
        \\ [9 6 9]]]]]
        \\
    , &.{ 2, 2, 2, 3, 3 },
        \\[[[[[ 36 51 15]
        \\ [ 42 61 18]
        \\ [ 36 69 21]]
        \\
        \\ [[ 44 36 56]
        \\ [ 69 41 43]
        \\ [ 56 44 66]]]
        \\
        \\
        \\ [[[ 39 37 25]
        \\ [ 51 29 41]
        \\ [ 66 12 64]]
        \\
        \\ [[ 23 42 34]
        \\ [ 12 28 24]
        \\ [ 67 98 74]]]]
        \\
        \\
        \\
        \\ [[[[ 8 32 56]
        \\ [ 27 52 79]
        \\ [ 33 48 66]]
        \\
        \\ [[ 16 78 48]
        \\ [ 16 46 24]
        \\ [ 32 76 36]]]
        \\
        \\
        \\ [[[ 56 98 28]
        \\ [ 7 28 8]
        \\ [ 56 63 18]]
        \\
        \\ [[ 90 48 108]
        \\ [ 54 24 72]
        \\ [ 29 18 31]]]]]
        \\
    );
}
// Run every numpy-generated multiply case for an unsigned, a signed and a
// float element type.
test {
    inline for (.{ u32, i32, f32 }) |T| {
        try testMuls(T);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment