-
-
Save qedawkins/ee0ca928634b5533b591ce804fa5e080 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Input program (before dispatch formation):
//
// func.func @main(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %out: tensor<128x128xf32>) -> tensor<128x128xf32> {
//   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out: tensor<128x128xf32>) -> tensor<128x128xf32>
//   return %0 : tensor<128x128xf32>
// }
//
// Equivalent dispatch-region form this script is applied to:
//
// func.func @main() {
//   %c0 = arith.constant 0 : index
//   %binding0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf32>>
//   %binding1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf32>>
//   NOTE(review): the next binding is typed readwrite but carries flags(ReadOnly) — presumably a copy-paste slip in this example; verify.
//   %binding2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
//   %arg0 = flow.dispatch.tensor.load %binding0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf32>> -> tensor<128x128xf32>
//   %arg1 = flow.dispatch.tensor.load %binding1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf32>> -> tensor<128x128xf32>
//   %arg2 = flow.dispatch.tensor.load %binding2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
//   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
//   flow.dispatch.tensor.store %0, %binding2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
//   return
// }
module attributes { transform.with_named_sequence } {
  // Shared cleanup sequence: canonicalize scf.for loops, run general
  // canonicalization patterns, then LICM and CSE on the target op.
  transform.named_sequence @cleanup(%target: !transform.any_op {transform.readonly}) {
    transform.apply_patterns to %target {
      transform.apply_patterns.scf.for_loop_canonicalization
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.iree.apply_licm %target : !transform.any_op
    transform.apply_cse to %target : !transform.any_op
    transform.yield
  }

  // GPU codegen strategy for the 128x128x128 f32 matmul dispatch shown in the
  // header comment: tile K serially, promote operands, distribute to threads,
  // vectorize, bufferize, and map to an 8x8x1 workgroup.
  transform.named_sequence @__transform_main(%func: !transform.any_op) {
    // Step 1. Find the matmul.
    // =========================================================================
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op

    // Step 2. Tile the matmul K (reduction) loop by 4 with a serial scf.for.
    // =========================================================================
    %reduced_matmul, %for = transform.structured.tile_using_for %matmul [0, 0, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()

    // Step 3. Promote the matmul operands. This allows tiling the resulting
    // operand copies to threads independently of the matmul itself.
    // =========================================================================
    %lhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [0] : (!transform.any_op) -> (!transform.any_op)
    %rhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [1] : (!transform.any_op) -> (!transform.any_op)

    // Step 4. Tile the matmul and the operand copies to threads.
    // =========================================================================
    %thread_matmul, %matmul_forall =
      transform.structured.tile_using_forall %reduced_matmul tile_sizes [16, 16]
        ( mapping = [#gpu.thread<y>, #gpu.thread<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %thread_lhs_copy, %lhs_copy_forall =
      transform.structured.tile_using_forall %lhs_copy tile_sizes [2, 4]
        ( mapping = [#gpu.thread<y>, #gpu.thread<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %thread_rhs_copy, %rhs_copy_forall =
      transform.structured.tile_using_forall %rhs_copy tile_sizes [2, 4]
        ( mapping = [#gpu.thread<y>, #gpu.thread<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
    transform.apply_patterns to %func {
      transform.apply_patterns.tensor.merge_consecutive_insert_extract_slice
    } : !transform.any_op

    // Step 5. Fuse the copy foralls into the matmul forall.
    // =========================================================================
    %matmul_forall_2 = transform.iree.fuse_forall %lhs_copy_forall into %matmul_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
    %matmul_forall_3 = transform.iree.fuse_forall %rhs_copy_forall into %matmul_forall_2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()

    // Step 6. Hoist the combined forall out of the serial loop and perform
    // other loop-invariant subset hoisting.
    // =========================================================================
    transform.apply_patterns to %func {
      transform.apply_patterns.iree.hoist_forall_from_for
    } : !transform.any_op
    %for_loop = transform.structured.match ops{["scf.for"]} in %func : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %for_loop : !transform.any_op
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()

    // Step 7. Decompose shuffles and vectorize.
    // =========================================================================
    transform.apply_patterns to %func {
      transform.apply_patterns.iree.lower_shuffle_tensor
    } : !transform.any_op
    %func2 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op
    %for_loop2 = transform.structured.match ops{["scf.for"]} in %func2 : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %for_loop2 : !transform.any_op
    transform.include @cleanup failures(propagate) (%func2) : (!transform.any_op) -> ()

    // Step 8. Bufferize.
    // =========================================================================
    transform.iree.eliminate_empty_tensors %func2 : (!transform.any_op) -> ()
    %func3 = transform.iree.bufferize { target_gpu } %func2 : (!transform.any_op) -> !transform.any_op
    transform.include @cleanup failures(propagate) (%func3) : (!transform.any_op) -> ()

    // Step 9. Map forall ops to GPU threads with an 8x8x1 workgroup.
    // =========================================================================
    transform.iree.map_nested_forall_to_gpu_threads %func3 workgroup_dims = [8, 8, 1] : (!transform.any_op) -> ()
    transform.apply_patterns to %func3 {
      transform.apply_patterns.memref.fold_memref_alias_ops
    } : !transform.any_op
    transform.include @cleanup failures(propagate) (%func3) : (!transform.any_op) -> ()

    // Step 10. Late cleanup: hoist static allocations, erase dead
    // allocations and stores.
    // =========================================================================
    transform.iree.hoist_static_alloc %func3 : (!transform.any_op) -> ()
    transform.memref.erase_dead_alloc_and_stores %func3 : (!transform.any_op) -> ()
    transform.yield
  }
} // module
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment