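// Transform dialect script for lowering a single IREE matmul dispatch to AMD
// MFMA instructions. The input dispatch this script expects is sketched below: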
// func.func @main() {
//   %c0 = arith.constant 0 : index
//   %binding0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
//   %binding1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x64xf32>>
//   %binding2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
//   %arg0 = flow.dispatch.tensor.load %binding0, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
//   %arg1 = flow.dispatch.tensor.load %binding1, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<64x64xf32>> -> tensor<64x64xf32>
//   %arg2 = flow.dispatch.tensor.load %binding2, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<64x64xf32>> -> tensor<64x64xf32>
//   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>
//   flow.dispatch.tensor.store %0, %binding2, offsets = [0, 0], sizes = [64, 64], strides = [1, 1] : tensor<64x64xf32> -> !flow.dispatch.tensor<readwrite:tensor<64x64xf32>>
//   return
// }
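
// Target intrinsic: the AMD MFMA variant computing a 32x32x8 tile with f16
// inputs and an f32 accumulator, as encoded in the layout name below.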
#layout_32x32x8 = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>

module attributes { transform.with_named_sequence } {
  transform.named_sequence @cleanup(%target: !transform.any_op {transform.readonly}) {
    transform.apply_patterns to %target {
      transform.apply_patterns.scf.for_loop_canonicalization
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.iree.apply_licm %target : !transform.any_op
    transform.apply_cse to %target : !transform.any_op
    transform.yield
  }
  transform.named_sequence @__transform_main(%func: !transform.any_op) {
    // Step 1. Find the matmul.
    // ===========================================================================
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op
    // Step 2. Tile the matmul K loop.
    // ===========================================================================
    %reduced_matmul, %for = transform.structured.tile_using_for %matmul [0, 0, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
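    // Sketch of the IR after K-tiling (illustrative, not verbatim output):
    //   %res = scf.for %k = %c0 to %c64 step %c8 iter_args(%acc = %arg2) -> (tensor<64x64xf32>) {
    //     %lhs = tensor.extract_slice %arg0[0, %k] [64, 8] [1, 1] : tensor<64x64xf32> to tensor<64x8xf32>
    //     %rhs = tensor.extract_slice %arg1[%k, 0] [8, 64] [1, 1] : tensor<64x64xf32> to tensor<8x64xf32>
    //     %mm = linalg.matmul ins(%lhs, %rhs : tensor<64x8xf32>, tensor<8x64xf32>) outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
    //     scf.yield %mm : tensor<64x64xf32>
    //   }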
    // Step 3. Promote the matmul operands. This allows tiling the resulting
    // operand copies to threads independently of the matmul itself.
    // ===========================================================================
    %lhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [0] : (!transform.any_op) -> (!transform.any_op)
    %rhs_copy = transform.iree.copy_tensor_operand %reduced_matmul [1] : (!transform.any_op) -> (!transform.any_op)
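    // Each operand slice now feeds through an explicit copy, roughly (assuming
    // the promotion materializes a linalg.copy; illustrative):
    //   %lhs_promoted = linalg.copy ins(%lhs : tensor<64x8xf32>) outs(%lhs_init : tensor<64x8xf32>) -> tensor<64x8xf32>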
    // Step 4. Tile the matmul and copies to threads.
    // ===========================================================================
    %thread_matmul, %matmul_forall =
      transform.structured.tile_using_forall %reduced_matmul tile_sizes [32, 32]
        ( mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %thread_lhs_copy, %lhs_copy_forall =
      transform.structured.tile_using_forall %lhs_copy tile_sizes [1, 2]
        ( mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %thread_rhs_copy, %rhs_copy_forall =
      transform.structured.tile_using_forall %rhs_copy tile_sizes [1, 2]
        ( mapping = [#gpu.thread<linear_dim_0>, #gpu.thread<linear_dim_1>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
    transform.apply_patterns to %func {
      transform.apply_patterns.tensor.merge_consecutive_insert_extract_slice
    } : !transform.any_op
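    // The 64x64 output is now covered by a 2x2 grid of 32x32 warp tiles, e.g.
    // (illustrative):
    //   scf.forall (%i, %j) in (2, 2) shared_outs(%out = %acc) -> (tensor<64x64xf32>) {
    //     ...
    //   } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>]}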
    // Step 4.5. Flatten the thread indices of the forall ops.
    // ===========================================================================
    %flat_lhs_forall = transform.iree.flatten_forall_mapping %lhs_copy_forall : (!transform.any_op) -> !transform.any_op
    %flat_rhs_forall = transform.iree.flatten_forall_mapping %rhs_copy_forall : (!transform.any_op) -> !transform.any_op
    %flat_matmul_forall = transform.iree.flatten_forall_mapping %matmul_forall : (!transform.any_op) -> !transform.any_op
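    // Flattening collapses each multi-dimensional mapping into a single
    // linearized id, e.g. the 2x2 warp forall becomes (illustrative):
    //   scf.forall (%id) in (4) ... {mapping = [#gpu.warp<linear_dim_0>]}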
    // Step 5. Fuse the forall ops.
    // ===========================================================================
    %matmul_forall_2, %fused_lhs = transform.iree.fuse_thread_with_warp_forall
      %flat_lhs_forall into %flat_matmul_forall {subgroup_size = 64} : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_forall_3, %fused_rhs = transform.iree.fuse_thread_with_warp_forall
      %flat_rhs_forall into %matmul_forall_2 {subgroup_size = 64} : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
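    // With subgroup_size = 64, each thread-level copy forall is rewritten in
    // terms of the enclosing warp forall's lanes and merged into it, leaving a
    // single combined forall that drives both the copies and the per-warp matmul.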
    // Step 5.5. Hoist the combined forall out of the serial K loop and perform other LICM.
    // ===========================================================================
    transform.apply_patterns to %func {
      transform.apply_patterns.iree.hoist_forall_from_for
    } : !transform.any_op
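    // hoist_forall_from_for interchanges the nest so the forall encloses the
    // serial scf.for over K, which lets the subset hoisting below pull
    // loop-invariant slices of the accumulator out of the K loop.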
    %for_loop = transform.structured.match ops{["scf.for"]} in %func : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %for_loop : !transform.any_op
    transform.include @cleanup failures(propagate) (%func) : (!transform.any_op) -> ()
    // Step 6. Decompose shuffles and vectorize.
    // ===========================================================================
    transform.apply_patterns to %func {
      transform.apply_patterns.iree.lower_shuffle_tensor
    } : !transform.any_op
    %func2 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> !transform.any_op
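    // After vectorization the per-warp matmul becomes a vector.contract over
    // the 32x32x8 tile, roughly (illustrative; element types follow the f32
    // input above):
    //   %d = vector.contract {...} %a, %b, %acc
    //          : vector<32x8xf32>, vector<8x32xf32> into vector<32x32xf32>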
    %for_loop2 = transform.structured.match ops{["scf.for"]} in %func2 : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %for_loop2 : !transform.any_op
    transform.apply_patterns to %func2 {
      transform.apply_patterns.vector.fold_arith_extension
    } : !transform.any_op
    transform.include @cleanup failures(propagate) (%func2) : (!transform.any_op) -> ()
    // Step 7. Bufferize.
    // ===========================================================================
    transform.iree.eliminate_empty_tensors %func2 : (!transform.any_op) -> ()
    %func3 = transform.iree.bufferize { target_gpu } %func2 : (!transform.any_op) -> !transform.any_op
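    // Bufferization replaces tensors with memrefs; with { target_gpu } the
    // promoted operand copies are expected to land in workgroup (shared)
    // memory allocations rather than private ones.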
    transform.include @cleanup failures(propagate) (%func3) : (!transform.any_op) -> ()
    // Step 8. Distribute to MFMA instructions.
    // ===========================================================================
    %contract = transform.structured.match ops{["vector.contract"]} in %func3 : (!transform.any_op) -> !transform.any_op
    %layout32x32x8 = transform.param.constant #layout_32x32x8 -> !transform.any_param
    transform.iree.set_contraction_layout_attributes %contract, %layout32x32x8 : !transform.any_op, !transform.any_param
    %func4 = transform.iree.amdgpu_distribute_vectors %func3 : (!transform.any_op) -> !transform.any_op
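    // Each annotated contraction is distributed across the 64 lanes of a
    // subgroup and lowered toward amdgpu.mfma, roughly one op per 32x32x8
    // step (illustrative; per-lane operand shapes are fixed by the intrinsic):
    //   %d = amdgpu.mfma %a * %b + %c { m = 32, n = 32, k = 8, blocks = 1 }
    //          blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>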
    transform.include @cleanup failures(propagate) (%func4) : (!transform.any_op) -> ()
    // Step 9. Map forall ops to threads.
    // ===========================================================================
    transform.iree.forall_to_lanes %func4 : (!transform.any_op) -> ()
    transform.iree.map_nested_forall_to_gpu_threads %func4 workgroup_dims = [256, 1, 1] subgroup_size = 64 : (!transform.any_op) -> ()
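    // The workgroup size follows from the earlier tiling: a 2x2 grid of warps
    // times 64 lanes per subgroup gives 4 * 64 = 256 threads along dim x.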
    transform.apply_patterns to %func4 {
      transform.apply_patterns.memref.fold_memref_alias_ops
    } : !transform.any_op
    transform.include @cleanup failures(propagate) (%func4) : (!transform.any_op) -> ()
    // Step 10. Late cleanup.
    // ===========================================================================
    transform.iree.hoist_static_alloc %func4 : (!transform.any_op) -> ()
    transform.memref.erase_dead_alloc_and_stores %func4 : (!transform.any_op) -> ()
    transform.yield
  }
} // module