Last active: August 15, 2021 22:32
-
-
Save ityonemo/725d293b49ee334393e564026a8a2927 to your computer and use it in GitHub Desktop.
Generic dot product in Zig (SIMD vectors, arrays, slices, and array pointers)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std"); | |
/// Returns the element type that `dot` sums over for an operand of type `T`.
///
/// Accepted operand types and their results:
///   - vector `std.meta.Vector(N, E)`  -> `E`
///   - array `[N]E`                    -> `E`
///   - slice `[]E` / `[]const E`       -> `E`
///   - single pointer to array `*[N]E` -> `E`
/// Every other type (including many-item and C pointers, and single pointers
/// to non-arrays) is rejected with a compile error.
fn childOf(comptime T: type) type {
    const unsupported = "tried to do a dot product on an unsupported type";
    switch (@typeInfo(T)) {
        .Vector => |vec_info| return vec_info.child,
        .Array => |arr_info| return arr_info.child,
        .Pointer => |ptr_info| switch (ptr_info.size) {
            .Slice => return ptr_info.child,
            // A single pointer is only accepted when it points at an array;
            // the result is that array's element type.
            .One => switch (@typeInfo(ptr_info.child)) {
                .Array => |pointee| return pointee.child,
                else => @compileError(unsupported),
            },
            else => @compileError(unsupported),
        },
        else => @compileError(unsupported),
    }
}
/// Computes the dot product `a[0]*b[0] + a[1]*b[1] + ...` of two operands.
///
/// The kind of `a` selects the strategy; the result type is `a`'s element
/// type as computed by `childOf` (unsupported types are a compile error there):
///   - vector: `a * b` is an element-wise vector multiply, so `b` must be a
///     valid right-hand operand for it; the products are then summed in a
///     compile-time-unrolled loop.
///   - array `[N]T`: `b` must also be an array of length `N`. Because array
///     lengths are part of the type, a mismatch is a compile error.
///   - slice or single pointer to array: `b` only needs `.len` and indexing.
///     Lengths are compared at runtime and a mismatch panics.
///
/// Arithmetic uses the element type's normal `*` and `+=`, so integer
/// overflow traps in safe build modes.
fn dot(a: anytype, b: anytype) childOf(@TypeOf(a)) {
    switch (@typeInfo(@TypeOf(a))) {
        .Vector => |vector| {
            // One element-wise vector multiply, then a horizontal sum.
            var sum: vector.child = 0;
            const len = vector.len;
            const prod = a * b;
            comptime var index = 0;
            inline while (index < len) : (index += 1) {
                sum += prod[index];
            }
            return sum;
        },
        .Array => |array| {
            // Both lengths are compile-time known, so reject mismatches at
            // compile time instead of with a runtime assertion.
            const b_info = @typeInfo(@TypeOf(b));
            if (b_info != .Array) {
                @compileError("dot product of an array requires another array of the same length");
            } else if (b_info.Array.len != array.len) {
                @compileError("you dot producted two arrays of different length!");
            }
            var sum: array.child = 0;
            comptime var index = 0;
            inline while (index < array.len) : (index += 1) {
                sum += a[index] * b[index];
            }
            return sum;
        },
        .Pointer => {
            // Slices (and array pointers) may only have runtime-known lengths.
            if (a.len != b.len) {
                @panic("you dot producted two things of different length!");
            }
            var sum: childOf(@TypeOf(a)) = 0;
            for (a) |number, index| {
                sum += number * b[index];
            }
            return sum;
        },
        // childOf in the return type already rejects every other type, so this
        // prong is never reached; keep it a compile error rather than UB.
        else => @compileError("tried to do a dot product on an unsupported type"),
    }
}
test "dot product works for SIMD vectors" {
    const V = std.meta.Vector(2, u8);
    const x = V{ 1, 2 };
    const y = V{ 3, 4 };
    // 1*3 + 2*4 = 11. expectEqual takes (expected, actual).
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
}
test "dot product works for Arrays" {
    const x = [_]u8{ 1, 2 };
    const y = [_]u8{ 3, 4 };
    // 1*3 + 2*4 = 11. expectEqual takes (expected, actual).
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
}
test "dot product works for Slices" {
    const x: []const u8 = ([_]u8{ 1, 2 })[0..];
    const y: []const u8 = ([_]u8{ 3, 4 })[0..];
    // 1*3 + 2*4 = 11. expectEqual takes (expected, actual).
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
}
test "dot product works for Array pointers" {
    // Slicing with comptime-known bounds yields a pointer to an array.
    const x = ([_]u8{ 1, 2 })[0..];
    const y = ([_]u8{ 3, 4 })[0..];
    // 1*3 + 2*4 = 11. expectEqual takes (expected, actual).
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.