# dense_to_jagged Triton kernel benchmark
# Gist by @davidberard98, created August 29, 2023
import torch
import triton
import triton.language as tl

@triton.jit
def dense_to_jagged_triton(
    inverse_offsets_ptr, offsets_ptr, dense_ptr, out_ptr0, xnumel, XBLOCK: tl.constexpr
):
    # Copies a padded dense tensor of shape (batch, 260, 256) into a jagged
    # (jagged_total_length, 256) layout. Each flat output index x2 splits into
    # a jagged row x1 and an embedding column x0; inverse_offsets maps the
    # jagged row to its batch index, and offsets recovers the row's position
    # within that batch. (Naming follows the style of TorchInductor-generated
    # kernels.)
    # xnumel = 33106688
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = xindex // 256  # jagged row index
    x0 = xindex % 256  # embedding column
    x2 = xindex
    # tmp0 = tl.load(inverse_offsets_ptr + (x1), xmask, eviction_policy="evict_last")
    tmp0 = tl.load(inverse_offsets_ptr + (x1), xmask)  # batch index of this row
    tl.device_assert(
        ((0 <= tmp0) & (tmp0 < 1025)) | ~xmask, "index out of bounds: 0 <= tmp0 < 1025"
    )
    tmp1 = tl.load(offsets_ptr + (tmp0), xmask)  # first jagged row of this batch
    tmp2 = x1
    tmp3 = tmp2 - tmp1  # sequence position within the batch
    tmp4 = tmp3
    tmp5 = tl.full([1], 260, tl.int32)
    tmp6 = tmp4 < tmp5  # guard against reading past max_seq_len
    # 66560 = 260 * 256, the stride of one batch in the dense tensor
    tmp7 = tl.load(
        dense_ptr + (x0 + (256 * tmp3) + (66560 * tmp0)), tmp6 & xmask, other=0
    ).to(tl.float32)
    tmp8 = tl.where(tmp6, tmp7, 0.0)
    # tmp8 = tl.load(dense_ptr + (x0 + (256 * tmp3) + (66560 * tmp0)), xmask, other=0)
    tl.store(out_ptr0 + x2, tmp8, xmask)

def dense_to_jagged(
    dense,
    offsets,
    inverse_offsets,
    jagged_total_length,
    print_ptx=False,
):
    assert dense.shape[-1] == 256
    output = torch.empty(
        (jagged_total_length, dense.shape[-1]),
        dtype=dense.dtype,
        device=dense.device,
    )
    BLOCK_SIZE = 1024
    num_warps = 4
    grid = (triton.cdiv(output.numel(), BLOCK_SIZE),)
    res = dense_to_jagged_triton[grid](
        inverse_offsets,
        offsets,
        dense,
        output,
        output.numel(),
        BLOCK_SIZE,
        num_warps=num_warps,
    )
    if print_ptx:
        print(" ~~~~~~~~~~~~~~~~")
        print(res.asm.keys())
        # print(res.asm["ttir"])
        print(res.asm["ttgir"])
        # print(res.asm["llir"])
        # print(res.asm["ptx"])
    return output
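
# Added for illustration (not in the original gist): a straightforward PyTorch
# version of the same gather, handy as a correctness reference for the kernel.
# The name dense_to_jagged_reference is an added helper.
def dense_to_jagged_reference(dense, offsets, inverse_offsets, jagged_total_length):
    rows = torch.arange(jagged_total_length, device=dense.device)
    batch = inverse_offsets.long()  # batch index of each jagged row
    seq = rows - offsets[batch].long()  # position of the row within its batch
    return dense[batch, seq]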

def generate_offsets(
    batch_size: int,
    max_seq_len: int,
    load_factor: float,
    offsets_dtype: torch.dtype,
    spread_radius: float,
) -> torch.Tensor:
    import random

    assert 0 <= load_factor <= 1
    assert 0 <= spread_radius <= 1
    if load_factor < 1:
        spread = int(max_seq_len * spread_radius)
        mean = int(max_seq_len * load_factor)
        lengths = [
            mean + random.randint(-spread, spread + 1) for _ in range(batch_size)
        ]
        lengths = [max(min(L, max_seq_len), 0) for L in lengths]
    else:
        lengths = [max_seq_len] * batch_size
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return torch.tensor(offsets, dtype=offsets_dtype)
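
# The returned offsets are a cumulative sum of lengths with batch_size + 1
# entries. For example (illustrative numbers only): batch_size=4,
# max_seq_len=260, load_factor=0.3, spread_radius=0.1 gives lengths near 78
# each, so something like offsets = [0, 71, 149, 233, 301].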

BATCH_SIZE = 1024
MAX_SEQ_LEN = 260
EMBEDDING_DIM = 256

dense = torch.rand((BATCH_SIZE, MAX_SEQ_LEN, EMBEDDING_DIM), device='cuda', dtype=torch.float16)
offsets = generate_offsets(BATCH_SIZE, MAX_SEQ_LEN, 0.3, torch.int32, 0.1)
jagged_lengths = offsets[1:] - offsets[:-1]

# Build the inverse mapping: inverse_offsets[row] = batch index of jagged row.
inverse_offsets = torch.zeros((offsets[-1].item(),), dtype=torch.int32)
idx = 0
for i, cnt in enumerate(jagged_lengths):
    for x in range(cnt.item()):
        inverse_offsets[idx] = i
        idx += 1
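
# An equivalent vectorized construction (an aside, not in the original gist):
# inverse_offsets = torch.repeat_interleave(
#     torch.arange(BATCH_SIZE, dtype=torch.int32), jagged_lengths.long()
# )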
jagged_total_length = offsets[-1].item()
offsets = offsets.to('cuda')
inverse_offsets = inverse_offsets.to('cuda')


def run_fn():
    dense_to_jagged(dense, offsets, inverse_offsets, jagged_total_length)


# One eager call with print_ptx=True to dump the TTGIR, then benchmark.
dense_to_jagged(dense, offsets, inverse_offsets, jagged_total_length, True)
ms, min_ms, max_ms = triton.testing.do_bench(run_fn, quantiles=[0.5, 0.2, 0.8])
print(ms)
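
# Optional sanity check, not in the original gist: compare the kernel's output
# against the plain PyTorch gather defined in dense_to_jagged_reference above.
out = dense_to_jagged(dense, offsets, inverse_offsets, jagged_total_length)
ref = dense_to_jagged_reference(dense, offsets, inverse_offsets, jagged_total_length)
torch.testing.assert_close(out, ref)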