Skip to content

Instantly share code, notes, and snippets.

@ityonemo
Last active January 4, 2021 13:54
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ityonemo/ca017dc6f9ddfe3fe304446e1ef2baf0 to your computer and use it in GitHub Desktop.
Save ityonemo/ca017dc6f9ddfe3fe304446e1ef2baf0 to your computer and use it in GitHub Desktop.
Strategy for composable Algebra in Zig
const std = @import("std");
/// Builds a namespace that dispatches the generic `@"<+>"` operator to one of
/// the registered `fields`.
///
/// Each entry of `fields` is expected to be a struct declaring a carrier type
/// `pub const Type: type` and a binary `pub fn @"<+>"` over that type (the
/// shape of the field descriptors used alongside this file — TODO confirm for
/// other callers). Lookup is by exact type equality on `Type`; the first
/// matching entry wins.
pub fn Algebra(comptime fields: []const type) type {
    return struct {
        /// Adds `a` and `b` with the field selected by `FieldFor`.
        /// The operands are forwarded unchanged, so when the selected field's
        /// `@"<+>"` takes `Type` parameters, an untyped literal on either side
        /// coerces to that type.
        pub fn @"<+>"(a: anytype, b: anytype) FieldFor(@TypeOf(a), @TypeOf(b)).Type {
            const F: type = FieldFor(@TypeOf(a), @TypeOf(b));
            return F.@"<+>"(a, b);
        }

        /// Returns the registered field whose `Type` equals `AType`; if none
        /// does (e.g. the left operand is an untyped literal), returns the
        /// field whose `Type` equals `BType`. Emits a compile error naming both
        /// operand types when neither is registered.
        pub fn FieldFor(comptime AType: type, comptime BType: type) type {
            inline for (fields) |field| {
                if (field.Type == AType) return field;
            }
            // The left operand did not select a field; try the right one so
            // that `@"<+>"(1, x)` resolves to the same field as `@"<+>"(x, 1)`.
            inline for (fields) |field| {
                if (field.Type == BType) return field;
            }
            @compileError("Algebra: no registered field for operand types '" ++
                @typeName(AType) ++ "' and '" ++ @typeName(BType) ++ "'");
        }
    };
}
/// Returns a mixin namespace that re-exports `base.@"<+>"` under the primed
/// name `@"<+'>"`. A type built on top of `base` can pull it in with
/// `usingnamespace` and still call the underlying operation by that name
/// while defining its own `@"<+>"`.
fn prime_inherit(comptime base: anytype) type {
    // TODO: this hard-codes `<+>`; mirroring every declaration of `base`
    // generically would be more flexible.
    const Primed = struct {
        pub const @"<+'>" = base.@"<+>";
    };
    return Primed;
}
/// Builds the complex extension of `field`: a struct with real and imaginary
/// components of type `field.Type`, usable itself as a field descriptor
/// (it exposes `Type` and `@"<+>"`).
///
/// `field` must declare `Type` and a binary `@"<+>"` over it.
pub fn Complex(comptime field: type) type {
    return struct {
        re: field.Type,
        im: field.Type,

        pub const Type: type = @This();

        // Brings the base field's addition into scope as `@"<+'>"`, so it can
        // be called without clashing with this type's own `@"<+>"`.
        usingnamespace prime_inherit(field);

        /// Componentwise addition: real parts are added to real parts and
        /// imaginary parts to imaginary parts, using the base field's `<+>`.
        pub fn @"<+>"(a: Type, b: Type) Type {
            return .{
                .re = @"<+'>"(a.re, b.re),
                // Previously read `a.re` here, which mixed the real part of
                // `a` into the imaginary result.
                .im = @"<+'>"(a.im, b.im),
            };
        }
    };
}
const std = @import("std");
/// Field descriptor for `i32`: exposes the carrier type as `Type` and
/// addition as `@"<+>"`.
const Integer = struct {
    /// Carrier type of this field.
    pub const Type: type = i32;

    /// Adds two integers with plain `+` (overflow traps in safe build modes).
    pub fn @"<+>"(lhs: Type, rhs: Type) Type {
        const sum = lhs + rhs;
        return sum;
    }
};
/// Field descriptor for `f64`: exposes the carrier type as `Type` and
/// addition as `@"<+>"`.
const Float = struct {
    /// Carrier type of this field.
    pub const Type: type = f64;

    /// Adds two floats with plain IEEE `+`.
    pub fn @"<+>"(lhs: Type, rhs: Type) Type {
        const sum = lhs + rhs;
        return sum;
    }
};
// Instantiate the complex factory from complex.zig over both base fields.
const complex_zig = @import("complex.zig");
const CFactory = complex_zig.Complex;
/// Complex numbers whose components are `Integer.Type` (i32).
const Gaussian = CFactory(Integer);
/// Complex numbers whose components are `Float.Type` (f64).
const Complex = CFactory(Float);

// Every field the unqualified `@"<+>"` in this file may dispatch to.
const registered_fields = [_]type{ Integer, Gaussian, Float, Complex };
// Pull `@"<+>"` (and `FieldFor`) from algebra.zig into this file's scope.
usingnamespace @import("algebra.zig").Algebra(&registered_fields);

/// No-op entry point; this file's behavior is exercised by its `test` block.
pub fn main() void {}
// Exercises `@"<+>"` dispatch for each registered field: the scalar fields
// (Integer, Float) and their complex extensions (Gaussian, Complex).
test "test" {
    // integers and gaussians
    const one = @as(i32, 1);
    const two = @as(i32, 2);
    std.debug.assert(2 == @"<+>"(one, one));
    // Distinct real and imaginary parts, so a re/im mix-up in the complex
    // addition cannot go unnoticed.
    const gauss_one_two: Gaussian = .{ .re = one, .im = two };
    const gauss_two_four = @"<+>"(gauss_one_two, gauss_one_two);
    std.debug.assert(2 == gauss_two_four.re);
    std.debug.assert(4 == gauss_two_four.im);
    // floats and complex
    const onef = @as(f64, 1.0);
    std.debug.assert(2.0 == @"<+>"(onef, onef));
    const cplx_one_two: Complex = .{ .re = 1.0, .im = 2.0 };
    const cplx_two_four = @"<+>"(cplx_one_two, cplx_one_two);
    // Check both components (the real part was previously checked twice).
    std.debug.assert(2.0 == cplx_two_four.re);
    std.debug.assert(4.0 == cplx_two_four.im);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment