import torch
import triton
import triton.language as tl
#@triton.jit  # applied explicitly inside mul_sum below, so the compiled binaries stay inspectable
def mul_sum_kernel(
    output_ptr, input_ptr0, input_ptr1,
    si00, si01, si02, si03,
    si10, si11, si12, si13,
    so0, so2, so3, sz1, sz2, sz3,
    BLOCK_SIZE: tl.constexpr
):
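    # Computes out[b, i2, i3] = sum over i1 of in0[b, i1, i2, i3] * in1[b, i1, i2, i3].
    # Each program handles one BLOCK_SIZE tile of the flattened (sz2, sz3) plane for a
    # single outer index b, accumulating the sz1-long reduction in float32.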
    inner_idx = tl.program_id(0)
    outer_idx = tl.program_id(1)
    inp_ptr0 = input_ptr0 + outer_idx * si00
    inp_ptr1 = input_ptr1 + outer_idx * si10
    out_ptr = output_ptr + outer_idx * so0
    offsets = tl.arange(0, BLOCK_SIZE) + inner_idx * BLOCK_SIZE
    mask = offsets < sz2 * sz3
    # Recover 2D indices (i2, i3) from the flattened offset, then build
    # strided element offsets into each input.
    i2 = offsets // sz3
    i3 = offsets - i2 * sz3
    offsets0 = i3 * si03 + i2 * si02
    offsets1 = i3 * si13 + i2 * si12
    accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for i in range(sz1):
        # Load the tiles into SRAM, masked since BLOCK_SIZE may exceed the number of valid values
        vals0 = tl.load(inp_ptr0 + offsets0, mask=mask).to(tl.float32)
        vals1 = tl.load(inp_ptr1 + offsets1, mask=mask).to(tl.float32)
        accum += vals0 * vals1
        # Advance both input pointers to the next slice along the reduction dimension
        inp_ptr0 += si01
        inp_ptr1 += si11
    #accum = accum.to(tl.bfloat16)
    o_offsets = i3 * so3 + i2 * so2
    tl.store(out_ptr + o_offsets, accum, mask=mask)
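
# For reference (a sketch added here, not part of the original gist): the contraction
# the kernel performs is equivalent to this plain PyTorch helper, mirroring the
# float32 reference used in the check below; the name mul_sum_reference is hypothetical.
def mul_sum_reference(x0, x1):
    # Upcast to float32, multiply elementwise, reduce over dim 1, cast back
    return (x0.float() * x1.float()).sum(1).to(x0.dtype)
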
# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
def mul_sum(x0, x1):
    # Broadcast x1 to x0's shape; expanded dimensions get stride 0, so the kernel
    # re-reads the same elements instead of needing a materialized copy.
    x1 = x1.expand_as(x0)
    s0, s1, s2, s3 = x0.shape
    si00, si01, si02, si03 = x0.stride()
    si10, si11, si12, si13 = x1.stride()
    BLOCK_SIZE = 1024
    num_warps = 4
    # 2D launch grid: tiles of the flattened (s2, s3) plane x the outer batch dimension
    grid = lambda meta: (triton.cdiv(s2 * s3, meta['BLOCK_SIZE']), s0)
    # Allocate output
    y = torch.empty(s0, s2, s3, device=x0.device, dtype=x0.dtype)
    so0, so2, so3 = y.stride()
    # JIT-compile here (rather than via the decorator) so the compiled
    # binaries can be inspected, as in the commented-out dump below
    jitted_fn = triton.jit(mul_sum_kernel)
    jitted_fn[grid](
        y,
        x0,
        x1,
        si00, si01, si02, si03,
        si10, si11, si12, si13,
        so0, so2, so3,
        s1, s2, s3,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # for v in jitted_fn.bin_cache.values():
    #     print(v.asm["ptx"])
    return y
torch.manual_seed(0)
for s in (16, 15):
    dtype = torch.bfloat16
    x0 = torch.randn(32, 2, 256, s * s, device='cuda', dtype=dtype)
    x1 = torch.randn(32, 2, 256, 1 * 1, device='cuda', dtype=dtype)
    for _ in range(1):
        y_triton = mul_sum(x0, x1)
    # y_torch = (x0 * x1).sum(1)  # bf16-accumulated reference, superseded below
    # Reference with float32 accumulation, matching the kernel's accumulator dtype
    y_torch = (x0.float() * x1.float()).sum(1).to(x0.dtype)
    print(y_torch.dtype)
    print((y_triton - y_torch).abs().max())
    #assert torch.allclose(y_triton, y_torch), (y_triton[0][0], y_torch[0][0])
# def t(input1, input2):
#     return (input1 * input2).sum(1)
# scripted = torch.jit.script(t)
# with torch.jit.fuser("fuser2"):
#     for _ in range(10):
#         out = scripted(x0, x1)
# assert torch.allclose(out, y_torch)  # , (y_triton, y_torch)
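
# A minimal timing sketch (an addition, not part of the original gist): compare the
# Triton kernel against the eager PyTorch reduction using CUDA events. The helper
# name `bench` and the iteration counts are arbitrary, and note the Triton side pays
# the cost of re-wrapping the kernel with triton.jit on every call inside mul_sum.
def bench(fn, iters=100):
    # Warm up so one-time compilation and allocator effects stay out of the timing
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

print("triton:", bench(lambda: mul_sum(x0, x1)), "ms")
print("torch :", bench(lambda: (x0 * x1).sum(1)), "ms")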