// Gist by @qedawkins, created May 1, 2024.
// Input dispatch for reference: a 64x64x64 f32 matmul that reads its operands
// from two read-only bindings and accumulates into a read-write binding.
//
// func.func @main() {
//   %c0 = arith.constant 0 : index
//   %binding0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
//   %binding1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
//   %binding2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
//   %arg0 = flow.dispatch.tensor.load %binding0, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
//   %arg1 = flow.dispatch.tensor.load %binding1, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
//   %arg2 = flow.dispatch.tensor.load %binding2, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>> -> tensor<64x64xf32>
//   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>
//   flow.dispatch.tensor.store %0, %binding2, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : tensor<64x64xf32> -> !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
//   return
// }
#layout_32x32x8 = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
module attributes { transform.with_named_sequence } {
  transform.named_sequence @cleanup(%target: !transform.any_op {transform.readonly}) {
    transform.apply_patterns to %target {
      transform.apply_patterns.scf.for_loop_canonicalization
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.iree.apply_licm %target : !transform.any_op
    transform.apply_cse to %target : !transform.any_op
    transform.yield
  }

  transform.named_sequence @__transform_main(%func: !transform.any_op) {
    // Step 1. Find the matmul.
    // =========================================================================
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op

    // Step 2. Tile the matmul K loop.
    // =========================================================================
    %reduced_matmul, %for = transform.structured.tile_using_for %matmul [0, 0, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
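
    // Rough sketch of the IR after this step (value names illustrative):
    //   %res = scf.for %k = %c0 to %c64 step %c8 iter_args(%acc = %arg2) -> (tensor<64x64xf32>) {
    //     %lhs = tensor.extract_slice %arg0[0, %k] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    //     %rhs = tensor.extract_slice %arg1[%k, 0] [8, 64] [1, 1] : tensor<64x64xf32> to tensor<8x64xf32>
    //     %partial = linalg.matmul ins(%lhs, %rhs : tensor<64x8xf32>, tensor<8x64xf32>) outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
    //     scf.yield %partial : tensor<64x64xf32>
    //   }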

    // Step 3. Promote the matmul operands. This allows tiling the operand
    // copies to threads independently of the matmul itself.
    // =========================================================================
    %lhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [0] : (!transform.any_op) -> (!transform.any_op)
    %rhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [1] : (!transform.any_op) -> (!transform.any_op)
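
    // copy_tensor_operand inserts an explicit copy of each operand, roughly
    // (sketch; after bufferization these copies land in workgroup memory):
    //   %lhs_init = tensor.empty() : tensor<64x8xf32>
    //   %lhs_promoted = linalg.copy ins(%lhs : tensor<64x8xf32>) outs(%lhs_init : tensor<64x8xf32>) -> tensor<64x8xf32>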

    // Step 4. Tile the matmul and copies to threads.
    // =========================================================================
    %thread_matmul, %matmul_forall =
      transform.structured.tile_using_forall %reduced_matmul tile_sizes [32, 32]
        ( mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %thread_lhs_copy, %lhs_copy_forall =
      transform.structured.tile_using_forall %lhs_copy tile_sizes [1, 2]
        ( mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %thread_rhs_copy, %rhs_copy_forall =
      transform.structured.tile_using_forall %rhs_copy tile_sizes [1, 2]
        ( mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
    transform.apply_patterns to %func {
      transform.apply_patterns.tensor.merge_consecutive_insert_extract_slice
    } : !transform.any_op
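
    // Sketch: the 64x64 matmul is now distributed over a 2x2 grid of warps,
    // and each 64x8 / 8x64 copy over a grid of threads moving 1x2 elements
    // each (illustrative):
    //   scf.forall (%i, %j) in (2, 2) shared_outs(%out = %acc) -> (tensor<64x64xf32>) {
    //     ... linalg.matmul on a 32x32 output tile ...
    //   } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]}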

    // Step 4.5. Flatten the thread indices of the forall ops.
    // =========================================================================
    %flat_lhs_forall = transform.iree.flatten_forall_mapping %lhs_copy_forall : (!transform.any_op) -> !transform.any_op
    %flat_rhs_forall = transform.iree.flatten_forall_mapping %rhs_copy_forall : (!transform.any_op) -> !transform.any_op
    %flat_matmul_forall = transform.iree.flatten_forall_mapping %matmul_forall : (!transform.any_op) -> !transform.any_op
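
    // Flattening linearizes each multi-dimensional mapping into a single id;
    // e.g. the 64x4-thread LHS copy becomes, roughly:
    //   scf.forall (%id) in (256) { ... } {mapping = [#gpu.thread<linear_dim_0>]}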

    // Step 5. Fuse the forall ops.
    // =========================================================================
    %matmul_forall_2, %fused_lhs = transform.iree.fuse_thread_with_warp_forall
        %flat_lhs_forall into %flat_matmul_forall {subgroup_size = 64} : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_forall_3, %fused_rhs = transform.iree.fuse_thread_with_warp_forall
        %flat_rhs_forall into %matmul_forall_2 {subgroup_size = 64} : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
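
    // With subgroup_size = 64 a flat thread id decomposes as
    //   thread_id = warp_id * 64 + lane_id
    // so each 256-thread copy forall folds into the 4-warp matmul forall,
    // re-indexed by warp and lane ids (a sketch of the intent, assuming this
    // is what fuse_thread_with_warp_forall produces).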

    // Step 6. Hoist the combined forall out of the serial loop and perform
    // other LICM.
    // =========================================================================
    transform.apply_patterns to %func {
      transform.apply_patterns.iree.hoist_forall_from_for
    } : !transform.any_op
    %for_loop = transform.structured.match ops{["scf.for"]} in %func : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %for_loop : !transform.any_op
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
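
    // Hoisting inverts the loop nest so the forall becomes the outer loop
    // (sketch):
    //   scf.forall (%flat_id) in (256) ... {
    //     scf.for %k = %c0 to %c64 step %c8 ... { ... }
    //   }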

    // Step 7. Decompose shuffles and vectorize.
    // =========================================================================
    transform.apply_patterns to %func {
      transform.apply_patterns.iree.lower_shuffle_tensor
    } : !transform.any_op
    %func2 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op
    %for_loop2 = transform.structured.match ops{["scf.for"]} in %func2 : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %for_loop2 : !transform.any_op
    transform.apply_patterns to %func2 {
      transform.apply_patterns.vector.fold_arith_extension
    } : !transform.any_op
    transform.include @cleanup failures(propagate) (%func2) : (!transform.any_op) -> ()
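
    // After vectorization the tiled linalg ops become vector transfers and
    // contractions, roughly (shapes illustrative for a 32x32 warp tile):
    //   %lhs_vec = vector.transfer_read ... : tensor<32x8xf32>, vector<32x8xf32>
    //   %out_vec = vector.contract ... %lhs_vec, %rhs_vec, %acc_vec
    //              : vector<32x8xf32>, vector<8x32xf32> into vector<32x32xf32>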

    // Step 8. Bufferize.
    // =========================================================================
    transform.iree.eliminate_empty_tensors %func2 : (!transform.any_op) -> ()
    %func3 = transform.iree.bufferize { target_gpu } %func2 : (!transform.any_op) -> !transform.any_op
    transform.include @cleanup failures(propagate) (%func3) : (!transform.any_op) -> ()
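
    // Bufferization replaces tensors with memrefs; with target_gpu the
    // promoted copy destinations become workgroup allocations, roughly:
    //   %alloc = memref.alloc() : memref<64x8xf32, #gpu.address_space<workgroup>>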

    // Step 9. Distribute to MFMA instructions.
    // =========================================================================
    %contract = transform.structured.match ops{["vector.contract"]} in %func3 : (!transform.any_op) -> !transform.any_op
    %layout32x32x8 = transform.param.constant #layout_32x32x8 -> !transform.any_param
    transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
    %func4 = transform.iree.amdgpu_distribute_vectors %func3 : (!transform.any_op) -> !transform.any_op
    transform.include @cleanup failures(propagate) (%func4) : (!transform.any_op) -> ()
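
    // Each contraction tagged with the 32x32x8 layout is lowered to amdgpu.mfma
    // ops of that intrinsic shape, roughly (assuming f16 operands after
    // extension folding, with a per-lane vector<16xf32> accumulator):
    //   %d = amdgpu.mfma %a * %b + %c {m = 32 : i32, n = 32 : i32, k = 8 : i32, blocks = 1 : i32}
    //        blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>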

    // Step 10. Map forall ops to threads.
    // =========================================================================
    transform.iree.forall_to_lanes %func4 : (!transform.any_op) -> ()
    transform.iree.map_nested_forall_to_gpu_threads %func4 workgroup_dims = [256, 1, 1] subgroup_size = 64 : (!transform.any_op) -> ()
    transform.apply_patterns to %func4 {
      transform.apply_patterns.memref.fold_memref_alias_ops
    } : !transform.any_op
    transform.include @cleanup failures(propagate) (%func4) : (!transform.any_op) -> ()
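
    // Sketch: lane-mapped foralls are rewritten in terms of the subgroup lane
    // id, and the remaining foralls in terms of gpu.thread_id over the
    // 256x1x1 workgroup:
    //   %tid = gpu.thread_id x
    //   %lane = affine.apply affine_map<(d0) -> (d0 mod 64)>(%tid)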

    // Step 11. Late cleanup.
    // =========================================================================
    transform.iree.hoist_static_alloc %func4 : (!transform.any_op) -> ()
    transform.memref.erase_dead_alloc_and_stores %func4 : (!transform.any_op) -> ()
    transform.yield
  }
} // module