-
-
Save bertmaher/e2c899fc6ff91a79e93583eb9f8c059d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// -----// IR Dump Before TritonGPUOptimizeDotOperands (tritongpu-optimize-dot-operands) ('builtin.module' operation) //----- // | |
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> | |
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}> | |
#loc = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0) | |
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> | |
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> | |
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> | |
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":39:0)) attributes {noinline = false} { | |
%c2_i32 = arith.constant 2 : i32 loc(#loc1) | |
%cst = arith.constant dense<16> : tensor<128x16xi32, #blocked> loc(#loc1) | |
%c16_i32 = arith.constant 16 : i32 loc(#loc1) | |
%c256_i32 = arith.constant 256 : i32 loc(#loc1) | |
%c128_i32 = arith.constant 128 : i32 loc(#loc1) | |
%c8_i32 = arith.constant 8 : i32 loc(#loc1) | |
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked> loc(#loc1) | |
%cst_1 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #blocked1> loc(#loc1) | |
%c0_i32 = arith.constant 0 : i32 loc(#loc1) | |
%c1_i32 = arith.constant 1 : i32 loc(#loc1) | |
%c127_i32 = arith.constant 127 : i32 loc(#loc1) | |
%c255_i32 = arith.constant 255 : i32 loc(#loc1) | |
%c15_i32 = arith.constant 15 : i32 loc(#loc1) | |
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1) | |
%0 = tt.get_program_id x : i32 loc(#loc2) | |
%1 = arith.addi %arg3, %c127_i32 : i32 loc(#loc56) | |
%2 = arith.divsi %1, %c128_i32 : i32 loc(#loc57) | |
%3 = arith.addi %arg4, %c255_i32 : i32 loc(#loc58) | |
%4 = arith.divsi %3, %c256_i32 : i32 loc(#loc59) | |
%5 = arith.muli %4, %c8_i32 : i32 loc(#loc7) | |
%6 = arith.divsi %0, %5 : i32 loc(#loc8) | |
%7 = arith.muli %6, %c8_i32 : i32 loc(#loc9) | |
%8 = arith.subi %2, %7 : i32 loc(#loc10) | |
%9 = arith.minsi %8, %c8_i32 : i32 loc(#loc11) | |
%10 = arith.remsi %0, %9 : i32 loc(#loc12) | |
%11 = arith.addi %7, %10 : i32 loc(#loc13) | |
%12 = arith.remsi %0, %5 : i32 loc(#loc14) | |
%13 = arith.divsi %12, %9 : i32 loc(#loc15) | |
%14 = arith.muli %11, %c128_i32 : i32 loc(#loc16) | |
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc17) | |
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc17) | |
%17 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc18) | |
%18 = tt.splat %14 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18) | |
%19 = arith.addi %17, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc18) | |
%20 = arith.addi %18, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18) | |
%21 = tt.splat %arg3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc19) | |
%22 = arith.remsi %19, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> loc(#loc19) | |
%23 = arith.muli %13, %c256_i32 : i32 loc(#loc20) | |
%24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21) | |
%25 = tt.splat %23 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22) | |
%26 = arith.addi %25, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22) | |
%27 = tt.splat %arg4 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc23) | |
%28 = arith.remsi %26, %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc23) | |
%29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc24) | |
%30 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked> loc(#loc25) | |
%31 = arith.muli %29, %30 : tensor<128x1xi32, #blocked> loc(#loc25) | |
%32 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> loc(#loc26) | |
%33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc26) | |
%34 = tt.broadcast %31 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> loc(#loc27) | |
%35 = tt.broadcast %33 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> loc(#loc27) | |
%36 = arith.addi %34, %35 : tensor<128x16xi32, #blocked> loc(#loc27) | |
%37 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x16x!tt.ptr<f32>, #blocked> loc(#loc28) | |
%38 = tt.addptr %37, %36 : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc28) | |
%39 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc29) | |
%40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc29) | |
%41 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc30) | |
%42 = arith.muli %40, %41 : tensor<16x1xi32, #blocked1> loc(#loc30) | |
%43 = tt.expand_dims %28 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc31) | |
%44 = tt.broadcast %42 : tensor<16x1xi32, #blocked1> -> tensor<16x256xi32, #blocked1> loc(#loc32) | |
%45 = tt.broadcast %43 : tensor<1x256xi32, #blocked1> -> tensor<16x256xi32, #blocked1> loc(#loc32) | |
%46 = arith.addi %44, %45 : tensor<16x256xi32, #blocked1> loc(#loc32) | |
%47 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<16x256x!tt.ptr<f32>, #blocked1> loc(#loc33) | |
%48 = tt.addptr %47, %46 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc33) | |
%49 = arith.addi %arg5, %c15_i32 : i32 loc(#loc60) | |
%50 = arith.divsi %49, %c16_i32 : i32 loc(#loc61) | |
%51 = arith.muli %arg7, %c16_i32 : i32 loc(#loc35) | |
%52 = tt.splat %51 : i32 -> tensor<16x256xi32, #blocked1> loc(#loc36) | |
%53 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x16xf32, #shared, mutable> loc(#loc37) | |
%54 = triton_gpu.local_alloc : () -> !tt.memdesc<2x16x256xf32, #shared1, mutable> loc(#loc38) | |
%55 = arith.cmpi sgt, %50, %c0_i32 : i32 loc(#loc39) | |
%56 = tt.splat %arg5 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40) | |
%57 = arith.cmpi slt, %33, %56 : tensor<1x16xi32, #blocked> loc(#loc40) | |
%58 = tt.broadcast %57 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37) | |
%59 = triton_gpu.memdesc_subview %53[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37) | |
%60 = tt.splat %55 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39) | |
%61 = arith.andi %60, %58 : tensor<128x16xi1, #blocked> loc(#loc39) | |
%62 = triton_gpu.async_copy_global_to_local %38, %59 mask %61 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, mutable> loc(#loc37) | |
%63 = triton_gpu.async_commit_group %62 loc(#loc37) | |
%64 = tt.splat %arg5 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41) | |
%65 = arith.cmpi slt, %40, %64 : tensor<16x1xi32, #blocked1> loc(#loc41) | |
%66 = tt.broadcast %65 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38) | |
%67 = triton_gpu.memdesc_subview %54[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38) | |
%68 = tt.splat %55 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39) | |
%69 = arith.andi %68, %66 : tensor<16x256xi1, #blocked1> loc(#loc39) | |
%70 = triton_gpu.async_copy_global_to_local %48, %67 mask %69 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, mutable> loc(#loc38) | |
%71 = triton_gpu.async_commit_group %70 loc(#loc38) | |
%72 = arith.cmpi sgt, %50, %c1_i32 : i32 loc(#loc39) | |
%73 = tt.addptr %38, %cst : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc42) | |
%74 = tt.addptr %48, %52 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc36) | |
%75 = arith.subi %arg5, %c16_i32 : i32 loc(#loc43) | |
%76 = tt.splat %75 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40) | |
%77 = arith.cmpi slt, %33, %76 : tensor<1x16xi32, #blocked> loc(#loc40) | |
%78 = tt.broadcast %77 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37) | |
%79 = triton_gpu.memdesc_subview %53[%c1_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37) | |
%80 = tt.splat %72 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39) | |
%81 = arith.andi %80, %78 : tensor<128x16xi1, #blocked> loc(#loc39) | |
%82 = triton_gpu.async_copy_global_to_local %73, %79 mask %81 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, mutable> loc(#loc37) | |
%83 = triton_gpu.async_commit_group %82 loc(#loc37) | |
%84 = tt.splat %75 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41) | |
%85 = arith.cmpi slt, %40, %84 : tensor<16x1xi32, #blocked1> loc(#loc41) | |
%86 = tt.broadcast %85 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38) | |
%87 = triton_gpu.memdesc_subview %54[%c1_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38) | |
%88 = tt.splat %72 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39) | |
%89 = arith.andi %88, %86 : tensor<16x256xi1, #blocked1> loc(#loc39) | |
%90 = triton_gpu.async_copy_global_to_local %74, %87 mask %89 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, mutable> loc(#loc38) | |
%91 = triton_gpu.async_commit_group %90 loc(#loc38) | |
%92 = triton_gpu.memdesc_subview %53[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37) | |
%93 = triton_gpu.async_wait %71 {num = 2 : i32} loc(#loc37) | |
%94 = triton_gpu.memdesc_subview %54[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38) | |
%c0_i32_3 = arith.constant 0 : i32 loc(#loc37) | |
%c0_i32_4 = arith.constant 0 : i32 loc(#loc37) | |
%95 = triton_gpu.memdesc_subview %92[%c0_i32_3, %c0_i32_4] : !tt.memdesc<128x16xf32, #shared, mutable> -> !tt.memdesc<128x8xf32, #shared> loc(#loc37) | |
%96 = triton_gpu.local_load %95 : !tt.memdesc<128x8xf32, #shared> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37) | |
%c0_i32_5 = arith.constant 0 : i32 loc(#loc38) | |
%c0_i32_6 = arith.constant 0 : i32 loc(#loc38) | |
%97 = triton_gpu.memdesc_subview %94[%c0_i32_5, %c0_i32_6] : !tt.memdesc<16x256xf32, #shared1, mutable> -> !tt.memdesc<8x256xf32, #shared1> loc(#loc38) | |
%98 = triton_gpu.local_load %97 : !tt.memdesc<8x256xf32, #shared1> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38) | |
%99:13 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %73, %arg12 = %74, %arg13 = %c1_i32, %arg14 = %c0_i32, %arg15 = %92, %arg16 = %93, %arg17 = %94, %arg18 = %93, %arg19 = %83, %arg20 = %91, %arg21 = %96, %arg22 = %98) -> (tensor<128x256xf32, #mma>, tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<16x256x!tt.ptr<f32>, #blocked1>, i32, i32, !tt.memdesc<128x16xf32, #shared, mutable>, !triton_gpu.async.token, !tt.memdesc<16x256xf32, #shared1, mutable>, !triton_gpu.async.token, !triton_gpu.async.token, !triton_gpu.async.token, tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>) : i32 { | |
%118 = arith.subi %50, %c2_i32 : i32 loc(#loc39) | |
%119 = arith.cmpi slt, %arg9, %118 : i32 loc(#loc39) | |
%120 = triton_gpu.local_load %arg15 : !tt.memdesc<128x16xf32, #shared, mutable> -> tensor<128x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37) | |
%121 = triton_gpu.local_load %arg17 : !tt.memdesc<16x256xf32, #shared1, mutable> -> tensor<16x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38) | |
%122 = tt.dot %120, %121, %arg10, inputPrecision = tf32 : tensor<128x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44) | |
%c0_i32_7 = arith.constant 0 : i32 loc(#loc37) | |
%c8_i32_8 = arith.constant 8 : i32 loc(#loc37) | |
%123 = triton_gpu.memdesc_subview %arg15[%c0_i32_7, %c8_i32_8] : !tt.memdesc<128x16xf32, #shared, mutable> -> !tt.memdesc<128x8xf32, #shared> loc(#loc37) | |
%124 = triton_gpu.local_load %123 : !tt.memdesc<128x8xf32, #shared> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37) | |
%c8_i32_9 = arith.constant 8 : i32 loc(#loc38) | |
%c0_i32_10 = arith.constant 0 : i32 loc(#loc38) | |
%125 = triton_gpu.memdesc_subview %arg17[%c8_i32_9, %c0_i32_10] : !tt.memdesc<16x256xf32, #shared1, mutable> -> !tt.memdesc<8x256xf32, #shared1> loc(#loc38) | |
%126 = triton_gpu.local_load %125 : !tt.memdesc<8x256xf32, #shared1> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38) | |
%127 = tt.dot %arg21, %arg22, %arg10, inputPrecision = tf32 : tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44) | |
%128 = tt.dot %124, %126, %127, inputPrecision = tf32 : tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<128x256xf32, #mma> loc(#loc44) | |
%129 = tt.addptr %arg11, %cst : tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<128x16xi32, #blocked> loc(#loc42) | |
%130 = tt.addptr %arg12, %52 : tensor<16x256x!tt.ptr<f32>, #blocked1>, tensor<16x256xi32, #blocked1> loc(#loc36) | |
%131 = arith.addi %arg13, %c1_i32 : i32 loc(#loc39) | |
%132 = arith.cmpi slt, %131, %c2_i32 : i32 loc(#loc39) | |
%133 = arith.select %132, %131, %c0_i32 : i32 loc(#loc39) | |
%134 = arith.addi %arg9, %c2_i32 : i32 loc(#loc39) | |
%135 = arith.muli %134, %c16_i32 : i32 loc(#loc45) | |
%136 = arith.subi %arg5, %135 : i32 loc(#loc43) | |
%137 = tt.splat %136 : i32 -> tensor<1x16xi32, #blocked> loc(#loc40) | |
%138 = arith.cmpi slt, %33, %137 : tensor<1x16xi32, #blocked> loc(#loc40) | |
%139 = tt.broadcast %138 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked> loc(#loc37) | |
%140 = triton_gpu.memdesc_subview %53[%133, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37) | |
%141 = tt.splat %119 : i1 -> tensor<128x16xi1, #blocked> loc(#loc39) | |
%142 = arith.andi %141, %139 : tensor<128x16xi1, #blocked> loc(#loc39) | |
%143 = triton_gpu.async_copy_global_to_local %129, %140 mask %142 other %cst_0 : tensor<128x16x!tt.ptr<f32>, #blocked> -> <128x16xf32, #shared, mutable> loc(#loc37) | |
%144 = triton_gpu.async_commit_group %143 loc(#loc37) | |
%145 = tt.splat %136 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41) | |
%146 = arith.cmpi slt, %40, %145 : tensor<16x1xi32, #blocked1> loc(#loc41) | |
%147 = tt.broadcast %146 : tensor<16x1xi1, #blocked1> -> tensor<16x256xi1, #blocked1> loc(#loc38) | |
%148 = triton_gpu.memdesc_subview %54[%133, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38) | |
%149 = tt.splat %119 : i1 -> tensor<16x256xi1, #blocked1> loc(#loc39) | |
%150 = arith.andi %149, %147 : tensor<16x256xi1, #blocked1> loc(#loc39) | |
%151 = triton_gpu.async_copy_global_to_local %130, %148 mask %150 other %cst_1 : tensor<16x256x!tt.ptr<f32>, #blocked1> -> <16x256xf32, #shared1, mutable> loc(#loc38) | |
%152 = triton_gpu.async_commit_group %151 loc(#loc38) | |
%153 = arith.addi %arg14, %c1_i32 : i32 loc(#loc39) | |
%154 = arith.cmpi slt, %153, %c2_i32 : i32 loc(#loc39) | |
%155 = arith.select %154, %153, %c0_i32 : i32 loc(#loc39) | |
%156 = triton_gpu.memdesc_subview %53[%155, %c0_i32, %c0_i32] : !tt.memdesc<2x128x16xf32, #shared, mutable> -> !tt.memdesc<128x16xf32, #shared, mutable> loc(#loc37) | |
%157 = triton_gpu.async_wait %arg20 {num = 2 : i32} loc(#loc37) | |
%158 = triton_gpu.memdesc_subview %54[%155, %c0_i32, %c0_i32] : !tt.memdesc<2x16x256xf32, #shared1, mutable> -> !tt.memdesc<16x256xf32, #shared1, mutable> loc(#loc38) | |
%c0_i32_11 = arith.constant 0 : i32 loc(#loc37) | |
%c0_i32_12 = arith.constant 0 : i32 loc(#loc37) | |
%159 = triton_gpu.memdesc_subview %156[%c0_i32_11, %c0_i32_12] : !tt.memdesc<128x16xf32, #shared, mutable> -> !tt.memdesc<128x8xf32, #shared> loc(#loc37) | |
%160 = triton_gpu.local_load %159 : !tt.memdesc<128x8xf32, #shared> -> tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> loc(#loc37) | |
%c0_i32_13 = arith.constant 0 : i32 loc(#loc38) | |
%c0_i32_14 = arith.constant 0 : i32 loc(#loc38) | |
%161 = triton_gpu.memdesc_subview %158[%c0_i32_13, %c0_i32_14] : !tt.memdesc<16x256xf32, #shared1, mutable> -> !tt.memdesc<8x256xf32, #shared1> loc(#loc38) | |
%162 = triton_gpu.local_load %161 : !tt.memdesc<8x256xf32, #shared1> -> tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc38) | |
scf.yield %128, %129, %130, %133, %155, %156, %157, %158, %157, %144, %152, %160, %162 : tensor<128x256xf32, #mma>, tensor<128x16x!tt.ptr<f32>, #blocked>, tensor<16x256x!tt.ptr<f32>, #blocked1>, i32, i32, !tt.memdesc<128x16xf32, #shared, mutable>, !triton_gpu.async.token, !tt.memdesc<16x256xf32, #shared1, mutable>, !triton_gpu.async.token, !triton_gpu.async.token, !triton_gpu.async.token, tensor<128x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>, tensor<8x256xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> loc(#loc39) | |
} loc(#loc39) | |
%100 = triton_gpu.async_wait {num = 0 : i32} loc(#loc39) | |
triton_gpu.local_dealloc %53 : !tt.memdesc<2x128x16xf32, #shared, mutable> loc(#loc39) | |
triton_gpu.local_dealloc %54 : !tt.memdesc<2x16x256xf32, #shared1, mutable> loc(#loc39) | |
%101 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc46) | |
%102 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc47) | |
%103 = arith.muli %102, %101 : tensor<128x1xi32, #blocked1> loc(#loc47) | |
%104 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked1> loc(#loc48) | |
%105 = tt.addptr %104, %103 : tensor<128x1x!tt.ptr<f32>, #blocked1>, tensor<128x1xi32, #blocked1> loc(#loc48) | |
%106 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc49) | |
%107 = tt.broadcast %105 : tensor<128x1x!tt.ptr<f32>, #blocked1> -> tensor<128x256x!tt.ptr<f32>, #blocked1> loc(#loc50) | |
%108 = tt.broadcast %106 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> loc(#loc50) | |
%109 = tt.addptr %107, %108 : tensor<128x256x!tt.ptr<f32>, #blocked1>, tensor<128x256xi32, #blocked1> loc(#loc50) | |
%110 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc51) | |
%111 = arith.cmpi slt, %101, %110 : tensor<128x1xi32, #blocked1> loc(#loc51) | |
%112 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> loc(#loc52) | |
%113 = arith.cmpi slt, %106, %112 : tensor<1x256xi32, #blocked1> loc(#loc52) | |
%114 = tt.broadcast %111 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc53) | |
%115 = tt.broadcast %113 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc53) | |
%116 = arith.andi %114, %115 : tensor<128x256xi1, #blocked1> loc(#loc53) | |
%117 = triton_gpu.convert_layout %99#0 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked1> loc(#loc54) | |
tt.store %109, %117, %116 : tensor<128x256x!tt.ptr<f32>, #blocked1> loc(#loc54) | |
tt.return loc(#loc55) | |
} loc(#loc) | |
} loc(#loc) | |
#loc1 = loc(unknown) | |
#loc2 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":62:24) | |
#loc3 = loc("/data/users/bertrand/triton/python/triton/language/standard.py":44:22) | |
#loc4 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":63:27) | |
#loc5 = loc("/data/users/bertrand/triton/python/triton/language/standard.py":44:28) | |
#loc6 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":64:27) | |
#loc7 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":65:38) | |
#loc8 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":66:22) | |
#loc9 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":67:29) | |
#loc10 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":68:35) | |
#loc11 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":68:48) | |
#loc12 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":69:33) | |
#loc13 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":69:27) | |
#loc14 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":70:19) | |
#loc15 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":70:40) | |
#loc16 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:23) | |
#loc17 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:51) | |
#loc18 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:38) | |
#loc19 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":79:68) | |
#loc20 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:23) | |
#loc21 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:51) | |
#loc22 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:38) | |
#loc23 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":80:68) | |
#loc24 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:30) | |
#loc25 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:41) | |
#loc26 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:60) | |
#loc27 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:53) | |
#loc28 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":82:22) | |
#loc29 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:29) | |
#loc30 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:40) | |
#loc31 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:60) | |
#loc32 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:52) | |
#loc33 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":83:22) | |
#loc34 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":91:33) | |
#loc35 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":101:33) | |
#loc36 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":101:18) | |
#loc37 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:20) | |
#loc38 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":95:20) | |
#loc39 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":91:22) | |
#loc40 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:51) | |
#loc41 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":95:51) | |
#loc42 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":100:18) | |
#loc43 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:55) | |
#loc44 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":97:35) | |
#loc45 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":94:59) | |
#loc46 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:41) | |
#loc47 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:33) | |
#loc48 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:21) | |
#loc49 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:72) | |
#loc50 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":112:52) | |
#loc51 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":113:33) | |
#loc52 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":113:58) | |
#loc53 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":113:39) | |
#loc54 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":114:21) | |
#loc55 = loc("/data/users/bertrand/sandbox/tf32_gemm/triton_kernel.py":114:4) | |
#loc56 = loc(callsite(#loc3 at #loc4)) | |
#loc57 = loc(callsite(#loc5 at #loc4)) | |
#loc58 = loc(callsite(#loc3 at #loc6)) | |
#loc59 = loc(callsite(#loc5 at #loc6)) | |
#loc60 = loc(callsite(#loc3 at #loc34)) | |
#loc61 = loc(callsite(#loc5 at #loc34)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment