Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created February 6, 2025 16:55
Show Gist options
  • Save pashu123/9848fa6cc6b2b8cdbdad4cfa98dfccfc to your computer and use it in GitHub Desktop.
Save pashu123/9848fa6cc6b2b8cdbdad4cfa98dfccfc to your computer and use it in GitHub Desktop.
// Single-function GPU dispatch (arith / vector / gpu / cf / memref / hal dialects) working on
// f8E4M3FNUZ data held in global memory.
// NOTE(review): the symbol name and the data flow (dot product -> running max / running sum ->
// weighted accumulation -> normalisation) look like a flash-attention style "decode" step; the
// Q/K/V/mask roles mentioned in comments below are inferred from that, not stated by the IR.
//
// Structure visible in the body:
//   1. scalar/vector constants, 8 push constants, and 9 typed views over 3 bindings;
//   2. per-thread prologue: two f8 elements are loaded, combined with two bf16 tables and with a
//      gathered (and conditionally negated) partner element, giving %101;
//   3. loop ^bb1..^bb5 over %104 in [0, %102) carrying a running max (%105), a running sum (%106)
//      and an accumulator (%107);
//   4. ^bb6: divide by the running sum and by a scalar loaded from binding 1, clamp to
//      [-240, 240], truncate to f8 and store one element.
func.func @decode_bs1$async_dispatch_21_attention_8x4xDx1x2x64xf8E4M3FNUZ_generic() {
// ---- Constants -------------------------------------------------------------------------------
// %c127_i32 / %c23_i32 are used to build an f32 bit pattern ((n + 127) << 23) in the exp expansions.
// %cst .. %cst_12 are clamp bounds, range-reduction constants and polynomial coefficients for those
// expansions. 0.693147182 is approximately ln(2); 1.44269502 (%cst_23) is approximately log2(e).
// NOTE(review): the coefficient values resemble the Cephes expf polynomial — confirm if it matters.
%c127_i32 = arith.constant 127 : i32
%c23_i32 = arith.constant 23 : i32
%cst = arith.constant 1.270000e+02 : f32
%cst_0 = arith.constant -1.270000e+02 : f32
%cst_1 = arith.constant 8.880000e+01 : f32
%cst_2 = arith.constant -8.780000e+01 : f32
%cst_3 = arith.constant 0.166666657 : f32
%cst_4 = arith.constant 0.0416657962 : f32
%cst_5 = arith.constant 0.00833345205 : f32
%cst_6 = arith.constant 0.00139819994 : f32
%cst_7 = arith.constant 1.98756912E-4 : f32
%cst_8 = arith.constant 2.12194442E-4 : f32
%cst_9 = arith.constant -0.693359375 : f32
%cst_10 = arith.constant 1.000000e+00 : f32
%cst_11 = arith.constant 5.000000e-01 : f32
%cst_12 = arith.constant 0.693147182 : f32
// Loop-carried initial values: running sum starts at 0, running max at -3.40282347E+38.
%cst_13 = arith.constant dense<0.000000e+00> : vector<f32>
%cst_14 = arith.constant dense<-3.40282347E+38> : vector<f32>
%cst_15 = arith.constant dense<64> : vector<1x2x1x1x1x1xindex>
// +/-240 are the clamp bounds applied before the final truncation to f8E4M3FNUZ.
%cst_16 = arith.constant dense<2.400000e+02> : vector<1x1x1xf32>
%cst_17 = arith.constant dense<-2.400000e+02> : vector<1x1x1xf32>
%cst_18 = arith.constant dense<1.000000e+00> : vector<1x1x1xf32>
%cst_19 = arith.constant dense<0.000000e+00> : vector<1x1x1xf32>
%cst_20 = arith.constant dense<0.000000e+00> : vector<1x2x1x1x1x1xf8E4M3FNUZ>
%cst_21 = arith.constant dense<true> : vector<1x2x1x1x1x1xi1>
// 0.00416666688 is approximately 1/240; it is ADDED to the scaled dot product inside the loop.
// NOTE(review): presumably a deliberate offset applied in the exponent domain — confirm.
%cst_22 = arith.constant 0.00416666688 : f32
%cst_23 = arith.constant 1.44269502 : f32
%cst_24 = arith.constant 0.000000e+00 : f32
%c67117632 = arith.constant 67117632 : index
%c32_i64 = arith.constant 32 : i64
%cst_25 = arith.constant 2.400000e+02 : f32
%c0 = arith.constant 0 : index
%c67108928 = arith.constant 67108928 : index
%c67109184 = arith.constant 67109184 : index
%c0_i32 = arith.constant 0 : i32
%cst_26 = arith.constant dense<2> : vector<1x2x1x1x1x1xi32>
%cst_27 = arith.constant dense<1> : vector<2x1x1x1x1x1xi32>
%cst_28 = arith.constant dense<[0, 32]> : vector<2xi32>
%c16_i32 = arith.constant 16 : i32
%c8_i32 = arith.constant 8 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c64_i32 = arith.constant 64 : i32
%c32_i32 = arith.constant 32 : i32
%cst_29 = arith.constant dense<0.000000e+00> : vector<1xf8E4M3FNUZ>
%cst_30 = arith.constant dense<0> : vector<2x1x1xindex>
%cst_31 = arith.constant dense<0.000000e+00> : vector<1x1x1xf8E4M3FNUZ>
%cst_32 = arith.constant dense<0.000000e+00> : vector<1xf32>
%cst_33 = arith.constant dense<0> : vector<1x2x1x1x1x1xindex>
%cst_34 = arith.constant dense<0> : vector<2x1x1x1x1x1xindex>
%cst_35 = arith.constant dense<0> : vector<1x2x1x1x1x1xi16>
// ---- Thread id and push constants ------------------------------------------------------------
%thread_id_x = gpu.thread_id x upper_bound 256
// Eight i32 push constants (ordinals 0..7). Their uses below:
//   %0, %1, %2, %5 -> byte offsets into bindings; %3/%4 -> low/high halves of a 64-bit offset;
//   %6 -> bit pattern of an f32 scale; %7 -> source of the dynamic dimension (%7 / 32).
%0 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
%2 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
%3 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
%4 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
%5 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
%6 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(6) : i32
%7 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(7) : i32
// Reassemble a 64-bit offset: %11 = zext(%3) | (zext(%4) << 32), then cast to index.
%8 = arith.extui %3 : i32 to i64
%9 = arith.extui %4 : i32 to i64
%10 = arith.shli %9, %c32_i64 : i64
%11 = arith.ori %8, %10 : i64
%12 = arith.index_castui %11 {stream.alignment = 64 : index, stream.values = [1075847616 : index, 1293968512 : index, 1512089408 : index, 1730210304 : index, 1948331200 : index, 2166452096 : index, 2384572992 : index, 2602693888 : index, 2820814784 : index, 3038935680 : index, 3257056576 : index, 3475177472 : index, 3693298368 : index, 3911419264 : index, 4129540160 : index, 4347661056 : index, 4565781952 : index, 4783902848 : index, 5002023744 : index, 5220144640 : index, 5438265536 : index, 5656386432 : index, 5874507328 : index, 6092628224 : index, 6310749120 : index, 6528870016 : index, 6746990912 : index, 6965111808 : index, 7183232704 : index, 7401353600 : index, 7619474496 : index, 7837595392 : index]} : i64 to index
// %13 reinterprets push constant 6 as f32; it is multiplied by %cst_23 to form %103 before the loop.
%13 = arith.bitcast %6 : i32 to f32
%14 = arith.index_castui %0 : i32 to index
%15 = arith.index_castui %1 : i32 to index
%16 = arith.index_castui %2 : i32 to index
%17 = arith.index_castui %5 : i32 to index
// ---- Buffer views ----------------------------------------------------------------------------
// %18 and %19 are two differently shaped views of the same region (binding 0, offset %14):
// %18 is the base of the vector.gather below, %19 is read with plain vector.load.
%18 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%14) flags("ReadOnly|Indirect") : memref<1x32x1x2x64xf8E4M3FNUZ, strided<[4096, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %18, 64 : memref<1x32x1x2x64xf8E4M3FNUZ, strided<[4096, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>
%19 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%14) flags("ReadOnly|Indirect") : memref<8x4x1x1x1x2x64xf8E4M3FNUZ, strided<[512, 128, 128, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %19, 64 : memref<8x4x1x1x1x2x64xf8E4M3FNUZ, strided<[512, 128, 128, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>
// %20 and %21: two 1x2x64 i16 tables at fixed byte offsets in binding 0; their contents are later
// bitcast to bf16 and extended to f32.
// NOTE(review): they look like rotary-embedding cos/sin tables — which is which is not visible here.
%20 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67108928) flags("ReadOnly|Indirect") : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554464>, #gpu.address_space<global>>
memref.assume_alignment %20, 64 : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554464>, #gpu.address_space<global>>
%21 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67109184) flags("ReadOnly|Indirect") : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554592>, #gpu.address_space<global>>
memref.assume_alignment %21, 64 : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554592>, #gpu.address_space<global>>
// %22: a single f32 in binding 1 at the 64-bit offset %12; read once in ^bb6 and used as a divisor.
%22 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%12) flags(ReadOnly) : memref<f32, strided<[], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %22, 64 : memref<f32, strided<[], offset: ?>, #gpu.address_space<global>>
// %23: the only buffer written by this function (binding 2, offset %17), 8x4x1x128 f8.
%23 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%17) flags(Indirect) : memref<8x4x1x128xf8E4M3FNUZ, strided<[512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %23, 64 : memref<8x4x1x128xf8E4M3FNUZ, strided<[512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>
// %25 = %7 / 32 (unsigned) is the dynamic extent shared by the three dynamically shaped views below.
%24 = arith.divui %7, %c32_i32 : i32
%25 = arith.index_castui %24 : i32 to index
// %26: 1x?x32 f8 values added to every score inside the loop.
// NOTE(review): presumably an attention mask / bias — confirm against the producer.
%26 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67117632) flags("ReadOnly|Indirect") : memref<1x?x32xf8E4M3FNUZ, strided<[?, 32, 1], offset: 67117632>, #gpu.address_space<global>>{%25}
memref.assume_alignment %26, 64 : memref<1x?x32xf8E4M3FNUZ, strided<[?, 32, 1], offset: 67117632>, #gpu.address_space<global>>
// %27 (offset %15) and %28 (offset %16) are the two operands streamed by the loop.
// NOTE(review): presumably K and V respectively. Both are only assumed 1-byte aligned, although
// the subspans themselves declare alignment(64) — confirm this weaker assumption is intended.
%27 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%15) flags("ReadOnly|Indirect") : memref<8x4x?x32x1x2x64xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>{%25}
memref.assume_alignment %27, 1 : memref<8x4x?x32x1x2x64xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>
%28 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%16) flags("ReadOnly|Indirect") : memref<8x4x128x?x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>{%25}
memref.assume_alignment %28, 1 : memref<8x4x128x?x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>
// ---- Work distribution -----------------------------------------------------------------------
// Workgroup ids z (< 8) and y (< 4) index the two leading dimensions of the views; x (< 32)
// together with thread_id / 64 selects the output column (see %217 and %250).
%workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 32 : index
%workgroup_id_y = hal.interface.workgroup.id[1] upper_bound 4 : index
%workgroup_id_z = hal.interface.workgroup.id[2] upper_bound 8 : index
// Lane decomposition: %32 = (tid mod 64) / 32 is a row in {0, 1}; %34 = tid mod 32 is a column;
// %39 = %34 + 32 is the second column handled by the same thread.
%29 = arith.index_castui %thread_id_x : index to i32
%30 = arith.remui %29, %c64_i32 : i32
%31 = arith.divui %30, %c32_i32 : i32
%32 = arith.index_castui %31 : i32 to index
%33 = arith.remui %29, %c32_i32 : i32
%34 = arith.index_castui %33 : i32 to index
// ---- Prologue: load this thread's two elements from %19 into a 1x2x1x1x1x1 vector (%42) --------
%35 = vector.load %19[%workgroup_id_z, %workgroup_id_y, %c0, %c0, %c0, %32, %34] : memref<8x4x1x1x1x2x64xf8E4M3FNUZ, strided<[512, 128, 128, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf8E4M3FNUZ>
%36 = vector.broadcast %35 : vector<1xf8E4M3FNUZ> to vector<1x1xf8E4M3FNUZ>
%37 = vector.insert_strided_slice %36, %cst_20 {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xf8E4M3FNUZ> into vector<1x2x1x1x1x1xf8E4M3FNUZ>
%38 = arith.addi %33, %c32_i32 : i32
%39 = arith.index_castui %38 : i32 to index
%40 = vector.load %19[%workgroup_id_z, %workgroup_id_y, %c0, %c0, %c0, %32, %39] : memref<8x4x1x1x1x2x64xf8E4M3FNUZ, strided<[512, 128, 128, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf8E4M3FNUZ>
%41 = vector.broadcast %40 : vector<1xf8E4M3FNUZ> to vector<1x1xf8E4M3FNUZ>
%42 = vector.insert_strided_slice %41, %37 {offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xf8E4M3FNUZ> into vector<1x2x1x1x1x1xf8E4M3FNUZ>
// Same two positions (row %32, columns %34 and %39) from table %20: i16 -> bf16 bits -> f32 (%50).
%43 = vector.load %20[%c0, %32, %34] : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554464>, #gpu.address_space<global>>, vector<1xi16>
%44 = vector.broadcast %43 : vector<1xi16> to vector<1x1xi16>
%45 = vector.insert_strided_slice %44, %cst_35 {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xi16> into vector<1x2x1x1x1x1xi16>
%46 = vector.load %20[%c0, %32, %39] : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554464>, #gpu.address_space<global>>, vector<1xi16>
%47 = vector.broadcast %46 : vector<1xi16> to vector<1x1xi16>
%48 = vector.insert_strided_slice %47, %45 {offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xi16> into vector<1x2x1x1x1x1xi16>
%49 = arith.bitcast %48 : vector<1x2x1x1x1x1xi16> to vector<1x2x1x1x1x1xbf16>
%50 = arith.extf %49 : vector<1x2x1x1x1x1xbf16> to vector<1x2x1x1x1x1xf32>
// Same two positions from table %21 -> f32 (%58).
%51 = vector.load %21[%c0, %32, %34] : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554592>, #gpu.address_space<global>>, vector<1xi16>
%52 = vector.broadcast %51 : vector<1xi16> to vector<1x1xi16>
%53 = vector.insert_strided_slice %52, %cst_35 {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xi16> into vector<1x2x1x1x1x1xi16>
%54 = vector.load %21[%c0, %32, %39] : memref<1x2x64xi16, strided<[128, 64, 1], offset: 33554592>, #gpu.address_space<global>>, vector<1xi16>
%55 = vector.broadcast %54 : vector<1xi16> to vector<1x1xi16>
%56 = vector.insert_strided_slice %55, %53 {offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xi16> into vector<1x2x1x1x1x1xi16>
%57 = arith.bitcast %56 : vector<1x2x1x1x1x1xi16> to vector<1x2x1x1x1x1xbf16>
%58 = arith.extf %57 : vector<1x2x1x1x1x1xbf16> to vector<1x2x1x1x1x1xf32>
// ---- Gather indices for the partner element ---------------------------------------------------
// %61 holds the row %32 in both lanes; %69 holds the two columns [%34, %34 + 32].
%59 = vector.splat %32 : vector<1x1xindex>
%60 = vector.insert %59, %cst_34 [0, 0, 0, 0] : vector<1x1xindex> into vector<2x1x1x1x1x1xindex>
%61 = vector.insert %59, %60 [1, 0, 0, 0] : vector<1x1xindex> into vector<2x1x1x1x1x1xindex>
%62 = vector.splat %34 : vector<2xindex>
%63 = arith.index_castui %62 : vector<2xindex> to vector<2xi32>
%64 = arith.addi %63, %cst_28 : vector<2xi32>
%65 = arith.index_castui %64 : vector<2xi32> to vector<2xindex>
%66 = vector.extract %65[0] : index from vector<2xindex>
%67 = vector.insert %66, %cst_30 [0, 0, 0] : index into vector<2x1x1xindex>
%68 = vector.extract %65[1] : index from vector<2xindex>
%69 = vector.insert %68, %67 [1, 0, 0] : index into vector<2x1x1xindex>
// %73 = workgroup_id_z * 4 + workgroup_id_y : linearised (z, y) tile id.
%70 = arith.index_castui %workgroup_id_z : index to i32
%71 = arith.muli %70, %c4_i32 overflow<nsw> : i32
%72 = arith.index_castui %workgroup_id_y : index to i32
%73 = arith.addi %71, %72 : i32
%74 = arith.index_castui %73 : i32 to index
// %76 = 1 - row, i.e. the OTHER row of the 2x64 tile.
// %84 = ((z * 4 + y) * 2 + (1 - row)) * 64 is the linear element offset of that row.
%75 = arith.index_castui %61 : vector<2x1x1x1x1x1xindex> to vector<2x1x1x1x1x1xi32>
%76 = arith.subi %cst_27, %75 : vector<2x1x1x1x1x1xi32>
%77 = arith.index_castui %76 : vector<2x1x1x1x1x1xi32> to vector<2x1x1x1x1x1xindex>
%78 = vector.transpose %77, [1, 0, 3, 2, 5, 4] : vector<2x1x1x1x1x1xindex> to vector<1x2x1x1x1x1xindex>
%79 = vector.splat %74 : vector<1x2x1x1x1x1xindex>
%80 = arith.index_castui %79 : vector<1x2x1x1x1x1xindex> to vector<1x2x1x1x1x1xi32>
%81 = arith.muli %80, %cst_26 : vector<1x2x1x1x1x1xi32>
%82 = arith.index_castui %81 : vector<1x2x1x1x1x1xi32> to vector<1x2x1x1x1x1xindex>
%83 = arith.addi %78, %82 : vector<1x2x1x1x1x1xindex>
%84 = arith.muli %83, %cst_15 : vector<1x2x1x1x1x1xindex>
%85 = vector.extract %69[0, 0] : vector<1xindex> from vector<2x1x1xindex>
%86 = vector.broadcast %85 : vector<1xindex> to vector<1x1xindex>
%87 = vector.insert %86, %cst_33 [0, 0, 0, 0] : vector<1x1xindex> into vector<1x2x1x1x1x1xindex>
%88 = vector.extract %69[1, 0] : vector<1xindex> from vector<2x1x1xindex>
%89 = vector.broadcast %88 : vector<1xindex> to vector<1x1xindex>
%90 = vector.insert %89, %87 [0, 1, 0, 0] : vector<1x1xindex> into vector<1x2x1x1x1x1xindex>
// %91 = row offset + column: final gather indices relative to the base of %18.
%91 = arith.addi %90, %84 : vector<1x2x1x1x1x1xindex>
// Gather the two partner elements (mask is all-true, so the zero pass-through is never selected).
%92 = vector.gather %18[%c0, %c0, %c0, %c0, %c0] [%91], %cst_21, %cst_20 : memref<1x32x1x2x64xf8E4M3FNUZ, strided<[4096, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>, vector<1x2x1x1x1x1xindex>, vector<1x2x1x1x1x1xi1>, vector<1x2x1x1x1x1xf8E4M3FNUZ> into vector<1x2x1x1x1x1xf8E4M3FNUZ>
// Use the negated partner when (1 - row) == 1, i.e. for row 0; row 1 takes the partner unchanged.
// NOTE(review): this matches the "rotate-half" step of a rotary embedding — inferred, confirm.
%93 = arith.negf %92 : vector<1x2x1x1x1x1xf8E4M3FNUZ>
%94 = arith.cmpi eq, %76, %cst_27 : vector<2x1x1x1x1x1xi32>
%95 = vector.transpose %94, [1, 0, 3, 2, 5, 4] : vector<2x1x1x1x1x1xi1> to vector<1x2x1x1x1x1xi1>
%96 = arith.select %95, %93, %92 : vector<1x2x1x1x1x1xi1>, vector<1x2x1x1x1x1xf8E4M3FNUZ>
// %101 = own * trunc(table %20) + partner * trunc(table %21); both tables are truncated to f8 first
// and the multiplies/add are carried out on f8E4M3FNUZ operands.
%97 = arith.truncf %58 : vector<1x2x1x1x1x1xf32> to vector<1x2x1x1x1x1xf8E4M3FNUZ>
%98 = arith.mulf %96, %97 : vector<1x2x1x1x1x1xf8E4M3FNUZ>
%99 = arith.truncf %50 : vector<1x2x1x1x1x1xf32> to vector<1x2x1x1x1x1xf8E4M3FNUZ>
%100 = arith.mulf %42, %99 : vector<1x2x1x1x1x1xf8E4M3FNUZ>
%101 = arith.addf %100, %98 : vector<1x2x1x1x1x1xf8E4M3FNUZ>
// Loop trip count %102 = (%7 / 32) * 32, and the score scale %103 = %13 * 1.44269502.
%102 = arith.muli %24, %c32_i32 overflow<nsw> : i32
%103 = arith.mulf %13, %cst_23 : f32
// ---- Main loop --------------------------------------------------------------------------------
// Block arguments of ^bb1: %104 = iteration index, %105 = running max, %106 = running sum,
// %107 = running weighted accumulator.
cf.br ^bb1(%c0_i32, %cst_14, %cst_13, %cst_19 : i32, vector<f32>, vector<f32>, vector<1x1x1xf32>)
^bb1(%104: i32, %105: vector<f32>, %106: vector<f32>, %107: vector<1x1x1xf32>): // 2 preds: ^bb0, ^bb5
%108 = arith.cmpi slt, %104, %102 : i32
cf.cond_br %108, ^bb2, ^bb6
^bb2: // pred: ^bb1
// Split the iteration index into (%110 = i / 32, %112 = i mod 32) to address the ?x32 dimensions.
%109 = arith.divui %104, %c32_i32 : i32
%110 = arith.index_castui %109 : i32 to index
%111 = arith.remui %104, %c32_i32 : i32
%112 = arith.index_castui %111 : i32 to index
// Load this thread's two elements of %27 for the current position (row %32, columns %34 / %39).
%113 = vector.load %27[%workgroup_id_z, %workgroup_id_y, %110, %112, %c0, %32, %34] : memref<8x4x?x32x1x2x64xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf8E4M3FNUZ>
%114 = vector.broadcast %113 : vector<1xf8E4M3FNUZ> to vector<1x1xf8E4M3FNUZ>
%115 = vector.insert_strided_slice %114, %cst_20 {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xf8E4M3FNUZ> into vector<1x2x1x1x1x1xf8E4M3FNUZ>
%116 = vector.load %27[%workgroup_id_z, %workgroup_id_y, %110, %112, %c0, %32, %39] : memref<8x4x?x32x1x2x64xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 128, 64, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf8E4M3FNUZ>
%117 = vector.broadcast %116 : vector<1xf8E4M3FNUZ> to vector<1x1xf8E4M3FNUZ>
%118 = vector.insert_strided_slice %117, %115 {offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]} : vector<1x1xf8E4M3FNUZ> into vector<1x2x1x1x1x1xf8E4M3FNUZ>
// Per-thread partial dot product in f32: %130 = %101[0] * %118[0] + %101[1] * %118[1].
%119 = arith.extf %101 : vector<1x2x1x1x1x1xf8E4M3FNUZ> to vector<1x2x1x1x1x1xf32>
%120 = arith.extf %118 : vector<1x2x1x1x1x1xf8E4M3FNUZ> to vector<1x2x1x1x1x1xf32>
%121 = vector.extract %119[0, 0, 0, 0, 0] : vector<1xf32> from vector<1x2x1x1x1x1xf32>
%122 = vector.extract %120[0, 0, 0, 0, 0] : vector<1xf32> from vector<1x2x1x1x1x1xf32>
%123 = arith.mulf %121, %122 : vector<1xf32>
%124 = vector.extract %123[0] : f32 from vector<1xf32>
%125 = arith.addf %124, %cst_24 : f32
%126 = vector.extract %119[0, 1, 0, 0, 0] : vector<1xf32> from vector<1x2x1x1x1x1xf32>
%127 = vector.extract %120[0, 1, 0, 0, 0] : vector<1xf32> from vector<1x2x1x1x1x1xf32>
%128 = arith.mulf %126, %127 : vector<1xf32>
%129 = vector.extract %128[0] : f32 from vector<1xf32>
%130 = arith.addf %129, %125 : f32
// XOR-butterfly reduction with width 64: exchange at offsets 32, 1, 2, 4, 8, 16 and add after each
// step, so %136 is the sum of %130 over the 64 lanes of the shuffle group.
// The %valid* results are not used.
%shuffleResult, %valid = gpu.shuffle xor %130, %c32_i32, %c64_i32 : f32
%131 = arith.addf %130, %shuffleResult : f32
%shuffleResult_36, %valid_37 = gpu.shuffle xor %131, %c1_i32, %c64_i32 : f32
%132 = arith.addf %131, %shuffleResult_36 : f32
%shuffleResult_38, %valid_39 = gpu.shuffle xor %132, %c2_i32, %c64_i32 : f32
%133 = arith.addf %132, %shuffleResult_38 : f32
%shuffleResult_40, %valid_41 = gpu.shuffle xor %133, %c4_i32, %c64_i32 : f32
%134 = arith.addf %133, %shuffleResult_40 : f32
%shuffleResult_42, %valid_43 = gpu.shuffle xor %134, %c8_i32, %c64_i32 : f32
%135 = arith.addf %134, %shuffleResult_42 : f32
%shuffleResult_44, %valid_45 = gpu.shuffle xor %135, %c16_i32, %c64_i32 : f32
%136 = arith.addf %135, %shuffleResult_44 : f32
// Wrap in vector<1xf32>, add the zero vector, unwrap again.
%137 = vector.insert %136, %cst_32 [0] : f32 into vector<1xf32>
%138 = arith.addf %137, %cst_32 : vector<1xf32>
%139 = vector.extract %138[0] : f32 from vector<1xf32>
// Score %146 = %103 * dot + %cst_22 + f32(%26[0, i / 32, i mod 32]) * 1.44269502.
%140 = arith.mulf %103, %139 : f32
%141 = arith.addf %140, %cst_22 : f32
%142 = vector.load %26[%c0, %110, %112] : memref<1x?x32xf8E4M3FNUZ, strided<[?, 32, 1], offset: 67117632>, #gpu.address_space<global>>, vector<f8E4M3FNUZ>
%143 = vector.extract %142[] : f8E4M3FNUZ from vector<f8E4M3FNUZ>
%144 = arith.extf %143 : f8E4M3FNUZ to f32
%145 = arith.mulf %144, %cst_23 : f32
%146 = arith.addf %141, %145 : f32
// New running max %148 = max(score, old max); %149 carries it to the next iteration.
%147 = vector.extract %105[] : f32 from vector<f32>
%148 = arith.maximumf %146, %147 : f32
%149 = vector.splat %148 : vector<f32>
// ---- Inlined exp #1: %176 approximates exp((old max - new max) * 0.693147182) ------------------
// Steps: x clamped to [-87.8, 88.8]; n = floor(x * 1.44269502 + 0.5) clamped to [-127, 127];
// r = x + n * (-0.693359375) + n * 2.12194442E-4; p = polynomial in r evaluated with fma;
// result = (p * r^2 + r + 1) * bitcast<f32>((n + 127) << 23).
// The uge/ule compares are unordered predicates, so a NaN input is selected through unchanged.
%150 = arith.subf %147, %148 : f32
%151 = arith.mulf %150, %cst_12 : f32
%152 = arith.cmpf uge, %151, %cst_2 : f32
%153 = arith.select %152, %151, %cst_2 : f32
%154 = arith.cmpf ule, %153, %cst_1 : f32
%155 = arith.select %154, %153, %cst_1 : f32
%156 = math.fma %155, %cst_23, %cst_11 : f32
%157 = math.floor %156 : f32
%158 = arith.cmpf uge, %157, %cst_0 : f32
%159 = arith.select %158, %157, %cst_0 : f32
%160 = arith.cmpf ule, %159, %cst : f32
%161 = arith.select %160, %159, %cst : f32
%162 = math.fma %cst_9, %161, %155 : f32
%163 = math.fma %cst_8, %161, %162 : f32
%164 = math.fma %163, %cst_7, %cst_6 : f32
%165 = math.fma %164, %163, %cst_5 : f32
%166 = math.fma %165, %163, %cst_4 : f32
%167 = math.fma %166, %163, %cst_3 : f32
%168 = math.fma %167, %163, %cst_11 : f32
%169 = arith.mulf %163, %163 : f32
%170 = math.fma %168, %169, %163 : f32
%171 = arith.addf %170, %cst_10 : f32
%172 = arith.fptosi %161 : f32 to i32
%173 = arith.addi %172, %c127_i32 : i32
%174 = arith.shli %173, %c23_i32 : i32
%175 = arith.bitcast %174 : i32 to f32
%176 = arith.mulf %171, %175 : f32
// Rescale the previous running sum by the correction factor %176.
%177 = vector.extract %106[] : f32 from vector<f32>
%178 = arith.mulf %176, %177 : f32
// ---- Inlined exp #2: %205 approximates exp((score - new max) * 0.693147182) --------------------
// Same sequence of operations as exp #1, applied to %180.
%179 = arith.subf %146, %148 : f32
%180 = arith.mulf %179, %cst_12 : f32
%181 = arith.cmpf uge, %180, %cst_2 : f32
%182 = arith.select %181, %180, %cst_2 : f32
%183 = arith.cmpf ule, %182, %cst_1 : f32
%184 = arith.select %183, %182, %cst_1 : f32
%185 = math.fma %184, %cst_23, %cst_11 : f32
%186 = math.floor %185 : f32
%187 = arith.cmpf uge, %186, %cst_0 : f32
%188 = arith.select %187, %186, %cst_0 : f32
%189 = arith.cmpf ule, %188, %cst : f32
%190 = arith.select %189, %188, %cst : f32
%191 = math.fma %cst_9, %190, %184 : f32
%192 = math.fma %cst_8, %190, %191 : f32
%193 = math.fma %192, %cst_7, %cst_6 : f32
%194 = math.fma %193, %192, %cst_5 : f32
%195 = math.fma %194, %192, %cst_4 : f32
%196 = math.fma %195, %192, %cst_3 : f32
%197 = math.fma %196, %192, %cst_11 : f32
%198 = arith.mulf %192, %192 : f32
%199 = math.fma %197, %198, %192 : f32
%200 = arith.addf %199, %cst_10 : f32
%201 = arith.fptosi %190 : f32 to i32
%202 = arith.addi %201, %c127_i32 : i32
%203 = arith.shli %202, %c23_i32 : i32
%204 = arith.bitcast %203 : i32 to f32
%205 = arith.mulf %200, %204 : f32
// New running sum %206 = weight + rescaled old sum (uses the unclamped f32 weight %205).
%206 = arith.addf %205, %178 : f32
%207 = vector.splat %206 : vector<f32>
// The weight used for accumulation is clamped to at most 240.0 and truncated to f8 (%210).
%208 = arith.minimumf %205, %cst_25 : f32
%209 = arith.truncf %208 : f32 to f8E4M3FNUZ
%210 = vector.splat %209 : vector<f8E4M3FNUZ>
// Rescale the previous accumulator by the correction factor %176.
%211 = vector.splat %176 : vector<1x1x1xf32>
%212 = arith.mulf %211, %107 : vector<1x1x1xf32>
// %217 = tid / 64 + workgroup_id_x * 4 : index into the 128-wide dimension of %28.
// This value does not depend on %104; it is recomputed every iteration as written.
%213 = arith.divui %29, %c64_i32 : i32
%214 = arith.index_castui %workgroup_id_x : index to i32
%215 = arith.muli %214, %c4_i32 overflow<nsw> : i32
%216 = arith.addi %213, %215 : i32
%217 = arith.index_castui %216 : i32 to index
// Inner loop ^bb3/^bb4 runs while %218 < 1, i.e. exactly once, loading a single element of %28.
cf.br ^bb3(%c0_i32, %cst_29 : i32, vector<1xf8E4M3FNUZ>)
^bb3(%218: i32, %219: vector<1xf8E4M3FNUZ>): // 2 preds: ^bb2, ^bb4
%220 = arith.cmpi slt, %218, %c1_i32 : i32
cf.cond_br %220, ^bb4, ^bb5
^bb4: // pred: ^bb3
%221 = memref.load %28[%workgroup_id_z, %workgroup_id_y, %217, %110, %112] : memref<8x4x128x?x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #gpu.address_space<global>>
%222 = vector.insertelement %221, %219[%c0 : index] : vector<1xf8E4M3FNUZ>
%223 = arith.addi %218, %c1_i32 : i32
cf.br ^bb3(%223, %222 : i32, vector<1xf8E4M3FNUZ>)
^bb5: // pred: ^bb3
// Accumulate: %230 = f32(f8 weight) * f32(loaded element) + rescaled accumulator, then advance
// the outer loop with the new max (%149), new sum (%207) and new accumulator (%230).
%224 = vector.insert_strided_slice %219, %cst_31 {offsets = [0, 0, 0], strides = [1]} : vector<1xf8E4M3FNUZ> into vector<1x1x1xf8E4M3FNUZ>
%225 = arith.extf %210 : vector<f8E4M3FNUZ> to vector<f32>
%226 = vector.extract %225[] : f32 from vector<f32>
%227 = vector.splat %226 : vector<1x1x1xf32>
%228 = arith.extf %224 : vector<1x1x1xf8E4M3FNUZ> to vector<1x1x1xf32>
%229 = arith.mulf %227, %228 : vector<1x1x1xf32>
%230 = arith.addf %229, %212 : vector<1x1x1xf32>
%231 = arith.addi %104, %c1_i32 : i32
cf.br ^bb1(%231, %149, %207, %230 : i32, vector<f32>, vector<f32>, vector<1x1x1xf32>)
^bb6: // pred: ^bb1
// ---- Epilogue ---------------------------------------------------------------------------------
// Normalise: %235 = accumulator * (1 / running sum). If the loop ran zero times the running sum is
// still its initial 0.0, so this divides by zero.
%232 = vector.extract %106[] : f32 from vector<f32>
%233 = vector.splat %232 : vector<1x1x1xf32>
%234 = arith.divf %cst_18, %233 : vector<1x1x1xf32>
%235 = arith.mulf %234, %107 : vector<1x1x1xf32>
// Divide by the scalar stored in %22.
// NOTE(review): presumably an output quantisation scale — confirm against the producer.
%236 = vector.load %22[] : memref<f32, strided<[], offset: ?>, #gpu.address_space<global>>, vector<f32>
%237 = vector.extract %236[] : f32 from vector<f32>
%238 = vector.splat %237 : vector<1x1x1xf32>
%239 = arith.divf %235, %238 : vector<1x1x1xf32>
// Clamp to [-240, 240] before truncating to f8. `ult` is an unordered predicate, so a NaN input
// selects the lower bound -240.
%240 = arith.cmpf ult, %239, %cst_17 : vector<1x1x1xf32>
%241 = arith.select %240, %cst_17, %239 : vector<1x1x1xi1>, vector<1x1x1xf32>
%242 = arith.cmpf ugt, %241, %cst_16 : vector<1x1x1xf32>
%243 = arith.select %242, %cst_16, %241 : vector<1x1x1xi1>, vector<1x1x1xf32>
%244 = arith.truncf %243 : vector<1x1x1xf32> to vector<1x1x1xf8E4M3FNUZ>
// Output column %250 = tid / 64 + workgroup_id_x * 4 (same formula as %217 inside the loop).
%245 = arith.divui %29, %c64_i32 : i32
%246 = vector.extract %244[0, 0] : vector<1xf8E4M3FNUZ> from vector<1x1x1xf8E4M3FNUZ>
%247 = arith.index_castui %workgroup_id_x : index to i32
%248 = arith.muli %247, %c4_i32 overflow<nsw> : i32
%249 = arith.addi %245, %248 : i32
%250 = arith.index_castui %249 : i32 to index
// The store is not guarded by any lane predicate, so every thread executes it; threads that share
// the same tid / 64 target the same address.
vector.store %246, %23[%workgroup_id_z, %workgroup_id_y, %c0, %250] : memref<8x4x1x128xf8E4M3FNUZ, strided<[512, 128, 128, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf8E4M3FNUZ>
return
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment