# Source: GitHub gist oliver-batchelor/7785f222440ea00154897f60e89d7490
# (last active May 26, 2021 23:45).
from os.path import split
from typing import List
import torch
# import opt_einsum as oe
from torch.nn.modules.activation import ReLU
import torch.utils.benchmark as benchmark
from torch import nn
import math
import nestedtensor
import torch.nn.functional as F
# Benchmark configuration and the shared input tensors.
dtype = torch.float16
# `device` was referenced below but never defined in this file; prefer CUDA
# when it is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
f1 = 256  # input feature size
f2 = 256  # output feature size
b = 100000  # batch size (number of feature rows)
n_instances = 1024  # number of independent weight matrices / sub-modules

# One (f2, f1) weight matrix per instance.
weights = torch.randn(n_instances, f2, f1, dtype=dtype, device=device)
# One f1-dimensional feature row per batch element.
features = torch.randn(b, f1, dtype=dtype, device=device)
# Instance id in [0, n_instances) assigned to every feature row.
inds = torch.randint(0, n_instances, [b], dtype=torch.int64).to(device=device)
def make_net():
    """Build the per-instance network: a single ``nn.Linear(f1, f2)`` layer.

    Reads the module-level sizes ``f1`` and ``f2``.

    Returns:
        nn.Module: a freshly initialised linear layer.
    """
    # Alternative two-layer MLP, left disabled by the author:
    # m = nn.Sequential(
    #     nn.Linear(f1, f2),
    #     nn.ReLU(),
    #     nn.Linear(f2, f2),
    #     nn.ReLU()
    # )
    return nn.Linear(f1, f2)
def group_by_indexes(f, inds):
    """Split the rows of ``f`` into one chunk per distinct value of ``inds``.

    Args:
        f: tensor whose first dimension lines up with ``inds``.
        inds: 1-D integer tensor giving the group id of every row of ``f``.

    Returns:
        A pair ``(fs, sort_inds)``:
          * ``fs`` is a tuple of tensors, one per distinct id that occurs in
            ``inds``, ordered by ascending id;
          * ``sort_inds`` is the permutation with ``torch.cat(fs)`` equal to
            ``f[sort_inds]``; pass it to ``ungroup_indexes`` to restore the
            original row order.
    """
    # torch.unique(..., return_inverse=True) yields the *inverse* mapping, not
    # a sorting permutation, so gathering with it does not group the rows.
    # Sort explicitly and count the runs of equal ids instead.
    sorted_inds, sort_inds = torch.sort(inds)
    _, counts = torch.unique_consecutive(sorted_inds, return_counts=True)
    fs = torch.split_with_sizes(f[sort_inds], counts.tolist())
    return fs, sort_inds
def ungroup_indexes(fs, sort_inds):
    """Concatenate grouped chunks and scatter the rows back to their origin.

    Args:
        fs: sequence of tensors whose concatenation along dim 0 is laid out
            in grouped order.
        sort_inds: 1-D integer tensor; row ``k`` of the concatenated chunks
            is written to position ``sort_inds[k]`` of the result.

    Returns:
        Tensor with the same shape, dtype and device as ``torch.cat(fs)``.
    """
    sorted_out = torch.cat(fs)
    outputs = torch.empty_like(sorted_out)
    # Inverse of the gather ``x[sort_inds]``.
    outputs[sort_inds] = sorted_out
    return outputs
class SplitModule(nn.Module):
    """Route every feature row through the sub-module selected by its index.

    Args:
        modules: iterable of ``nn.Module``; ``inds`` values passed to
            ``forward`` index into this list.
    """

    def __init__(self, modules):
        super().__init__()
        self.split = nn.ModuleList(modules)

    def forward(self, features, inds):
        """Apply ``self.split[inds[k]]`` to ``features[k]`` for every row ``k``.

        Args:
            features: tensor whose first dimension lines up with ``inds``.
            inds: 1-D integer tensor of sub-module ids, one per row.

        Returns:
            The per-row outputs, scattered back with ``ungroup_indexes``.
        """
        grouped_features, sort_inds = group_by_indexes(features, inds)
        # Groups exist only for the ids that actually occur in ``inds`` (in
        # ascending order), so pair each group with its id rather than with
        # its position; otherwise a missing id shifts every later group onto
        # the wrong sub-module.
        present = torch.unique(inds).tolist()
        outputs = [
            self.split[k](group) for k, group in zip(present, grouped_features)
        ]
        return ungroup_indexes(outputs, sort_inds)
# Modules under test, converted to the benchmark dtype and device so they
# accept the `features` tensor directly.
linear_only = nn.Linear(f1, f2).to(dtype=dtype, device=device)
split_module = SplitModule(
    [make_net() for _ in range(n_instances)]
).to(dtype=dtype, device=device)
module = make_net().to(dtype=dtype, device=device)
def bmm_linear(weights, features, inds):
    """Multiply every feature row by the weight matrix chosen by its index.

    Args:
        weights: tensor of stacked matrices, indexed along dim 0 by ``inds``.
        features: 2-D tensor with one row per entry of ``inds``.
        inds: 1-D integer tensor selecting a matrix for each row.

    Returns:
        The batched product ``weights[inds] @ features[:, :, None]``; the
        trailing singleton dimension is kept.
    """
    # Indexing gathers a separate matrix for every row, so the intermediate
    # has ``len(inds)`` matrices regardless of how many are distinct.
    b_w = weights[inds]
    return torch.bmm(b_w, features.unsqueeze(2))
def pad_to(features, block_size):
    """Zero-pad ``features`` so its row count is a multiple of ``block_size``.

    Args:
        features: tensor to pad; for a 2-D input the padding rows are
            appended at the end of dim 0.
        block_size: positive integer the row count is rounded up to.

    Returns:
        The padded tensor (no rows are added when the count already divides
        evenly).
    """
    # The outer modulo maps the "already aligned" case from block_size to 0.
    padding = (block_size - features.size(0) % block_size) % block_size
    # F.pad takes (left, right) pairs starting from the last dimension.
    return F.pad(features, (0, 0, 0, padding))
def block_shared_bmm(weights, features, inds, block_size=128):
    """Per-instance linear map computed with one weight matrix per block.

    Rows are sorted by instance id, each instance's rows are zero-padded to a
    multiple of ``block_size``, and every block of ``block_size`` rows is
    multiplied by its instance's matrix in a single ``torch.bmm``. This avoids
    gathering a separate matrix for every individual row.

    Args:
        weights: ``(n_instances, out_features, in_features)`` tensor.
        features: ``(rows, in_features)`` tensor.
        inds: 1-D integer tensor of length ``rows`` with values indexing
            dim 0 of ``weights``.
        block_size: number of rows that share one gathered weight matrix.

    Returns:
        ``(rows, out_features)`` tensor whose row ``k`` equals
        ``weights[inds[k]] @ features[k]``.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")

    num_rows, in_features = features.size(0), features.size(1)
    out_features = weights.size(1)
    if num_rows == 0:
        return features.new_empty((0, out_features))

    # Bring rows of the same instance together.
    sorted_inds, sort_inds = torch.sort(inds)
    instance_ids, counts = torch.unique_consecutive(
        sorted_inds, return_counts=True)

    # Integer ceil(counts / block_size): blocks needed per present instance.
    num_blocks = torch.div(
        counts + (block_size - 1), block_size, rounding_mode="floor")
    padded_counts = num_blocks * block_size

    # Start offset of every group in the sorted and in the padded layout.
    group_starts = counts.cumsum(0) - counts
    padded_starts = padded_counts.cumsum(0) - padded_counts

    # Destination slot in the padded buffer for every sorted row.
    group_ids = torch.arange(counts.numel(), device=counts.device)
    group_of_row = torch.repeat_interleave(group_ids, counts)
    row_ids = torch.arange(num_rows, device=counts.device)
    dest = padded_starts[group_of_row] + (row_ids - group_starts[group_of_row])

    total_blocks = int(num_blocks.sum())
    padded = features.new_zeros((total_blocks * block_size, in_features))
    padded[dest] = features[sort_inds]
    blocks = padded.view(total_blocks, block_size, in_features)

    # One weight matrix per block instead of one per row.
    block_weights = weights[torch.repeat_interleave(instance_ids, num_blocks)]
    out_blocks = torch.bmm(block_weights, blocks.permute(0, 2, 1))

    # (blocks, out, block_size) -> (padded rows, out); then drop the padding
    # rows and undo the sort.
    out_padded = out_blocks.permute(0, 2, 1).reshape(-1, out_features)
    outputs = out_padded.new_empty((num_rows, out_features))
    outputs[sort_inds] = out_padded[dest]
    return outputs
def vmap_linear(weights, features, inds):
    """Per-row matrix-vector product expressed with ``torch.vmap``.

    Args:
        weights: tensor of stacked matrices, indexed along dim 0 by ``inds``.
        features: 2-D tensor with one row per entry of ``inds``.
        inds: 1-D integer tensor selecting a matrix for each row.

    Returns:
        Tensor whose row ``k`` is ``weights[inds[k]] @ features[k]``.
    """
    def apply_one(w, x):
        # Inside vmap: ``w`` is a single matrix and ``x`` a single vector.
        return torch.matmul(w, x)

    # Gathers one matrix per row before mapping over the batch dimension.
    b_w = weights[inds]
    return torch.vmap(apply_one)(b_w, features)
def module_nested(m, f, inds):
    """Run module ``m`` on the rows of ``f`` packed as a nested tensor.

    Args:
        m: callable module applied to the nested tensor.
        f: tensor whose first dimension lines up with ``inds``.
        inds: 1-D integer tensor of group ids, one per row of ``f``.

    Returns:
        Whatever ``m`` returns for the nested tensor. The rows stay in
        grouped order; nothing here scatters them back to the input order.
    """
    # The sorting permutation is not needed because the result is not
    # ungrouped.
    groups, _ = group_by_indexes(f, inds)
    nested = nestedtensor.nested_tensor(groups, device=f.device)
    return m(nested)
# Statements that call functions defined in this script reach them through
# `M`, so the import must be handed to the Timer as its setup code.
setup = 'import __main__ as M'

test_bmm = benchmark.Timer(
    stmt='M.bmm_linear(w, f, i)',
    setup=setup,
    globals={'i': inds, 'w': weights, 'f': features})

test_shared_bmm = benchmark.Timer(
    stmt='M.block_shared_bmm(w, f, i)',
    setup=setup,
    globals={'i': inds, 'w': weights, 'f': features})

# Without an explicit stmt the Timer measures its default `pass` statement,
# so the plain module benchmarks need the call spelled out.
test_module = benchmark.Timer(
    stmt='m(f)',
    globals={'m': module, 'f': features})

test_linear = benchmark.Timer(
    stmt='m(f)',
    globals={'m': linear_only, 'f': features})

test_split = benchmark.Timer(
    stmt='m.forward(f, i)',
    globals={'i': inds, 'm': split_module, 'f': features})

test_nested = benchmark.Timer(
    stmt='M.module_nested(m, f, i)',
    setup=setup,
    globals={'m': module, 'i': inds, 'f': features})
# Report the problem size before timing anything.
for summary in (
    f"feature size inputs {f1}, outputs {f2}",
    f"batch size {b}, instances {n_instances}",
):
    print(summary)

runs = 50
test_linear.timeit(runs)  # warm-up pass; the measurement is discarded

# The bmm and nested-tensor benchmarks were left disabled by the author:
# print("test_bmm", test_linear.timeit(10))
# print("test_nested", test_nested.timeit(10))
for label, timer in (
    ("test_linear", test_linear),
    ("test_shared_bmm", test_shared_bmm),
    ("test_module", test_module),
    ("test_split", test_split),
):
    print(label, timer.timeit(runs))
# [GitHub page footer] Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.