Skip to content

Instantly share code, notes, and snippets.

@travisstaloch
Last active July 30, 2024 06:09
Show Gist options
  • Save travisstaloch/c4f5f87dc494fda954188eca9b700da9 to your computer and use it in GitHub Desktop.
Save travisstaloch/c4f5f87dc494fda954188eca9b700da9 to your computer and use it in GitHub Desktop.
A more ergonomic vector type without usingnamespace. If you wanted to create a new Vec3(u32) type, for instance, you'd have to copy some decls. Don't forget to call check(Vec3(u32)) to make sure you haven't missed any shared methods.
const std = @import("std");
/// Returns a namespace of vector helpers for `Self`, a struct whose fields all
/// have the same element type `T`.
///
/// The helpers call each other with method syntax on `Self` (`a.simd()`,
/// `a.mul(b)`, ...), so `Self` must re-export the declarations it wants to use
/// (see `Vec2`). Conversions are done with `@bitCast`, so `Self` needs a layout
/// that can be bit-cast to `[N]T` / `@Vector(N, T)`; `Vec2` uses `extern struct`.
pub fn SharedVecMethods(comptime Self: type) type {
    return struct {
        const fields = @typeInfo(Self).Struct.fields;
        /// Element type, taken from the first field.
        const T = fields[0].type;

        comptime {
            // All fields must share the element type; report a readable error
            // instead of a bare failed assertion.
            for (fields) |f| {
                if (f.type != T) {
                    @compileError("field '" ++ f.name ++ "' of " ++ @typeName(Self) ++
                        " has type " ++ @typeName(f.type) ++ ", expected " ++ @typeName(T));
                }
            }
        }

        const Simd = @Vector(fields.len, T);
        const Array = [fields.len]T;

        // conversion methods

        /// Reinterprets the struct as a SIMD vector.
        pub fn simd(a: Self) Simd {
            return @bitCast(a);
        }

        /// Reinterprets a SIMD vector as the struct.
        pub fn fromSimd(a: Simd) Self {
            return @bitCast(a);
        }

        /// Reinterprets the struct as an array of its elements.
        pub fn array(a: Self) Array {
            return @bitCast(a);
        }

        /// Reinterprets an array of elements as the struct.
        pub fn fromArray(a: Array) Self {
            return @bitCast(a);
        }

        /// Builds a vector with every element set to `a`.
        pub fn initBy(a: T) Self {
            return fromArray([1]T{a} ** fields.len);
        }

        /// Vector with every element equal to 0.
        pub const zero: Self = initBy(0);
        /// Vector with every element equal to 1.
        pub const one: Self = initBy(1);

        /// Multiplies every element by -1. The literal -1 must be
        /// representable in `T`, so this does not compile for unsigned `T`.
        pub fn negate(a: Self) Self {
            return a.mulBy(-1);
        }

        // element-wise arithmetic between two vectors

        pub fn add(a: Self, b: Self) Self {
            return fromSimd(a.simd() + b.simd());
        }

        pub fn sub(a: Self, b: Self) Self {
            return fromSimd(a.simd() - b.simd());
        }

        pub fn mul(a: Self, b: Self) Self {
            return fromSimd(a.simd() * b.simd());
        }

        pub fn div(a: Self, b: Self) Self {
            return fromSimd(a.simd() / b.simd());
        }

        // arithmetic between a vector and a scalar splatted to every lane

        pub fn addBy(a: Self, b: T) Self {
            return a.add(fromSimd(@splat(b)));
        }

        pub fn subBy(a: Self, b: T) Self {
            return a.sub(fromSimd(@splat(b)));
        }

        pub fn mulBy(a: Self, b: T) Self {
            return a.mul(fromSimd(@splat(b)));
        }

        pub fn divBy(a: Self, b: T) Self {
            return a.div(fromSimd(@splat(b)));
        }

        /// Magnitude: square root of the sum of squared elements.
        /// Relies on `@sqrt`, so it is meant for floating point `T`.
        pub fn mag(a: Self) T {
            return @sqrt(@reduce(.Add, a.mul(a).simd()));
        }

        /// Divides every element by the magnitude.
        pub fn unitize(a: Self) Self {
            return a.divBy(a.mag());
        }

        /// Dot product: sum of the element-wise products.
        pub fn dot(a: Self, b: Self) T {
            return @reduce(.Add, a.simd() * b.simd());
        }

        /// `std.fmt` hook. Prints `<T>x<N>{ e0, e1, ... }`, e.g. `f32x2{ 1.00, 1.00 }`.
        ///
        /// With an empty specifier the elements are printed with `d`,
        /// precision 2 and fill '0'; otherwise the caller's specifier and
        /// options are used as given. A `x`/`X` specifier prefixes each
        /// element with "0x".
        pub fn format(self: Self, comptime fmt_: []const u8, options_: std.fmt.FormatOptions, writer: anytype) !void {
            // use {d} specifier by default
            const fmt = if (fmt_.len == 0) "d" else fmt_;
            // use 2 digits of precision and fill='0' by default
            const options = if (fmt_.len != 0) options_ else .{
                .precision = 2,
                .fill = '0',
            };
            // construct name=f32x2 for T=f32 and len=2
            const name = comptime std.fmt.comptimePrint("{s}x{}{{ ", .{ @typeName(T), fields.len });
            // writeAll, not write: write may accept only part of the bytes,
            // and ignoring its count would silently drop output.
            try writer.writeAll(name);
            for (self.array(), 0..) |ele, i| {
                if (i != 0) try writer.writeAll(", ");
                if (comptime std.ascii.eqlIgnoreCase(fmt, "x")) try writer.writeAll("0x");
                try std.fmt.formatType(ele, fmt, options, writer, 0);
            }
            try writer.writeAll(" }");
        }
    };
}
/// Two-component vector with element type `T`.
///
/// Declared as an `extern struct` and converted to arrays and SIMD vectors by
/// the shared methods via `@bitCast`. All arithmetic and formatting helpers
/// are re-exported from `SharedVecMethods(Self)`.
pub fn Vec2(comptime T: type) type {
    return extern struct {
        x: T,
        y: T,

        const Self = @This();

        /// Builds a vector from its two components.
        /// Public so that code outside this file can construct a `Vec2(T)` by
        /// component, like every other declaration here.
        pub fn init(x: T, y: T) Self {
            return .{ .x = x, .y = y };
        }

        // To create a new VecN(T) type, copy all these decls.
        pub const S = SharedVecMethods(Self);
        pub const simd = S.simd;
        pub const fromSimd = S.fromSimd;
        pub const array = S.array;
        pub const fromArray = S.fromArray;
        pub const initBy = S.initBy;
        pub const zero = S.zero;
        pub const one = S.one;
        pub const negate = S.negate;
        pub const add = S.add;
        pub const sub = S.sub;
        pub const mul = S.mul;
        pub const div = S.div;
        pub const addBy = S.addBy;
        pub const subBy = S.subBy;
        pub const mulBy = S.mulBy;
        pub const divBy = S.divBy;
        pub const mag = S.mag;
        pub const unitize = S.unitize;
        pub const dot = S.dot;
        pub const format = S.format;
    };
}
/// Compile-time check that `V` declares every name that
/// `std.meta.declarations` reports for the shared method namespace.
/// Emits a compile error naming the first missing declaration.
pub fn check(comptime V: type) void {
    comptime {
        // The declaration names do not depend on the element type, so any
        // instantiation serves as the reference list.
        const required = std.meta.declarations(SharedVecMethods(Vec2(u0)));
        for (required) |shared| {
            const present = @hasDecl(V, shared.name);
            if (!present) {
                @compileError(std.fmt.comptimePrint("{s} is missing shared method '{s}'", .{ @typeName(V), shared.name }));
            }
        }
    }
}
comptime {
    // Each new VecN(T) type you create needs its own check(...) call here.
    // The element type is irrelevant to the check; any int or float type works.
    const Checked = Vec2(u0);
    check(Checked);
}
/// Exercises the arithmetic and formatting helpers of `Vec2(T)` for a single
/// element type. Float types additionally cover `mag`, `unitize` and
/// precision formatting; all other types cover the default and hex output.
fn testT(comptime T: type) !void {
    const testing = std.testing;
    const V = Vec2(T);
    const ones = V.one;

    // arithmetic
    try testing.expectEqual(V.initBy(3), ones.add(ones.mulBy(2)));
    try testing.expectEqual(V.initBy(2), ones.addBy(1));
    try testing.expectEqual(V.initBy(0), ones.subBy(1));
    try testing.expectEqual(V.initBy(2), ones.mulBy(2));
    try testing.expectEqual(V.initBy(2), ones.mulBy(4).divBy(2));
    try testing.expectEqual(4, ones.dot(ones.mulBy(2)));

    // formatting is prefixed with e.g. "f32x2"
    const label = @typeName(T) ++ "x2";
    switch (@typeInfo(T)) {
        .Float => {
            try testing.expectEqual(std.math.sqrt2, ones.mag());
            try testing.expectApproxEqAbs(std.math.sqrt1_2, ones.unitize().x, std.math.floatEps(T));
            try testing.expectFmt(label ++ "{ 1.00, 1.00 }", "{}", .{ones});
            try testing.expectFmt(label ++ "{ 1, 1 }", "{d:.0}", .{ones});
            try testing.expectFmt(label ++ "{ 1.0, 1.0 }", "{d:.1}", .{ones});
        },
        else => {
            try testing.expectFmt(label ++ "{ 1, 1 }", "{}", .{ones});
            try testing.expectFmt(label ++ "{ 0xA, 0xA }", "{X}", .{V.initBy(10)});
        },
    }
}
test {
    // Run the shared test body for both float widths and for signed and
    // unsigned integers of several sizes, in this order.
    const element_types = .{ f32, f64, u8, i8, u32, i32, u64, i64 };
    inline for (element_types) |T| {
        try testT(T);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment