@oliver-batchelor
Last active May 17, 2021
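"""Benchmark several strategies for applying per-instance linear layers on the GPU:

  * einsum_linear: contract per-instance weights against a one-hot index matrix
  * bmm_linear:    gather each example's weight matrix, then a batched matmul
  * split_linear:  sort the batch by instance and run one nn.Linear per instance
  * masked_linear: scatter features into a one-hot-expanded input for one wide nn.Linear
  * test_linear:   a single shared nn.Linear as a baseline
"""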
import torch
import opt_einsum as oe
import torch.utils.benchmark as benchmark
from torch import nn
import torch.nn.functional as F
# Problem size: map f1 input features to f2 output features, with one weight
# matrix per instance.
f1 = 256
f2 = 256
b = 100000  # batch size
splits = 2
n_instances = splits * splits * splits  # 2**3 = 8 instances

# Per-instance weights, random input features, and an instance index per example.
weights = torch.randn(n_instances, f2, f1).cuda()  # [n_instances, f2, f1]
features = torch.randn(b, f1).cuda()               # [b, f1]
inds = torch.randint(0, n_instances, [b]).cuda()   # [b], values in [0, n_instances)

# Baseline: a single shared linear layer.
module = nn.Linear(f1, f2)
module.cuda()

# One wide linear layer over a one-hot-expanded input of size f1 * n_instances.
masked_module = nn.Linear(f1 * n_instances, f2)
masked_module.cuda()

# A separate linear module for each instance.
split_modules = nn.ModuleList([nn.Linear(f1, f2) for i in range(n_instances)])
split_modules.cuda()

def einsum_linear(i, w, f):
    # One-hot the instance indices and contract them into the weight tensor:
    # out[b, n] = sum_d w[i[b], n, d] * f[b, d].
    i = F.one_hot(i, w.shape[0]).to(dtype=torch.float32)
    return oe.contract("x n d, b x, b d -> b n", w, i, f)

def bmm_linear(i, w, f):
    # Gather each example's weight matrix, then batched matmul:
    # [b, f2, f1] @ [b, f1, 1] -> [b, f2, 1], squeezed to [b, f2].
    b_w = w[i]
    return torch.bmm(b_w, f.unsqueeze(2)).squeeze(2)

def split_linear(i, modules, f):
    # Sort the batch by instance so each module sees one contiguous chunk,
    # then scatter the outputs back into the original order.
    # Note: zip(modules, fs) assumes every instance occurs at least once in the batch.
    i, inds = torch.sort(i)
    _, counts = torch.unique_consecutive(i, return_counts=True)
    fs = torch.split_with_sizes(f[inds], tuple(counts))
    sorted_out = torch.cat([m.forward(x) for m, x in zip(modules, fs)])
    outputs = torch.empty_like(sorted_out)
    outputs[inds] = sorted_out
    return outputs

def test_linear(m, f):
    # Baseline: one shared linear layer, ignoring the instance indices.
    return m.forward(f)

def masked_linear(i, m, f):
    # Scatter each example's features into its instance's slot of a zero-padded
    # [b, n_instances, f1] tensor, then apply one wide linear over the flattened input.
    sparse_f = f.new_zeros(f.shape[0], n_instances, f.shape[1])
    sparse_f[torch.arange(f.shape[0], device=f.device), i] = f
    sparse_f = sparse_f.view(f.shape[0], -1)
    return m.forward(sparse_f)
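
# Quick sanity check (a small sketch, not part of the timing runs): einsum_linear
# and bmm_linear read the same weights tensor, so their outputs should agree up to
# fp32 rounding; the other variants are only checked for output shape here.
small_i = torch.arange(n_instances).repeat(4).cuda()  # every instance occurs
small_f = torch.randn(small_i.shape[0], f1).cuda()
out_e = einsum_linear(small_i, weights, small_f)
out_b = bmm_linear(small_i, weights, small_f)
assert out_e.shape == out_b.shape == (small_i.shape[0], f2)
assert torch.allclose(out_e, out_b, atol=1e-3)  # loose tolerance for fp32
assert split_linear(small_i, split_modules, small_f).shape == (small_i.shape[0], f2)
assert masked_linear(small_i, masked_module, small_f).shape == (small_i.shape[0], f2)
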
t0 = benchmark.Timer(
    stmt='einsum_linear(i, w, f)',
    setup='from __main__ import einsum_linear',
    globals={'i': inds, 'w': weights, 'f': features})

t1 = benchmark.Timer(
    stmt='bmm_linear(i, w, f)',
    setup='from __main__ import bmm_linear',
    globals={'i': inds, 'w': weights, 'f': features})

t2 = benchmark.Timer(
    stmt='test_linear(m, f)',
    setup='from __main__ import test_linear',
    globals={'m': module, 'f': features})

t3 = benchmark.Timer(
    stmt='split_linear(i, m, f)',
    setup='from __main__ import split_linear',
    globals={'i': inds, 'm': split_modules, 'f': features})
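
# A matching timer for the masked one-hot variant (a sketch mirroring the timers
# above; masked_linear and masked_module are defined earlier in this file).
t4 = benchmark.Timer(
    stmt='masked_linear(i, m, f)',
    setup='from __main__ import masked_linear',
    globals={'i': inds, 'm': masked_module, 'f': features})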
print(f"feature size inputs {f1}, outputs {f2}")
print(f"batch size {b}, instances {n_instances}")
# print(t0.timeit(10))
print(t3.timeit(10))
# print(t1.timeit(10))
print(t2.timeit(10))
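# print(t4.timeit(10))  # masked one-hot variant, commented out like t0 and t1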