Created
November 1, 2025 04:54
-
-
Save VerdagonModular/093557f5c0fd424ab44d7e8ab5db7858 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # ./bazelw run //KGEN/tools/mojo -- run run_staging_fma.mojo | |
| from collections import List | |
| from utils import Variant | |
| from os import abort | |
| from benchmark import keep | |
| from benchmark.memory import clobber_memory | |
| from time import perf_counter_ns | |
| from math import tanh | |
| # options for elementwise fusion: | |
| # - alias is_elementwise? | |
| # - see if can return a Optional[impl of trait ElementwiseMatrixOp] | |
| # for itself? | |
| # rust iterators can already enable this generally. | |
| # what is it we're doing special? | |
| # we can do it when recursion is involved i suppose? | |
| # we can do it on an arbitrary tree? | |
| # can't rust do it for an arbitrary tree? wasnt the difference that we | |
| # could rearrange things at compile time? | |
| # in rust, can we do a .map() | |
| trait MatrixTrait(Copyable): | |
| alias rows: Int | |
| alias cols: Int | |
| @always_inline | |
| fn __getitem__(ref self, i: Int, j: Int) -> Float32: | |
| ... | |
| @fieldwise_init | |
| struct Matrix[Rows: Int, Cols: Int](MatrixTrait, Movable, Copyable): | |
| alias rows: Int = Rows | |
| alias cols: Int = Cols | |
| var values: List[Float32] | |
| fn __init__(out self): | |
| self.values = List[Float32]() | |
| for _ in range(0, Rows): | |
| for _ in range(0, Cols): | |
| self.values.append(0) | |
| fn __init__[OtherRows: Int, OtherCols: Int]( | |
| out self, | |
| deinit other: Matrix[OtherRows, OtherCols] | |
| ): | |
| constrained[Rows == OtherRows]() | |
| constrained[Cols == OtherCols]() | |
| self.values = other.values^ | |
| fn copy(self) -> Self: | |
| return Self(self.values.copy()) | |
| @always_inline | |
| fn __getitem__(ref self, i: Int, j: Int) -> Float32: | |
| return self.values[i * Self.cols + j] | |
| trait MatrixOp: | |
| fn source_index_if_map(self) -> Optional[Int]: | |
| ... | |
| fn get_result_rows(self) -> Int: | |
| ... | |
| fn get_result_cols(self) -> Int: | |
| ... | |
| @staticmethod | |
| @always_inline | |
| fn execute_comptime[ | |
| # TODO: better syntax for this | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], | |
| //, | |
| nodes: List[MatrixOpVariant], | |
| self_: Self | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes] | |
| ) -> Matrix[ | |
| self_.get_result_rows(), | |
| self_.get_result_cols() | |
| ]: | |
| ... | |
| # TODO: This leaves extra nodes in the out_nodes list, let's avoid that | |
| fn optimize( | |
| self, | |
| in_nodes: List[MatrixOpVariant], | |
| mut out_nodes: List[MatrixOpVariant], | |
| node_index: Int, | |
| rotate_matrix_multiply_chains: Bool, | |
| fuse_matrix_multiply_map: Bool | |
| ) -> Int: | |
| ... | |
| # TODO: detect when theres a MatrixMap containing a MatrixMul, and then | |
| # fuse them into a MatrixMulMap. | |
| @fieldwise_init | |
| struct MatrixMap[ | |
| F: fn[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`] | |
| ]( | |
| VariadicPack[_, _, MatrixTrait, *InputTypes], | |
| Int, | |
| Int, | |
| Float32 | |
| )->Float32 | |
| ](MatrixOp, Copyable, Movable): | |
| var source_index: Int | |
| var result_rows: Int | |
| var result_cols: Int | |
| fn source_index_if_map(self) -> Optional[Int]: | |
| return self.source_index | |
| fn get_result_rows(self) -> Int: | |
| return self.result_rows | |
| fn get_result_cols(self) -> Int: | |
| return self.result_cols | |
| @staticmethod | |
| # @always_inline TODO: Had to remove this because it thought there was a cycle | |
| fn execute_comptime[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //, | |
| nodes: List[MatrixOpVariant], | |
| self_: Self | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes] | |
| ) -> Matrix[ | |
| self_.get_result_rows(), | |
| self_.get_result_cols() | |
| ]: | |
| alias source_node = nodes[self_.source_index] | |
| alias source_rows = MatrixOp_runtime_dispatch_get_result_rows(source_node) | |
| alias source_cols = MatrixOp_runtime_dispatch_get_result_cols(source_node) | |
| var source_val: Matrix[source_rows, source_cols] = | |
| MatrixOp_comptime_dispatch_execute_comptime_old[nodes, source_node](inputs) | |
| var result = Matrix[self_.get_result_rows(), self_.get_result_cols()]() | |
| # TODO: skip this zero initing | |
| for _ in range(source_rows * source_cols): | |
| result.values.append(0.0) | |
| for i in range(source_rows): | |
| for j in range(source_cols): | |
| result.values[i * source_cols + j] = F( | |
| inputs, | |
| i, | |
| j, | |
| source_val.values[i * source_cols + j] | |
| ) | |
| return result^ | |
| # TODO: This leaves extra nodes in the out_nodes list, let's avoid that | |
| fn optimize( | |
| self, | |
| in_nodes: List[MatrixOpVariant], | |
| mut out_nodes: List[MatrixOpVariant], | |
| node_index: Int, | |
| rotate_matrix_multiply_chains: Bool, | |
| fuse_matrix_multiply_map: Bool | |
| ) -> Int: | |
| var node = in_nodes[node_index].copy() | |
| var source_if_is_map = MatrixOp_runtime_dispatch_source_index_if_map(node) | |
| var source_index = source_if_is_map.value() | |
| var map_source_index = | |
| MatrixOp_runtime_dispatch_optimize_old( | |
| in_nodes, out_nodes, source_index, rotate_matrix_multiply_chains, fuse_matrix_multiply_map) | |
| var map_source_node = out_nodes[map_source_index].copy() | |
| if fuse_matrix_multiply_map: | |
| if map_source_node.isa[MatrixMul](): | |
| var mul_node = map_source_node[MatrixMul].copy() | |
| mul_node.elementwise_tails.append(map_source_index) | |
| out_nodes[map_source_index] = mul_node^ | |
| return map_source_index | |
| out_nodes.append( | |
| MatrixMap[F]( | |
| map_source_index, | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node))) | |
| return len(out_nodes) - 1 | |
| @fieldwise_init | |
| struct MatrixMul(MatrixOp, Copyable, Movable): | |
| var left_index: Int | |
| var right_index: Int | |
| var elementwise_tails: List[Int] | |
| var result_rows: Int | |
| var result_cols: Int | |
| fn source_index_if_map(self) -> Optional[Int]: | |
| return None | |
| fn get_result_rows(self) -> Int: | |
| return self.result_rows | |
| fn get_result_cols(self) -> Int: | |
| return self.result_cols | |
| @staticmethod | |
| @always_inline | |
| fn execute_comptime[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //, | |
| nodes: List[MatrixOpVariant], | |
| self_: Self | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes] | |
| ) -> Matrix[ | |
| self_.get_result_rows(), | |
| self_.get_result_cols() | |
| ]: | |
| alias left_node = nodes[self_.left_index] | |
| alias right_node = nodes[self_.right_index] | |
| alias left_rows = MatrixOp_runtime_dispatch_get_result_rows(left_node) | |
| alias left_cols = MatrixOp_runtime_dispatch_get_result_cols(left_node) | |
| alias right_rows = MatrixOp_runtime_dispatch_get_result_rows(right_node) | |
| alias right_cols = MatrixOp_runtime_dispatch_get_result_cols(right_node) | |
| var left_val: Matrix[left_rows, left_cols] = | |
| MatrixOp_comptime_dispatch_execute_comptime_old[nodes, left_node](inputs) | |
| var right_val_type_not_corrected: Matrix[right_rows, right_cols] = | |
| MatrixOp_comptime_dispatch_execute_comptime_old[nodes, right_node](inputs) | |
| constrained[left_cols == right_rows]() | |
| alias left_cols_right_rows = left_cols | |
| var result = Matrix[self_.get_result_rows(), self_.get_result_cols()]() | |
| # TODO: skip this zero initing | |
| for _ in range(left_rows * right_cols): | |
| result.values.append(0.0) | |
| for i in range(left_rows): | |
| for j in range(right_cols): | |
| var sum = Float32(0.0) | |
| for k in range(left_cols_right_rows): | |
| sum += left_val.values[i * left_cols_right_rows + k] * right_val_type_not_corrected.values[k * right_cols + j] | |
| @parameter | |
| for z in self_.elementwise_tails: | |
| alias elementwise_tails_node = nodes[z] | |
| alias elementwise_tails_node_elementwised_opt = | |
| MatrixOp_comptime_dispatch_get_elementwise_scalar_op[nodes, elementwise_tails_node]() | |
| @parameter | |
| if elementwise_tails_node_elementwised_opt: | |
| alias elementwise_tails_node_elementwised = elementwise_tails_node_elementwised_opt.value() | |
| sum = ScalarOp_comptime_dispatch_execute_comptime[nodes, elementwise_tails_node_elementwised](inputs, i, j, sum) | |
| result.values[i * right_cols + j] = sum | |
| return result^ | |
| # TODO: This leaves extra nodes in the out_nodes list, let's avoid that | |
| fn optimize( | |
| self, | |
| in_nodes: List[MatrixOpVariant], | |
| mut out_nodes: List[MatrixOpVariant], | |
| node_index: Int, | |
| rotate_matrix_multiply_chains: Bool, | |
| fuse_matrix_multiply_map: Bool | |
| ) -> Int: | |
| var node = in_nodes[node_index].copy() | |
| var source_if_is_map = MatrixOp_runtime_dispatch_source_index_if_map(node) | |
| # TODO: reordered and switched a lot of these refs to vars to avoid some | |
| # memory unsafety that was actually caught in the interpreter. we could | |
| # really use actual memory safety. | |
| var mul_node = node[MatrixMul].copy() | |
| var mul_left_index = MatrixOp_runtime_dispatch_optimize_old(in_nodes, out_nodes, mul_node.left_index, rotate_matrix_multiply_chains, fuse_matrix_multiply_map) | |
| var mul_right_index = MatrixOp_runtime_dispatch_optimize_old(in_nodes, out_nodes, mul_node.right_index, rotate_matrix_multiply_chains, fuse_matrix_multiply_map) | |
| var mul_left_node = out_nodes[mul_left_index].copy() | |
| var mul_right_node = out_nodes[mul_right_index].copy() | |
| if rotate_matrix_multiply_chains: | |
| if mul_left_node.isa[MatrixMul](): | |
| var mul_left_mul_node = mul_left_node[MatrixMul].copy() | |
| var a_index = mul_left_mul_node.left_index | |
| var a_node = out_nodes[a_index].copy() | |
| var b_index = mul_left_mul_node.right_index | |
| var b_node = out_nodes[b_index].copy() | |
| var c_index = mul_right_index | |
| var c_node = mul_right_node.copy() | |
| # If A is 10 × 30, B is 30 × 5, and C is 5 × 60, then: | |
| # - Computing (AB)C would need (10×30×5) + (10×5×60) = 4500 ops | |
| # - Computing A(BC) would need (30×5×60) + (10×30×60) = 27000 ops | |
| var ab_c_total_computations = | |
| MatrixOp_runtime_dispatch_get_result_rows(a_node) * MatrixOp_runtime_dispatch_get_result_cols(a_node) * MatrixOp_runtime_dispatch_get_result_cols(b_node) + | |
| MatrixOp_runtime_dispatch_get_result_rows(a_node) * MatrixOp_runtime_dispatch_get_result_cols(b_node) * MatrixOp_runtime_dispatch_get_result_cols(c_node) | |
| var a_bc_total_computations = | |
| MatrixOp_runtime_dispatch_get_result_rows(b_node) * MatrixOp_runtime_dispatch_get_result_cols(b_node) * MatrixOp_runtime_dispatch_get_result_cols(c_node) + | |
| MatrixOp_runtime_dispatch_get_result_rows(a_node) * MatrixOp_runtime_dispatch_get_result_cols(a_node) * MatrixOp_runtime_dispatch_get_result_cols(c_node) | |
| if ab_c_total_computations <= a_bc_total_computations: | |
| # Keep current structure, already does (a * b) * c | |
| pass | |
| else: | |
| # Reorder/rotate! | |
| out_nodes.append(MatrixMul(b_index, c_index, List[Int](), MatrixOp_runtime_dispatch_get_result_rows(b_node), MatrixOp_runtime_dispatch_get_result_cols(c_node))) | |
| var intermediate_index = len(out_nodes) - 1 | |
| out_nodes.append(MatrixMul(a_index, intermediate_index, List[Int](), MatrixOp_runtime_dispatch_get_result_rows(a_node), MatrixOp_runtime_dispatch_get_result_cols(c_node))) | |
| var result_index = len(out_nodes) - 1 | |
| return result_index | |
| # TODO: the above reorders (AB)C to A(BC), we need similar handling | |
| # for making A(BC) to (AB)C | |
| # Keep existing structure | |
| out_nodes.append( | |
| MatrixMul( | |
| mul_left_index, mul_right_index, List[Int](), | |
| mul_node.get_result_rows(), mul_node.get_result_cols())) | |
| var result_index = len(out_nodes) - 1 | |
| return result_index | |
| # Tried putting the vector itself in here, was less separable. | |
| # Rule of thumb?: Have AST refer to inputs somewhere else / symbolically | |
| @fieldwise_init | |
| struct MatrixLit(MatrixOp, Copyable, Movable): | |
| var input_index: Int | |
| var rows: Int | |
| var cols: Int | |
| fn source_index_if_map(self) -> Optional[Int]: | |
| return None | |
| fn get_result_rows(self) -> Int: | |
| return self.rows | |
| fn get_result_cols(self) -> Int: | |
| return self.cols | |
| @staticmethod | |
| @always_inline | |
| fn execute_comptime[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //, | |
| nodes: List[MatrixOpVariant], | |
| self_: Self | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes] | |
| ) -> Matrix[ | |
| self_.get_result_rows(), | |
| self_.get_result_cols() | |
| ]: | |
| ref input_ref = inputs[self_.input_index] | |
| # I don't think we can avoid this rebind, *something* needs to check | |
| # that the input_index'th element matches the shape we expect. | |
| return rebind[Matrix[ | |
| self_.get_result_rows(), | |
| self_.get_result_cols() | |
| ]](input_ref).copy() | |
| # TODO: This leaves extra nodes in the out_nodes list, let's avoid that | |
| fn optimize( | |
| self, | |
| in_nodes: List[MatrixOpVariant], | |
| mut out_nodes: List[MatrixOpVariant], | |
| node_index: Int, | |
| rotate_matrix_multiply_chains: Bool, | |
| fuse_matrix_multiply_map: Bool | |
| ) -> Int: | |
| var node = in_nodes[node_index].copy() | |
| var source_if_is_map = MatrixOp_runtime_dispatch_source_index_if_map(node) | |
| var lit_node = node[MatrixLit].copy() | |
| out_nodes.append(lit_node.copy()) | |
| return len(out_nodes) - 1 | |
| @fieldwise_init | |
| struct CompileResult: | |
| var nodes: List[MatrixOpVariant] | |
| var root_index: Int | |
| fn compile( | |
| rotate_matrix_multiply_chains: Bool, | |
| fuse_matrix_multiply_map: Bool | |
| ) -> CompileResult: | |
| var nodes = List[MatrixOpVariant]() | |
| # L1 data cache: 65KB, L2 cache: 4MB | |
| # Using larger matrices to exceed cache and see fusion effects | |
| # If A is 200 × 100, B is 100 × 200, and C is 200 × 100, then: | |
| # - Computing (AB)C would need (200×100×200) + (200×200×100) = 100M + 100M = 200M ops | |
| # - Computing A(BC) would need (100×200×100) + (200×100×100) = 10M + 10M = 20M ops | |
| # So A(BC) is 10x faster. | |
| # Matrix sizes: A=200×100×4=400KB, B=100×200×4=400KB, C=200×100×4=400KB, total=1.2MB | |
| alias a_rows = 200 | |
| alias a_cols = 100 | |
| alias b_rows = 100 | |
| alias b_cols = 200 | |
| alias c_rows = 200 | |
| alias c_cols = 100 | |
| nodes.append(MatrixLit(0, a_rows, a_cols)) | |
| nodes.append(MatrixLit(1, b_rows, b_cols)) | |
| nodes.append(MatrixLit(2, c_rows, c_cols)) | |
| nodes.append(MatrixMul(0, 1, List[Int](), a_rows, b_cols)) | |
| nodes.append(MatrixMul(3, 2, List[Int](), a_rows, c_cols)) | |
| # TODO: had an infinite recursion interpreter crash because i put | |
| # 5 here instead of 4. no symptoms, just crash, impossible to figure | |
| # out without firing up the debugger. | |
| nodes.append(MatrixMap[residual](4, a_rows, c_cols)) | |
| nodes.append(MatrixMap[gelu](5, a_rows, c_cols)) | |
| var optimized_nodes = List[MatrixOpVariant]() | |
| var root_index = | |
| MatrixOp_runtime_dispatch_optimize_old( | |
| nodes, | |
| optimized_nodes, | |
| len(nodes) - 1, | |
| rotate_matrix_multiply_chains, | |
| fuse_matrix_multiply_map) | |
| return CompileResult(optimized_nodes^, root_index) | |
| fn describe_node(node: MatrixOpVariant) -> String: | |
| """Return a human-readable description of a node.""" | |
| if node.isa[MatrixLit](): | |
| var lit = node[MatrixLit].copy() | |
| return "MatrixLit(" + String(lit.get_result_rows()) + "x" + String(lit.get_result_cols()) + ")" | |
| elif node.isa[MatrixMul](): | |
| var mul = node[MatrixMul].copy() | |
| var tails = " +tails[" + String(len(mul.elementwise_tails)) + "]" if len(mul.elementwise_tails) > 0 else "" | |
| return "MatrixMul(left=" + String(mul.left_index) + ", right=" + String(mul.right_index) + tails + ")" | |
| elif node.isa[MatrixMap[residual]](): | |
| var map_node = node[MatrixMap[residual]].copy() | |
| return "MatrixMap[residual](source=" + String(map_node.source_index) + ")" | |
| elif node.isa[MatrixMap[gelu]](): | |
| var map_node = node[MatrixMap[gelu]].copy() | |
| return "MatrixMap[gelu](source=" + String(map_node.source_index) + ")" | |
| else: | |
| return "Unknown" | |
| fn assert_graph_equals(actual: CompileResult, expected_description: String, test_name: String): | |
| """Compare actual graph against expected description.""" | |
| print("\n--- Checking:", test_name, "---") | |
| # Build actual description | |
| var actual_desc = "" | |
| for i in range(len(actual.nodes)): | |
| if i > 0: | |
| actual_desc += " | " | |
| actual_desc += String(i) + ":" + describe_node(actual.nodes[i]) | |
| actual_desc += " | root=" + String(actual.root_index) | |
| print("Expected:", expected_description) | |
| print("Actual: ", actual_desc) | |
| if actual_desc != expected_description: | |
| print("❌ MISMATCH!") | |
| abort("Graph mismatch for " + test_name) | |
| else: | |
| print("✓ Match!") | |
| fn main(): | |
| # Shows how, even though we dont know the graph until runtime, | |
| # we can still do some optimization at runtime to turn one runtime | |
| # graph into another runtime. | |
| # TODO: use this instead of the below chunk, currently causes behavior | |
| # differences (causes us to hit our panic for "Unknown Op") | |
| # benchmark_runtime_fma_fused() | |
| alias a_rows = 200 | |
| alias a_cols = 100 | |
| alias b_rows = 100 | |
| alias b_cols = 200 | |
| alias c_rows = 200 | |
| alias c_cols = 100 | |
| var matrix_a = Matrix[a_rows, a_cols]() | |
| var matrix_b = Matrix[b_rows, b_cols]() | |
| var matrix_c = Matrix[c_rows, c_cols]() | |
| var matrix_d = Matrix[a_rows, c_cols]() # Fourth matrix for elementwise addition | |
| var result_rt_unopt = compile(False, False) | |
| assert_graph_equals( | |
| result_rt_unopt, | |
| "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=2, right=3) | 5:MatrixMap[residual](source=4) | 6:MatrixMap[gelu](source=5) | root=6", | |
| "Unoptimized" | |
| ) | |
| var result_rt_rotate_only = compile(True, False) | |
| assert_graph_equals( | |
| result_rt_rotate_only, | |
| "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=1, right=3) | 5:MatrixMul(left=0, right=4) | 6:MatrixMap[residual](source=5) | 7:MatrixMap[gelu](source=6) | root=7", | |
| "Rotate only" | |
| ) | |
| var result_rt_fuse_only = compile(False, True) | |
| assert_graph_equals( | |
| result_rt_fuse_only, | |
| "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=2, right=3 +tails[2]) | root=4", | |
| "Fuse only" | |
| ) | |
| var result_rt_both = compile(True, True) | |
| assert_graph_equals( | |
| result_rt_both, | |
| "0:MatrixLit(200x100) | 1:MatrixLit(100x200) | 2:MatrixMul(left=0, right=1) | 3:MatrixLit(200x100) | 4:MatrixMul(left=1, right=3) | 5:MatrixMul(left=0, right=4 +tails[2]) | root=5", | |
| "Rotate + Fuse" | |
| ) | |
| print("Starting...") | |
| alias result_ct_unopt = compile(False, False) | |
| alias root_node_ct_unopt = result_ct_unopt.nodes[result_ct_unopt.root_index] | |
| var elapsed_ns_ct_unopt = 0 | |
| for _ in range(0, 4000): | |
| clobber_memory() # TODO: I don't actually know what clobber_memory does | |
| var start_ct_unopt = perf_counter_ns() | |
| var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_unopt.nodes, root_node_ct_unopt](matrix_a, matrix_b, matrix_c, matrix_d) | |
| keep(len(result.values) > 0) | |
| var end_ct_unopt = perf_counter_ns() | |
| elapsed_ns_ct_unopt += (end_ct_unopt - start_ct_unopt) | |
| var elapsed_ms_ct_unopt = Float64(elapsed_ns_ct_unopt) / 1_000_000.0 | |
| print("Comptime (rotate=False, fuse=False): ", elapsed_ms_ct_unopt, " ms") | |
| alias result_ct_rotate = compile(True, False) | |
| alias root_node_ct_rotate = result_ct_rotate.nodes[result_ct_rotate.root_index] | |
| var elapsed_ns_ct_rotate = 0 | |
| for _ in range(0, 4000): | |
| clobber_memory() # TODO: I don't actually know what clobber_memory does | |
| var start_ct_rotate = perf_counter_ns() | |
| var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_rotate.nodes, root_node_ct_rotate](matrix_a, matrix_b, matrix_c, matrix_d) | |
| keep(len(result.values) > 0) | |
| var end_ct_rotate = perf_counter_ns() | |
| elapsed_ns_ct_rotate += (end_ct_rotate - start_ct_rotate) | |
| var elapsed_ms_ct_rotate = Float64(elapsed_ns_ct_rotate) / 1_000_000.0 | |
| print("Comptime (rotate=True, fuse=False): ", elapsed_ms_ct_rotate, " ms") | |
| alias result_ct_fuse = compile(False, True) | |
| alias root_node_ct_fuse = result_ct_fuse.nodes[result_ct_fuse.root_index] | |
| var elapsed_ns_ct_fuse = 0 | |
| for _ in range(0, 4000): | |
| clobber_memory() # TODO: I don't actually know what clobber_memory does | |
| var start_ct_fuse = perf_counter_ns() | |
| var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_fuse.nodes, root_node_ct_fuse](matrix_a, matrix_b, matrix_c, matrix_d) | |
| keep(len(result.values) > 0) | |
| var end_ct_fuse = perf_counter_ns() | |
| elapsed_ns_ct_fuse += (end_ct_fuse - start_ct_fuse) | |
| var elapsed_ms_ct_fuse = Float64(elapsed_ns_ct_fuse) / 1_000_000.0 | |
| print("Comptime (rotate=False, fuse=True): ", elapsed_ms_ct_fuse, " ms") | |
| alias result_ct_both = compile(True, True) | |
| alias root_node_ct_both = result_ct_both.nodes[result_ct_both.root_index] | |
| var elapsed_ns_ct_both = 0 | |
| for _ in range(0, 4000): | |
| clobber_memory() # TODO: I don't actually know what clobber_memory does | |
| var start_ct_both = perf_counter_ns() | |
| var result = MatrixOp_comptime_dispatch_execute_comptime_old_variadic[result_ct_both.nodes, root_node_ct_both](matrix_a, matrix_b, matrix_c, matrix_d) | |
| keep(len(result.values) > 0) | |
| var end_ct_both = perf_counter_ns() | |
| elapsed_ns_ct_both += (end_ct_both - start_ct_both) | |
| var elapsed_ms_ct_both = Float64(elapsed_ns_ct_both) / 1_000_000.0 | |
| print("Comptime (rotate=True, fuse=True): ", elapsed_ms_ct_both, " ms") | |
| # Shows doing some fusion for a custom op with a runtime graph. | |
| # expected: really slow | |
| # Shows doing some fusion for a custom op with a comptime graph. | |
| # expected: actually inlines, super fast | |
| print("Done!") | |
| fn residual[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`] | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes], | |
| i: Int, | |
| j: Int, | |
| x: Float32 | |
| ) -> Float32: | |
| # TODO: really confusing error here when i forgot to pass in the fourth | |
| # matrix. lots of crazy rebind problems in the elaborator. | |
| return x + inputs[3][i, j] | |
| fn gelu[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`] | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes], | |
| i: Int, | |
| j: Int, | |
| x: Float32 | |
| ) -> Float32: | |
| # Just apply GELU activation | |
| alias sqrt_2_over_pi = 0.7978845608028654 | |
| alias coeff = 0.044715 | |
| var x_cubed = x * x * x | |
| var tanh_arg = sqrt_2_over_pi * (x + coeff * x_cubed) | |
| # Simple tanh approximation | |
| var tanh_val: Float32 | |
| if abs(tanh_arg) < 3.0: | |
| var x2 = tanh_arg * tanh_arg | |
| tanh_val = tanh_arg * (27.0 + x2) / (27.0 + 9.0 * x2) | |
| else: | |
| tanh_val = Float32(1.0) if tanh_arg > 0 else Float32(-1.0) | |
| return x * 0.5 * (1.0 + tanh_val) | |
| # ===== BEGIN ScalarOp code | |
| trait ScalarOp: | |
| @staticmethod | |
| @always_inline | |
| fn execute_comptime[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`] | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes], | |
| i: Int, | |
| j: Int, | |
| x: Float32 | |
| ) -> Float32: | |
| ... | |
| @fieldwise_init | |
| struct ScalarOpFunc[ | |
| F: fn[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`] | |
| ]( | |
| VariadicPack[_, _, MatrixTrait, *InputTypes], | |
| Int, | |
| Int, | |
| Float32 | |
| )->Float32 | |
| ](ScalarOp, Copyable, Movable): | |
| @staticmethod | |
| @always_inline | |
| fn execute_comptime[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`] | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes], | |
| i: Int, | |
| j: Int, | |
| x: Float32 | |
| ) -> Float32: | |
| return F(inputs, i, j, x) | |
| alias ScalarOpVariant = Variant[ScalarOpFunc[residual], ScalarOpFunc[gelu]] | |
| fn ScalarOp_comptime_dispatch_execute_comptime[ | |
| InputTypes: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //, | |
| nodes: List[MatrixOpVariant], | |
| scalar_op: ScalarOpVariant | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *InputTypes], | |
| i: Int, | |
| j: Int, | |
| x: Float32 | |
| ) -> Float32: | |
| @parameter | |
| if scalar_op.isa[ScalarOpFunc[residual]](): | |
| return scalar_op[ScalarOpFunc[residual]].execute_comptime(inputs, i, j, x) | |
| elif scalar_op.isa[ScalarOpFunc[gelu]](): | |
| return scalar_op[ScalarOpFunc[gelu]].execute_comptime(inputs, i, j, x) | |
| else: | |
| abort("Unknown scalar op") | |
| return 0.0 | |
| # ===== END ScalarOp code | |
| # ===== BEGIN MatrixOp existential code | |
| alias MatrixOpVariant = Variant[MatrixLit, MatrixMul, MatrixMap[residual], MatrixMap[gelu]] | |
| # TODO: look into a constrained Variant to maybe make this general | |
| fn MatrixOp_dispatch_runtime[ | |
| R: AnyType, //, | |
| func: fn[T: MatrixOp](x: T) capturing -> R | |
| ](node: MatrixOpVariant) -> R: | |
| if node.isa[MatrixLit](): | |
| return func(node[MatrixLit]) | |
| elif node.isa[MatrixMul](): | |
| return func(node[MatrixMul]) | |
| elif node.isa[MatrixMap[residual]](): | |
| return func(node[MatrixMap[residual]]) | |
| elif node.isa[MatrixMap[gelu]](): | |
| return func(node[MatrixMap[gelu]]) | |
| else: | |
| abort("Unknown op") | |
| while True: | |
| pass | |
| # TODO: look into a constrained Variant to maybe make this general | |
| @always_inline | |
| fn MatrixOp_dispatch_comptime[ | |
| R: AnyType, //, | |
| func: fn[T: MatrixOp, //, node: T]() capturing -> R, | |
| node: MatrixOpVariant | |
| ]() -> R: | |
| @parameter | |
| if node.isa[MatrixLit](): | |
| return func[node[MatrixLit]]() | |
| elif node.isa[MatrixMul](): | |
| return func[node[MatrixMul]]() | |
| elif node.isa[MatrixMap[residual]](): | |
| return func[node[MatrixMap[residual]]]() | |
| elif node.isa[MatrixMap[gelu]](): | |
| return func[node[MatrixMap[gelu]]]() | |
| else: | |
| abort("Unknown op") | |
| while True: | |
| pass | |
| fn MatrixOp_runtime_dispatch_source_index_if_map( | |
| node: MatrixOpVariant, | |
| ) -> Optional[Int]: | |
| @parameter | |
| fn source_index_if_map[T: MatrixOp](node: T) -> Optional[Int]: | |
| return node.source_index_if_map() | |
| return MatrixOp_dispatch_runtime[R=Optional[Int], source_index_if_map](node) | |
| fn MatrixOp_runtime_dispatch_get_result_rows( | |
| node: MatrixOpVariant, | |
| ) -> Int: | |
| @parameter | |
| fn get_result_rows[T: MatrixOp](node: T) -> Int: | |
| return node.get_result_rows() | |
| return MatrixOp_dispatch_runtime[R=Int, get_result_rows](node) | |
| fn MatrixOp_runtime_dispatch_get_result_cols( | |
| node: MatrixOpVariant, | |
| ) -> Int: | |
| @parameter | |
| fn get_result_cols[T: MatrixOp](node: T) -> Int: | |
| return node.get_result_cols() | |
| return MatrixOp_dispatch_runtime[R=Int, get_result_cols](node) | |
| @always_inline | |
| fn MatrixOp_comptime_dispatch_get_elementwise_scalar_op[ | |
| nodes: List[MatrixOpVariant], | |
| node: MatrixOpVariant | |
| ]() -> Optional[ScalarOpVariant]: | |
| @parameter | |
| if node.isa[MatrixMap[residual]](): | |
| var scalar_op = ScalarOpFunc[residual]() | |
| var variant = ScalarOpVariant(scalar_op^) | |
| return Optional[ScalarOpVariant](variant^) | |
| elif node.isa[MatrixMap[gelu]](): | |
| var scalar_op = ScalarOpFunc[gelu]() | |
| var variant = ScalarOpVariant(scalar_op^) | |
| return Optional[ScalarOpVariant](variant^) | |
| else: | |
| return None | |
| @always_inline | |
| fn MatrixOp_comptime_dispatch_execute_comptime[ | |
| T: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //, | |
| nodes: List[MatrixOpVariant], | |
| node: MatrixOpVariant | |
| ]( | |
| *inputs: *T | |
| ) -> Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ]: | |
| alias Result = Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ] | |
| @parameter | |
| fn execute_comptime[T: MatrixOp, //, specific_node: T]() -> Result: | |
| # TODO: remove the nodes= and self_= | |
| var result = specific_node.execute_comptime[nodes=nodes, self_=specific_node](inputs) | |
| # TODO: find some way to convince it these are the same | |
| return Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ](result^) | |
| return MatrixOp_dispatch_comptime[R=Result, execute_comptime, node]() | |
| # TODO: Use `MatrixOp_comptime_dispatch_execute_comptime` instead, but that | |
| # currently causes a compiler crash | |
| @always_inline | |
| fn MatrixOp_comptime_dispatch_execute_comptime_old[ | |
| T: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //, | |
| nodes: List[MatrixOpVariant], | |
| node: MatrixOpVariant, | |
| ]( | |
| inputs: VariadicPack[_, _, MatrixTrait, *T], | |
| ) -> Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ]: | |
| @parameter | |
| if node.isa[MatrixLit](): | |
| alias lit_node = node[MatrixLit] | |
| var result = lit_node.execute_comptime[nodes, lit_node](inputs) | |
| # TODO: find some way to convince it these are the same | |
| return Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ](result^) | |
| elif node.isa[MatrixMul](): | |
| alias mul_node = node[MatrixMul] | |
| var result = mul_node.execute_comptime[nodes, mul_node](inputs) | |
| # TODO: find some way to convince it these are the same | |
| return Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ](result^) | |
| elif node.isa[MatrixMap[residual]](): | |
| alias map_node = node[MatrixMap[residual]] | |
| # TODO: had to put nodes= and self_= because a bug, it was getting | |
| # confused about parameter binding order | |
| var result = map_node.execute_comptime[nodes=nodes, self_=map_node](inputs) | |
| # TODO: find some way to convince it these are the same | |
| return Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ](result^) | |
| elif node.isa[MatrixMap[gelu]](): | |
| alias map_node = node[MatrixMap[gelu]] | |
| var result = map_node.execute_comptime[nodes=nodes, self_=map_node](inputs) | |
| return Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ](result^) | |
| else: | |
| abort("Unknown op") | |
| while True: | |
| pass | |
| # TODO: Get rid of this | |
| @always_inline | |
| fn MatrixOp_comptime_dispatch_execute_comptime_old_variadic[ | |
| T: __mlir_type[`!kgen.variadic<`, MatrixTrait, `>`], //, | |
| nodes: List[MatrixOpVariant], | |
| node: MatrixOpVariant, | |
| ]( | |
| *inputs: *T, | |
| ) -> Matrix[ | |
| MatrixOp_runtime_dispatch_get_result_rows(node), | |
| MatrixOp_runtime_dispatch_get_result_cols(node) | |
| ]: | |
| return MatrixOp_comptime_dispatch_execute_comptime_old[nodes, node](inputs) | |
| fn MatrixOp_runtime_dispatch_optimize_old( | |
| in_nodes: List[MatrixOpVariant], | |
| mut out_nodes: List[MatrixOpVariant], | |
| node_index: Int, | |
| rotate_matrix_multiply_chains: Bool, | |
| fuse_matrix_multiply_map: Bool | |
| ) -> Int: | |
| var node = in_nodes[node_index] | |
| if node.isa[MatrixLit](): | |
| var lit_node = node[MatrixLit].copy() | |
| return lit_node.optimize( | |
| in_nodes, | |
| out_nodes, | |
| node_index, | |
| rotate_matrix_multiply_chains, | |
| fuse_matrix_multiply_map) | |
| elif node.isa[MatrixMul](): | |
| var mul_node = node[MatrixMul].copy() | |
| return mul_node.optimize( | |
| in_nodes, | |
| out_nodes, | |
| node_index, | |
| rotate_matrix_multiply_chains, | |
| fuse_matrix_multiply_map) | |
| elif node.isa[MatrixMap[residual]](): | |
| var map_node = node[MatrixMap[residual]].copy() | |
| # TODO: had to put nodes= and self_= because a bug, it was getting | |
| # confused about parameter binding order | |
| return map_node.optimize( | |
| in_nodes, | |
| out_nodes, | |
| node_index, | |
| rotate_matrix_multiply_chains, | |
| fuse_matrix_multiply_map) | |
| elif node.isa[MatrixMap[gelu]](): | |
| var map_node = node[MatrixMap[gelu]].copy() | |
| # TODO: had to put nodes= and self_= because a bug, it was getting | |
| # confused about parameter binding order | |
| return map_node.optimize( | |
| in_nodes, | |
| out_nodes, | |
| node_index, | |
| rotate_matrix_multiply_chains, | |
| fuse_matrix_multiply_map) | |
| else: | |
| abort("Unknown op") | |
| while True: | |
| pass | |
| # ===== END MatrixOp existential code |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment