@AranKomat
Last active November 11, 2021 02:52
Log-linear version of cumsum and cumprod

A parallel prefix scan (Blelloch up-sweep/down-sweep) in PyTorch: the Python-level loop runs in O(log n) sequential steps instead of the O(n) of a naive recurrence. The helpers appear to be ported from JAX (`slice_in_dim` mirrors `jax.lax.slice_in_dim`).
from functools import partial  # unused below
import torch


def _const(example, val):
    # Scalar tensor of `val` with the same dtype as `example` (unused below).
    return torch.tensor(val, dtype=example.dtype)
def pad(x, axis, side):
    """Interleave `x` with zeros along `axis`, doubling its length.

    side='right' yields [x0, 0, x1, 0, ...]; side='left' yields [0, x0, 0, x1, ...].
    """
    shape = list(x.size())
    if axis == -1:
        axis = len(shape) - 1
    length = shape[axis]
    x = x.unsqueeze(axis + 1)
    if side == 'right':
        x = torch.cat([x, torch.zeros_like(x)], axis + 1)
    else:
        x = torch.cat([torch.zeros_like(x), x], axis + 1)
    shape[axis] = 2 * length
    return x.reshape(shape)
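# Illustration (added, not in the original): interleaving a 1-D tensor along axis 0.
#   pad(torch.tensor([1, 2, 3, 4]), axis=0, side='right') -> tensor([1, 0, 2, 0, 3, 0, 4, 0])
#   pad(torch.tensor([1, 2, 3, 4]), axis=0, side='left')  -> tensor([0, 1, 0, 2, 0, 3, 0, 4])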
def slice_in_dim(operand, start_index, limit_index, stride: int = 1, axis: int = 0):
    """Convenience wrapper around slice applying to only one dimension."""
    # Translate `None` to concrete start/stop indices.
    len_axis = operand.shape[axis]
    start_index_int = int(start_index) if start_index is not None else 0
    limit_index_int = int(limit_index) if limit_index is not None else len_axis
    # Translate negative indices.
    if start_index_int < 0:
        start_index_int = start_index_int + len_axis
    if limit_index_int < 0:
        limit_index_int = limit_index_int + len_axis
    axis = int(axis)
    # Move `axis` last, slice with ordinary Python slicing, and move it back.
    return operand.transpose(axis, -1)[..., start_index_int:limit_index_int:stride].transpose(axis, -1)
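# Illustration (added, not in the original): strided slicing along a dimension.
#   slice_in_dim(torch.arange(6), 0, None, stride=2, axis=0) -> tensor([0, 2, 4])
#   slice_in_dim(torch.arange(6), 1, None, stride=2, axis=0) -> tensor([1, 3, 5])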
def _prescan_power_of_two(x, axis, op, unit):
    """Blelloch exclusive scan along `axis`; returns (exclusive_scan, total)."""
    n = x.shape[axis]
    assert n != 0 and n & (n - 1) == 0, "n must be a power of 2"
    # Up-sweep: combine adjacent pairs, halving the length at each step.
    xs = []
    for d in range(0, n.bit_length() - 1):
        x1 = slice_in_dim(x, 0, None, stride=2, axis=axis)
        xs.append(x1)
        x2 = slice_in_dim(x, 1, None, stride=2, axis=axis)
        x = op(x1, x2)
    total = x
    # Down-sweep: start from the identity and re-expand. The literal `+` merges
    # the two zero-interleaved halves (values sit at disjoint positions), so it
    # is correct for any `op`, not just addition.
    x = torch.full_like(total, unit)
    for w in reversed(xs):
        x1 = pad(x, axis=axis, side='right')
        x2 = pad(x, axis=axis, side='left')
        w = pad(w, axis=axis, side='left')
        x = x1 + op(x2, w)
    return x, total
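# Worked example (added, not in the original), with op=add, unit=0, x=[1, 2, 3, 4]:
#   up-sweep:   [1, 2, 3, 4] -> [3, 7] -> [10], so total = [10]
#   down-sweep: [0] -> [0, 3] -> [0, 1, 3, 6]   (the exclusive scan)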
def _parallel_prefix_scan(x, axis, op, unit):
    n = x.shape[axis]
    if n == 0:
        return x
    # Pad to the next power of two with the identity element `unit`.
    nbits = n.bit_length()
    if n == (1 << (nbits - 1)):  # n is already a power of two
        nbits -= 1
    shape = list(x.size())
    shape[axis] = (1 << nbits) - n
    padding = x.new_zeros(shape).fill_(unit)
    x = torch.cat([x, padding], axis)
    x, total = _prescan_power_of_two(x, axis, op, unit)
    # Shift the exclusive scan left by one and append the total -> inclusive scan.
    return torch.cat([slice_in_dim(x, 1, n, axis=axis), total], axis)
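# Worked example (added, not in the original): for n=3 and op=add, [a, b, c] is
# padded to [a, b, c, 0], prescanned to the exclusive scan [0, a, a+b, a+b+c]
# with total a+b+c, and the final cat yields the inclusive scan [a, a+b, a+b+c].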
def cumsum(x, dim=-1):
    def add(y, z):
        return y + z
    return _parallel_prefix_scan(x, axis=dim, op=add, unit=0)


def cumprod(x, dim=-1):
    def mult(y, z):
        return y * z
    return _parallel_prefix_scan(x, axis=dim, op=mult, unit=1)
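# Added sanity check (not part of the original): on CPU random input with a
# non-power-of-two length, the scan should agree with torch's built-ins.
_t = torch.rand(4, 7)
assert torch.allclose(cumsum(_t, dim=-1), torch.cumsum(_t, dim=-1), atol=1e-5)
assert torch.allclose(cumprod(_t, dim=-1), torch.cumprod(_t, dim=-1), atol=1e-5)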
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # still takes effect here: PyTorch initializes CUDA lazily

# GPU smoke test: `torch.empty` is uninitialized, so this only exercises shapes
# and speed (8000 is not a power of two, so the input is padded to 8192).
a = torch.empty(512, 8000, dtype=torch.float32).to('cuda')
cumsum(a, dim=-1)