Skip to content

Instantly share code, notes, and snippets.

@euclaise
Created February 9, 2024 00:08
Show Gist options
  • Save euclaise/fa0c56e4054936e69f29a651139adeaa to your computer and use it in GitHub Desktop.
Save euclaise/fa0c56e4054936e69f29a651139adeaa to your computer and use it in GitHub Desktop.
Prefix-sum scan in PyTorch
import math
from typing import Callable, Sequence

import torch
from torch.nn import functional as F
def split(xs):
    """Deinterleave every tensor in ``xs`` along its last axis.

    Each tensor is expected to be 2-D ``(rows, length)`` with an even
    ``length``; it is viewed as ``(rows, length // 2, 2)`` so that the final
    axis enumerates the two members of each adjacent pair.

    Returns:
        A tuple ``(evens, odds)`` of two lists. ``evens[i]`` holds the
        elements of ``xs[i]`` at even positions and ``odds[i]`` those at odd
        positions, each of shape ``(rows, length // 2)``. Both are strided
        views of the input, not copies.
    """
    evens, odds = [], []
    for tensor in xs:
        pairs = tensor.view(tensor.shape[0], tensor.shape[-1] // 2, 2)
        evens.append(pairs[:, :, 0])
        odds.append(pairs[:, :, 1])
    return evens, odds
def merge1(l, r):
    """Interleave two 2-D tensors along the last axis.

    ``l`` supplies the even positions of the result and ``r`` the odd ones,
    so two ``(rows, half)`` inputs produce one ``(rows, 2 * half)`` output.
    """
    rows, half = l.shape
    # Stacking on a new trailing axis puts each (l, r) pair side by side in
    # memory; collapsing that axis yields the interleaved sequence.
    paired = torch.stack([l, r], dim=2)
    return paired.reshape(rows, 2 * half)
def merge(ls, rs):
return [merge1(l, r) for l, r in zip(ls, rs)]
# https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
def pscan_(xs, d, log2n, op):
    """Recursive step of an exclusive prefix scan over the last axis.

    Args:
        xs: list of 2-D tensors; the last axis is expected to have length
            ``2 ** (log2n - d)`` so that it can be halved at every level.
        d: current recursion depth, starting at 0.
        log2n: total number of halving levels; recursion stops once
            ``d == log2n - 1``.
        op: callable taking two lists of tensors (left operands, right
            operands) and returning the list of combined tensors.

    Returns:
        A list of tensors shaped like ``xs`` in which position ``k`` holds
        the combination of all elements strictly before ``k``.

    NOTE(review): the deepest level seeds the scan with all-zero tensors, so
    this presumably relies on zeros acting as a left identity for ``op`` --
    confirm for any non-additive operator.
    """
    evens, odds = split(xs)
    # Combine each adjacent pair into one element of a half-length problem.
    pair_totals = op(evens, odds)
    if d != log2n - 1:
        # Exclusive scan over the pair totals gives the prefix in front of
        # every pair.
        prefix = pscan_(pair_totals, d + 1, log2n, op)
    else:
        # Deepest level: nothing precedes the first pair.
        prefix = [torch.zeros_like(total) for total in pair_totals]
    # Even slots take the pair's prefix unchanged; odd slots additionally
    # fold in the even element that sits just before them.
    return merge(prefix, op(prefix, evens))
@torch.compile
def pscan(xs: Sequence[torch.Tensor], op: Callable, dim: int) -> list:
    """Inclusive parallel prefix scan of a group of tensors along ``dim``.

    Args:
        xs: the tensors to scan. They are scanned jointly: ``op`` always
            receives and returns one list entry per tensor. The scan length
            is taken from the first tensor.
        op: binary combine function ``op(lefts, rights) -> combined`` working
            on lists of tensors; it should be associative.
        dim: axis to scan along.

    Returns:
        A list of tensors with the original shapes, where position ``k``
        along ``dim`` holds the combination of elements ``0..k``.

    NOTE(review): the underlying exclusive scan is seeded with zeros, so
    results are presumably only meaningful when zeros are a left identity
    for ``op`` (e.g. addition) -- confirm before using other operators.
    """
    # Bring the scan axis last and fold every other axis into one batch axis.
    xs = [x.transpose(dim, -1) for x in xs]
    orig_shape = [x.shape for x in xs]
    xs = [x.reshape(-1, x.shape[-1]) for x in xs]

    N = xs[0].shape[-1]
    # The recursive scan needs at least one pairing level. A length-1 axis
    # would give log2(1) == 0 levels and fail when viewed as pairs, so it is
    # padded up to length 2 instead.
    log2n = math.ceil(math.log2(N)) if N > 1 else 1
    next_pow2 = 2 ** log2n
    # Right-pad to a power of two; padded slots sit after every real element
    # and therefore never feed into the prefixes that are kept.
    xs = [F.pad(x, (0, next_pow2 - N)) for x in xs]

    # Exclusive scan combined with the inputs gives the inclusive scan.
    xs = op(pscan_(xs, 0, log2n, op), xs)

    # Drop the padding and restore the original layout.
    xs = [x[:, :N] for x in xs]
    xs = [x.reshape(orig_shape[i]).transpose(dim, -1) for i, x in enumerate(xs)]
    return xs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment