Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active May 31, 2024 21:52
Show Gist options
  • Save Chillee/e3089e7a11419c6b85f68de170e0ba0c to your computer and use it in GitHub Desktop.
Higher Order Kernel - associative scan
import torch
import torch.nn as nn
from torch._higher_order_ops.associative_scan import associative_scan
from triton.testing import do_bench
# Make CUDA the default device so the torch.randn inputs created below land on
# the GPU without explicit device= arguments.
# NOTE(review): this script therefore requires a CUDA-capable machine.
torch.set_default_device('cuda')
def combine_fn(i, j):
    """Compose two affine steps ``h -> mul * h + add`` of the recurrence.

    Applying step ``i`` and then step ``j`` yields
    ``h -> mul_j * (mul_i * h + add_i) + add_j``, i.e. a new step with
    multiplier ``mul_i * mul_j`` and offset ``add_i * mul_j + add_j``.
    This composition is associative, which is what associative_scan needs.

    Args:
        i: ``(multiplier, offset)`` pair for the earlier step.
        j: ``(multiplier, offset)`` pair for the later step.

    Returns:
        The ``(multiplier, offset)`` pair of the composed step.
    """
    (mul_i, add_i), (mul_j, add_j) = i, j
    composed_mul = mul_i * mul_j
    composed_add = add_i * mul_j + add_j
    return composed_mul, composed_add
# Random inputs for the recurrence h[t] = a[t] * h[t-1] + b[t]: `a` holds the
# multipliers and `b` the offsets; 1024 independent rows, each scanned along a
# last dimension of 1024 * 10 steps. They are drawn in order (a first, then b),
# so the RNG stream matches two separate torch.randn calls.
# NOTE(review): no seed is set, so values differ from run to run.
a, b = [torch.randn(1024, 1024 * 10) for _ in range(2)]
def baseline(v, u):
    """Sequentially evaluate the recurrence h[t] = v[t] * h[t-1] + u[t].

    This is the eager, step-by-step reference for the scan. It runs along the
    last dimension with h[0] = u[0], so ``v[..., 0]`` is never used. That is
    the same as starting from h[-1] = 0. Every leading dimension is treated as
    an independent sequence.

    Args:
        v: Multiplicative coefficients, scanned along the last dimension.
        u: Additive terms, same shape as ``v``.

    Returns:
        A tensor shaped like ``u`` whose ``[..., t]`` slice is h[t].
    """
    steps = v.shape[-1]
    if steps == 0:
        # Nothing to scan; torch.stack would reject an empty list.
        return u.clone()
    # Use the arguments, not the module-level `a`/`b`, so the reference is
    # valid for any inputs passed in.
    h = [u[..., 0]]
    for t in range(1, steps):
        h.append(v[..., t] * h[t - 1] + u[..., t])
    return torch.stack(h, dim=-1)
@torch.compile
def compiled_scan(a, b):
    """Evaluate h[t] = a[t] * h[t-1] + b[t] as a parallel prefix scan.

    ``associative_scan`` folds ``combine_fn`` over the ``(a, b)`` pairs along
    the last dimension. The scanned offsets are the h[t] values; the scanned
    multipliers are discarded.
    """
    _scanned_mul, scanned_h = associative_scan(combine_fn, (a, b), dim=-1)
    return scanned_h
# Correctness check: largest absolute gap between the eager reference and the
# compiled scan.
out1 = baseline(a, b)
out2 = compiled_scan(a, b)
print((out1 - out2).abs().max())

# Timing: each thunk is handed to triton's do_bench and its reported runtime
# is printed next to the label. "two cumprods" times two plain cumulative
# products of the same size, presumably as a point of comparison for the scan.
for label, thunk in (
    ("eager", lambda: baseline(a, b)),
    ("compiled", lambda: compiled_scan(a, b)),
    ("two cumprods", lambda: [torch.cumprod(a, dim=-1), torch.cumprod(b, dim=-1)]),
):
    print(label, do_bench(thunk))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment