Skip to content

Instantly share code, notes, and snippets.

@antiagainst
Last active February 22, 2024 04:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antiagainst/665c8236414d211c4050b4e57b523ab2 to your computer and use it in GitHub Desktop.
Save antiagainst/665c8236414d211c4050b4e57b523ab2 to your computer and use it in GitHub Desktop.
module {
llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
llvm.mlir.global private @__shared_memory___0() {addr_space = 3 : i32, alignment = 2 : i64} : !llvm.array<8 x array<40 x f16>>
llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 2 : i64} : !llvm.array<32 x array<16 x f16>>
llvm.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf32_for_LLVMGPUVectorDistribute_32_32_8_64_1_1_dispatch_0_matmul_512x512x128_f16xf16xf32(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias, llvm.readonly}, %arg1: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias, llvm.readonly}, %arg2: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.mlir.constant(3 : i64) : i64
%2 = llvm.mlir.constant(2 : i64) : i64
%3 = llvm.mlir.constant(1 : i64) : i64
%4 = llvm.mlir.constant(63 : index) : i64
%5 = llvm.mlir.constant(512 : index) : i64
%6 = llvm.mlir.constant(0 : i64) : i64
%7 = llvm.mlir.constant(dense<0.000000e+00> : vector<1x4x1x1x4xf32>) : !llvm.array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>
%8 = llvm.mlir.constant(dense<0.000000e+00> : vector<1x4xf32>) : !llvm.array<1 x vector<4xf32>>
%9 = llvm.mlir.constant(8 : index) : i64
%10 = llvm.mlir.constant(128 : index) : i64
%11 = llvm.mlir.constant(0 : index) : i64
%12 = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1x4x1x1x4xf32>) : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%13 = llvm.mlir.constant(1 : index) : i64
%14 = llvm.mlir.constant(2 : index) : i64
%15 = llvm.mlir.constant(32 : index) : i64
%16 = llvm.mlir.constant(64 : index) : i64
%17 = llvm.mlir.constant(4 : index) : i64
%18 = llvm.mlir.constant(16 : index) : i64
%19 = llvm.mlir.constant(24 : index) : i64
%20 = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf16>) : vector<4xf16>
%21 = llvm.mlir.constant(dense<0.000000e+00> : vector<4x4xf32>) : !llvm.array<4 x vector<4xf32>>
%22 = llvm.mlir.constant(dense<0.000000e+00> : vector<4x1x1x4xf32>) : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%23 = llvm.mlir.constant(3 : index) : i64
%24 = llvm.mlir.constant(9 : index) : i64
%25 = llvm.mlir.constant(10 : index) : i64
%26 = llvm.mlir.constant(11 : index) : i64
%27 = llvm.mlir.constant(17 : index) : i64
%28 = llvm.mlir.constant(18 : index) : i64
%29 = llvm.mlir.constant(19 : index) : i64
%30 = llvm.mlir.constant(25 : index) : i64
%31 = llvm.mlir.constant(26 : index) : i64
%32 = llvm.mlir.constant(27 : index) : i64
%33 = llvm.mlir.constant(-1 : index) : i64
%34 = llvm.mlir.constant(40 : index) : i64
%35 = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
%36 = llvm.getelementptr %35[0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<0 x i8>
%37 = llvm.getelementptr %36[0, 0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<8 x array<40 x f16>>
%38 = llvm.getelementptr %35[0, 640] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<0 x i8>
%39 = llvm.getelementptr %38[0, 0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<32 x array<16 x f16>>
%40 = rocdl.workitem.id.x : i32
%41 = llvm.sext %40 : i32 to i64
%42 = rocdl.workitem.id.y : i32
%43 = llvm.sext %42 : i32 to i64
%44 = rocdl.workitem.id.z : i32
%45 = llvm.sext %44 : i32 to i64
%46 = llvm.mul %43, %16 : i64
%47 = llvm.add %41, %46 : i64
%48 = llvm.mul %45, %16 : i64
%49 = llvm.add %47, %48 : i64
%50 = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
%51 = llvm.and %50, %4 : i64
%52 = llvm.icmp "eq" %51, %11 : i64
"llvm.intr.assume"(%52) : (i1) -> ()
%53 = llvm.ptrtoint %arg1 : !llvm.ptr<1> to i64
%54 = llvm.and %53, %4 : i64
%55 = llvm.icmp "eq" %54, %11 : i64
"llvm.intr.assume"(%55) : (i1) -> ()
%56 = llvm.ptrtoint %arg2 : !llvm.ptr<1> to i64
%57 = llvm.and %56, %4 : i64
%58 = llvm.icmp "eq" %57, %11 : i64
"llvm.intr.assume"(%58) : (i1) -> ()
%59 = rocdl.workgroup.id.x : i32
%60 = llvm.sext %59 : i32 to i64
%61 = rocdl.workgroup.id.y : i32
%62 = llvm.sext %61 : i32 to i64
%63 = llvm.srem %49, %16 : i64
%64 = llvm.icmp "slt" %63, %11 : i64
%65 = llvm.add %63, %16 : i64
%66 = llvm.select %64, %65, %63 : i1, i64
%67 = llvm.icmp "slt" %66, %11 : i64
%68 = llvm.sub %33, %66 : i64
%69 = llvm.select %67, %68, %66 : i1, i64
%70 = llvm.sdiv %69, %15 : i64
%71 = llvm.sub %33, %70 : i64
%72 = llvm.select %67, %71, %70 : i1, i64
%73 = llvm.srem %49, %15 : i64
%74 = llvm.icmp "slt" %73, %11 : i64
%75 = llvm.add %73, %15 : i64
%76 = llvm.select %74, %75, %73 : i1, i64
%77 = llvm.mul %62, %15 : i64
%78 = llvm.mul %72, %17 : i64
%79 = llvm.add %77, %78 : i64
%80 = llvm.mul %60, %15 : i64
%81 = llvm.add %76, %80 : i64
%82 = llvm.mul %79, %5 : i64
%83 = llvm.add %82, %81 : i64
%84 = llvm.getelementptr %arg2[%83] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%85 = llvm.load %84 : !llvm.ptr<1> -> f32
%86 = llvm.add %79, %13 : i64
%87 = llvm.mul %86, %5 : i64
%88 = llvm.add %87, %81 : i64
%89 = llvm.getelementptr %arg2[%88] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%90 = llvm.load %89 : !llvm.ptr<1> -> f32
%91 = llvm.add %79, %14 : i64
%92 = llvm.mul %91, %5 : i64
%93 = llvm.add %92, %81 : i64
%94 = llvm.getelementptr %arg2[%93] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%95 = llvm.load %94 : !llvm.ptr<1> -> f32
%96 = llvm.add %79, %23 : i64
%97 = llvm.mul %96, %5 : i64
%98 = llvm.add %97, %81 : i64
%99 = llvm.getelementptr %arg2[%98] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%100 = llvm.load %99 : !llvm.ptr<1> -> f32
%101 = llvm.extractvalue %8[0] : !llvm.array<1 x vector<4xf32>>
%102 = llvm.insertelement %85, %101[%6 : i64] : vector<4xf32>
%103 = llvm.insertvalue %102, %8[0] : !llvm.array<1 x vector<4xf32>>
%104 = llvm.insertelement %90, %102[%3 : i64] : vector<4xf32>
%105 = llvm.insertvalue %104, %103[0] : !llvm.array<1 x vector<4xf32>>
%106 = llvm.insertelement %95, %104[%2 : i64] : vector<4xf32>
%107 = llvm.insertvalue %106, %105[0] : !llvm.array<1 x vector<4xf32>>
%108 = llvm.insertelement %100, %106[%1 : i64] : vector<4xf32>
%109 = llvm.insertvalue %108, %107[0] : !llvm.array<1 x vector<4xf32>>
%110 = llvm.insertvalue %109, %12[0, 0, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%111 = llvm.add %79, %9 : i64
%112 = llvm.mul %111, %5 : i64
%113 = llvm.add %112, %81 : i64
%114 = llvm.getelementptr %arg2[%113] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%115 = llvm.load %114 : !llvm.ptr<1> -> f32
%116 = llvm.add %79, %24 : i64
%117 = llvm.mul %116, %5 : i64
%118 = llvm.add %117, %81 : i64
%119 = llvm.getelementptr %arg2[%118] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%120 = llvm.load %119 : !llvm.ptr<1> -> f32
%121 = llvm.add %79, %25 : i64
%122 = llvm.mul %121, %5 : i64
%123 = llvm.add %122, %81 : i64
%124 = llvm.getelementptr %arg2[%123] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%125 = llvm.load %124 : !llvm.ptr<1> -> f32
%126 = llvm.add %79, %26 : i64
%127 = llvm.mul %126, %5 : i64
%128 = llvm.add %127, %81 : i64
%129 = llvm.getelementptr %arg2[%128] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%130 = llvm.load %129 : !llvm.ptr<1> -> f32
%131 = llvm.insertelement %115, %101[%6 : i64] : vector<4xf32>
%132 = llvm.insertvalue %131, %8[0] : !llvm.array<1 x vector<4xf32>>
%133 = llvm.insertelement %120, %131[%3 : i64] : vector<4xf32>
%134 = llvm.insertvalue %133, %132[0] : !llvm.array<1 x vector<4xf32>>
%135 = llvm.insertelement %125, %133[%2 : i64] : vector<4xf32>
%136 = llvm.insertvalue %135, %134[0] : !llvm.array<1 x vector<4xf32>>
%137 = llvm.insertelement %130, %135[%1 : i64] : vector<4xf32>
%138 = llvm.insertvalue %137, %136[0] : !llvm.array<1 x vector<4xf32>>
%139 = llvm.insertvalue %138, %110[0, 0, 1, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%140 = llvm.add %79, %18 : i64
%141 = llvm.mul %140, %5 : i64
%142 = llvm.add %141, %81 : i64
%143 = llvm.getelementptr %arg2[%142] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%144 = llvm.load %143 : !llvm.ptr<1> -> f32
%145 = llvm.add %79, %27 : i64
%146 = llvm.mul %145, %5 : i64
%147 = llvm.add %146, %81 : i64
%148 = llvm.getelementptr %arg2[%147] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%149 = llvm.load %148 : !llvm.ptr<1> -> f32
%150 = llvm.add %79, %28 : i64
%151 = llvm.mul %150, %5 : i64
%152 = llvm.add %151, %81 : i64
%153 = llvm.getelementptr %arg2[%152] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%154 = llvm.load %153 : !llvm.ptr<1> -> f32
%155 = llvm.add %79, %29 : i64
%156 = llvm.mul %155, %5 : i64
%157 = llvm.add %156, %81 : i64
%158 = llvm.getelementptr %arg2[%157] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%159 = llvm.load %158 : !llvm.ptr<1> -> f32
%160 = llvm.insertelement %144, %101[%6 : i64] : vector<4xf32>
%161 = llvm.insertvalue %160, %8[0] : !llvm.array<1 x vector<4xf32>>
%162 = llvm.insertelement %149, %160[%3 : i64] : vector<4xf32>
%163 = llvm.insertvalue %162, %161[0] : !llvm.array<1 x vector<4xf32>>
%164 = llvm.insertelement %154, %162[%2 : i64] : vector<4xf32>
%165 = llvm.insertvalue %164, %163[0] : !llvm.array<1 x vector<4xf32>>
%166 = llvm.insertelement %159, %164[%1 : i64] : vector<4xf32>
%167 = llvm.insertvalue %166, %165[0] : !llvm.array<1 x vector<4xf32>>
%168 = llvm.insertvalue %167, %139[0, 0, 2, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%169 = llvm.add %79, %19 : i64
%170 = llvm.mul %169, %5 : i64
%171 = llvm.add %170, %81 : i64
%172 = llvm.getelementptr %arg2[%171] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%173 = llvm.load %172 : !llvm.ptr<1> -> f32
%174 = llvm.add %79, %30 : i64
%175 = llvm.mul %174, %5 : i64
%176 = llvm.add %175, %81 : i64
%177 = llvm.getelementptr %arg2[%176] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%178 = llvm.load %177 : !llvm.ptr<1> -> f32
%179 = llvm.add %79, %31 : i64
%180 = llvm.mul %179, %5 : i64
%181 = llvm.add %180, %81 : i64
%182 = llvm.getelementptr %arg2[%181] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%183 = llvm.load %182 : !llvm.ptr<1> -> f32
%184 = llvm.add %79, %32 : i64
%185 = llvm.mul %184, %5 : i64
%186 = llvm.add %185, %81 : i64
%187 = llvm.getelementptr %arg2[%186] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%188 = llvm.load %187 : !llvm.ptr<1> -> f32
%189 = llvm.insertelement %173, %101[%6 : i64] : vector<4xf32>
%190 = llvm.insertvalue %189, %8[0] : !llvm.array<1 x vector<4xf32>>
%191 = llvm.insertelement %178, %189[%3 : i64] : vector<4xf32>
%192 = llvm.insertvalue %191, %190[0] : !llvm.array<1 x vector<4xf32>>
%193 = llvm.insertelement %183, %191[%2 : i64] : vector<4xf32>
%194 = llvm.insertvalue %193, %192[0] : !llvm.array<1 x vector<4xf32>>
%195 = llvm.insertelement %188, %193[%1 : i64] : vector<4xf32>
%196 = llvm.insertvalue %195, %194[0] : !llvm.array<1 x vector<4xf32>>
%197 = llvm.insertvalue %196, %168[0, 0, 3, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%198 = llvm.sdiv %69, %14 : i64
%199 = llvm.sub %33, %198 : i64
%200 = llvm.select %67, %199, %198 : i1, i64
%201 = llvm.srem %49, %14 : i64
%202 = llvm.icmp "slt" %201, %11 : i64
%203 = llvm.add %201, %14 : i64
%204 = llvm.select %202, %203, %201 : i1, i64
%205 = llvm.add %200, %77 : i64
%206 = llvm.mul %204, %17 : i64
%207 = llvm.sdiv %69, %9 : i64
%208 = llvm.sub %33, %207 : i64
%209 = llvm.select %67, %208, %207 : i1, i64
%210 = llvm.srem %49, %9 : i64
%211 = llvm.icmp "slt" %210, %11 : i64
%212 = llvm.add %210, %9 : i64
%213 = llvm.select %211, %212, %210 : i1, i64
%214 = llvm.mul %213, %17 : i64
%215 = llvm.add %80, %214 : i64
%216 = llvm.add %78, %13 : i64
%217 = llvm.add %78, %14 : i64
%218 = llvm.add %78, %23 : i64
llvm.br ^bb1(%11, %197 : i64, !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>)
^bb1(%219: i64, %220: !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>): // 2 preds: ^bb0, ^bb2
%221 = llvm.icmp "slt" %219, %10 : i64
llvm.cond_br %221, ^bb2, ^bb3
^bb2: // pred: ^bb1
%222 = llvm.add %206, %219 : i64
%223 = llvm.mul %205, %10 : i64
%224 = llvm.add %223, %222 : i64
%225 = llvm.getelementptr %arg0[%224] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f16
%226 = llvm.load %225 {alignment = 2 : i64} : !llvm.ptr<1> -> vector<4xf16>
%227 = llvm.add %209, %219 : i64
%228 = llvm.mul %227, %5 : i64
%229 = llvm.add %228, %215 : i64
%230 = llvm.getelementptr %arg1[%229] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f16
%231 = llvm.load %230 {alignment = 2 : i64} : !llvm.ptr<1> -> vector<4xf16>
%232 = llvm.mul %200, %18 : i64
%233 = llvm.add %232, %206 : i64
%234 = llvm.getelementptr %39[%233] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
llvm.store %226, %234 {alignment = 2 : i64} : vector<4xf16>, !llvm.ptr<3>
llvm.inline_asm has_side_effects asm_dialect = att "s_waitcnt lgkmcnt(0)\0As_barrier", "" : () -> ()
%235 = llvm.mul %209, %34 : i64
%236 = llvm.add %235, %214 : i64
%237 = llvm.getelementptr %37[%236] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
llvm.store %231, %237 {alignment = 2 : i64} : vector<4xf16>, !llvm.ptr<3>
llvm.inline_asm has_side_effects asm_dialect = att "s_waitcnt lgkmcnt(0)\0As_barrier", "" : () -> ()
%238 = llvm.mul %76, %18 : i64
%239 = llvm.add %238, %78 : i64
%240 = llvm.getelementptr %39[%239] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
%241 = llvm.load %240 {alignment = 2 : i64} : !llvm.ptr<3> -> vector<4xf16>
%242 = llvm.mul %78, %34 : i64
%243 = llvm.add %242, %76 : i64
%244 = llvm.getelementptr %37[%243] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
%245 = llvm.load %244 : !llvm.ptr<3> -> f16
%246 = llvm.mul %216, %34 : i64
%247 = llvm.add %246, %76 : i64
%248 = llvm.getelementptr %37[%247] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
%249 = llvm.load %248 : !llvm.ptr<3> -> f16
%250 = llvm.mul %217, %34 : i64
%251 = llvm.add %250, %76 : i64
%252 = llvm.getelementptr %37[%251] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
%253 = llvm.load %252 : !llvm.ptr<3> -> f16
%254 = llvm.mul %218, %34 : i64
%255 = llvm.add %254, %76 : i64
%256 = llvm.getelementptr %37[%255] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
%257 = llvm.load %256 : !llvm.ptr<3> -> f16
%258 = llvm.extractelement %241[%6 : i64] : vector<4xf16>
%259 = llvm.insertelement %258, %20[%6 : i64] : vector<4xf16>
%260 = llvm.extractelement %241[%3 : i64] : vector<4xf16>
%261 = llvm.insertelement %260, %259[%3 : i64] : vector<4xf16>
%262 = llvm.extractelement %241[%2 : i64] : vector<4xf16>
%263 = llvm.insertelement %262, %261[%2 : i64] : vector<4xf16>
%264 = llvm.extractelement %241[%1 : i64] : vector<4xf16>
%265 = llvm.insertelement %264, %263[%1 : i64] : vector<4xf16>
%266 = llvm.insertelement %245, %20[%6 : i64] : vector<4xf16>
%267 = llvm.insertelement %249, %266[%3 : i64] : vector<4xf16>
%268 = llvm.insertelement %253, %267[%2 : i64] : vector<4xf16>
%269 = llvm.insertelement %257, %268[%1 : i64] : vector<4xf16>
%270 = llvm.extractvalue %220[0, 0, 0, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%271 = llvm.extractelement %270[%6 : i64] : vector<4xf32>
%272 = llvm.extractvalue %21[0] : !llvm.array<4 x vector<4xf32>>
%273 = llvm.insertelement %271, %272[%6 : i64] : vector<4xf32>
%274 = llvm.insertvalue %273, %21[0] : !llvm.array<4 x vector<4xf32>>
%275 = llvm.extractelement %270[%3 : i64] : vector<4xf32>
%276 = llvm.insertelement %275, %273[%3 : i64] : vector<4xf32>
%277 = llvm.insertvalue %276, %274[0] : !llvm.array<4 x vector<4xf32>>
%278 = llvm.extractelement %270[%2 : i64] : vector<4xf32>
%279 = llvm.insertelement %278, %276[%2 : i64] : vector<4xf32>
%280 = llvm.insertvalue %279, %277[0] : !llvm.array<4 x vector<4xf32>>
%281 = llvm.extractelement %270[%1 : i64] : vector<4xf32>
%282 = llvm.insertelement %281, %279[%1 : i64] : vector<4xf32>
%283 = llvm.insertvalue %282, %280[0] : !llvm.array<4 x vector<4xf32>>
%284 = llvm.extractvalue %220[0, 0, 1, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%285 = llvm.extractelement %284[%6 : i64] : vector<4xf32>
%286 = llvm.extractvalue %21[1] : !llvm.array<4 x vector<4xf32>>
%287 = llvm.insertelement %285, %286[%6 : i64] : vector<4xf32>
%288 = llvm.insertvalue %287, %283[1] : !llvm.array<4 x vector<4xf32>>
%289 = llvm.extractelement %284[%3 : i64] : vector<4xf32>
%290 = llvm.insertelement %289, %287[%3 : i64] : vector<4xf32>
%291 = llvm.insertvalue %290, %288[1] : !llvm.array<4 x vector<4xf32>>
%292 = llvm.extractelement %284[%2 : i64] : vector<4xf32>
%293 = llvm.insertelement %292, %290[%2 : i64] : vector<4xf32>
%294 = llvm.insertvalue %293, %291[1] : !llvm.array<4 x vector<4xf32>>
%295 = llvm.extractelement %284[%1 : i64] : vector<4xf32>
%296 = llvm.insertelement %295, %293[%1 : i64] : vector<4xf32>
%297 = llvm.insertvalue %296, %294[1] : !llvm.array<4 x vector<4xf32>>
%298 = llvm.extractvalue %220[0, 0, 2, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%299 = llvm.extractelement %298[%6 : i64] : vector<4xf32>
%300 = llvm.extractvalue %21[2] : !llvm.array<4 x vector<4xf32>>
%301 = llvm.insertelement %299, %300[%6 : i64] : vector<4xf32>
%302 = llvm.insertvalue %301, %297[2] : !llvm.array<4 x vector<4xf32>>
%303 = llvm.extractelement %298[%3 : i64] : vector<4xf32>
%304 = llvm.insertelement %303, %301[%3 : i64] : vector<4xf32>
%305 = llvm.insertvalue %304, %302[2] : !llvm.array<4 x vector<4xf32>>
%306 = llvm.extractelement %298[%2 : i64] : vector<4xf32>
%307 = llvm.insertelement %306, %304[%2 : i64] : vector<4xf32>
%308 = llvm.insertvalue %307, %305[2] : !llvm.array<4 x vector<4xf32>>
%309 = llvm.extractelement %298[%1 : i64] : vector<4xf32>
%310 = llvm.insertelement %309, %307[%1 : i64] : vector<4xf32>
%311 = llvm.insertvalue %310, %308[2] : !llvm.array<4 x vector<4xf32>>
%312 = llvm.extractvalue %220[0, 0, 3, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%313 = llvm.extractelement %312[%6 : i64] : vector<4xf32>
%314 = llvm.extractvalue %21[3] : !llvm.array<4 x vector<4xf32>>
%315 = llvm.insertelement %313, %314[%6 : i64] : vector<4xf32>
%316 = llvm.insertvalue %315, %311[3] : !llvm.array<4 x vector<4xf32>>
%317 = llvm.extractelement %312[%3 : i64] : vector<4xf32>
%318 = llvm.insertelement %317, %315[%3 : i64] : vector<4xf32>
%319 = llvm.insertvalue %318, %316[3] : !llvm.array<4 x vector<4xf32>>
%320 = llvm.extractelement %312[%2 : i64] : vector<4xf32>
%321 = llvm.insertelement %320, %318[%2 : i64] : vector<4xf32>
%322 = llvm.insertvalue %321, %319[3] : !llvm.array<4 x vector<4xf32>>
%323 = llvm.extractelement %312[%1 : i64] : vector<4xf32>
%324 = llvm.insertelement %323, %321[%1 : i64] : vector<4xf32>
%325 = llvm.insertvalue %324, %322[3] : !llvm.array<4 x vector<4xf32>>
%326 = rocdl.mfma.f32.32x32x8f16 %265, %269, %325, %0, %0, %0 : (vector<4xf16>, vector<4xf16>, !llvm.array<4 x vector<4xf32>>, i32, i32, i32) -> !llvm.array<4 x vector<4xf32>>
%327 = llvm.extractvalue %326[0] : !llvm.array<4 x vector<4xf32>>
%328 = llvm.extractelement %327[%6 : i64] : vector<4xf32>
%329 = llvm.extractvalue %22[0, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%330 = llvm.insertelement %328, %329[%6 : i64] : vector<4xf32>
%331 = llvm.insertvalue %330, %22[0, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%332 = llvm.extractelement %327[%3 : i64] : vector<4xf32>
%333 = llvm.insertelement %332, %330[%3 : i64] : vector<4xf32>
%334 = llvm.insertvalue %333, %331[0, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%335 = llvm.extractelement %327[%2 : i64] : vector<4xf32>
%336 = llvm.insertelement %335, %333[%2 : i64] : vector<4xf32>
%337 = llvm.insertvalue %336, %334[0, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%338 = llvm.extractelement %327[%1 : i64] : vector<4xf32>
%339 = llvm.insertelement %338, %336[%1 : i64] : vector<4xf32>
%340 = llvm.insertvalue %339, %337[0, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%341 = llvm.extractvalue %326[1] : !llvm.array<4 x vector<4xf32>>
%342 = llvm.extractelement %341[%6 : i64] : vector<4xf32>
%343 = llvm.extractvalue %22[1, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%344 = llvm.insertelement %342, %343[%6 : i64] : vector<4xf32>
%345 = llvm.insertvalue %344, %340[1, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%346 = llvm.extractelement %341[%3 : i64] : vector<4xf32>
%347 = llvm.insertelement %346, %344[%3 : i64] : vector<4xf32>
%348 = llvm.insertvalue %347, %345[1, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%349 = llvm.extractelement %341[%2 : i64] : vector<4xf32>
%350 = llvm.insertelement %349, %347[%2 : i64] : vector<4xf32>
%351 = llvm.insertvalue %350, %348[1, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%352 = llvm.extractelement %341[%1 : i64] : vector<4xf32>
%353 = llvm.insertelement %352, %350[%1 : i64] : vector<4xf32>
%354 = llvm.insertvalue %353, %351[1, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%355 = llvm.extractvalue %326[2] : !llvm.array<4 x vector<4xf32>>
%356 = llvm.extractelement %355[%6 : i64] : vector<4xf32>
%357 = llvm.extractvalue %22[2, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%358 = llvm.insertelement %356, %357[%6 : i64] : vector<4xf32>
%359 = llvm.insertvalue %358, %354[2, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%360 = llvm.extractelement %355[%3 : i64] : vector<4xf32>
%361 = llvm.insertelement %360, %358[%3 : i64] : vector<4xf32>
%362 = llvm.insertvalue %361, %359[2, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%363 = llvm.extractelement %355[%2 : i64] : vector<4xf32>
%364 = llvm.insertelement %363, %361[%2 : i64] : vector<4xf32>
%365 = llvm.insertvalue %364, %362[2, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%366 = llvm.extractelement %355[%1 : i64] : vector<4xf32>
%367 = llvm.insertelement %366, %364[%1 : i64] : vector<4xf32>
%368 = llvm.insertvalue %367, %365[2, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%369 = llvm.extractvalue %326[3] : !llvm.array<4 x vector<4xf32>>
%370 = llvm.extractelement %369[%6 : i64] : vector<4xf32>
%371 = llvm.extractvalue %22[3, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%372 = llvm.insertelement %370, %371[%6 : i64] : vector<4xf32>
%373 = llvm.insertvalue %372, %368[3, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%374 = llvm.extractelement %369[%3 : i64] : vector<4xf32>
%375 = llvm.insertelement %374, %372[%3 : i64] : vector<4xf32>
%376 = llvm.insertvalue %375, %373[3, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%377 = llvm.extractelement %369[%2 : i64] : vector<4xf32>
%378 = llvm.insertelement %377, %375[%2 : i64] : vector<4xf32>
%379 = llvm.insertvalue %378, %376[3, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%380 = llvm.extractelement %369[%1 : i64] : vector<4xf32>
%381 = llvm.insertelement %380, %378[%1 : i64] : vector<4xf32>
%382 = llvm.insertvalue %381, %379[3, 0, 0] : !llvm.array<4 x array<1 x array<1 x vector<4xf32>>>>
%383 = llvm.insertvalue %382, %7[0] : !llvm.array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>
%384 = llvm.insertvalue %383, %12[0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%385 = llvm.add %219, %9 : i64
llvm.br ^bb1(%385, %384 : i64, !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>)
^bb3: // pred: ^bb1
%386 = llvm.extractvalue %220[0, 0, 0, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%387 = llvm.extractelement %386[%6 : i64] : vector<4xf32>
llvm.store %387, %84 : f32, !llvm.ptr<1>
%388 = llvm.extractelement %386[%3 : i64] : vector<4xf32>
llvm.store %388, %89 : f32, !llvm.ptr<1>
%389 = llvm.extractelement %386[%2 : i64] : vector<4xf32>
llvm.store %389, %94 : f32, !llvm.ptr<1>
%390 = llvm.extractelement %386[%1 : i64] : vector<4xf32>
llvm.store %390, %99 : f32, !llvm.ptr<1>
%391 = llvm.extractvalue %220[0, 0, 1, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%392 = llvm.extractelement %391[%6 : i64] : vector<4xf32>
llvm.store %392, %114 : f32, !llvm.ptr<1>
%393 = llvm.extractelement %391[%3 : i64] : vector<4xf32>
llvm.store %393, %119 : f32, !llvm.ptr<1>
%394 = llvm.extractelement %391[%2 : i64] : vector<4xf32>
llvm.store %394, %124 : f32, !llvm.ptr<1>
%395 = llvm.extractelement %391[%1 : i64] : vector<4xf32>
llvm.store %395, %129 : f32, !llvm.ptr<1>
%396 = llvm.extractvalue %220[0, 0, 2, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%397 = llvm.extractelement %396[%6 : i64] : vector<4xf32>
llvm.store %397, %143 : f32, !llvm.ptr<1>
%398 = llvm.extractelement %396[%3 : i64] : vector<4xf32>
llvm.store %398, %148 : f32, !llvm.ptr<1>
%399 = llvm.extractelement %396[%2 : i64] : vector<4xf32>
llvm.store %399, %153 : f32, !llvm.ptr<1>
%400 = llvm.extractelement %396[%1 : i64] : vector<4xf32>
llvm.store %400, %158 : f32, !llvm.ptr<1>
%401 = llvm.extractvalue %220[0, 0, 3, 0, 0] : !llvm.array<1 x array<1 x array<4 x array<1 x array<1 x vector<4xf32>>>>>>
%402 = llvm.extractelement %401[%6 : i64] : vector<4xf32>
llvm.store %402, %172 : f32, !llvm.ptr<1>
%403 = llvm.extractelement %401[%3 : i64] : vector<4xf32>
llvm.store %403, %177 : f32, !llvm.ptr<1>
%404 = llvm.extractelement %401[%2 : i64] : vector<4xf32>
llvm.store %404, %182 : f32, !llvm.ptr<1>
%405 = llvm.extractelement %401[%1 : i64] : vector<4xf32>
llvm.store %405, %187 : f32, !llvm.ptr<1>
llvm.return
}
}
// tools/iree-compile --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942
#compilation2 = #iree_codegen.compilation_info<
lowering_config = <tile_sizes = [[32, 32, 8]]>,
translation_info = <LLVMGPUVectorDistribute,
{ pipeline_depth = 0, store_stage = 1, mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mfma_layout<F16_32x32x8_F32>, subgroup_m_count = 1, subgroup_n_count = 1, subgroup_m_tile_count = 1, subgroup_n_tile_count = 1, subgroup_k_tile_count = 1> }>,
workgroup_size = [64, 1, 1]>
func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf32_for_LLVMGPUVectorDistribute_32_32_8_64_1_1(%lhs: tensor<512x128xf16>, %rhs: tensor<128x512xf16>, %acc: tensor<512x512xf32>) -> tensor<512x512xf32> {
%result = linalg.matmul {compilation_info = #compilation2} ins(%lhs, %rhs: tensor<512x128xf16>, tensor<128x512xf16>) outs(%acc: tensor<512x512xf32>) -> tensor<512x512xf32>
return %result: tensor<512x512xf32>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment