@bertmaher
Created May 23, 2024 19:50
// -----// IR Dump Before TritonGPUOptimizeDotOperands (tritongpu-optimize-dot-operands) ('builtin.module' operation) //----- //
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#loc = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0)
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0)) attributes {noinline = false} {
%c2_i32 = arith.constant 2 : i32 loc(#loc1)
%cst = arith.constant dense<16> : tensor<128x16xi32, #blocked> loc(#loc1)
%c16_i32 = arith.constant 16 : i32 loc(#loc1)
%c256_i32 = arith.constant 256 : i32 loc(#loc1)
%c128_i32 = arith.constant 128 : i32 loc(#loc1)
%c8_i32 = arith.constant 8 : i32 loc(#loc1)
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked> loc(#loc1)
%cst_1 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #blocked1> loc(#loc1)
%c0_i32 = arith.constant 0 : i32 loc(#loc1)
%c1_i32 = arith.constant 1 : i32 loc(#loc1)
%c127_i32 = arith.constant 127 : i32 loc(#loc1)
%c255_i32 = arith.constant 255 : i32 loc(#loc1)
%c15_i32 = arith.constant 15 : i32 loc(#loc1)
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1)
%0 = tt.get_program_id x : i32 loc(#loc2)
%1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc56)
%2 = arith.divsi %1, %c128_i32 : i32 loc(#loc57)
%3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc58)
%4 = arith.divsi %3, %c256_i32 : i32 loc(#loc59)
%5 = arith.muli %4, %c8_i32 : i32 loc(#loc7)
%6 = arith.divsi %0, %5 : i32 loc(#loc8)
%7 = arith.muli %6, %c8_i32 : i32 loc(#loc9)
%8 = arith.subi %2, %7 : i32 loc(#loc10)
%9 = arith.minsi %8, %c8_i32 : i32 loc(#loc11)
%10 = arith.remsi %0, %9 : i32 loc(#loc12)
%11 = arith.addi %7, %10 : i32 loc(#loc13)
%12 = arith.remsi %0, %5 : i32 loc(#loc14)
%13 = arith.divsi %12, %9 : i32 loc(#loc15)
%14 = arith.muli %11, %c128_i32 : i32 loc(#loc16)
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc17)
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc17)
%17 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc18)
%18 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18)
%19 = arith.addi %17, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc18)
%20 = arith.addi %18, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18)
%21 = tt.splat %arg3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc19)
%22 = arith.remsi %19, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc19)
%23 = arith.muli %13, %c256_i32 : i32 loc(#loc20)
%24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21)
%25 = tt.splat %23 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22)
%26 = arith.addi %25, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22)
%27 = tt.splat %arg4 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc23)
%28 = arith.remsi %26, %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc23)
%29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc24)
%30 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked> loc(#loc25)
%31 = arith.muli %29, %30 : tensor<128x1xi32, #blocked> loc(#loc25)
%32 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> loc(#loc26)
%33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc26)
%34 = tt.broadcast %31 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> loc(#loc27)
%35 = tt.broadcast %33 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> loc(#loc27)
%36 = arith.addi %34, %35 : tensor<128x16xi32, #blocked> loc(#loc27)
%37 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x16x!tt.ptr<f32>, #blocked> loc(#loc28)
%38 = tt.addptr %37, %36 : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc28)
%39 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc29)
%40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc29)
%41 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc30)
%42 = arith.muli %40, %41 : tensor<16x1xi32, #blocked1> loc(#loc30)
%43 = tt.expand_dims %28 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc31)
%44 = tt.broadcast %42 : tensor<16x1xi32, #blocked1> -> tensor<16x256xi32, #blocked1> loc(#loc32)
%45 = tt.broadcast %43 : tensor<1x256xi32, #blocked1> -> tensor<16x256xi32, #blocked1> loc(#loc32)
%46 = arith.addi %44, %45 : tensor<16x256xi32, #blocked1> loc(#loc32)
%47 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<16x256x!tt.ptr<f32>, #blocked1> loc(#loc33)
%48 = tt.addptr %47, %46 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc33)
%49 = arith.addi %arg5, %c15_i32 : i32 loc(#loc60)
%50 = arith.divsi %49, %c16_i32 : i32 loc(#loc61)
%51 = arith.muli %arg7, %c16_i32 : i32 loc(#loc35)
%52 = tt.splat %51 : i32 -> tensor<16x256xi32, #blocked1> loc(#loc36)
%53 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x16xf32, #shared, mutable> loc(#loc37)
%54 = triton_gpu.local_alloc : () -> !tt.memdesc<2x16x256xf32, #shared1, mutable> loc(#loc38)
%55 = arith.cmpi sgt, %50, %c0_i32 : i32 loc(#loc39)
%56 = tt.splat %arg5 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40)
%57 = arith.cmpi slt, %33, %56 : tensor<1x16xi32, #blocked> loc(#loc40)
%58 = tt.broadcast %57 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37)
%59 = triton_gpu.memdesc_subview %53[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37)
%60 = tt.splat %55 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39)
%61 = arith.andi %60, %58 : tensor<128x16xi1, #blocked> loc(#loc39)
%62 = triton_gpu.async_copy_global_to_local %38, %59 mask %61 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, mutable> loc(#loc37)
%63 = triton_gpu.async_commit_group %62 loc(#loc37)
%64 = tt.splat %arg5 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41)
%65 = arith.cmpi slt, %40, %64 : tensor<16x1xi32, #blocked1> loc(#loc41)
%66 = tt.broadcast %65 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38)
%67 = triton_gpu.memdesc_subview %54[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38)
%68 = tt.splat %55 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39)
%69 = arith.andi %68, %66 : tensor<16x256xi1, #blocked1> loc(#loc39)
%70 = triton_gpu.async_copy_global_to_local %48, %67 mask %69 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, mutable> loc(#loc38)
%71 = triton_gpu.async_commit_group %70 loc(#loc38)
%72 = arith.cmpi sgt, %50, %c1_i32 : i32 loc(#loc39)
%73 = tt.addptr %38, %cst : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc42)
%74 = tt.addptr %48, %52 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc36)
%75 = arith.subi %arg5, %c16_i32 : i32 loc(#loc43)
%76 = tt.splat %75 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40)
%77 = arith.cmpi slt, %33, %76 : tensor<1x16xi32, #blocked> loc(#loc40)
%78 = tt.broadcast %77 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37)
%79 = triton_gpu.memdesc_subview %53[%c1_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37)
%80 = tt.splat %72 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39)
%81 = arith.andi %80, %78 : tensor<128x16xi1, #blocked> loc(#loc39)
%82 = triton_gpu.async_copy_global_to_local %73, %79 mask %81 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, mutable> loc(#loc37)
%83 = triton_gpu.async_commit_group %82 loc(#loc37)
%84 = tt.splat %75 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41)
%85 = arith.cmpi slt, %40, %84 : tensor<16x1xi32, #blocked1> loc(#loc41)
%86 = tt.broadcast %85 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38)
%87 = triton_gpu.memdesc_subview %54[%c1_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38)
%88 = tt.splat %72 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39)
%89 = arith.andi %88, %86 : tensor<16x256xi1, #blocked1> loc(#loc39)
%90 = triton_gpu.async_copy_global_to_local %74, %87 mask %89 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, mutable> loc(#loc38)
%91 = triton_gpu.async_commit_group %90 loc(#loc38)
%92 = triton_gpu.memdesc_subview %53[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37)
%93 = triton_gpu.async_wait %71 {num = 2 : i32} loc(#loc37)
%94 = triton_gpu.memdesc_subview %54[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38)
%c0_i32_3 = arith.constant 0 : i32 loc(#loc37)
%c0_i32_4 = arith.constant 0 : i32 loc(#loc37)
%95 = triton_gpu.memdesc_subview %92[%c0_i32_3, %c0_i32_4] : !tt.memdesc<128x16xf32, #shared, mutable> -> !tt.memdesc<128x8xf32, #shared> loc(#loc37)
%96 = triton_gpu.local_load %95 : !tt.memdesc<128x8xf32, #shared> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37)
%c0_i32_5 = arith.constant 0 : i32 loc(#loc38)
%c0_i32_6 = arith.constant 0 : i32 loc(#loc38)
%97 = triton_gpu.memdesc_subview %94[%c0_i32_5, %c0_i32_6] : !tt.memdesc<16x256xf32, #shared1, mutable> -> !tt.memdesc<8x256xf32, #shared1> loc(#loc38)
%98 = triton_gpu.local_load %97 : !tt.memdesc<8x256xf32, #shared1> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38)
%99:13 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %73, %arg12 = %74, %arg13 = %c1_i32, %arg14 = %c0_i32, %arg15 = %92, %arg16 = %93, %arg17 = %94, %arg18 = %93, %arg19 = %83, %arg20 = %91, %arg21 = %96, %arg22 = %98) -> (tensor<128x256xf32, #mma>, tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<16x256x!tt.ptr<f32>, #blocked1>, i32, i32, !tt.memdesc<128x16xf32, #shared, mutable>, !triton_gpu.async.token, !tt.memdesc<16x256xf32, #shared1, mutable>, !triton_gpu.async.token, !triton_gpu.async.token, !triton_gpu.async.token, tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>) : i32 {
%118 = arith.subi %50, %c2_i32 : i32 loc(#loc39)
%119 = arith.cmpi slt, %arg9, %118 : i32 loc(#loc39)
%120 = triton_gpu.local_load %arg15 : !tt.memdesc<128x16xf32, #shared, mutable> -> tensor<128x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37)
%121 = triton_gpu.local_load %arg17 : !tt.memdesc<16x256xf32, #shared1, mutable> -> tensor<16x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38)
%122 = tt.dot %120, %121, %arg10, inputPrecision = tf32 : tensor<128x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44)
%c0_i32_7 = arith.constant 0 : i32 loc(#loc37)
%c8_i32_8 = arith.constant 8 : i32 loc(#loc37)
%123 = triton_gpu.memdesc_subview %arg15[%c0_i32_7, %c8_i32_8] : !tt.memdesc<128x16xf32, #shared, mutable> -> !tt.memdesc<128x8xf32, #shared> loc(#loc37)
%124 = triton_gpu.local_load %123 : !tt.memdesc<128x8xf32, #shared> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37)
%c8_i32_9 = arith.constant 8 : i32 loc(#loc38)
%c0_i32_10 = arith.constant 0 : i32 loc(#loc38)
%125 = triton_gpu.memdesc_subview %arg17[%c8_i32_9, %c0_i32_10] : !tt.memdesc<16x256xf32, #shared1, mutable> -> !tt.memdesc<8x256xf32, #shared1> loc(#loc38)
%126 = triton_gpu.local_load %125 : !tt.memdesc<8x256xf32, #shared1> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38)
%127 = tt.dot %arg21, %arg22, %arg10, inputPrecision = tf32 : tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44)
%128 = tt.addptr %arg11, %cst : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc42)
%129 = tt.addptr %arg12, %52 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc36)
%130 = arith.addi %arg13, %c1_i32 : i32 loc(#loc39)
%131 = arith.cmpi slt, %130, %c2_i32 : i32 loc(#loc39)
%132 = arith.select %131, %130, %c0_i32 : i32 loc(#loc39)
%133 = arith.addi %arg9, %c2_i32 : i32 loc(#loc39)
%134 = arith.muli %133, %c16_i32 : i32 loc(#loc45)
%135 = arith.subi %arg5, %134 : i32 loc(#loc43)
%136 = tt.splat %135 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40)
%137 = arith.cmpi slt, %33, %136 : tensor<1x16xi32, #blocked> loc(#loc40)
%138 = tt.broadcast %137 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37)
%139 = triton_gpu.memdesc_subview %53[%132, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37)
%140 = tt.splat %119 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39)
%141 = arith.andi %140, %138 : tensor<128x16xi1, #blocked> loc(#loc39)
%142 = triton_gpu.async_copy_global_to_local %128, %139 mask %141 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, mutable> loc(#loc37)
%143 = triton_gpu.async_commit_group %142 loc(#loc37)
%144 = tt.splat %135 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41)
%145 = arith.cmpi slt, %40, %144 : tensor<16x1xi32, #blocked1> loc(#loc41)
%146 = tt.broadcast %145 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38)
%147 = triton_gpu.memdesc_subview %54[%132, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38)
%148 = tt.splat %119 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39)
%149 = arith.andi %148, %146 : tensor<16x256xi1, #blocked1> loc(#loc39)
%150 = triton_gpu.async_copy_global_to_local %129, %147 mask %149 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, mutable> loc(#loc38)
%151 = triton_gpu.async_commit_group %150 loc(#loc38)
%152 = arith.addi %arg14, %c1_i32 : i32 loc(#loc39)
%153 = arith.cmpi slt, %152, %c2_i32 : i32 loc(#loc39)
%154 = arith.select %153, %152, %c0_i32 : i32 loc(#loc39)
%155 = triton_gpu.memdesc_subview %53[%154, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37)
%156 = triton_gpu.async_wait %arg20 {num = 2 : i32} loc(#loc37)
%157 = triton_gpu.memdesc_subview %54[%154, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38)
%c0_i32_11 = arith.constant 0 : i32 loc(#loc37)
%c0_i32_12 = arith.constant 0 : i32 loc(#loc37)
%158 = triton_gpu.memdesc_subview %155[%c0_i32_11, %c0_i32_12] : !tt.memdesc<128x16xf32, #shared, mutable> -> !tt.memdesc<128x8xf32, #shared> loc(#loc37)
%159 = triton_gpu.local_load %158 : !tt.memdesc<128x8xf32, #shared> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37)
%c0_i32_13 = arith.constant 0 : i32 loc(#loc38)
%c0_i32_14 = arith.constant 0 : i32 loc(#loc38)
%160 = triton_gpu.memdesc_subview %157[%c0_i32_13, %c0_i32_14] : !tt.memdesc<16x256xf32, #shared1, mutable> -> !tt.memdesc<8x256xf32, #shared1> loc(#loc38)
%161 = triton_gpu.local_load %160 : !tt.memdesc<8x256xf32, #shared1> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38)
%162 = tt.dot %124, %126, %127, inputPrecision = tf32 : tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44)
scf.yield %162, %128, %129, %132, %154, %155, %156, %157, %156, %143, %151, %159, %161 : tensor<128x256xf32, #mma>, tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<16x256x!tt.ptr<f32>, #blocked1>, i32, i32, !tt.memdesc<128x16xf32, #shared, mutable>, !triton_gpu.async.token, !tt.memdesc<16x256xf32, #shared1, mutable>, !triton_gpu.async.token, !triton_gpu.async.token, !triton_gpu.async.token, tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc39)
} loc(#loc39)
%100 = triton_gpu.async_wait {num = 0 : i32} loc(#loc39)
triton_gpu.local_dealloc %53 : !tt.memdesc<2x128x16xf32, #shared, mutable> loc(#loc39)
triton_gpu.local_dealloc %54 : !tt.memdesc<2x16x256xf32, #shared1, mutable> loc(#loc39)
%101 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc46)
%102 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc47)
%103 = arith.muli %102, %101 : tensor<128x1xi32, #blocked1> loc(#loc47)
%104 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked1> loc(#loc48)
%105 = tt.addptr %104, %103 : tensor<128x1x!tt.ptr<f32>, #blocked1>, tensor<128x1xi32, #blocked1> loc(#loc48)
%106 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc49)
%107 = tt.broadcast %105 : tensor<128x1x!tt.ptr<f32>, #blocked1> -> tensor<128x256x!tt.ptr<f32>, #blocked1> loc(#loc50)
%108 = tt.broadcast %106 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> loc(#loc50)
%109 = tt.addptr %107, %108 : tensor<128x256x!tt.ptr<f32>, #blocked1>, tensor<128x256xi32, #blocked1> loc(#loc50)
%110 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc51)
%111 = arith.cmpi slt, %101, %110 : tensor<128x1xi32, #blocked1> loc(#loc51)
%112 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> loc(#loc52)
%113 = arith.cmpi slt, %106, %112 : tensor<1x256xi32, #blocked1> loc(#loc52)
%114 = tt.broadcast %111 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc53)
%115 = tt.broadcast %113 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc53)
%116 = arith.andi %114, %115 : tensor<128x256xi1, #blocked1> loc(#loc53)
%117 = triton_gpu.convert_layout %99#0 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked1> loc(#loc54)
tt.store %109, %117, %116 : tensor<128x256x!tt.ptr<f32>, #blocked1> loc(#loc54)
tt.return loc(#loc55)
} loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":62:24)
#loc3 = loc("/data/users/bertrand/triton/python/triton/language/standard.py":44:22)
#loc4 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":63:27)
#loc5 = loc("/data/users/bertrand/triton/python/triton/language/standard.py":44:28)
#loc6 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":64:27)
#loc7 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":65:38)
#loc8 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":66:22)
#loc9 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":67:29)
#loc10 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":68:35)
#loc11 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":68:48)
#loc12 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":69:33)
#loc13 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":69:27)
#loc14 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":70:19)
#loc15 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":70:40)
#loc16 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:23)
#loc17 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:51)
#loc18 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:38)
#loc19 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:68)
#loc20 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:23)
#loc21 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:51)
#loc22 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:38)
#loc23 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:68)
#loc24 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:30)
#loc25 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:41)
#loc26 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:60)
#loc27 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:53)
#loc28 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:22)
#loc29 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:29)
#loc30 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:40)
#loc31 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:60)
#loc32 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:52)
#loc33 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:22)
#loc34 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":91:33)
#loc35 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":101:33)
#loc36 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":101:18)
#loc37 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:20)
#loc38 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":95:20)
#loc39 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":91:22)
#loc40 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:51)
#loc41 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":95:51)
#loc42 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":100:18)
#loc43 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:55)
#loc44 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":97:35)
#loc45 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:59)
#loc46 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:41)
#loc47 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:33)
#loc48 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:21)
#loc49 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:72)
#loc50 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:52)
#loc51 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":113:33)
#loc52 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":113:58)
#loc53 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":113:39)
#loc54 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":114:21)
#loc55 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":114:4)
#loc56 = loc(callsite(#loc3 at #loc4))
#loc57 = loc(callsite(#loc5 at #loc4))
#loc58 = loc(callsite(#loc3 at #loc6))
#loc59 = loc(callsite(#loc5 at #loc6))
#loc60 = loc(callsite(#loc3 at #loc34))
#loc61 = loc(callsite(#loc5 at #loc34))
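
For reference, below is a hypothetical Triton-Python sketch of a kernel that could lower to the dump above. The overall structure (grouped program-id swizzle, modulo-wrapped row/column offsets, masked K-loop loads with other=0.0, a tf32 tl.dot, and a masked store) is read directly off the IR; the argument names and constexpr values (BLOCK_M=128, BLOCK_N=256, BLOCK_K=16, GROUP_M=8) are inferred from the tensor shapes and constants, not taken from the original triton_kernel.py, so treat this as an approximation rather than the actual source.

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,          # %arg0..%arg2
    M, N, K,                      # %arg3..%arg5
    stride_am, stride_bk, stride_cm,  # %arg6..%arg8 (other strides assumed 1)
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Grouped pid swizzle; matches %0..%13 in the IR.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # A/B tile pointers; offsets wrap by M/N as in %22 and %28.
    offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :]
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :]

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Masked loads with other=0.0 become the pipelined
        # async_copy_global_to_local ops (mask %61/%69, other %cst_0/%cst_1).
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc = tl.dot(a, b, acc, input_precision="tf32")  # inputPrecision = tf32
        a_ptrs += BLOCK_K                   # %cst = dense<16>
        b_ptrs += BLOCK_K * stride_bk       # %52

    # Masked store of the accumulator; matches %101..%116 and the tt.store.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)

Note that the double buffering visible in the IR (the 2x128x16 and 2x16x256 shared-memory allocations, memdesc_subview ping-pong, and async wait/commit tokens threaded through the scf.for) is not written in the Python source; it is introduced by the software pipeliner, here with what appears to be num_stages=2 and num_warps=8.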