Created
August 29, 2023 17:24
-
-
Save davidberard98/0f6d1e7f147d5fe9e6c71876414f47c4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import triton | |
import triton.language as tl | |
# Gather kernel: copies valid rows of a padded dense tensor into a packed
# ("jagged") output, one flattened output element per lane. Each program
# instance handles XBLOCK consecutive elements of the flat output.
@triton.jit
def dense_to_jagged_triton(
    inverse_offsets_ptr, offsets_ptr, dense_ptr, out_ptr0, xnumel, XBLOCK: tl.constexpr
):
    # xnumel = 33106688
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # Mask off lanes past the end of the output (final partial block).
    xmask = xindex < xnumel
    # Embedding dim is hard-coded to 256 here (wrapper asserts dense.shape[-1] == 256):
    x1 = xindex // 256  # jagged row index
    x0 = xindex % 256   # position within the embedding vector
    x2 = xindex         # flat offset into the output
    # tmp0 = tl.load(inverse_offsets_ptr + (x1), xmask, eviction_policy="evict_last")
    # Batch element that owns jagged row x1 (inverse of the offsets mapping).
    tmp0 = tl.load(inverse_offsets_ptr + (x1), xmask)
    # NOTE(review): 1025 looks like BATCH_SIZE + 1, i.e. len(offsets) — confirm
    # against the caller; this kernel bakes in the problem size.
    tl.device_assert(
        ((0 <= tmp0) & (tmp0 < 1025)) | ~xmask, "index out of bounds: 0 <= tmp0 < 1025"
    )
    # First jagged row belonging to batch element tmp0.
    tmp1 = tl.load(offsets_ptr + (tmp0), xmask)
    tmp2 = x1
    # Sequence position of this row within its batch element.
    tmp3 = tmp2 - tmp1
    tmp4 = tmp3
    # 260 is presumably MAX_SEQ_LEN — TODO confirm; guards reads past the
    # padded sequence length of the dense tensor.
    tmp5 = tl.full([1], 260, tl.int32)
    tmp6 = tmp4 < tmp5
    # 66560 == 260 * 256: per-batch-element stride of the dense tensor
    # (max_seq_len * embedding_dim). Out-of-range lanes read 0 instead.
    tmp7 = tl.load(
        dense_ptr + (x0 + (256 * tmp3) + (66560 * tmp0)), tmp6 & xmask, other=0
    ).to(tl.float32)
    tmp8 = tl.where(tmp6, tmp7, 0.0)
    # tmp8 = tl.load(dense_ptr + (x0 + (256 * tmp3) + (66560 * tmp0)), xmask, other=0)
    tl.store(out_ptr0 + x2, tmp8, xmask)
def dense_to_jagged(
    dense,
    offsets,
    inverse_offsets,
    jagged_total_length,
    print_ptx=False,
):
    """Pack a padded dense tensor into jagged layout via the Triton kernel.

    Allocates a (jagged_total_length, embedding_dim) output on the same
    device/dtype as ``dense`` and launches ``dense_to_jagged_triton`` over
    the flattened output. When ``print_ptx`` is true, dumps the compiled
    kernel's TTGIR for inspection.
    """
    # The kernel hard-codes an embedding dimension of 256.
    assert dense.shape[-1] == 256

    out = torch.empty(
        (jagged_total_length, dense.shape[-1]),
        dtype=dense.dtype,
        device=dense.device,
    )
    total_elems = out.numel()
    block_size = 1024
    launch_grid = (triton.cdiv(total_elems, block_size),)

    compiled = dense_to_jagged_triton[launch_grid](
        inverse_offsets,
        offsets,
        dense,
        out,
        total_elems,
        block_size,
        num_warps=4,
    )

    if print_ptx:
        print(" ~~~~~~~~~~~~~~~~")
        print(compiled.asm.keys())
        print(compiled.asm["ttgir"])

    return out
def generate_offsets(
    batch_size: int,
    max_seq_len: int,
    load_factor: float,
    offsets_dtype: torch.dtype,
    spread_radius: float,
) -> torch.Tensor:
    """Build a (batch_size + 1,) offsets tensor for a synthetic jagged batch.

    Each of the ``batch_size`` sequences gets a length of roughly
    ``max_seq_len * load_factor``, jittered uniformly within
    ``+/- max_seq_len * spread_radius`` and clamped to [0, max_seq_len].
    With ``load_factor == 1`` every sequence is exactly ``max_seq_len``.

    Returns cumulative offsets starting at 0, so ``offsets[i+1] - offsets[i]``
    is the length of sequence ``i``.
    """
    import random

    assert 0 <= load_factor <= 1
    assert 0 <= spread_radius <= 1

    if load_factor < 1:
        spread = int(max_seq_len * spread_radius)
        mean = int(max_seq_len * load_factor)
        # BUG FIX: random.randint is inclusive on BOTH ends, so the original
        # randint(-spread, spread + 1) skewed the jitter range to
        # [-spread, spread + 1]. Use the symmetric [-spread, spread].
        lengths = [
            min(max(mean + random.randint(-spread, spread), 0), max_seq_len)
            for _ in range(batch_size)
        ]
    else:
        lengths = [max_seq_len] * batch_size

    # Prefix-sum the lengths into offsets (offsets[0] == 0).
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return torch.tensor(offsets, dtype=offsets_dtype)
# Benchmark driver: build a synthetic padded batch on GPU, derive the jagged
# metadata, and time the dense->jagged Triton kernel.
BATCH_SIZE = 1024
MAX_SEQ_LEN = 260
EMBEDDING_DIM = 256

dense = torch.rand((BATCH_SIZE, MAX_SEQ_LEN, EMBEDDING_DIM), device='cuda', dtype=torch.float16)
offsets = generate_offsets(BATCH_SIZE, MAX_SEQ_LEN, 0.3, torch.int32, 0.1)
jagged_lengths = offsets[1:] - offsets[:-1]
jagged_total_length = offsets[-1].item()

# Map each jagged row back to its batch element. Replaces the original
# O(total_rows) Python double loop with a single vectorized op:
# batch index i appears jagged_lengths[i] times.
inverse_offsets = torch.repeat_interleave(
    torch.arange(BATCH_SIZE, dtype=torch.int32), jagged_lengths.long()
)

offsets = offsets.to('cuda')
inverse_offsets = inverse_offsets.to('cuda')

def run_fn():
    dense_to_jagged(dense, offsets, inverse_offsets, jagged_total_length)

# One extra call with print_ptx=True to dump the compiled kernel's TTGIR.
dense_to_jagged(dense, offsets, inverse_offsets, jagged_total_length, True)

# Median / p20 / p80 latency in milliseconds.
ms, min_ms, max_ms = triton.testing.do_bench(run_fn, quantiles=[0.5, 0.2, 0.8])
print(ms)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment