# ./bazelw run //KGEN/tools/mojo -- run run_staging_fma.mojo
from collections import List
from utils import Variant
from os import abort
from benchmark import keep
from benchmark.memory import clobber_memory
from time import perf_counter_ns
from math import tanh


# Options for elementwise fusion:
# - alias is_elementwise?
# - see if we can return an Optional[impl of trait ElementwiseMatrixOp]
#   for itself?
# Rust iterators can already enable this generally, so what are we doing
# that's special? We can do it when recursion is involved, I suppose, and
# on an arbitrary tree. But can't Rust do it for an arbitrary tree too?
# Wasn't the difference that we can rearrange things at compile time?
# In Rust, can we do a .map()?
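#
# To make the goal concrete, here's a minimal sketch (pseudocode, not part
# of this program) of what fusing two elementwise maps means: instead of
# materializing an intermediate matrix per map, apply the composed scalar
# ops in one traversal.
#
#   # unfused: two passes, one intermediate matrix
#   tmp = map(f, m)
#   out = map(g, tmp)
#
#   # fused: one pass, no intermediate
#   for i in range(rows):
#       for j in range(cols):
#           out[i, j] = g(f(m[i, j]))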


trait MatrixTrait(Copyable):
    alias rows: Int
    alias cols: Int

    @always_inline
    fn __getitem__(ref self, i: Int, j: Int) -> Float32:
        ...


@fieldwise_init
struct Matrix[Rows: Int, Cols: Int](MatrixTrait, Movable, Copyable):
    alias rows: Int = Rows
    alias cols: Int = Cols
    var values: List[Float32]

    fn __init__(out self):
        self.values = List[Float32]()
        for _ in range(0, Rows):
            for _ in range(0, Cols):
                self.values.append(0)

    fn __init__[OtherRows: Int, OtherCols: Int](
        out self,
        deinit other: Matrix[OtherRows, OtherCols]
    ):
        constrained[Rows == OtherRows]()
        constrained[Cols == OtherCols]()
        self.values = other.values^

    fn copy(self) -> Self:
        return Self(self.values.copy())

    @always_inline
    fn __getitem__(ref self, i: Int, j: Int) -> Float32:
        return self.values[i * Self.cols + j]


trait MatrixOp:
    fn source_index_if_map(self) -> Optional[Int]:
        ...

    fn get_result_rows(self) -> Int:
        ...

    fn get_result_cols(self) -> Int:
        ...

    @staticmethod
    @always_inline
    fn execute_comptime[
        # TODO: better syntax for this
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`],
        //,
        nodes: List[MatrixOpVariant],
        self_: Self
    ](
        inputs: VariadicPack[_, _, MatrixTrait, *InputTypes]
    ) -> Matrix[
        self_.get_result_rows(),
        self_.get_result_cols()
    ]:
        ...

    # TODO: This leaves extra nodes in the out_nodes list, let's avoid that
    fn optimize(
        self,
        in_nodes: List[MatrixOpVariant],
        mut out_nodes: List[MatrixOpVariant],
        node_index: Int,
        rotate_matrix_multiply_chains: Bool,
        fuse_matrix_multiply_map: Bool
    ) -> Int:
        ...


# TODO: detect when there's a MatrixMap containing a MatrixMul, and then
# fuse them into a MatrixMulMap.
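# As a sketch of that fused form (pseudocode, not part of this program), a
# MatrixMulMap would apply the scalar op to each dot product as it is
# produced, instead of writing the product matrix out and re-reading it:
#
#   for i in range(m):
#       for j in range(p):
#           sum = 0.0
#           for k in range(n):
#               sum += a[i, k] * b[k, j]
#           out[i, j] = scalar_op(sum)  # fused elementwise tail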
@fieldwise_init
struct MatrixMap[
    F: fn[
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`]
    ](
        VariadicPack[_, _, MatrixTrait, *InputTypes],
        Int,
        Int,
        Float32
    ) -> Float32
](MatrixOp, Copyable, Movable):
    var source_index: Int
    var result_rows: Int
    var result_cols: Int

    fn source_index_if_map(self) -> Optional[Int]:
        return self.source_index

    fn get_result_rows(self) -> Int:
        return self.result_rows

    fn get_result_cols(self) -> Int:
        return self.result_cols

    @staticmethod
    # @always_inline TODO: Had to remove this because it thought there was a cycle
    fn execute_comptime[
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //,
        nodes: List[MatrixOpVariant],
        self_: Self
    ](
        inputs: VariadicPack[_, _, MatrixTrait, *InputTypes]
    ) -> Matrix[
        self_.get_result_rows(),
        self_.get_result_cols()
    ]:
        alias source_node = nodes[self_.source_index]
        alias source_rows = MatrixOp_runtime_dispatch_get_result_rows(source_node)
        alias source_cols = MatrixOp_runtime_dispatch_get_result_cols(source_node)
        var source_val: Matrix[source_rows, source_cols] =
            MatrixOp_comptime_dispatch_execute_comptime_old[nodes, source_node](inputs)
        # Matrix's default initializer already zero-fills `values`, so no
        # extra appends are needed here. TODO: skip this zero initing.
        var result = Matrix[self_.get_result_rows(), self_.get_result_cols()]()
        for i in range(source_rows):
            for j in range(source_cols):
                result.values[i * source_cols + j] = F(
                    inputs,
                    i,
                    j,
                    source_val.values[i * source_cols + j]
                )
        return result^

    # TODO: This leaves extra nodes in the out_nodes list, let's avoid that
    fn optimize(
        self,
        in_nodes: List[MatrixOpVariant],
        mut out_nodes: List[MatrixOpVariant],
        node_index: Int,
        rotate_matrix_multiply_chains: Bool,
        fuse_matrix_multiply_map: Bool
    ) -> Int:
        var node = in_nodes[node_index].copy()
        var source_if_is_map = MatrixOp_runtime_dispatch_source_index_if_map(node)
        var source_index = source_if_is_map.value()
        var map_source_index =
            MatrixOp_runtime_dispatch_optimize_old(
                in_nodes, out_nodes, source_index, rotate_matrix_multiply_chains, fuse_matrix_multiply_map)
        var map_source_node = out_nodes[map_source_index].copy()
        if fuse_matrix_multiply_map:
            if map_source_node.isa[MatrixMul]():
                # Fuse: record this map as an elementwise tail on the mul.
                # The map node itself must still be appended so the tail
                # index refers to an actual MatrixMap; execute_comptime
                # looks the tail up in `nodes` to recover the scalar op.
                var mul_node = map_source_node[MatrixMul].copy()
                out_nodes.append(
                    MatrixMap[F](
                        map_source_index,
                        MatrixOp_runtime_dispatch_get_result_rows(node),
                        MatrixOp_runtime_dispatch_get_result_cols(node)))
                mul_node.elementwise_tails.append(len(out_nodes) - 1)
                out_nodes[map_source_index] = mul_node^
                return map_source_index
        out_nodes.append(
            MatrixMap[F](
                map_source_index,
                MatrixOp_runtime_dispatch_get_result_rows(node),
                MatrixOp_runtime_dispatch_get_result_cols(node)))
        return len(out_nodes) - 1


@fieldwise_init
struct MatrixMul(MatrixOp, Copyable, Movable):
    var left_index: Int
    var right_index: Int
    var elementwise_tails: List[Int]
    var result_rows: Int
    var result_cols: Int

    fn source_index_if_map(self) -> Optional[Int]:
        return None

    fn get_result_rows(self) -> Int:
        return self.result_rows

    fn get_result_cols(self) -> Int:
        return self.result_cols

    @staticmethod
    @always_inline
    fn execute_comptime[
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //,
        nodes: List[MatrixOpVariant],
        self_: Self
    ](
        inputs: VariadicPack[_, _, MatrixTrait, *InputTypes]
    ) -> Matrix[
        self_.get_result_rows(),
        self_.get_result_cols()
    ]:
        alias left_node = nodes[self_.left_index]
        alias right_node = nodes[self_.right_index]
        alias left_rows = MatrixOp_runtime_dispatch_get_result_rows(left_node)
        alias left_cols = MatrixOp_runtime_dispatch_get_result_cols(left_node)
        alias right_rows = MatrixOp_runtime_dispatch_get_result_rows(right_node)
        alias right_cols = MatrixOp_runtime_dispatch_get_result_cols(right_node)
        var left_val: Matrix[left_rows, left_cols] =
            MatrixOp_comptime_dispatch_execute_comptime_old[nodes, left_node](inputs)
        var right_val_type_not_corrected: Matrix[right_rows, right_cols] =
            MatrixOp_comptime_dispatch_execute_comptime_old[nodes, right_node](inputs)
        constrained[left_cols == right_rows]()
        alias left_cols_right_rows = left_cols
        # Matrix's default initializer already zero-fills `values`, so no
        # extra appends are needed here. TODO: skip this zero initing.
        var result = Matrix[self_.get_result_rows(), self_.get_result_cols()]()
        for i in range(left_rows):
            for j in range(right_cols):
                var sum = Float32(0.0)
                for k in range(left_cols_right_rows):
                    sum += left_val.values[i * left_cols_right_rows + k] * right_val_type_not_corrected.values[k * right_cols + j]
                # Apply any fused elementwise tails to the dot product while
                # it is still in a register.
                @parameter
                for z in self_.elementwise_tails:
                    alias elementwise_tails_node = nodes[z]
                    alias elementwise_tails_node_elementwised_opt =
                        MatrixOp_comptime_dispatch_get_elementwise_scalar_op[nodes, elementwise_tails_node]()
                    @parameter
                    if elementwise_tails_node_elementwised_opt:
                        alias elementwise_tails_node_elementwised = elementwise_tails_node_elementwised_opt.value()
                        sum = ScalarOp_comptime_dispatch_execute_comptime[nodes, elementwise_tails_node_elementwised](inputs, i, j, sum)
                result.values[i * right_cols + j] = sum
        return result^

    # TODO: This leaves extra nodes in the out_nodes list, let's avoid that
    fn optimize(
        self,
        in_nodes: List[MatrixOpVariant],
        mut out_nodes: List[MatrixOpVariant],
        node_index: Int,
        rotate_matrix_multiply_chains: Bool,
        fuse_matrix_multiply_map: Bool
    ) -> Int:
        var node = in_nodes[node_index].copy()
        var source_if_is_map = MatrixOp_runtime_dispatch_source_index_if_map(node)
        # TODO: reordered and switched a lot of these refs to vars to avoid
        # some memory unsafety that was actually caught in the interpreter.
        # We could really use actual memory safety.
        var mul_node = node[MatrixMul].copy()
        var mul_left_index = MatrixOp_runtime_dispatch_optimize_old(in_nodes, out_nodes, mul_node.left_index, rotate_matrix_multiply_chains, fuse_matrix_multiply_map)
        var mul_right_index = MatrixOp_runtime_dispatch_optimize_old(in_nodes, out_nodes, mul_node.right_index, rotate_matrix_multiply_chains, fuse_matrix_multiply_map)
        var mul_left_node = out_nodes[mul_left_index].copy()
        var mul_right_node = out_nodes[mul_right_index].copy()
        if rotate_matrix_multiply_chains:
            if mul_left_node.isa[MatrixMul]():
                var mul_left_mul_node = mul_left_node[MatrixMul].copy()
                var a_index = mul_left_mul_node.left_index
                var a_node = out_nodes[a_index].copy()
                var b_index = mul_left_mul_node.right_index
                var b_node = out_nodes[b_index].copy()
                var c_index = mul_right_index
                var c_node = mul_right_node.copy()
                # If A is 10 × 30, B is 30 × 5, and C is 5 × 60, then:
                # - Computing (AB)C would need (10×30×5) + (10×5×60) = 4500 ops
                # - Computing A(BC) would need (30×5×60) + (10×30×60) = 27000 ops
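                # Cost model (the standard matrix-chain-multiplication one):
                # multiplying an m×n matrix by an n×p matrix costs roughly
                # m*n*p multiply-adds, so each parenthesization's cost below
                # is the sum of that product over its two multiplies.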
                var ab_c_total_computations =
                    MatrixOp_runtime_dispatch_get_result_rows(a_node) * MatrixOp_runtime_dispatch_get_result_cols(a_node) * MatrixOp_runtime_dispatch_get_result_cols(b_node) +
                    MatrixOp_runtime_dispatch_get_result_rows(a_node) * MatrixOp_runtime_dispatch_get_result_cols(b_node) * MatrixOp_runtime_dispatch_get_result_cols(c_node)
                var a_bc_total_computations =
                    MatrixOp_runtime_dispatch_get_result_rows(b_node) * MatrixOp_runtime_dispatch_get_result_cols(b_node) * MatrixOp_runtime_dispatch_get_result_cols(c_node) +
                    MatrixOp_runtime_dispatch_get_result_rows(a_node) * MatrixOp_runtime_dispatch_get_result_cols(a_node) * MatrixOp_runtime_dispatch_get_result_cols(c_node)
                if ab_c_total_computations <= a_bc_total_computations:
                    # Keep current structure, already does (a * b) * c
                    pass
                else:
                    # Reorder/rotate!
                    out_nodes.append(MatrixMul(b_index, c_index, List[Int](), MatrixOp_runtime_dispatch_get_result_rows(b_node), MatrixOp_runtime_dispatch_get_result_cols(c_node)))
                    var intermediate_index = len(out_nodes) - 1
                    out_nodes.append(MatrixMul(a_index, intermediate_index, List[Int](), MatrixOp_runtime_dispatch_get_result_rows(a_node), MatrixOp_runtime_dispatch_get_result_cols(c_node)))
                    var result_index = len(out_nodes) - 1
                    return result_index
        # TODO: the above reorders (AB)C into A(BC); we need similar handling
        # to turn A(BC) into (AB)C.
        # Keep existing structure
        out_nodes.append(
            MatrixMul(
                mul_left_index, mul_right_index, List[Int](),
                mul_node.get_result_rows(), mul_node.get_result_cols()))
        var result_index = len(out_nodes) - 1
        return result_index


# Tried putting the vector itself in here, but it was less separable.
# Rule of thumb?: have the AST refer to its inputs somewhere else / symbolically.
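# One concrete payoff of that rule of thumb: because a node holds only Ints
# (indices and shapes) rather than the matrix data, the whole node list can
# be a compile-time value (an `alias`), which is what lets the `compile(...)`
# results below feed the comptime dispatch path.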
@fieldwise_init
struct MatrixLit(MatrixOp, Copyable, Movable):
    var input_index: Int
    var rows: Int
    var cols: Int

    fn source_index_if_map(self) -> Optional[Int]:
        return None

    fn get_result_rows(self) -> Int:
        return self.rows

    fn get_result_cols(self) -> Int:
        return self.cols

    @staticmethod
    @always_inline
    fn execute_comptime[
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //,
        nodes: List[MatrixOpVariant],
        self_: Self
    ](
        inputs: VariadicPack[_, _, MatrixTrait, *InputTypes]
    ) -> Matrix[
        self_.get_result_rows(),
        self_.get_result_cols()
    ]:
        ref input_ref = inputs[self_.input_index]
        # I don't think we can avoid this rebind; *something* needs to check
        # that the input_index'th element matches the shape we expect.
        return rebind[Matrix[
            self_.get_result_rows(),
            self_.get_result_cols()
        ]](input_ref).copy()

    # TODO: This leaves extra nodes in the out_nodes list, let's avoid that
    fn optimize(
        self,
        in_nodes: List[MatrixOpVariant],
        mut out_nodes: List[MatrixOpVariant],
        node_index: Int,
        rotate_matrix_multiply_chains: Bool,
        fuse_matrix_multiply_map: Bool
    ) -> Int:
        var node = in_nodes[node_index].copy()
        var source_if_is_map = MatrixOp_runtime_dispatch_source_index_if_map(node)
        var lit_node = node[MatrixLit].copy()
        out_nodes.append(lit_node.copy())
        return len(out_nodes) - 1


@fieldwise_init
struct CompileResult:
    var nodes: List[MatrixOpVariant]
    var root_index: Int


fn compile(
    rotate_matrix_multiply_chains: Bool,
    fuse_matrix_multiply_map: Bool
) -> CompileResult:
    var nodes = List[MatrixOpVariant]()
    # L1 data cache: 65KB, L2 cache: 4MB.
    # Using matrices big enough to exceed L1 so fusion effects are visible.
    # If A is 200 × 100, B is 100 × 200, and C is 200 × 100, then:
    # - Computing (AB)C needs (200×100×200) + (200×200×100) = 4M + 4M = 8M ops
    # - Computing A(BC) needs (100×200×100) + (200×100×100) = 2M + 2M = 4M ops
    # So A(BC) is 2x cheaper.
    # Matrix sizes: A=200×100×4B=80KB, B=100×200×4B=80KB, C=200×100×4B=80KB,
    # total=240KB.
    alias a_rows = 200
    alias a_cols = 100
    alias b_rows = 100
    alias b_cols = 200
    alias c_rows = 200
    alias c_cols = 100
    nodes.append(MatrixLit(0, a_rows, a_cols))
    nodes.append(MatrixLit(1, b_rows, b_cols))
    nodes.append(MatrixLit(2, c_rows, c_cols))
    nodes.append(MatrixMul(0, 1, List[Int](), a_rows, b_cols))
    nodes.append(MatrixMul(3, 2, List[Int](), a_rows, c_cols))
    # TODO: had an infinite recursion interpreter crash because I put
    # 5 here instead of 4. No symptoms, just a crash; impossible to figure
    # out without firing up the debugger.
    nodes.append(MatrixMap[residual](4, a_rows, c_cols))
    nodes.append(MatrixMap[gelu](5, a_rows, c_cols))
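    # The graph built above, by index into `nodes`:
    #   0: A (200x100)   1: B (100x200)   2: C (200x100)
    #   3: Mul(0, 1)          = A*B
    #   4: Mul(3, 2)          = (A*B)*C
    #   5: Map[residual](4)
    #   6: Map[gelu](5)       <- root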
    var optimized_nodes = List[MatrixOpVariant]()
    var root_index =
        MatrixOp_runtime_dispatch_optimize_old(
            nodes,
            optimized_nodes,
            len(nodes) - 1,
            rotate_matrix_multiply_chains,
            fuse_matrix_multiply_map)
    return CompileResult(optimized_nodes^, root_index)


fn describe_node(node: MatrixOpVariant) -> String:
    """Return a human-readable description of a node."""
    if node.isa[MatrixLit]():
        var lit = node[MatrixLit].copy()
        return "MatrixLit(" + String(lit.get_result_rows()) + "x" + String(lit.get_result_cols()) + ")"
    elif node.isa[MatrixMul]():
        var mul = node[MatrixMul].copy()
        var tails = " +tails[" + String(len(mul.elementwise_tails)) + "]" if len(mul.elementwise_tails) > 0 else ""
        return "MatrixMul(left=" + String(mul.left_index) + ", right=" + String(mul.right_index) + tails + ")"
    elif node.isa[MatrixMap[residual]]():
        var map_node = node[MatrixMap[residual]].copy()
        return "MatrixMap[residual](source=" + String(map_node.source_index) + ")"
    elif node.isa[MatrixMap[gelu]]():
        var map_node = node[MatrixMap[gelu]].copy()
        return "MatrixMap[gelu](source=" + String(map_node.source_index) + ")"
    else:
        return "Unknown"


fn assert_graph_equals(actual: CompileResult, expected_description: String, test_name: String):
    """Compare the actual graph against an expected description."""
    print("\n--- Checking:", test_name, "---")
    # Build actual description
    var actual_desc = ""
    for i in range(len(actual.nodes)):
        if i > 0:
            actual_desc += " | "
        actual_desc += String(i) + ":" + describe_node(actual.nodes[i])
    actual_desc += " | root=" + String(actual.root_index)
    print("Expected:", expected_description)
    print("Actual:  ", actual_desc)
    if actual_desc != expected_description:
        print("❌ MISMATCH!")
        abort("Graph mismatch for " + test_name)
    else:
        print("✓ Match!")


fn main():
    # Shows how, even though we don't know the graph until runtime,
    # we can still do some optimization at runtime to turn one runtime
    # graph into another runtime graph.
    # TODO: use this instead of the below chunk, currently causes behavior
    # differences (causes us to hit our panic for "Unknown Op")
    # benchmark_runtime_fma_fused()
    alias a_rows = 200
    alias a_cols = 100
    alias b_rows = 100
    alias b_cols = 200
    alias c_rows = 200
    alias c_cols = 100
    var matrix_a = Matrix[a_rows, a_cols]()
    var matrix_b = Matrix[b_rows, b_cols]()
    var matrix_c = Matrix[c_rows, c_cols]()
    var matrix_d = Matrix[a_rows, c_cols]()  # Fourth matrix for elementwise addition
    var result_rt_unopt = compile(False, False)
    assert_graph_equals(
        result_rt_unopt,
        "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=2, right=3) | 5:MatrixMap[residual](source=4) | 6:MatrixMap[gelu](source=5) | root=6",
        "Unoptimized"
    )
    var result_rt_rotate_only = compile(True, False)
    assert_graph_equals(
        result_rt_rotate_only,
        "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=1, right=3) | 5:MatrixMul(left=0, right=4) | 6:MatrixMap[residual](source=5) | 7:MatrixMap[gelu](source=6) | root=7",
        "Rotate only"
    )
    var result_rt_fuse_only = compile(False, True)
    assert_graph_equals(
        result_rt_fuse_only,
        "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=2, right=3 +tails[2]) | 5:MatrixMap[residual](source=4) | 6:MatrixMap[gelu](source=4) | root=4",
        "Fuse only"
    )
    var result_rt_both = compile(True, True)
    assert_graph_equals(
        result_rt_both,
        "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=1, right=3) | 5:MatrixMul(left=0, right=4 +tails[2]) | 6:MatrixMap[residual](source=5) | 7:MatrixMap[gelu](source=5) | root=5",
        "Rotate + Fuse"
    )
print("Starting...")
alias result_ct_unopt = compile(False, False)
alias root_node_ct_unopt = result_ct_unopt.nodes[result_ct_unopt.root_index]
var elapsed_ns_ct_unopt = 0
for _ in range(0, 4000):
clobber_memory() # TODO: I don't actually know what clobber_memory does
var start_ct_unopt = perf_counter_ns()
var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_unopt.nodes, root_node_ct_unopt](matrix_a, matrix_b, matrix_c, matrix_d)
keep(len(result.values) > 0)
var end_ct_unopt = perf_counter_ns()
elapsed_ns_ct_unopt += (end_ct_unopt - start_ct_unopt)
var elapsed_ms_ct_unopt = Float64(elapsed_ns_ct_unopt) / 1_000_000.0
print("Comptime (rotate=False, fuse=False): ", elapsed_ms_ct_unopt, " ms")
alias result_ct_rotate = compile(True, False)
alias root_node_ct_rotate = result_ct_rotate.nodes[result_ct_rotate.root_index]
var elapsed_ns_ct_rotate = 0
for _ in range(0, 4000):
clobber_memory() # TODO: I don't actually know what clobber_memory does
var start_ct_rotate = perf_counter_ns()
var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_rotate.nodes, root_node_ct_rotate](matrix_a, matrix_b, matrix_c, matrix_d)
keep(len(result.values) > 0)
var end_ct_rotate = perf_counter_ns()
elapsed_ns_ct_rotate += (end_ct_rotate - start_ct_rotate)
var elapsed_ms_ct_rotate = Float64(elapsed_ns_ct_rotate) / 1_000_000.0
print("Comptime (rotate=True, fuse=False): ", elapsed_ms_ct_rotate, " ms")
alias result_ct_fuse = compile(False, True)
alias root_node_ct_fuse = result_ct_fuse.nodes[result_ct_fuse.root_index]
var elapsed_ns_ct_fuse = 0
for _ in range(0, 4000):
clobber_memory() # TODO: I don't actually know what clobber_memory does
var start_ct_fuse = perf_counter_ns()
var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_fuse.nodes, root_node_ct_fuse](matrix_a, matrix_b, matrix_c, matrix_d)
keep(len(result.values) > 0)
var end_ct_fuse = perf_counter_ns()
elapsed_ns_ct_fuse += (end_ct_fuse - start_ct_fuse)
var elapsed_ms_ct_fuse = Float64(elapsed_ns_ct_fuse) / 1_000_000.0
print("Comptime (rotate=False, fuse=True): ", elapsed_ms_ct_fuse, " ms")
alias result_ct_both = compile(True, True)
alias root_node_ct_both = result_ct_both.nodes[result_ct_both.root_index]
var elapsed_ns_ct_both = 0
for _ in range(0, 4000):
clobber_memory() # TODO: I don't actually know what clobber_memory does
var start_ct_both = perf_counter_ns()
var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_both.nodes, root_node_ct_both](matrix_a, matrix_b, matrix_c, matrix_d)
keep(len(result.values) > 0)
var end_ct_both = perf_counter_ns()
elapsed_ns_ct_both += (end_ct_both - start_ct_both)
var elapsed_ms_ct_both = Float64(elapsed_ns_ct_both) / 1_000_000.0
print("Comptime (rotate=True, fuse=True): ", elapsed_ms_ct_both, " ms")
# Shows doing some fusion for a custom op with a runtime graph.
# expected: really slow
# Shows doing some fusion for a custom op with a comptime graph.
# expected: actually inlines, super fast
print("Done!")


fn residual[
    InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`]
](
    inputs: VariadicPack[_, _, MatrixTrait, *InputTypes],
    i: Int,
    j: Int,
    x: Float32
) -> Float32:
    # TODO: really confusing error here when I forgot to pass in the fourth
    # matrix. Lots of crazy rebind problems in the elaborator.
    return x + inputs[3][i, j]


fn gelu[
    InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`]
](
    inputs: VariadicPack[_, _, MatrixTrait, *InputTypes],
    i: Int,
    j: Int,
    x: Float32
) -> Float32:
    # Just apply the GELU activation
    alias sqrt_2_over_pi = 0.7978845608028654
    alias coeff = 0.044715
    var x_cubed = x * x * x
    var tanh_arg = sqrt_2_over_pi * (x + coeff * x_cubed)
    # Simple rational approximation of tanh, accurate for small arguments;
    # tanh saturates to ±1 outside that range anyway.
    var tanh_val: Float32
    if abs(tanh_arg) < 3.0:
        var x2 = tanh_arg * tanh_arg
        tanh_val = tanh_arg * (27.0 + x2) / (27.0 + 9.0 * x2)
    else:
        tanh_val = Float32(1.0) if tanh_arg > 0 else Float32(-1.0)
    return x * 0.5 * (1.0 + tanh_val)
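

# For reference, a sketch of the exact tanh-based GELU (presumably why
# `tanh` is imported from `math` above); the rational approximation in
# `gelu` avoids this transcendental call in the inner loop. This helper is
# illustrative only and is not used by the benchmark.
fn gelu_exact(x: Float32) -> Float32:
    alias sqrt_2_over_pi = 0.7978845608028654
    alias coeff = 0.044715
    return x * 0.5 * (1.0 + tanh(sqrt_2_over_pi * (x + coeff * x * x * x)))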


# ===== BEGIN ScalarOp code
trait ScalarOp:
    @staticmethod
    @always_inline
    fn execute_comptime[
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`]
    ](
        inputs: VariadicPack[_, _, MatrixTrait, *InputTypes],
        i: Int,
        j: Int,
        x: Float32
    ) -> Float32:
        ...


@fieldwise_init
struct ScalarOpFunc[
    F: fn[
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`]
    ](
        VariadicPack[_, _, MatrixTrait, *InputTypes],
        Int,
        Int,
        Float32
    ) -> Float32
](ScalarOp, Copyable, Movable):
    @staticmethod
    @always_inline
    fn execute_comptime[
        InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`]
    ](
        inputs: VariadicPack[_, _, MatrixTrait, *InputTypes],
        i: Int,
        j: Int,
        x: Float32
    ) -> Float32:
        return F(inputs, i, j, x)


alias ScalarOpVariant = Variant[ScalarOpFunc[residual], ScalarOpFunc[gelu]]


fn ScalarOp_comptime_dispatch_execute_comptime[
    InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //,
    nodes: List[MatrixOpVariant],
    scalar_op: ScalarOpVariant
](
    inputs: VariadicPack[_, _, MatrixTrait, *InputTypes],
    i: Int,
    j: Int,
    x: Float32
) -> Float32:
    @parameter
    if scalar_op.isa[ScalarOpFunc[residual]]():
        return scalar_op[ScalarOpFunc[residual]].execute_comptime(inputs, i, j, x)
    elif scalar_op.isa[ScalarOpFunc[gelu]]():
        return scalar_op[ScalarOpFunc[gelu]].execute_comptime(inputs, i, j, x)
    else:
        abort("Unknown scalar op")
        return 0.0

# ===== END ScalarOp code
# ===== BEGIN MatrixOp existential code
alias MatrixOpVariant = Variant[MatrixLit, MatrixMul, MatrixMap[residual], MatrixMap[gelu]]
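# The Variant above is a closed-world stand-in for an existential/trait
# object ("any MatrixOp"): every node type, including each MatrixMap
# instantiation used, must be listed. The two dispatchers below then unpack
# it either at runtime (chained isa checks) or at compile time (@parameter
# if, which folds to a single direct call when the variant is a parameter).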


# TODO: look into a constrained Variant to maybe make this general
fn MatrixOp_dispatch_runtime[
    R: AnyType, //,
    func: fn[T: MatrixOp](x: T) capturing -> R
](node: MatrixOpVariant) -> R:
    if node.isa[MatrixLit]():
        return func(node[MatrixLit])
    elif node.isa[MatrixMul]():
        return func(node[MatrixMul])
    elif node.isa[MatrixMap[residual]]():
        return func(node[MatrixMap[residual]])
    elif node.isa[MatrixMap[gelu]]():
        return func(node[MatrixMap[gelu]])
    else:
        abort("Unknown op")
        # Unreachable: abort() never returns; the loop satisfies the
        # checker that every path returns an R.
        while True:
            pass


# TODO: look into a constrained Variant to maybe make this general
@always_inline
fn MatrixOp_dispatch_comptime[
    R: AnyType, //,
    func: fn[T: MatrixOp, //, node: T]() capturing -> R,
    node: MatrixOpVariant
]() -> R:
    @parameter
    if node.isa[MatrixLit]():
        return func[node[MatrixLit]]()
    elif node.isa[MatrixMul]():
        return func[node[MatrixMul]]()
    elif node.isa[MatrixMap[residual]]():
        return func[node[MatrixMap[residual]]]()
    elif node.isa[MatrixMap[gelu]]():
        return func[node[MatrixMap[gelu]]]()
    else:
        abort("Unknown op")
        # Unreachable, as above.
        while True:
            pass


fn MatrixOp_runtime_dispatch_source_index_if_map(
    node: MatrixOpVariant,
) -> Optional[Int]:
    @parameter
    fn source_index_if_map[T: MatrixOp](node: T) -> Optional[Int]:
        return node.source_index_if_map()
    return MatrixOp_dispatch_runtime[R=Optional[Int], source_index_if_map](node)


fn MatrixOp_runtime_dispatch_get_result_rows(
    node: MatrixOpVariant,
) -> Int:
    @parameter
    fn get_result_rows[T: MatrixOp](node: T) -> Int:
        return node.get_result_rows()
    return MatrixOp_dispatch_runtime[R=Int, get_result_rows](node)


fn MatrixOp_runtime_dispatch_get_result_cols(
    node: MatrixOpVariant,
) -> Int:
    @parameter
    fn get_result_cols[T: MatrixOp](node: T) -> Int:
        return node.get_result_cols()
    return MatrixOp_dispatch_runtime[R=Int, get_result_cols](node)


@always_inline
fn MatrixOp_comptime_dispatch_get_elementwise_scalar_op[
    nodes: List[MatrixOpVariant],
    node: MatrixOpVariant
]() -> Optional[ScalarOpVariant]:
    @parameter
    if node.isa[MatrixMap[residual]]():
        var scalar_op = ScalarOpFunc[residual]()
        var variant = ScalarOpVariant(scalar_op^)
        return Optional[ScalarOpVariant](variant^)
    elif node.isa[MatrixMap[gelu]]():
        var scalar_op = ScalarOpFunc[gelu]()
        var variant = ScalarOpVariant(scalar_op^)
        return Optional[ScalarOpVariant](variant^)
    else:
        return None


@always_inline
fn MatrixOp_comptime_dispatch_execute_comptime[
    T: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //,
    nodes: List[MatrixOpVariant],
    node: MatrixOpVariant
](
    *inputs: *T
) -> Matrix[
    MatrixOp_runtime_dispatch_get_result_rows(node),
    MatrixOp_runtime_dispatch_get_result_cols(node)
]:
    alias Result = Matrix[
        MatrixOp_runtime_dispatch_get_result_rows(node),
        MatrixOp_runtime_dispatch_get_result_cols(node)
    ]

    @parameter
    fn execute_comptime[T: MatrixOp, //, specific_node: T]() -> Result:
        # TODO: remove the nodes= and self_=
        var result = specific_node.execute_comptime[nodes=nodes, self_=specific_node](inputs)
        # TODO: find some way to convince it these are the same
        return Matrix[
            MatrixOp_runtime_dispatch_get_result_rows(node),
            MatrixOp_runtime_dispatch_get_result_cols(node)
        ](result^)

    return MatrixOp_dispatch_comptime[R=Result, execute_comptime, node]()


# TODO: Use `MatrixOp_comptime_dispatch_execute_comptime` instead, but that
# currently causes a compiler crash
@always_inline
fn MatrixOp_comptime_dispatch_execute_comptime_old[
    T: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //,
    nodes: List[MatrixOpVariant],
    node: MatrixOpVariant,
](
    inputs: VariadicPack[_, _, MatrixTrait, *T],
) -> Matrix[
    MatrixOp_runtime_dispatch_get_result_rows(node),
    MatrixOp_runtime_dispatch_get_result_cols(node)
]:
    @parameter
    if node.isa[MatrixLit]():
        alias lit_node = node[MatrixLit]
        var result = lit_node.execute_comptime[nodes, lit_node](inputs)
        # TODO: find some way to convince it these are the same
        return Matrix[
            MatrixOp_runtime_dispatch_get_result_rows(node),
            MatrixOp_runtime_dispatch_get_result_cols(node)
        ](result^)
    elif node.isa[MatrixMul]():
        alias mul_node = node[MatrixMul]
        var result = mul_node.execute_comptime[nodes, mul_node](inputs)
        # TODO: find some way to convince it these are the same
        return Matrix[
            MatrixOp_runtime_dispatch_get_result_rows(node),
            MatrixOp_runtime_dispatch_get_result_cols(node)
        ](result^)
    elif node.isa[MatrixMap[residual]]():
        alias map_node = node[MatrixMap[residual]]
        # TODO: had to put nodes= and self_= because a bug, it was getting
        # confused about parameter binding order
        var result = map_node.execute_comptime[nodes=nodes, self_=map_node](inputs)
        # TODO: find some way to convince it these are the same
        return Matrix[
            MatrixOp_runtime_dispatch_get_result_rows(node),
            MatrixOp_runtime_dispatch_get_result_cols(node)
        ](result^)
    elif node.isa[MatrixMap[gelu]]():
        alias map_node = node[MatrixMap[gelu]]
        var result = map_node.execute_comptime[nodes=nodes, self_=map_node](inputs)
        return Matrix[
            MatrixOp_runtime_dispatch_get_result_rows(node),
            MatrixOp_runtime_dispatch_get_result_cols(node)
        ](result^)
    else:
        abort("Unknown op")
        while True:
            pass


# TODO: Get rid of this
@always_inline
fn MatrixOp_comptime_dispatch_execute_comptime_old_variadic[
    T: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //,
    nodes: List[MatrixOpVariant],
    node: MatrixOpVariant,
](
    *inputs: *T,
) -> Matrix[
    MatrixOp_runtime_dispatch_get_result_rows(node),
    MatrixOp_runtime_dispatch_get_result_cols(node)
]:
    return MatrixOp_comptime_dispatch_execute_comptime_old[nodes, node](inputs)


fn MatrixOp_runtime_dispatch_optimize_old(
    in_nodes: List[MatrixOpVariant],
    mut out_nodes: List[MatrixOpVariant],
    node_index: Int,
    rotate_matrix_multiply_chains: Bool,
    fuse_matrix_multiply_map: Bool
) -> Int:
    var node = in_nodes[node_index]
    if node.isa[MatrixLit]():
        var lit_node = node[MatrixLit].copy()
        return lit_node.optimize(
            in_nodes,
            out_nodes,
            node_index,
            rotate_matrix_multiply_chains,
            fuse_matrix_multiply_map)
    elif node.isa[MatrixMul]():
        var mul_node = node[MatrixMul].copy()
        return mul_node.optimize(
            in_nodes,
            out_nodes,
            node_index,
            rotate_matrix_multiply_chains,
            fuse_matrix_multiply_map)
    elif node.isa[MatrixMap[residual]]():
        var map_node = node[MatrixMap[residual]].copy()
        return map_node.optimize(
            in_nodes,
            out_nodes,
            node_index,
            rotate_matrix_multiply_chains,
            fuse_matrix_multiply_map)
    elif node.isa[MatrixMap[gelu]]():
        var map_node = node[MatrixMap[gelu]].copy()
        return map_node.optimize(
            in_nodes,
            out_nodes,
            node_index,
            rotate_matrix_multiply_chains,
            fuse_matrix_multiply_map)
    else:
        abort("Unknown op")
        while True:
            pass

# ===== END MatrixOp existential code