@ngimel · Created April 20, 2022
import torch
import triton
import triton.language as tl
from itertools import product
@triton.jit
def copy_kernel(
    output_ptr, input_ptr,
    bs, size_inp_0, size_inp_1,
    batch_stride_inp, stride_inp_0, stride_inp_1,
    batch_stride_out, stride_out_0, stride_out_1,
    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Each program instance copies one BLOCK_M x BLOCK_K tile of one batch element.
    pid = tl.program_id(0)
    pid_batch = tl.program_id(1)
    # Decompose the flat tile index into (row tile, column tile) coordinates.
    grid_m = (size_inp_0 + BLOCK_M - 1) // BLOCK_M
    grid_k = (size_inp_1 + BLOCK_K - 1) // BLOCK_K
    pid_m = pid // grid_k
    pid_k = pid - pid_m * grid_k
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    # The tile is read at [rm, rk] of the input and written at [rk, rm] of the output
    # (through the output's own strides), so this is a strided, transposing copy.
    A = input_ptr + batch_stride_inp * pid_batch + (rm[:, None] * stride_inp_0 + rk[None, :] * stride_inp_1)
    B = output_ptr + batch_stride_out * pid_batch + (rk[None, :] * stride_out_0 + rm[:, None] * stride_out_1)
    # One mask guards both the load and the store against out-of-bounds rows/columns.
    mask = (rm < size_inp_0)[:, None] & (rk < size_inp_1)[None, :]
    a = tl.load(A, mask=mask, other=0)
    tl.store(B, a, mask=mask)

def copy(dst, src):
    # dst must be laid out as the transpose of src: src is (B, M, K), dst is (B, K, M).
    if dst.ndim == 2:
        dst = dst.unsqueeze(0)
    if src.ndim == 2:
        src = src.unsqueeze(0)
    bsz, sz0, sz1 = src.shape
    bsd, sd0, sd1 = dst.stride()
    bss, ss0, ss1 = src.stride()
    BLOCK_M = 32
    BLOCK_K = 32
    grid = lambda meta: (triton.cdiv(sz0, BLOCK_M) * triton.cdiv(sz1, BLOCK_K), bsz)
    copy_kernel[grid](
        dst, src,
        bsz, sz0, sz1,
        bss, ss0, ss1,  # input (src) batch stride and inner strides
        bsd, sd0, sd1,  # output (dst) batch stride and inner strides
        BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K,
    )
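
# For reference only (a sketch that is not part of the original gist): with the stride
# layout used above, copy(dst, src) performs the same transposing copy as the eager
# PyTorch one-liner below, where src has shape (B, M, K) and dst has shape (B, K, M).
# The helper name copy_reference is hypothetical, introduced here just for illustration.
def copy_reference(dst, src):
    # Write src[b, m, k] into dst[b, k, m] using PyTorch's strided copy.
    dst.copy_(src.transpose(-1, -2))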

def test_transpose(x):
    res = x.transpose(-1, -2).contiguous()
    out = torch.empty_like(res)
    copy(out, x)
    # Optional correctness check:
    # print(torch.allclose(out, res))

dtypes = (torch.float, torch.half)
sizes = ((16*16*16, 16*16*16), (16, 256, 512*128), (16, 512*128, 256))
for s, dtype in product(sizes, dtypes):
    x = torch.randn(s, device="cuda", dtype=dtype)
    test_transpose(x)
torch.cuda.synchronize()
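
# A minimal timing sketch (an addition, not part of the original gist; assumes a CUDA
# device): compares the Triton transposing copy against transpose().contiguous() using
# CUDA events. bench_transpose and its arguments are hypothetical names used only here.
def bench_transpose(x, iters=50):
    ref_out = x.transpose(-1, -2).contiguous()
    out = torch.empty_like(ref_out)

    def time_fn(fn):
        # Warm up, then time `iters` launches between two CUDA events.
        for _ in range(3):
            fn()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # average milliseconds per call

    triton_ms = time_fn(lambda: copy(out, x))
    eager_ms = time_fn(lambda: x.transpose(-1, -2).contiguous())
    print(f"shape={tuple(x.shape)} dtype={x.dtype}: triton {triton_ms:.3f} ms, eager {eager_ms:.3f} ms")

bench_transpose(torch.randn(16, 256, 512*128, device="cuda", dtype=torch.half))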