vmvx notes
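// vmvx module imports: each elementwise 2d op takes a (buffer, offset, strides) triple per operand plus a sizes tuple; copy/fill are element-width generic (x8/x16/x32/x64); matmul/mmt4d take row strides, m/n/k (plus m0/n0/k0 inner tile sizes for mmt4d) and a flags word.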
vm.import @vmvx.add.2d.f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.add.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.and.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.div.2d.f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.divs.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.divu.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.mul.2d.f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.mul.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.or.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.shl.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.shrs.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.shru.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.sub.2d.f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.sub.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.xor.2d.i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.abs.2d.f32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.ceil.2d.f32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.ctlz.2d.i32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.exp.2d.f32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.floor.2d.f32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.log.2d.f32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.neg.2d.f32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.rsqrt.2d.f32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.copy.2d.x8(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.copy.2d.x16(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.copy.2d.x32(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.copy.2d.x64(%in_buffer : !vm.buffer, %in_offset : i64, %in_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %sizes : tuple<i64, i64>)
vm.import @vmvx.fill.2d.x32(%fill_value : i32, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %size_m : i64, %size_n : i64)
vm.import @vmvx.matmul.f32f32f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_row_stride : i64, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_row_stride : i64, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %m : i64, %n : i64, %k : i64, %flags : i32)
vm.import @vmvx.matmul.i8i8i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_row_stride : i64, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_row_stride : i64, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %m : i64, %n : i64, %k : i64, %flags : i32)
vm.import @vmvx.mmt4d.f32f32f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_row_stride : i64, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_row_stride : i64, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %m : i64, %n : i64, %k : i64, %m0 : i32, %n0 : i32, %k0 : i32, %flags : i32)
vm.import @vmvx.mmt4d.i8i8i32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_row_stride : i64, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_row_stride : i64, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %m : i64, %n : i64, %k : i64, %m0 : i32, %n0 : i32, %k0 : i32, %flags : i32)
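// the dumps below trace a linalg.mmt4d (logical 384x384x512 matmul, 4x1x4 tiling) through the vmvx pipeline: dispatch region outlining, workgroup tiling, buffer descriptor resolution, and serialization to a vm.module that ends up calling @vmvx.mmt4d.f32f32f32.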
// -----// IR Dump Before Canonicalizer (canonicalize) //----- //
func.func @mmt4d_384x384x512_4x1x4(%arg0: tensor<96x384x4x1xf32>, %arg1: tensor<128x384x4x1xf32>, %arg2: tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32> {
%0 = linalg.mmt4d ins(%arg0, %arg1 : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%arg2 : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
return %0 : tensor<96x128x4x4xf32>
}
// -----// IR Dump After OutlineDispatchRegions (iree-flow-outline-dispatch-regions) //----- //
module {
flow.executable private @mmt4d_384x384x512_4x1x4_dispatch_0 {
flow.executable.export public @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1 workgroups(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !flow.dispatch.tensor<readonly:96x384x4x1xf32>, %arg1: !flow.dispatch.tensor<readonly:128x384x4x1xf32>, %arg2: !flow.dispatch.tensor<writeonly:96x128x4x4xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%0 = flow.dispatch.tensor.load %arg0, offsets = [0, 0, 0, 0], sizes = [96, 384, 4, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:96x384x4x1xf32> -> tensor<96x384x4x1xf32>
%1 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0, 0], sizes = [128, 384, 4, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:128x384x4x1xf32> -> tensor<128x384x4x1xf32>
%2 = linalg.init_tensor [96, 128, 4, 4] : tensor<96x128x4x4xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
%4 = linalg.mmt4d ins(%0, %1 : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%3 : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
flow.dispatch.tensor.store %4, %arg2, offsets = [0, 0, 0, 0], sizes = [96, 128, 4, 4], strides = [1, 1, 1, 1] : tensor<96x128x4x4xf32> -> !flow.dispatch.tensor<writeonly:96x128x4x4xf32>
return
}
}
}
func.func @mmt4d_384x384x512_4x1x4(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<96x384x4x1xf32>
%1 = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<128x384x4x1xf32>
%2 = flow.dispatch @mmt4d_384x384x512_4x1x4_dispatch_0::@mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1[%c96, %c128, %c1, %c4, %c4, %c1](%0, %1) : (tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) -> tensor<96x128x4x4xf32>
%3 = hal.tensor.export %2 : tensor<96x128x4x4xf32> -> !hal.buffer_view
return %3 : !hal.buffer_view
}
}
// -----// IR Dump Before RemoveSingleIterationLoop (iree-codegen-remove-single-iteration-loop) //----- //
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1() {
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<96x384x4x1xf32>
memref.assume_alignment %0, 64 : memref<96x384x4x1xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384x4x1xf32>
memref.assume_alignment %1, 64 : memref<128x384x4x1xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<96x128x4x4xf32>
memref.assume_alignment %2, 64 : memref<96x128x4x4xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_y]
scf.for %arg0 = %3 to %c96 step %4 {
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg1 = %5 to %c128 step %6 {
%7 = memref.subview %2[%arg0, %arg1, 0, 0] [32, 64, 4, 4] [1, 1, 1, 1] : memref<96x128x4x4xf32> to memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
%8 = memref.subview %0[%arg0, 0, 0, 0] [32, 384, 4, 1] [1, 1, 1, 1] : memref<96x384x4x1xf32> to memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
%9 = memref.subview %1[%arg1, 0, 0, 0] [64, 384, 4, 1] [1, 1, 1, 1] : memref<128x384x4x1xf32> to memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
linalg.fill ins(%cst : f32) outs(%7 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>)
linalg.mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 64, 0, 0, 0, 0]]>} ins(%8, %9 : memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>, memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>) outs(%7 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>)
}
}
return
}
// -----// IR Dump After RemoveSingleIterationLoop (iree-codegen-remove-single-iteration-loop) //----- //
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<96x384x4x1xf32>
memref.assume_alignment %0, 64 : memref<96x384x4x1xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384x4x1xf32>
memref.assume_alignment %1, 64 : memref<128x384x4x1xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<96x128x4x4xf32>
memref.assume_alignment %2, 64 : memref<96x128x4x4xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%5 = memref.subview %2[%3, %4, 0, 0] [32, 64, 4, 4] [1, 1, 1, 1] : memref<96x128x4x4xf32> to memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
%6 = memref.subview %0[%3, 0, 0, 0] [32, 384, 4, 1] [1, 1, 1, 1] : memref<96x384x4x1xf32> to memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
%7 = memref.subview %1[%4, 0, 0, 0] [64, 384, 4, 1] [1, 1, 1, 1] : memref<128x384x4x1xf32> to memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
linalg.fill ins(%cst : f32) outs(%5 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>)
linalg.mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 64, 0, 0, 0, 0]]>} ins(%6, %7 : memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>, memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>) outs(%5 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>)
return
}
// -----// IR Dump Before ResolveBufferDescriptors (iree-vmvx-resolve-buffer-descriptors) //----- //
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1() {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c64 = arith.constant 64 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<96x384x4x1xf32>
memref.assume_alignment %0, 64 : memref<96x384x4x1xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384x4x1xf32>
memref.assume_alignment %1, 64 : memref<128x384x4x1xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<96x128x4x4xf32>
memref.assume_alignment %2, 64 : memref<96x128x4x4xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%5 = memref.subview %2[%3, %4, 0, 0] [32, 64, 4, 4] [1, 1, 1, 1] : memref<96x128x4x4xf32> to memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
%base_buffer, %offset, %sizes:4, %strides:4 = vmvx.get_buffer_descriptor %5 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>> -> !util.buffer, index, index, index, index, index, index, index, index, index
%6 = memref.subview %0[%3, 0, 0, 0] [32, 384, 4, 1] [1, 1, 1, 1] : memref<96x384x4x1xf32> to memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
%base_buffer_0, %offset_1, %sizes_2:4, %strides_3:4 = vmvx.get_buffer_descriptor %6 : memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>> -> !util.buffer, index, index, index, index, index, index, index, index, index
%7 = memref.subview %1[%4, 0, 0, 0] [64, 384, 4, 1] [1, 1, 1, 1] : memref<128x384x4x1xf32> to memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
%base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = vmvx.get_buffer_descriptor %7 : memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>> -> !util.buffer, index, index, index, index, index, index, index, index, index
scf.for %arg0 = %c0 to %c32 step %c1 {
scf.for %arg1 = %c0 to %c64 step %c1 {
scf.for %arg2 = %c0 to %c4 step %c1 {
scf.for %arg3 = %c0 to %c4 step %c1 {
memref.store %cst, %5[%arg0, %arg1, %arg2, %arg3] : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
}
}
}
}
vmvx.mmt4d lhs(%base_buffer_0 offset %offset_1 row_stride %strides_3#0 : !util.buffer) rhs(%base_buffer_4 offset %offset_5 row_stride %strides_7#0 : !util.buffer) out(%base_buffer offset %offset row_stride %strides#0 : !util.buffer) mnk(%sizes_2#0, %sizes_6#0, %sizes_6#1) tile_mnk(%sizes_2#2, %sizes_6#2, %sizes_6#3) flags(1) : (f32, f32, f32)
return
}
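// note: the linalg.fill lowers to scalar scf loops + memref.store here instead of hitting the vmvx.fill.2d.x32 import; only the mmt4d itself maps to a vmvx op.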
// -----// IR Dump After ForOpCanonicalization (iree-codegen-canonicalize-scf-for) //----- //
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
%c1024 = arith.constant 1024 : index
%c65536 = arith.constant 65536 : index
%c2 = arith.constant 2 : index
%c16 = arith.constant 16 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c1536 = arith.constant 1536 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c384 = arith.constant 384 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = util.list.get %arg2[%c2] : !util.list<!util.buffer>
%1 = arith.index_cast %arg3 : i32 to index
%2 = arith.index_cast %arg4 : i32 to index
%3 = arith.muli %2, %c32 : index
%4 = arith.muli %1, %c64 : index
%5 = arith.muli %3, %c2048 : index
%6 = arith.muli %4, %c16 : index
%7 = arith.addi %5, %6 : index
%8 = util.list.get %arg2[%c0] : !util.list<!util.buffer>
%9 = arith.muli %3, %c1536 : index
%10 = util.list.get %arg2[%c1] : !util.list<!util.buffer>
%11 = arith.muli %4, %c1536 : index
scf.for %arg12 = %c0 to %c32 step %c1 {
scf.for %arg13 = %c0 to %c64 step %c1 {
scf.for %arg14 = %c0 to %c4 step %c1 {
scf.for %arg15 = %c0 to %c4 step %c1 {
%12 = arith.muli %arg14, %c4 : index
%13 = arith.addi %12, %arg15 : index
%14 = arith.muli %arg12, %c2048 : index
%15 = arith.addi %13, %14 : index
%16 = arith.muli %arg13, %c16 : index
%17 = arith.addi %15, %16 : index
%18 = arith.muli %2, %c65536 : index
%19 = arith.addi %17, %18 : index
%20 = arith.muli %1, %c1024 : index
%21 = arith.addi %19, %20 : index
%buffer_size = util.buffer.size %0 : !util.buffer
%22 = arith.muli %21, %c4 : index
util.buffer.store %cst, %0[%22] : f32 -> !util.buffer{%buffer_size}
}
}
}
}
vmvx.mmt4d lhs(%8 offset %9 row_stride %c1536 : !util.buffer) rhs(%10 offset %11 row_stride %c1536 : !util.buffer) out(%0 offset %7 row_stride %c2048 : !util.buffer) mnk(%c32, %c64, %c384) tile_mnk(%c4, %c4, %c1) flags(1) : (f32, f32, f32)
return
}
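// after canonicalization the descriptor queries have folded away: row strides (1536/2048) and the mnk/tile sizes are constants, and only the offsets derived from the workgroup ids (%arg3/%arg4, presumably x/y) stay dynamic.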
// -----// IR Dump Before mlir::iree_compiler::IREE::HAL::SerializeTargetExecutablesPass (iree-hal-serialize-target-executables) //----- //
hal.executable private @mmt4d_384x384x512_4x1x4_dispatch_0 {
hal.executable.variant public @vmvx_bytecode_fb, target = <"vmvx", "vmvx-bytecode-fb"> {
hal.executable.export public @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {translation_info = #iree_codegen.translation_info<VMVXDefault workload_per_wg = [64, 32]>} {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
%c64 = arith.constant 64 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = arith.cmpi sle, %arg1, %c0 : index
%1 = arith.subi %c0, %arg1 : index
%2 = arith.subi %arg1, %c1 : index
%3 = arith.select %0, %1, %2 : index
%4 = arith.divsi %3, %c32 : index
%5 = arith.subi %c0, %4 : index
%6 = arith.addi %4, %c1 : index
%7 = arith.select %0, %5, %6 : index
%8 = arith.cmpi sle, %arg2, %c0 : index
%9 = arith.subi %c0, %arg2 : index
%10 = arith.subi %arg2, %c1 : index
%11 = arith.select %8, %9, %10 : index
%12 = arith.divsi %11, %c64 : index
%13 = arith.subi %c0, %12 : index
%14 = arith.addi %12, %c1 : index
%15 = arith.select %8, %13, %14 : index
hal.return %15, %7, %c1 : index, index, index
}
builtin.module attributes {vm.toplevel} {
vm.module public @module {
vm.import @vmvx.mmt4d.f32f32f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_row_stride : i64, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_row_stride : i64, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %m : i64, %n : i64, %k : i64, %m0 : i32, %n0 : i32, %k0 : i32, %flags : i32) attributes {sym_visibility = "private"}
vm.func private @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !vm.buffer, %arg1: !vm.buffer, %arg2: !vm.list<!vm.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
%c98304 = vm.const.i32 98304
%c49152 = vm.const.i32 49152
%c1024 = vm.const.i32 1024
%c65536 = vm.const.i32 65536
%c384 = vm.const.i64 384
%c64 = vm.const.i64 64
%c32 = vm.const.i64 32
%c2048 = vm.const.i64 2048
%c1536 = vm.const.i64 1536
%c2 = vm.const.i32 2
%c16 = vm.const.i32 16
%c2048_0 = vm.const.i32 2048
%zero = vm.const.i32.zero
%c4 = vm.const.i32 4
%c32_1 = vm.const.i32 32
%c64_2 = vm.const.i32 64
%c1 = vm.const.i32 1
%zero_3 = vm.const.f32.zero
%buffer = vm.list.get.ref %arg2, %c2 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
%0 = vm.mul.i32 %arg4, %c65536 : i32
%1 = vm.mul.i32 %arg3, %c1024 : i32
%2 = vm.add.i32 %0, %1 : i32
%buffer_4 = vm.list.get.ref %arg2, %zero : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
%3 = vm.mul.i32 %arg4, %c49152 : i32
%buffer_5 = vm.list.get.ref %arg2, %c1 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
%4 = vm.mul.i32 %arg3, %c98304 : i32
vm.br ^bb1(%zero : i32)
^bb1(%5: i32): // 2 preds: ^bb0, ^bb11
%slt = vm.cmp.lt.i32.s %5, %c32_1 : i32
vm.cond_br %slt, ^bb2, ^bb12
^bb2: // pred: ^bb1
%6 = vm.mul.i32 %5, %c2048_0 : i32
vm.br ^bb3(%zero : i32)
^bb3(%7: i32): // 2 preds: ^bb2, ^bb10
%slt_6 = vm.cmp.lt.i32.s %7, %c64_2 : i32
vm.cond_br %slt_6, ^bb4, ^bb11
^bb4: // pred: ^bb3
%8 = vm.mul.i32 %7, %c16 : i32
vm.br ^bb5(%zero : i32)
^bb5(%9: i32): // 2 preds: ^bb4, ^bb9
%slt_7 = vm.cmp.lt.i32.s %9, %c4 : i32
vm.cond_br %slt_7, ^bb6, ^bb10
^bb6: // pred: ^bb5
%10 = vm.mul.i32 %9, %c4 : i32
vm.br ^bb7(%zero : i32)
^bb7(%11: i32): // 2 preds: ^bb6, ^bb8
%slt_8 = vm.cmp.lt.i32.s %11, %c4 : i32
vm.cond_br %slt_8, ^bb8, ^bb9
^bb8: // pred: ^bb7
%12 = vm.add.i32 %10, %11 : i32
%13 = vm.add.i32 %12, %6 : i32
%14 = vm.add.i32 %13, %8 : i32
%15 = vm.add.i32 %14, %0 : i32
%16 = vm.add.i32 %15, %1 : i32
%17 = vm.mul.i32 %16, %c4 : i32
%18 = vm.ext.i32.i64.u %17 : i32 -> i64
vm.buffer.store.f32 %zero_3, %buffer[%18] : f32 -> !vm.buffer
%19 = vm.add.i32 %11, %c1 : i32
vm.br ^bb7(%19 : i32)
^bb9: // pred: ^bb7
%20 = vm.add.i32 %9, %c1 : i32
vm.br ^bb5(%20 : i32)
^bb10: // pred: ^bb5
%21 = vm.add.i32 %7, %c1 : i32
vm.br ^bb3(%21 : i32)
^bb11: // pred: ^bb3
%22 = vm.add.i32 %5, %c1 : i32
vm.br ^bb1(%22 : i32)
^bb12: // pred: ^bb1
%23 = vm.ext.i32.i64.s %3 : i32 -> i64
%24 = vm.ext.i32.i64.s %4 : i32 -> i64
%25 = vm.ext.i32.i64.s %2 : i32 -> i64
vm.call @vmvx.mmt4d.f32f32f32(%buffer_4, %23, %c1536, %buffer_5, %24, %c1536, %buffer, %25, %c2048, %c32, %c64, %c384, %c4, %c4, %c1, %c1) : (!vm.buffer, i64, i64, !vm.buffer, i64, i64, !vm.buffer, i64, i64, i64, i64, i64, i32, i32, i32, i32) -> ()
vm.return
}
vm.export @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1
}
}
}
}
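// second example: the same mmt4d with a fused elementwise mhlo.minimum consumer; the dumps below pick up an extra storage buffer binding and a fused min loop nest after the vmvx.mmt4d call.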
func.func @mmt4d_384x384x512_4x1x4(%lhs: tensor<96x384x4x1xf32>, %rhs: tensor<128x384x4x1xf32>, %dst: tensor<96x128x4x4xf32>, %fused: tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32> {
%0 = linalg.mmt4d ins(%lhs, %rhs : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%dst : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
%1 = mhlo.minimum %0, %fused : tensor<96x128x4x4xf32>
return %1 : tensor<96x128x4x4xf32>
}
// -----// IR Dump Before RemoveSingleIterationLoop (iree-codegen-remove-single-iteration-loop) //----- //
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1() {
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<96x384x4x1xf32>
memref.assume_alignment %0, 64 : memref<96x384x4x1xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384x4x1xf32>
memref.assume_alignment %1, 64 : memref<128x384x4x1xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<96x128x4x4xf32>
memref.assume_alignment %2, 64 : memref<96x128x4x4xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : memref<96x128x4x4xf32>
memref.assume_alignment %3, 64 : memref<96x128x4x4xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_count_y]
scf.for %arg0 = %4 to %c96 step %5 {
%6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg1 = %6 to %c128 step %7 {
%8 = memref.subview %3[%arg0, %arg1, 0, 0] [32, 64, 4, 4] [1, 1, 1, 1] : memref<96x128x4x4xf32> to memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
%9 = memref.subview %0[%arg0, 0, 0, 0] [32, 384, 4, 1] [1, 1, 1, 1] : memref<96x384x4x1xf32> to memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
%10 = memref.subview %1[%arg1, 0, 0, 0] [64, 384, 4, 1] [1, 1, 1, 1] : memref<128x384x4x1xf32> to memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
linalg.fill ins(%cst : f32) outs(%8 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>)
linalg.mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 64, 0, 0, 0, 0]]>} ins(%9, %10 : memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>, memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>) outs(%8 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>)
%11 = memref.subview %2[%arg0, %arg1, 0, 0] [32, 64, 4, 4] [1, 1, 1, 1] : memref<96x128x4x4xf32> to memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>) outs(%8 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>) {
^bb0(%arg2: f32, %arg3: f32):
%12 = arith.minf %arg3, %arg2 : f32
linalg.yield %12 : f32
}
}
}
return
}
// -----// IR Dump After LinalgLowerToLoops (convert-linalg-to-loops) //----- //
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1() {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c64 = arith.constant 64 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<96x384x4x1xf32>
memref.assume_alignment %0, 64 : memref<96x384x4x1xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384x4x1xf32>
memref.assume_alignment %1, 64 : memref<128x384x4x1xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : memref<96x128x4x4xf32>
memref.assume_alignment %2, 64 : memref<96x128x4x4xf32>
%3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : memref<96x128x4x4xf32>
memref.assume_alignment %3, 64 : memref<96x128x4x4xf32>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%6 = memref.subview %3[%4, %5, 0, 0] [32, 64, 4, 4] [1, 1, 1, 1] : memref<96x128x4x4xf32> to memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
%base_buffer, %offset, %sizes:4, %strides:4 = vmvx.get_buffer_descriptor %6 : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>> -> !util.buffer, index, index, index, index, index, index, index, index, index
%7 = memref.subview %0[%4, 0, 0, 0] [32, 384, 4, 1] [1, 1, 1, 1] : memref<96x384x4x1xf32> to memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
%base_buffer_0, %offset_1, %sizes_2:4, %strides_3:4 = vmvx.get_buffer_descriptor %7 : memref<32x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>> -> !util.buffer, index, index, index, index, index, index, index, index, index
%8 = memref.subview %1[%5, 0, 0, 0] [64, 384, 4, 1] [1, 1, 1, 1] : memref<128x384x4x1xf32> to memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>>
%base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = vmvx.get_buffer_descriptor %8 : memref<64x384x4x1xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1536 + s0 + d1 * 4 + d2 + d3)>> -> !util.buffer, index, index, index, index, index, index, index, index, index
scf.for %arg0 = %c0 to %c32 step %c1 {
scf.for %arg1 = %c0 to %c64 step %c1 {
scf.for %arg2 = %c0 to %c4 step %c1 {
scf.for %arg3 = %c0 to %c4 step %c1 {
memref.store %cst, %6[%arg0, %arg1, %arg2, %arg3] : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
}
}
}
}
vmvx.mmt4d lhs(%base_buffer_0 offset %offset_1 row_stride %strides_3#0 : !util.buffer) rhs(%base_buffer_4 offset %offset_5 row_stride %strides_7#0 : !util.buffer) out(%base_buffer offset %offset row_stride %strides#0 : !util.buffer) mnk(%sizes_2#0, %sizes_6#0, %sizes_6#1) tile_mnk(%sizes_2#2, %sizes_6#2, %sizes_6#3) flags(1) : (f32, f32, f32)
%9 = memref.subview %2[%4, %5, 0, 0] [32, 64, 4, 4] [1, 1, 1, 1] : memref<96x128x4x4xf32> to memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
scf.for %arg0 = %c0 to %c32 step %c1 {
scf.for %arg1 = %c0 to %c64 step %c1 {
scf.for %arg2 = %c0 to %c4 step %c1 {
scf.for %arg3 = %c0 to %c4 step %c1 {
%10 = memref.load %9[%arg0, %arg1, %arg2, %arg3] : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
%11 = memref.load %6[%arg0, %arg1, %arg2, %arg3] : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
%12 = arith.minf %11, %10 : f32
memref.store %12, %6[%arg0, %arg1, %arg2, %arg3] : memref<32x64x4x4xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2048 + s0 + d1 * 16 + d2 * 4 + d3)>>
}
}
}
}
return
}
// -----// IR Dump Before SCFToControlFlow (convert-scf-to-cf) //----- //
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
%c1024 = arith.constant 1024 : index
%c65536 = arith.constant 65536 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c16 = arith.constant 16 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c1536 = arith.constant 1536 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c384 = arith.constant 384 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = util.list.get %arg2[%c2] : !util.list<!util.buffer>
%1 = util.list.get %arg2[%c3] : !util.list<!util.buffer>
%2 = arith.index_cast %arg3 : i32 to index
%3 = arith.index_cast %arg4 : i32 to index
%4 = arith.muli %3, %c32 : index
%5 = arith.muli %2, %c64 : index
%6 = arith.muli %4, %c2048 : index
%7 = arith.muli %5, %c16 : index
%8 = arith.addi %6, %7 : index
%9 = util.list.get %arg2[%c0] : !util.list<!util.buffer>
%10 = arith.muli %4, %c1536 : index
%11 = util.list.get %arg2[%c1] : !util.list<!util.buffer>
%12 = arith.muli %5, %c1536 : index
%13 = arith.muli %3, %c65536 : index
%14 = arith.muli %2, %c1024 : index
%buffer_size = util.buffer.size %1 : !util.buffer
scf.for %arg12 = %c0 to %c32 step %c1 {
%15 = arith.muli %arg12, %c2048 : index
scf.for %arg13 = %c0 to %c64 step %c1 {
%16 = arith.muli %arg13, %c16 : index
scf.for %arg14 = %c0 to %c4 step %c1 {
%17 = arith.muli %arg14, %c4 : index
scf.for %arg15 = %c0 to %c4 step %c1 {
%18 = arith.addi %17, %arg15 : index
%19 = arith.addi %18, %15 : index
%20 = arith.addi %19, %16 : index
%21 = arith.addi %20, %13 : index
%22 = arith.addi %21, %14 : index
%23 = arith.muli %22, %c4 : index
util.buffer.store %cst, %1[%23] : f32 -> !util.buffer{%buffer_size}
}
}
}
}
vmvx.mmt4d lhs(%9 offset %10 row_stride %c1536 : !util.buffer) rhs(%11 offset %12 row_stride %c1536 : !util.buffer) out(%1 offset %8 row_stride %c2048 : !util.buffer) mnk(%c32, %c64, %c384) tile_mnk(%c4, %c4, %c1) flags(1) : (f32, f32, f32)
%buffer_size_0 = util.buffer.size %0 : !util.buffer
scf.for %arg12 = %c0 to %c32 step %c1 {
%15 = arith.muli %arg12, %c2048 : index
scf.for %arg13 = %c0 to %c64 step %c1 {
%16 = arith.muli %arg13, %c16 : index
scf.for %arg14 = %c0 to %c4 step %c1 {
%17 = arith.muli %arg14, %c4 : index
scf.for %arg15 = %c0 to %c4 step %c1 {
%18 = arith.addi %17, %arg15 : index
%19 = arith.addi %18, %15 : index
%20 = arith.addi %19, %16 : index
%21 = arith.addi %20, %13 : index
%22 = arith.addi %21, %14 : index
%23 = arith.muli %22, %c4 : index
%24 = util.buffer.load %0[%23] : !util.buffer{%buffer_size_0} -> f32
%25 = util.buffer.load %1[%23] : !util.buffer{%buffer_size} -> f32
%26 = arith.minf %25, %24 : f32
util.buffer.store %26, %1[%23] : f32 -> !util.buffer{%buffer_size}
}
}
}
}
return
}
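// third trace: mmt4d accumulating into the caller-provided destination (readwrite binding, no fill), compiled through stream and the vmvx-inline target, where the executable gets inlined and the dispatch becomes a plain func.call.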
// -----// IR Dump Before Canonicalizer (canonicalize) ('func.func' operation: @mmt4d_384x384x512_4x1x4) //----- //
module {
func.func @mmt4d_384x384x512_4x1x4(%arg0: tensor<96x384x4x1xf32>, %arg1: tensor<128x384x4x1xf32>, %arg2: tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32> {
%0 = linalg.mmt4d ins(%arg0, %arg1 : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%arg2 : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
return %0 : tensor<96x128x4x4xf32>
}
}
// -----// IR Dump After DispatchLinalgOnTensors (iree-flow-dispatch-linalg-on-tensors-pass) ('func.func' operation: @mmt4d_384x384x512_4x1x4) //----- //
module {
func.func @mmt4d_384x384x512_4x1x4(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<96x384x4x1xf32>
%1 = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<128x384x4x1xf32>
%2 = hal.tensor.import %arg2 : !hal.buffer_view -> tensor<96x128x4x4xf32>
%3 = flow.dispatch.workgroups[%c96, %c128, %c1, %c4, %c4, %c1](%0, %1, %2) : (tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>, tensor<96x128x4x4xf32>) -> %2 =
(%arg3: !flow.dispatch.tensor<readonly:96x384x4x1xf32>, %arg4: !flow.dispatch.tensor<readonly:128x384x4x1xf32>, %arg5: !flow.dispatch.tensor<readwrite:96x128x4x4xf32>) {
%5 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [96, 384, 4, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:96x384x4x1xf32> -> tensor<96x384x4x1xf32>
%6 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [128, 384, 4, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:128x384x4x1xf32> -> tensor<128x384x4x1xf32>
%7 = flow.dispatch.tensor.load %arg5, offsets = [0, 0, 0, 0], sizes = [96, 128, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:96x128x4x4xf32> -> tensor<96x128x4x4xf32>
%8 = linalg.mmt4d ins(%5, %6 : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%7 : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
flow.dispatch.tensor.store %8, %arg5, offsets = [0, 0, 0, 0], sizes = [96, 128, 4, 4], strides = [1, 1, 1, 1] : tensor<96x128x4x4xf32> -> !flow.dispatch.tensor<readwrite:96x128x4x4xf32>
flow.return
} count(%arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6, %arg7, %arg8
flow.return %x, %y, %z : index, index, index
}
%4 = hal.tensor.export %3 : tensor<96x128x4x4xf32> -> !hal.buffer_view
return %4 : !hal.buffer_view
}
}
// -----// IR Dump Before CSE (cse) ('func.func' operation: @mmt4d_384x384x512_4x1x4) //----- //
module {
stream.executable private @mmt4d_384x384x512_4x1x4_dispatch_0 {
stream.executable.export public @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1 workgroups(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding) {
%c0 = arith.constant 0 : index
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:96x384x4x1xf32>
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:128x384x4x1xf32>
%2 = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:96x128x4x4xf32>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [96, 384, 4, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:96x384x4x1xf32> -> tensor<96x384x4x1xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [128, 384, 4, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:128x384x4x1xf32> -> tensor<128x384x4x1xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [96, 128, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:96x128x4x4xf32> -> tensor<96x128x4x4xf32>
%6 = linalg.mmt4d ins(%3, %4 : tensor<96x384x4x1xf32>, tensor<128x384x4x1xf32>) outs(%5 : tensor<96x128x4x4xf32>) -> tensor<96x128x4x4xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0], sizes = [96, 128, 4, 4], strides = [1, 1, 1, 1] : tensor<96x128x4x4xf32> -> !flow.dispatch.tensor<readwrite:96x128x4x4xf32>
return
}
}
}
func.func @mmt4d_384x384x512_4x1x4(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c589824 = arith.constant 589824 : index
%c786432 = arith.constant 786432 : index
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c553648160_i32 = arith.constant 553648160 : i32
%c1_i32 = arith.constant 1 : i32
%c384 = arith.constant 384 : index
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("tensor") shape([%c96, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<96x384x4x1xf32> in !stream.resource<external>{%c589824}
hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("tensor") shape([%c128, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%1 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<128x384x4x1xf32> in !stream.resource<external>{%c786432}
hal.buffer_view.assert<%arg2 : !hal.buffer_view> message("tensor") shape([%c96, %c128, %c4, %c4]) type(%c553648160_i32) encoding(%c1_i32)
%2 = stream.tensor.import %arg2 : !hal.buffer_view -> tensor<96x128x4x4xf32> in !stream.resource<external>{%c786432}
%results, %result_timepoint = stream.async.execute with(%0 as %arg3: !stream.resource<external>{%c589824}, %1 as %arg4: !stream.resource<external>{%c786432}, %2 as %arg5: !stream.resource<external>{%c786432}) -> %2{%c786432} {
%5 = stream.async.dispatch @mmt4d_384x384x512_4x1x4_dispatch_0::@mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1[%c96, %c128, %c1, %c4, %c4, %c1](%arg3, %arg4, %arg5) : (!stream.resource<external>{%c589824}, !stream.resource<external>{%c786432}, !stream.resource<external>{%c786432}) -> %arg5{%c786432}
stream.yield %5 : !stream.resource<external>{%c786432}
} => !stream.timepoint
%3 = stream.timepoint.await %result_timepoint => %results : !stream.resource<external>{%c786432}
%4 = stream.tensor.export %3 : tensor<96x128x4x4xf32> in !stream.resource<external>{%c786432} -> !hal.buffer_view
return %4 : !hal.buffer_view
}
}
// -----// IR Dump Before InlineExecutables (iree-hal-inline-executables) ('builtin.module' operation) //----- //
#device_target_vmvx_inline = #hal.device.target<"vmvx-inline", {executable_targets = [#hal.executable.target<"vmvx-inline", "vmvx-ir">]}>
#executable_target_vmvx_ir = #hal.executable.target<"vmvx-inline", "vmvx-ir">
#map0 = affine_map<()[s0] -> (s0 ceildiv 32)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 64)>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
#translation = #iree_codegen.translation_info<VMVXDefault workload_per_wg = [64, 32]>
module attributes {hal.device.targets = [#device_target_vmvx_inline]} {
hal.executable private @mmt4d_384x384x512_4x1x4_dispatch_0 {
hal.executable.variant public @vmvx_ir, target = #executable_target_vmvx_ir {
hal.executable.export public @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
%c1 = arith.constant 1 : index
%0 = affine.apply #map0()[%arg1]
%1 = affine.apply #map1()[%arg2]
hal.return %1, %0, %c1 : index, index, index
}
builtin.module {
func.func @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
%c2 = arith.constant 2 : index
%c1536 = arith.constant 1536 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c384 = arith.constant 384 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c2048 = arith.constant 2048 : index
%0 = arith.index_cast %arg3 : i32 to index
%1 = arith.index_cast %arg4 : i32 to index
%2 = arith.muli %1, %c32 : index
%3 = arith.muli %0, %c64 : index
%4 = util.list.get %arg2[%c0] : !util.list<!util.buffer>
%5 = arith.muli %2, %c1536 : index
%6 = util.list.get %arg2[%c1] : !util.list<!util.buffer>
%7 = arith.muli %3, %c1536 : index
%8 = util.list.get %arg2[%c2] : !util.list<!util.buffer>
%9 = arith.muli %2, %c2048 : index
%10 = arith.muli %3, %c16 : index
%11 = arith.addi %9, %10 : index
vmvx.mmt4d lhs(%4 offset %5 row_stride %c1536 : !util.buffer) rhs(%6 offset %7 row_stride %c1536 : !util.buffer) out(%8 offset %11 row_stride %c2048 : !util.buffer) mnk(%c32, %c64, %c384) tile_mnk(%c4, %c4, %c1) flags(1) : (f32, f32, f32)
return
}
}
}
}
func.func @mmt4d_384x384x512_4x1x4(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c0 = arith.constant 0 : index
%c589824 = arith.constant 589824 : index
%c786432 = arith.constant 786432 : index
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c553648160_i32 = arith.constant 553648160 : i32
%c1_i32 = arith.constant 1 : i32
%c384 = arith.constant 384 : index
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("tensor") shape([%c96, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<96x384x4x1xf32> in !stream.resource<external>{%c589824}
hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("tensor") shape([%c128, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%1 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<128x384x4x1xf32> in !stream.resource<external>{%c786432}
hal.buffer_view.assert<%arg2 : !hal.buffer_view> message("tensor") shape([%c96, %c128, %c4, %c4]) type(%c553648160_i32) encoding(%c1_i32)
%2 = stream.tensor.import %arg2 : !hal.buffer_view -> tensor<96x128x4x4xf32> in !stream.resource<external>{%c786432}
%3 = stream.cmd.execute with(%0 as %arg3: !stream.resource<external>{%c589824}, %1 as %arg4: !stream.resource<external>{%c786432}, %2 as %arg5: !stream.resource<external>{%c786432}) {
stream.cmd.dispatch @mmt4d_384x384x512_4x1x4_dispatch_0::@mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1[%c96, %c128, %c1, %c4, %c4, %c1] {
ro %arg3[%c0 for %c589824] : !stream.resource<external>{%c589824},
ro %arg4[%c0 for %c786432] : !stream.resource<external>{%c786432},
rw %arg5[%c0 for %c786432] : !stream.resource<external>{%c786432}
} attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]}
} => !stream.timepoint
%4 = stream.timepoint.await %3 => %2 : !stream.resource<external>{%c786432}
%5 = stream.tensor.export %4 : tensor<96x128x4x4xf32> in !stream.resource<external>{%c786432} -> !hal.buffer_view
return %5 : !hal.buffer_view
}
}
// -----// IR Dump Before PropagateSubranges (iree-util-propagate-subranges) ('builtin.module' operation) //----- //
#device_target_vmvx_inline = #hal.device.target<"vmvx-inline", {executable_targets = [#hal.executable.target<"vmvx-inline", "vmvx-ir">]}>
#map0 = affine_map<()[s0] -> (s0 ceildiv 32)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 64)>
module attributes {hal.device.targets = [#device_target_vmvx_inline]} {
func.func private @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.buffer, %arg3: !util.buffer, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: index, %arg11: index, %arg12: index) {
%c2048 = arith.constant 2048 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
%c384 = arith.constant 384 : index
%c64 = arith.constant 64 : index
%c32 = arith.constant 32 : index
%c4 = arith.constant 4 : index
%c1536 = arith.constant 1536 : index
%0 = arith.muli %arg5, %c32 : index
%1 = arith.muli %arg4, %c64 : index
%2 = arith.muli %0, %c1536 : index
%3 = arith.muli %1, %c1536 : index
%4 = arith.muli %0, %c2048 : index
%5 = arith.muli %1, %c16 : index
%6 = arith.addi %4, %5 : index
vmvx.mmt4d lhs(%arg1 offset %2 row_stride %c1536 : !util.buffer) rhs(%arg2 offset %3 row_stride %c1536 : !util.buffer) out(%arg3 offset %6 row_stride %c2048 : !util.buffer) mnk(%c32, %c64, %c384) tile_mnk(%c4, %c4, %c1) flags(1) : (f32, f32, f32)
return
}
func.func private @__dispatch_mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: !util.buffer, %arg7: !util.buffer, %arg8: !util.buffer, %arg9: index, %arg10: index, %arg11: index, %arg12: index, %arg13: index, %arg14: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = affine.apply #map0()[%arg0]
%1 = affine.apply #map1()[%arg1]
%2 = util.null : !util.buffer
%buffer_size = util.buffer.size %arg6 : !util.buffer
%buffer_span = util.buffer.subspan %arg6[%arg9] : !util.buffer{%buffer_size} -> !util.buffer{%arg12}
%buffer_size_0 = util.buffer.size %arg7 : !util.buffer
%buffer_span_1 = util.buffer.subspan %arg7[%arg10] : !util.buffer{%buffer_size_0} -> !util.buffer{%arg13}
%buffer_size_2 = util.buffer.size %arg8 : !util.buffer
%buffer_span_3 = util.buffer.subspan %arg8[%arg11] : !util.buffer{%buffer_size_2} -> !util.buffer{%arg14}
scf.for %arg15 = %c0 to %0 step %c1 {
scf.for %arg16 = %c0 to %1 step %c1 {
func.call @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%2, %buffer_span, %buffer_span_1, %buffer_span_3, %arg16, %arg15, %c0, %c1, %c1, %c1, %1, %0, %c1) : (!util.buffer, !util.buffer, !util.buffer, !util.buffer, index, index, index, index, index, index, index, index, index) -> ()
}
}
return
}
func.func @mmt4d_384x384x512_4x1x4(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c0 = arith.constant 0 : index
%c589824 = arith.constant 589824 : index
%c786432 = arith.constant 786432 : index
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c553648160_i32 = arith.constant 553648160 : i32
%c1_i32 = arith.constant 1 : i32
%c384 = arith.constant 384 : index
hal_inline.buffer_view.assert<%arg0 : !hal.buffer_view> message("tensor") shape([%c96, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%buffer = hal_inline.buffer_view.buffer<%arg0 : !hal.buffer_view> : !hal.buffer
hal_inline.buffer_view.assert<%arg1 : !hal.buffer_view> message("tensor") shape([%c128, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%buffer_0 = hal_inline.buffer_view.buffer<%arg1 : !hal.buffer_view> : !hal.buffer
hal_inline.buffer_view.assert<%arg2 : !hal.buffer_view> message("tensor") shape([%c96, %c128, %c4, %c4]) type(%c553648160_i32) encoding(%c1_i32)
%buffer_1 = hal_inline.buffer_view.buffer<%arg2 : !hal.buffer_view> : !hal.buffer
%storage = hal_inline.buffer.storage<%buffer : !hal.buffer> : !util.buffer
%length = hal_inline.buffer.length<%buffer : !hal.buffer> : index
%storage_2 = hal_inline.buffer.storage<%buffer_0 : !hal.buffer> : !util.buffer
%length_3 = hal_inline.buffer.length<%buffer_0 : !hal.buffer> : index
%storage_4 = hal_inline.buffer.storage<%buffer_1 : !hal.buffer> : !util.buffer
%length_5 = hal_inline.buffer.length<%buffer_1 : !hal.buffer> : index
call @__dispatch_mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%c96, %c128, %c1, %c4, %c4, %c1, %storage, %storage_2, %storage_4, %c0, %c0, %c0, %c589824, %c786432, %c786432) : (index, index, index, index, index, index, !util.buffer, !util.buffer, !util.buffer, index, index, index, index, index, index) -> ()
%c0_i64 = arith.constant 0 : i64
%c96_6 = arith.constant 96 : index
%c128_7 = arith.constant 128 : index
%c4_8 = arith.constant 4 : index
%c4_9 = arith.constant 4 : index
%c1_i32_10 = arith.constant 1 : i32
%c553648160_i32_11 = arith.constant 553648160 : i32
%view = hal_inline.buffer_view.create buffer(%buffer_1 : !hal.buffer) shape([%c96_6, %c128_7, %c4_8, %c4_9]) type(%c553648160_i32_11) encoding(%c1_i32_10) : !hal.buffer_view
return %view : !hal.buffer_view
}
}
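// after subrange propagation each buffer crosses the call boundary as a (buffer, size, offset, length) tuple and the subspans are re-materialized inside the dispatch function before feeding vmvx.mmt4d.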
// -----// IR Dump Before SCFToControlFlow (convert-scf-to-cf) ('func.func' operation: @mmt4d_384x384x512_4x1x4) //----- //
#device_target_vmvx_inline = #hal.device.target<"vmvx-inline", {executable_targets = [#hal.executable.target<"vmvx-inline", "vmvx-ir">]}>
module attributes {hal.device.targets = [#device_target_vmvx_inline]} {
func.func private @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%arg0: !util.buffer, %arg1: index, %arg2: index, %arg3: index, %arg4: !util.buffer, %arg5: index, %arg6: index, %arg7: index, %arg8: !util.buffer, %arg9: index, %arg10: index, %arg11: index, %arg12: !util.buffer, %arg13: index, %arg14: index, %arg15: index, %arg16: index, %arg17: index, %arg18: index, %arg19: index, %arg20: index, %arg21: index, %arg22: index, %arg23: index, %arg24: index) {
%c1536 = arith.constant 1536 : index
%c4 = arith.constant 4 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c384 = arith.constant 384 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c2048 = arith.constant 2048 : index
%buffer_span = util.buffer.subspan %arg4[%arg6] : !util.buffer{%arg5} -> !util.buffer{%arg7}
%buffer_span_0 = util.buffer.subspan %arg8[%arg10] : !util.buffer{%arg9} -> !util.buffer{%arg11}
%buffer_span_1 = util.buffer.subspan %arg12[%arg14] : !util.buffer{%arg13} -> !util.buffer{%arg15}
%0 = arith.muli %arg17, %c32 : index
%1 = arith.muli %arg16, %c64 : index
%2 = arith.muli %0, %c1536 : index
%3 = arith.muli %1, %c1536 : index
%4 = arith.muli %0, %c2048 : index
%5 = arith.muli %1, %c16 : index
%6 = arith.addi %4, %5 : index
vmvx.mmt4d lhs(%buffer_span offset %2 row_stride %c1536 : !util.buffer) rhs(%buffer_span_0 offset %3 row_stride %c1536 : !util.buffer) out(%buffer_span_1 offset %6 row_stride %c2048 : !util.buffer) mnk(%c32, %c64, %c384) tile_mnk(%c4, %c4, %c1) flags(1) : (f32, f32, f32)
return
}
func.func @mmt4d_384x384x512_4x1x4(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%c589824 = arith.constant 589824 : index
%c786432 = arith.constant 786432 : index
%c96 = arith.constant 96 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c553648160_i32 = arith.constant 553648160 : i32
%c1_i32 = arith.constant 1 : i32
%c384 = arith.constant 384 : index
hal_inline.buffer_view.assert<%arg0 : !hal.buffer_view> message("tensor") shape([%c96, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%buffer = hal_inline.buffer_view.buffer<%arg0 : !hal.buffer_view> : !hal.buffer
hal_inline.buffer_view.assert<%arg1 : !hal.buffer_view> message("tensor") shape([%c128, %c384, %c4, %c1]) type(%c553648160_i32) encoding(%c1_i32)
%buffer_0 = hal_inline.buffer_view.buffer<%arg1 : !hal.buffer_view> : !hal.buffer
hal_inline.buffer_view.assert<%arg2 : !hal.buffer_view> message("tensor") shape([%c96, %c128, %c4, %c4]) type(%c553648160_i32) encoding(%c1_i32)
%buffer_1 = hal_inline.buffer_view.buffer<%arg2 : !hal.buffer_view> : !hal.buffer
%storage = hal_inline.buffer.storage<%buffer : !hal.buffer> : !util.buffer
%storage_2 = hal_inline.buffer.storage<%buffer_0 : !hal.buffer> : !util.buffer
%storage_3 = hal_inline.buffer.storage<%buffer_1 : !hal.buffer> : !util.buffer
%buffer_size = util.buffer.size %storage : !util.buffer
%buffer_size_4 = util.buffer.size %storage_2 : !util.buffer
%buffer_size_5 = util.buffer.size %storage_3 : !util.buffer
%0 = util.null : !util.buffer
%1 = arith.muli %c3, %c2 : index
%buffer_size_6 = util.buffer.size %0 : !util.buffer
scf.for %arg3 = %c0 to %1 step %c1 {
%2 = arith.remsi %arg3, %c2 : index
%3 = arith.divsi %arg3, %c2 : index
func.call @mmt4d_384x384x512_4x1x4_dispatch_0_mmt4d_96x128x384x4x4x1(%0, %buffer_size_6, %c0, %buffer_size_6, %storage, %buffer_size, %c0, %c589824, %storage_2, %buffer_size_4, %c0, %c786432, %storage_3, %buffer_size_5, %c0, %c786432, %2, %3, %c0, %c1, %c1, %c1, %c2, %c3, %c1) : (!util.buffer, index, index, index, !util.buffer, index, index, index, !util.buffer, index, index, index, !util.buffer, index, index, index, index, index, index, index, index, index, index, index, index) -> ()
}
%view = hal_inline.buffer_view.create buffer(%buffer_1 : !hal.buffer) shape([%c96, %c128, %c4, %c4]) type(%c553648160_i32) encoding(%c1_i32) : !hal.buffer_view
return %view : !hal.buffer_view
}
}
// TODO: fold util.buffer.subspan into vmvx ops via an op interface