Skip to content

Instantly share code, notes, and snippets.

@micahswitzer
Last active December 25, 2021 15:35
Implementation of the visitor pattern in Zig
const std = @import("std");
/// Group the three related types of one visitor family, all parameterized by
/// the same `Argument`, `Return`, and tuple of visitable `types`.
fn VisitorSet(comptime Argument: type, comptime Return: type, comptime types: anytype) type {
    return struct {
        /// Type-erased handle wrapping a pointer to a visitable object.
        const VisiteeType = Visitee(Argument, Return, types);
        /// Visitor that carries a pointer to its function table.
        const VisitorType = Visitor(Argument, Return, types);
        /// The function-table type, taken from `VisitorType` so both agree.
        const VtableType = VisitorType.VtableType;
    };
}
/// Type-erased handle to an object that a `Visitor(Argument, Return, types)`
/// can visit.
///
/// `ptr` holds the object's address as an integer. `acceptPtr` is the
/// `Dispatcher` trampoline specialized for the object's concrete type. That
/// trampoline casts `ptr` back and calls the matching vtable entry.
fn Visitee(comptime Argument: type, comptime Return: type, comptime types: anytype) type {
    return struct {
        ptr: usize,
        acceptPtr: fn (usize, *const VisitorType, Argument) Return,

        const Self = @This();
        const VisitorType = Visitor(Argument, Return, types);

        /// Dispatch `visitor` to the wrapped object and return its result.
        ///
        /// The handle is only read, so this takes `*const Self`. That lets it
        /// be called on a temporary (e.g. `x.getVisitee(T).accept(...)`).
        /// Existing `*Self` arguments still coerce.
        pub fn accept(self: *const Self, visitor: *const VisitorType, arg: Argument) Return {
            return self.acceptPtr(self.ptr, visitor, arg);
        }

        /// Wrap `obj`, a pointer to a value whose type appears in `types`.
        ///
        /// Only the address is stored (`@ptrToInt`). The pointee must stay
        /// valid for as long as the returned handle is used.
        pub fn create(obj: anytype) Self {
            return Self{
                .ptr = @ptrToInt(obj),
                .acceptPtr = Dispatcher(Argument, Return, std.meta.Child(@TypeOf(obj)), types).dispatch,
            };
        }
    };
}
/// Build the vtable struct type for a visitor over `types`.
///
/// `types` must be a tuple of types (e.g. `.{ A, B, C }`); any non-type
/// element is a compile error. For each `T` in `types`, the result has one
/// field named `"visit" ++ @typeName(T)` of type `fn (*T, Argument) Return`.
///
/// NOTE(review): callers that spell entries literally (e.g. `.visitA`) assume
/// `@typeName` yields the bare declaration name. Confirm this for the targeted
/// Zig version.
fn VisitorVtable(comptime Argument: type, comptime Return: type, comptime types: anytype) type {
    const tuple_fields = std.meta.fields(@TypeOf(types));
    var vtable_fields: [tuple_fields.len]std.builtin.TypeInfo.StructField = undefined;

    // Validate each tuple element and emit its vtable entry in a single pass.
    inline for (tuple_fields) |tuple_field, idx| {
        if (tuple_field.field_type != type)
            @compileError("types must be a tuple of types");

        const Visited = @field(types, tuple_field.name);
        const VisitFn = fn (*Visited, Argument) Return;
        vtable_fields[idx] = .{
            .name = "visit" ++ @typeName(Visited),
            .field_type = VisitFn,
            .default_value = null,
            .is_comptime = false,
            .alignment = @alignOf(VisitFn),
        };
    }

    const no_decls = [_]std.builtin.TypeInfo.Declaration{};
    return @Type(std.builtin.TypeInfo{ .Struct = .{
        .layout = .Auto,
        .fields = &vtable_fields,
        .decls = &no_decls,
        .is_tuple = false,
    } });
}
/// Visitor over `types`: a pointer to a vtable holding one
/// `fn (*T, Argument) Return` entry per visitable type `T`.
fn Visitor(comptime Argument: type, comptime Return: type, comptime types: anytype) type {
    return struct {
        const Self = @This();
        pub const VtableType = VisitorVtable(Argument, Return, types);

        ptr: usize,
        vtable: *const VtableType,

        /// Make a visitor that dispatches through `vtable`.
        /// `ptr` is left undefined and must not be read on the result.
        pub fn fromVtable(vtable: *const VtableType) Self {
            const result = Self{
                .vtable = vtable,
                .ptr = undefined,
            };
            return result;
        }
    };
}
/// Generate the type-specific trampoline stored in `Visitee.acceptPtr`.
///
/// `dispatch` converts the integer address `obj` back into a `*T` and calls
/// the vtable entry named `"visit" ++ @typeName(T)` with it and `arg`.
fn Dispatcher(comptime Argument: type, comptime Return: type, comptime T: type, comptime types: anytype) type {
    return struct {
        const VisitorType = Visitor(Argument, Return, types);

        pub fn dispatch(obj: usize, visitor: *const VisitorType, arg: Argument) Return {
            // `obj` must be an address produced by `@ptrToInt` on a `*T`.
            // The pointer binding itself is never reassigned, so it is `const`.
            const impl = @intToPtr(*T, obj);
            return @field(visitor.vtable, "visit" ++ @typeName(T))(impl, arg);
        }
    };
}
/// Demo: build the expression `B{2} - C{3}` as an `A` node and evaluate it.
/// A vtable-based visitor prints each visited node to stdout. The final
/// result is printed as "got val ...".
pub fn main() void {
    const B = struct {
        value: isize,
        const Self = @This();
        fn visitee(self: *Self, comptime VisiteeType: type) VisiteeType {
            return VisiteeType.create(self);
        }
    };
    const C = struct {
        value: isize,
        const Self = @This();
        fn visitee(self: *Self, comptime VisiteeType: type) VisiteeType {
            return VisiteeType.create(self);
        }
    };
    const A = struct {
        op: enum { Add, Subtract, Multiply },
        left: Union,
        right: Union,
        const Union = union(enum) {
            b: B,
            c: C,
            // Wrap whichever payload is active as a visitee.
            fn getVisitee(self: *Union, comptime VisiteeType: type) VisiteeType {
                return switch (self.*) {
                    .c => |*c| c.visitee(VisiteeType),
                    .b => |*b| b.visitee(VisiteeType),
                };
            }
        };
        const Self = @This();
        fn visitee(self: *Self, comptime VisiteeType: type) VisiteeType {
            return VisiteeType.create(self);
        }
    };
    var writer = std.io.getStdOut().writer();
    const Argument = *@TypeOf(writer);
    const Return = isize;
    const V = VisitorSet(Argument, Return, .{ A, B, C });
    // NOTE: the visit functions return a plain `Return`, so stdout write
    // errors cannot be propagated. They hit `unreachable` instead, which is
    // undefined behavior in release builds.
    const VisitorImpl = struct {
        fn visitA(obj: *A, w: Argument) Return {
            w.print("Visiting A! left = {?}, right = {?}, op = {s}\n", .{ obj.left, obj.right, @tagName(obj.op) }) catch unreachable;
            // Recursively evaluate both operands with this same visitor.
            const leftV = obj.left.getVisitee(V.VisiteeType).accept(&Self.visitor, w);
            const rightV = obj.right.getVisitee(V.VisiteeType).accept(&Self.visitor, w);
            return switch (obj.op) {
                .Add => leftV + rightV,
                .Subtract => leftV - rightV,
                .Multiply => leftV * rightV,
            };
        }
        fn visitB(obj: *B, w: Argument) Return {
            w.print("Visiting B! b = {}\n", .{obj.value}) catch unreachable;
            return obj.value;
        }
        fn visitC(obj: *C, w: Argument) Return {
            // Fixed copy-paste error: this previously printed "Visiting B!".
            w.print("Visiting C! c = {}\n", .{obj.value}) catch unreachable;
            return obj.value;
        }
        const Self = @This();
        const vtable = V.VtableType{
            .visitA = visitA,
            .visitB = visitB,
            .visitC = visitC,
        };
        const visitor = V.VisitorType.fromVtable(&vtable);
    };
    // `b` and `c` are only copied into `a`, so they never need to be mutable.
    const b = B{ .value = 2 };
    const c = C{ .value = 3 };
    var a = A{ .left = .{ .b = b }, .right = .{ .c = c }, .op = .Subtract };
    var visitees = [_]V.VisiteeType{a.visitee(V.VisiteeType)};
    for (visitees) |*visitee| {
        const res = visitee.accept(&VisitorImpl.visitor, &writer);
        writer.print("got val {}\n", .{res}) catch unreachable;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment