Skip to content

Instantly share code, notes, and snippets.

@qedawkins
Last active April 24, 2024 17:01
Show Gist options
  • Save qedawkins/ee0ca928634b5533b591ce804fa5e080 to your computer and use it in GitHub Desktop.
Save qedawkins/ee0ca928634b5533b591ce804fa5e080 to your computer and use it in GitHub Desktop.
// func.func @main(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %out: tensor<128x128xf32>) -> tensor<128x128xf32> {
// %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out: tensor<128x128xf32>) -> tensor<128x128xf32>
// return %0 : tensor<128x128xf32>
// }
// func.func @main() {
// %c0 = arith.constant 0 : index
// %binding0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf32>>
// %binding1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x128xf32>>
// %binding2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
// %arg0 = flow.dispatch.tensor.load %binding0, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf32>> -> tensor<128x128xf32>
// %arg1 = flow.dispatch.tensor.load %binding1, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x128xf32>> -> tensor<128x128xf32>
// %arg2 = flow.dispatch.tensor.load %binding2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<128x128xf32>> -> tensor<128x128xf32>
// %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
// flow.dispatch.tensor.store %0, %binding2, offsets = [0, 0], sizes = [128, 128], strides = [1, 1] : tensor<128x128xf32> -> !flow.dispatch.tensor<readwrite:tensor<128x128xf32>>
// return
// }
// Transform-dialect script lowering the 128x128xf32 matmul dispatch above to
// GPU: tile the K loop serially, promote the matmul operands via explicit
// copies, distribute work to threads, vectorize, bufferize, and map the
// result onto an 8x8 workgroup.
module attributes { transform.with_named_sequence } {
// Reusable cleanup sequence: scf.for and generic canonicalization patterns,
// then loop-invariant code motion and CSE over %target.
transform.named_sequence @cleanup(%target: !transform.any_op {transform.readonly}) {
transform.apply_patterns to %target {
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.iree.apply_licm %target : !transform.any_op
transform.apply_cse to %target : !transform.any_op
transform.yield
}
transform.named_sequence @__transform_main(%func: !transform.any_op) {
// Step 1. Find the matmul
// ===========================================================================
%matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op
// Step 2. Tile the matmul K loop. Tile sizes [0, 0, 4] leave M and N
// untiled, yielding a serial scf.for over K with step 4.
// ===========================================================================
%reduced_matmul, %for = transform.structured.tile_using_for %matmul [0, 0, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
// Step 3. Promote the matmul operands: materialize explicit copies of the
// LHS (operand 0) and RHS (operand 1). This allows tiling the copies to
// threads independently of the matmul in Step 4.
// NOTE(review): copy_tensor_operand presumably stages these copies for
// shared-memory allocation once bufferized with { target_gpu } below —
// confirm against the IREE transform extension's semantics.
// ===========================================================================
%lhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [0] : (!transform.any_op) -> (!transform.any_op)
%rhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [1] : (!transform.any_op) -> (!transform.any_op)
// Step 4. Tile the matmul ([16, 16] tiles) and both copies ([2, 4] tiles)
// to threads: each produces an scf.forall mapped to #gpu.thread y/x.
// ===========================================================================
%thread_matmul, %matmul_forall =
transform.structured.tile_using_forall %reduced_matmul tile_sizes [16, 16]
( mapping = [#gpu.thread<y>, #gpu.thread<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%thread_lhs_copy, %lhs_copy_forall =
transform.structured.tile_using_forall %lhs_copy tile_sizes [2, 4]
( mapping = [#gpu.thread<y>, #gpu.thread<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%thread_rhs_copy, %rhs_copy_forall =
transform.structured.tile_using_forall %rhs_copy tile_sizes [2, 4]
( mapping = [#gpu.thread<y>, #gpu.thread<x>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
transform.apply_patterns to %func {
transform.apply_patterns.tensor.merge_consecutive_insert_extract_slice
} : !transform.any_op
// Step 5. Fuse the copy foralls into the matmul forall, leaving a single
// forall that contains both copies and the thread-level matmul.
// ===========================================================================
%matmul_forall_2 = transform.iree.fuse_forall %lhs_copy_forall into %matmul_forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
%matmul_forall_3 = transform.iree.fuse_forall %rhs_copy_forall into %matmul_forall_2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op)
transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
// Step 6. Hoist the combined forall out of the serial K loop and perform
// other LICM.
// ===========================================================================
transform.apply_patterns to %func {
transform.apply_patterns.iree.hoist_forall_from_for
} : !transform.any_op
%for_loop = transform.structured.match ops{["scf.for"]} in %func : (!transform.any_op) -> !transform.any_op
transform.loop.hoist_loop_invariant_subsets %for_loop : !transform.any_op
transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
// Step 7. Decompose shuffles and vectorize
// ===========================================================================
transform.apply_patterns to %func {
transform.apply_patterns.iree.lower_shuffle_tensor
} : !transform.any_op
%func2 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op
%for_loop2 = transform.structured.match ops{["scf.for"]} in %func2 : (!transform.any_op) -> !transform.any_op
transform.loop.hoist_loop_invariant_subsets %for_loop2 : !transform.any_op
transform.include @cleanup failures(propagate) (%func2) : (!transform.any_op) -> ()
// Step 8. Bufferize (tensors -> memrefs) with GPU-targeted allocation.
// ===========================================================================
transform.iree.eliminate_empty_tensors %func2 : (!transform.any_op) -> ()
%func3 = transform.iree.bufferize { target_gpu } %func2 : (!transform.any_op) -> !transform.any_op
transform.include @cleanup failures(propagate) (%func3) : (!transform.any_op) -> ()
// Step 9. Map forall ops to threads. The 8x8x1 workgroup matches the
// 128/16 x 128/16 thread tiling of the matmul from Step 4.
// ===========================================================================
transform.iree.map_nested_forall_to_gpu_threads %func3 workgroup_dims = [8, 8, 1] : (!transform.any_op) -> ()
transform.apply_patterns to %func3 {
transform.apply_patterns.memref.fold_memref_alias_ops
} : !transform.any_op
transform.include @cleanup failures(propagate) (%func3) : (!transform.any_op) -> ()
// Step 10. Late cleanup: hoist static allocations to the function entry and
// erase dead allocs/stores.
// ===========================================================================
transform.iree.hoist_static_alloc %func3 : (!transform.any_op) -> ()
transform.memref.erase_dead_alloc_and_stores %func3 : (!transform.any_op) -> ()
transform.yield
}
} // module
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment