Skip to content

Instantly share code, notes, and snippets.

@Jokeren
Created March 1, 2023 06:13
Show Gist options
  • Save Jokeren/a9707277e6623e235e3a3b55d71c1661 to your computer and use it in GitHub Desktop.
Save Jokeren/a9707277e6623e235e3a3b55d71c1661 to your computer and use it in GitHub Desktop.
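fp16 mov reproducer: two PTX builds of the same Triton kernel (correct.ptx and wrong.ptx), where one gives the expected result and the other gives a wrong one.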
//
// Generated by LLVM NVPTX Back-End
//
// correct.ptx: `python test.py correct.ptx` prints the expected mean -0.1250.
//
// Kernel triton__5: one CTA per row (x0 = %ctaid.x), 64 threads per CTA.
// Parameters, in the order of the triton__5 signature in the gist:
//   param_0 = in_ptr0   (fp16 input, read)
//   param_1 = out_ptr2  (fp16 output, written)
//   param_2 = xnumel    (number of rows)
//   param_3 = rnumel    (columns per row; row stride is hard-coded as 197)
// In every 256-wide chunk, each thread handles columns lane, lane+64,
// lane+128 and lane+192. For each of them it keeps a masked running max of
// input*0.125 in f32 and stores that max back as fp16.
.version 8.0
.target sm_80
.address_size 64
// .globl triton__5_0d1d23
.visible .entry triton__5_0d1d23(
.param .u64 triton__5_0d1d23_param_0,
.param .u64 triton__5_0d1d23_param_1,
.param .u32 triton__5_0d1d23_param_2,
.param .u32 triton__5_0d1d23_param_3
)
.maxntid 64, 1, 1
{
.reg .pred %p<20>;
.reg .b16 %h<9>;
.reg .b16 %rs<9>;
.reg .f32 %f<26>;
.reg .b32 %r<16>;
.reg .b64 %rd<14>;
// Skip all work when rnumel < 1.
ld.param.u32 %r7, [triton__5_0d1d23_param_3];
setp.lt.s32 %p1, %r7, 1;
@%p1 bra LBB0_3;
ld.param.u32 %r6, [triton__5_0d1d23_param_2];
ld.param.u64 %rd2, [triton__5_0d1d23_param_1];
ld.param.u64 %rd1, [triton__5_0d1d23_param_0];
// %r1 = lane (tid.x & 63), %r2 = row x0, %r3 = x0*197 + lane (element index).
mov.u32 %r8, %tid.x;
and.b32 %r1, %r8, 63;
mov.u32 %r2, %ctaid.x;
mad.lo.s32 %r3, %r2, 197, %r1;
// Four running-max accumulators %f22..%f25, seeded with -inf (0fFF800000).
mov.f32 %f22, 0fFF800000;
// %r15 = roffset: loop counter, advanced by 256 per iteration.
mov.u32 %r15, 0;
// %p10 = xmask (row < xnumel).
setp.lt.s32 %p10, %r2, %r6;
mov.f32 %f23, %f22;
mov.f32 %f24, %f22;
mov.f32 %f25, %f22;
LBB0_2:
// Columns for this iteration: roffset + lane + {0, 64, 128, 192}.
add.s32 %r10, %r1, %r15;
add.s32 %r11, %r10, 64;
add.s32 %r12, %r10, 128;
add.s32 %r13, %r10, 192;
// %p11..%p14 = rmask (column < rnumel) for the four columns.
setp.lt.s32 %p11, %r10, %r7;
setp.lt.s32 %p12, %r11, %r7;
setp.lt.s32 %p13, %r12, %r7;
setp.lt.s32 %p14, %r13, %r7;
// Byte offset of the first element: (x0*197 + lane + roffset) * 2.
add.s32 %r14, %r3, %r15;
mul.wide.s32 %rd13, %r14, 2;
add.s64 %rd5, %rd1, %rd13;
// %p2..%p5 = rmask & xmask for the four columns.
and.pred %p2, %p11, %p10;
and.pred %p3, %p12, %p10;
and.pred %p4, %p13, %p10;
and.pred %p5, %p14, %p10;
// Predicated fp16 loads from one base address. Byte offsets +128/+256/+384
// are element offsets +64/+128/+192. Each %hN is copied straight from its
// own %rsN; no 32-bit packing is involved.
// NOTE(review): when a predicate is false, its %rsN is not written here.
@%p2 ld.global.b16 {%rs1}, [ %rd5 + 0];
mov.b16 %h1, %rs1;
@%p3 ld.global.b16 {%rs2}, [ %rd5 + 128];
mov.b16 %h2, %rs2;
@%p4 ld.global.b16 {%rs3}, [ %rd5 + 256];
mov.b16 %h3, %rs3;
@%p5 ld.global.b16 {%rs4}, [ %rd5 + 384];
mov.b16 %h4, %rs4;
// Widen to f32, then scale by 0.125 (0f3E000000).
cvt.f32.f16 %f10, %h1;
cvt.f32.f16 %f11, %h2;
cvt.f32.f16 %f12, %h3;
cvt.f32.f16 %f13, %h4;
mul.f32 %f14, %f10, 0f3E000000;
mul.f32 %f15, %f11, 0f3E000000;
mul.f32 %f16, %f12, 0f3E000000;
mul.f32 %f17, %f13, 0f3E000000;
// Masked max update per column: acc = (mask && acc < v) ? v : acc.
setp.lt.f32 %p15, %f22, %f14;
setp.lt.f32 %p16, %f23, %f15;
setp.lt.f32 %p17, %f24, %f16;
setp.lt.f32 %p18, %f25, %f17;
selp.f32 %f18, %f14, %f22, %p15;
selp.f32 %f22, %f18, %f22, %p2;
selp.f32 %f19, %f15, %f23, %p16;
selp.f32 %f23, %f19, %f23, %p3;
selp.f32 %f20, %f16, %f24, %p17;
selp.f32 %f24, %f20, %f24, %p4;
selp.f32 %f21, %f17, %f25, %p18;
selp.f32 %f25, %f21, %f25, %p5;
// The output address uses the same byte offset as the input.
add.s64 %rd9, %rd2, %rd13;
// Round the accumulators to fp16 and store them under the same masks.
cvt.rn.f16.f32 %h5, %f22;
cvt.rn.f16.f32 %h6, %f23;
cvt.rn.f16.f32 %h7, %f24;
cvt.rn.f16.f32 %h8, %f25;
mov.b16 %rs5, %h5;
@%p2 st.global.b16 [ %rd9 + 0] , {%rs5};
mov.b16 %rs6, %h6;
@%p3 st.global.b16 [ %rd9 + 128] , {%rs6};
mov.b16 %rs7, %h7;
@%p4 st.global.b16 [ %rd9 + 256] , {%rs7};
mov.b16 %rs8, %h8;
@%p5 st.global.b16 [ %rd9 + 384] , {%rs8};
// roffset += 256; loop while roffset < rnumel.
add.s32 %r15, %r15, 256;
setp.lt.s32 %p19, %r15, %r7;
@%p19 bra LBB0_2;
LBB0_3:
ret;
}

Install

git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
pip uninstall pytorch-triton -y

Expected result (prints a mean of -0.1250):

rm -rf ~/.triton/cache/*
python test.py correct.ptx

Wrong result (prints a mean of -inf):

rm -rf ~/.triton/cache/*
python test.py wrong.ptx
import triton
import triton.language as tl
import torch
import sys
# Usage: python test.py <kernel.ptx>
# Compiles the given precompiled PTX kernel (correct.ptx or wrong.ptx) and
# launches it over an all -1.0 fp16 input. Then prints the mean of the output
# buffer: -0.1250 is expected, and wrong.ptx is reported to print -inf.
if len(sys.argv) != 2:
    sys.exit(f"usage: {sys.argv[0]} <kernel.ptx>")
version = sys.argv[1]
torch.manual_seed(123)  # no random data is drawn below; kept from the original reproducer
s0 = 4
RNUMEL = 197  # row length; the kernel also hard-codes 197 as the row stride
buf11 = -torch.ones((6*s0, RNUMEL, RNUMEL), device='cuda', dtype=torch.float16)
buf14 = -torch.ones((s0, 6, RNUMEL, RNUMEL), device='cuda', dtype=torch.float16)
# One program per row: xnumel rows of RNUMEL elements each (1182 == 6 * 197).
triton__5_xnumel = 1182*s0
# Guard against out-of-bounds access if the shapes above are ever edited.
assert triton__5_xnumel * RNUMEL == buf11.numel() == buf14.numel()
kernel = triton.compile(version, shared=0, num_warps=2)
kernel[(triton__5_xnumel, 1, 1,)](buf11.data_ptr(), buf14.data_ptr(), triton__5_xnumel, RNUMEL)
print(buf14.mean())
//
// Generated by LLVM NVPTX Back-End
//
// wrong.ptx: `python test.py wrong.ptx` is reported to print -inf instead of
// the expected -0.1250. This is the failing side of the reproducer, so do
// not change any instruction here.
//
// Kernel triton__5 computes the same thing as correct.ptx: one CTA per row
// (x0 = %ctaid.x), 64 threads, and four columns per thread (lane, +64, +128,
// +192) per 256-wide chunk. Each thread keeps a masked running max of
// input*0.125 in f32 and stores it as fp16.
// Parameters: param_0 = in_ptr0, param_1 = out_ptr2, param_2 = xnumel,
// param_3 = rnumel.
.version 8.0
.target sm_80
.address_size 64
// .globl triton__5_0d1d23
.visible .entry triton__5_0d1d23(
.param .u64 triton__5_0d1d23_param_0,
.param .u64 triton__5_0d1d23_param_1,
.param .u32 triton__5_0d1d23_param_2,
.param .u32 triton__5_0d1d23_param_3
)
.maxntid 64, 1, 1
{
.reg .pred %p<20>;
.reg .b16 %h<9>;
.reg .b16 %rs<9>;
.reg .b32 %r<28>;
.reg .f32 %f<29>;
.reg .b64 %rd<15>;
// Skip all work when rnumel < 1.
ld.param.u32 %r8, [triton__5_0d1d23_param_3];
setp.lt.s32 %p1, %r8, 1;
@%p1 bra $L__BB0_3;
ld.param.u32 %r7, [triton__5_0d1d23_param_2];
ld.param.u64 %rd2, [triton__5_0d1d23_param_1];
ld.param.u64 %rd1, [triton__5_0d1d23_param_0];
// %r1 = lane (tid.x & 63), %r2 = row x0, %r4 = x0*197 + lane.
mov.u32 %r9, %tid.x;
and.b32 %r1, %r9, 63;
mov.u32 %r2, %ctaid.x;
mul.lo.s32 %r3, %r2, 197;
add.s32 %r4, %r3, %r1;
// Accumulators %f25..%f28 seeded with -inf; %r27 = roffset.
mov.f32 %f25, 0fFF800000;
mov.u32 %r27, 0;
// %p10 = xmask (row < xnumel).
setp.lt.s32 %p10, %r2, %r7;
mov.f32 %f26, %f25;
mov.f32 %f27, %f25;
mov.f32 %f28, %f25;
$L__BB0_2:
// %p11..%p14 = rmask for columns roffset + lane + {0, 64, 128, 192}.
add.s32 %r11, %r1, %r27;
add.s32 %r12, %r11, 64;
add.s32 %r13, %r11, 128;
add.s32 %r14, %r11, 192;
setp.lt.s32 %p11, %r11, %r8;
setp.lt.s32 %p12, %r12, %r8;
setp.lt.s32 %p13, %r13, %r8;
setp.lt.s32 %p14, %r14, %r8;
// Unlike correct.ptx, a separate input address is computed per element:
// %rd3..%rd6 = in_ptr0 + 2*(x0*197 + lane + roffset + {0, 64, 128, 192}).
add.s32 %r15, %r4, %r27;
add.s32 %r16, %r15, 64;
add.s32 %r17, %r15, 128;
add.s32 %r18, %r15, 192;
mul.wide.s32 %rd11, %r15, 2;
add.s64 %rd3, %rd1, %rd11;
mul.wide.s32 %rd12, %r16, 2;
add.s64 %rd4, %rd1, %rd12;
mul.wide.s32 %rd13, %r17, 2;
add.s64 %rd5, %rd1, %rd13;
mul.wide.s32 %rd14, %r18, 2;
add.s64 %rd6, %rd1, %rd14;
// %p2..%p5 = xmask & rmask for the four columns.
and.pred %p2, %p10, %p11;
and.pred %p3, %p10, %p12;
and.pred %p4, %p10, %p13;
and.pred %p5, %p10, %p14;
@%p2 ld.global.b16 { %rs1 }, [ %rd3 + 0 ];
@%p3 ld.global.b16 { %rs2 }, [ %rd4 + 0 ];
@%p4 ld.global.b16 { %rs3 }, [ %rd5 + 0 ];
@%p5 ld.global.b16 { %rs4 }, [ %rd6 + 0 ];
// Pack the four 16-bit loads pairwise into 32-bit registers:
//   %r22 = (%rs1 << 16) | %rs2,   %r26 = (%rs3 << 16) | %rs4
cvt.u32.u16 %r19, %rs1;
shl.b32 %r20, %r19, 16;
cvt.u32.u16 %r21, %rs2;
or.b32 %r22, %r21, %r20;
cvt.u32.u16 %r23, %rs3;
shl.b32 %r24, %r23, 16;
cvt.u32.u16 %r25, %rs4;
or.b32 %r26, %r25, %r24;
// ...then unpack them with mov.b32 {lo, hi}, giving
// %h1 = %rs4, %h2 = %rs3, %h3 = %rs2, %h4 = %rs1.
// NOTE(review): on lanes where a predicate is false, the matching %rsN is
// never written, yet it is packed into the same 32-bit register as a validly
// loaded neighbour. correct.ptx instead moves each %rsN straight into %hN.
// This difference is presumably what yields the -inf result; confirm
// against the ptxas/SASS output.
mov.b32 {%h1, %h2}, %r26;
mov.b32 {%h3, %h4}, %r22;
// Widen and scale by 0.125. After the reordering:
// %f17 = rs1*0.125, %f18 = rs2*0.125, %f19 = rs3*0.125, %f20 = rs4*0.125.
cvt.f32.f16 %f13, %h1;
cvt.f32.f16 %f14, %h2;
cvt.f32.f16 %f15, %h3;
cvt.f32.f16 %f16, %h4;
mul.f32 %f17, %f16, 0f3E000000;
mul.f32 %f18, %f15, 0f3E000000;
mul.f32 %f19, %f14, 0f3E000000;
mul.f32 %f20, %f13, 0f3E000000;
// Masked max update, acc = (mask && acc < v) ? v : acc, with pairs
// %f25<-rs1/%p2, %f26<-rs2/%p3, %f27<-rs3/%p4, %f28<-rs4/%p5.
setp.lt.f32 %p15, %f28, %f20;
setp.lt.f32 %p16, %f27, %f19;
setp.lt.f32 %p17, %f26, %f18;
setp.lt.f32 %p18, %f25, %f17;
selp.f32 %f21, %f17, %f25, %p18;
selp.f32 %f25, %f21, %f25, %p2;
selp.f32 %f22, %f18, %f26, %p17;
selp.f32 %f26, %f22, %f26, %p3;
selp.f32 %f23, %f19, %f27, %p16;
selp.f32 %f27, %f23, %f27, %p4;
selp.f32 %f24, %f20, %f28, %p15;
selp.f32 %f28, %f24, %f28, %p5;
// Output addresses reuse the per-element byte offsets of the input.
add.s64 %rd7, %rd2, %rd11;
add.s64 %rd8, %rd2, %rd12;
add.s64 %rd9, %rd2, %rd13;
add.s64 %rd10, %rd2, %rd14;
// Round to fp16 and store under the same masks.
cvt.rn.f16.f32 %h5, %f25;
cvt.rn.f16.f32 %h6, %f26;
cvt.rn.f16.f32 %h7, %f27;
cvt.rn.f16.f32 %h8, %f28;
mov.b16 %rs5, %h5;
@%p2 st.global.b16 [ %rd7 + 0 ], { %rs5 };
mov.b16 %rs6, %h6;
@%p3 st.global.b16 [ %rd8 + 0 ], { %rs6 };
mov.b16 %rs7, %h7;
@%p4 st.global.b16 [ %rd9 + 0 ], { %rs7 };
mov.b16 %rs8, %h8;
@%p5 st.global.b16 [ %rd10 + 0 ], { %rs8 };
// roffset += 256; loop while roffset < rnumel.
add.s32 %r27, %r27, 256;
setp.lt.s32 %p19, %r27, %r8;
@%p19 bra $L__BB0_2;
$L__BB0_3:
ret;
}
@Jokeren
Copy link
Author

Jokeren commented Mar 1, 2023

Semantics (the Triton kernel that both PTX files implement):

@triton.jit
def triton__5(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # Each program handles XBLOCK rows. Each lane (row, j) keeps a running
    # maximum of in_ptr0[row, roffset + j] * 0.125 over the RBLOCK-wide chunks
    # seen so far, and stores it at column roffset + j of out_ptr2.
    # Rows are 197 elements apart (hard-coded stride); rnumel bounds columns.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Running maximum per lane, seeded with -inf.
    _tmp4 = (tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")).to(tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        # In-bounds lanes, shared by the load, the max update and the store.
        mask = rmask & xmask
        tmp0 = tl.load(in_ptr0 + (r1 + (197*x0)), mask).to(tl.float32)
        tmp1 = 0.125
        tmp2 = tmp0 * tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp5 = _tmp4 < tmp3
        # Reuse the comparison instead of recomputing it: only in-bounds
        # lanes that see a strictly larger value update the max.
        _tmp4 = tl.where(mask & tmp5, tmp3, _tmp4)
        tl.store(out_ptr2 + (r1 + (197*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, mask)

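The loop steps by RBLOCK = 256 (see `add ..., 256` in the PTX), which is at least rnumel = 197, so it runs only one iteration. Each stored value is therefore max(-inf, input * 0.125) = input * 0.125. With the all -1 input from test.py, the expected mean is -0.125.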

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment