Skip to content

Instantly share code, notes, and snippets.

@ityonemo
Last active August 15, 2021 22:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ityonemo/725d293b49ee334393e564026a8a2927 to your computer and use it in GitHub Desktop.
Save ityonemo/725d293b49ee334393e564026a8a2927 to your computer and use it in GitHub Desktop.
Generic dot product in Zig
const std = @import("std");
/// Returns the element type that a dot product over a `T` accumulates in.
///
/// Accepted shapes: SIMD vectors, arrays, slices, and single-item pointers
/// to arrays. Every other type (including many-item and C pointers, and
/// single-item pointers to non-arrays) is rejected with a compile error.
fn childOf(comptime T: type) type {
    const unsupported = "tried to do a dot product on an unsupported type";
    const info = @typeInfo(T);

    // Value containers expose their element type directly.
    if (info == .Vector) return info.Vector.child;
    if (info == .Array) return info.Array.child;

    // Beyond this point only pointer shapes are acceptable.
    if (info != .Pointer) @compileError(unsupported);
    const ptr = info.Pointer;
    if (ptr.size == .Slice) return ptr.child;
    if (ptr.size != .One) @compileError(unsupported);

    // A single-item pointer is only valid when it points at an array.
    const pointee = @typeInfo(ptr.child);
    if (pointee != .Array) @compileError(unsupported);
    return pointee.Array.child;
}
/// Computes the dot product of `a` and `b`: the sum of their element-wise products.
///
/// The accepted container kinds and the result type come from `childOf`:
/// SIMD vectors, arrays, slices, and single-item pointers to arrays, with the
/// container's element type as the result. `b` is indexed the same way as `a`,
/// so it should be the same kind of container with the same element type.
///
/// Length mismatches are handled per kind:
///   - vectors: no explicit check; `a * b` is left to the compiler's type rules;
///   - arrays: a compile error, since both lengths are comptime-known;
///   - slices and array pointers: a runtime panic.
///
/// The ordinary `*` and `+` operators are used, so integer overflow of a
/// product or of the running sum traps in safe build modes.
fn dot(a: anytype, b: anytype) childOf(@TypeOf(a)) {
    switch (@typeInfo(@TypeOf(a))) {
        .Vector => |vector| {
            var sum: vector.child = 0;
            // Multiply all lanes in one operation, then fold the lanes into `sum`.
            const prod = a * b;
            // Unrolled at comptime so each lane is read with a comptime index.
            comptime var index = 0;
            inline while (index < vector.len) : (index += 1) {
                sum += prod[index];
            }
            return sum;
        },
        .Array => |array| {
            // Both lengths are known at compile time, so reject a mismatch here
            // rather than with a runtime assert (a crash in Debug, UB in ReleaseFast).
            if (comptime @typeInfo(@TypeOf(b)).Array.len != array.len) {
                @compileError("you dot producted two things of different length!");
            }
            var sum: array.child = 0;
            comptime var index = 0;
            inline while (index < array.len) : (index += 1) {
                sum += a[index] * b[index];
            }
            return sum;
        },
        .Pointer => {
            // Slice lengths are only known at runtime, so the check must be too.
            if (a.len != b.len) {
                @panic("you dot producted two things of different length!");
            }
            var sum: childOf(@TypeOf(a)) = 0;
            for (a) |number, index| {
                sum += number * b[index];
            }
            return sum;
        },
        // `childOf` in the return type has already rejected every other kind.
        else => unreachable,
    }
}
test "dot product works for SIMD vectors" {
    const V = std.meta.Vector(2, u8);
    var x = V{ 1, 2 };
    var y = V{ 3, 4 };
    // expectEqual takes (expected, actual); @as gives the expected value the
    // result type so `actual` can coerce to it.
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
}
test "dot product works for Arrays" {
    var x = [_]u8{ 1, 2 };
    var y = [_]u8{ 3, 4 };
    // expectEqual takes (expected, actual).
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
}
test "dot product works for Slices" {
    var x: []const u8 = &[_]u8{ 1, 2 };
    var y: []const u8 = &[_]u8{ 3, 4 };
    // expectEqual takes (expected, actual).
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
    // Empty slices have a dot product of zero.
    try std.testing.expectEqual(@as(u8, 0), dot(x[0..0], y[0..0]));
}
test "dot product works for Array pointers" {
    // Explicit types guarantee this exercises the pointer-to-array path.
    var x: *const [2]u8 = &[_]u8{ 1, 2 };
    var y: *const [2]u8 = &[_]u8{ 3, 4 };
    // expectEqual takes (expected, actual).
    try std.testing.expectEqual(@as(u8, 11), dot(x, y));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment