Last active
October 13, 2018 12:47
-
-
Save gngdb/70fce4f27cdaeeb3f8f18cf9929e60d3 to your computer and use it in GitHub Desktop.
Trying to make a chained matmul reduction in PyTorch faster.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from functools import reduce | |
import math | |
def is_power(n):
    """Return True if ``n`` is a positive integer power of two (1, 2, 4, 8, ...).

    Uses the bit trick from https://stackoverflow.com/a/29480710/6937913.
    A positive power of two has exactly one set bit, so ``n & (n - 1)``
    clears it to zero. Integer arithmetic is exact for arbitrarily large
    ``n``, unlike repeated float halving.

    Args:
        n: Integer to test. Floats are accepted for backward compatibility;
            an integral float is tested as its integer value, and any other
            float returns False.

    Returns:
        True for 1, 2, 4, 8, ...; False for zero, negatives, non-integral
        floats and every other integer.
    """
    if isinstance(n, float):
        if not n.is_integer():  # also rejects inf and nan
            return False
        n = int(n)
    return n > 0 and (n & (n - 1)) == 0
def functools_reduce(x):
    """Multiply matrices left to right with ``torch.matmul`` (sequential baseline).

    Computes ``((x[0] @ x[1]) @ x[2]) @ ...`` as a left fold over ``x``.
    This is the one-matmul-per-step reference that the pairwise
    ``recursive_reduce`` is compared and timed against.

    Args:
        x: Non-empty iterable of tensors whose shapes chain under
            ``torch.matmul``.

    Returns:
        The chained product. A single-element input is returned as-is.

    Raises:
        TypeError: If ``x`` is empty (raised by ``functools.reduce``).
    """
    multiply = torch.matmul
    chained_product = reduce(multiply, x)
    return chained_product
def recursive_reduce(x):
    """Multiply a sequence of square matrices with pairwise tree reduction.

    Stacks ``x`` into a single ``(N, M, M)`` tensor and passes it to
    ``recursive_reduce_with_tensor``, which returns
    ``x[0] @ x[1] @ ... @ x[N-1]``.

    Args:
        x: Sized sequence of N same-shape ``(M, M)`` tensors. ``torch.stack``
            requires matching shapes. N must be a power of two
            (1, 2, 4, 8, ...).

    Returns:
        Tensor of shape ``(M, M)``.

    Raises:
        ValueError: If ``len(x)`` is not a positive power of two.
    """
    n = len(x)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    # The bit test is inlined: a positive power of two has exactly one set bit.
    # It also accepts N = 1 and N = 2, which the pairwise reduction handles fine.
    if n <= 0 or (n & (n - 1)) != 0:
        raise ValueError(
            "recursive_reduce needs a power-of-two number of matrices, got %d" % n
        )
    # torch.stack adds the leading N axis, same as cat-ing unsqueeze(0) copies.
    stacked = torch.stack(list(x), 0)
    return recursive_reduce_with_tensor(stacked)
def recursive_reduce_with_tensor(x):
    """Multiply a stack of square matrices together by pairwise tree reduction.

    ``x`` has shape ``(N, M, M)``. The result is the ordered product
    ``x[0] @ x[1] @ ... @ x[N-1]`` with shape ``(M, M)``.

    Each round multiplies adjacent pairs ``(x[0], x[1]), (x[2], x[3]), ...``
    in one batched ``torch.matmul`` call and roughly halves N. The whole
    product therefore takes about ceil(log2(N)) batched calls instead of
    N - 1 sequential ones.

    When a round has an odd count, the unpaired last matrix is carried
    unchanged into the next round, so any N >= 1 works. Matrix order is
    preserved, so the result equals the left-to-right product up to
    floating-point rounding.

    Args:
        x: Tensor of shape ``(N, M, M)`` with ``N >= 1``.

    Returns:
        Tensor of shape ``(M, M)``.

    Raises:
        ValueError: If ``N == 0``.
    """
    n, m, _ = x.size()
    if n == 0:
        raise ValueError("cannot reduce an empty stack of matrices")
    while n > 1:
        odd = n % 2
        # Group adjacent matrices into pairs. pairs[0] holds x[0], x[2], ...
        # and pairs[1] holds x[1], x[3], ...  reshape acts as a zero-copy view
        # here but does not require contiguity. (Strided slicing,
        # x[:-1:2] / x[1::2], is equivalent but was noted as slower.)
        pairs = x[:n - odd].reshape(n // 2, 2, m, m).permute(1, 0, 2, 3)
        reduced = torch.matmul(pairs[0], pairs[1])
        if odd:
            # Carry the unpaired last matrix forward at the end to keep order.
            reduced = torch.cat([reduced, x[n - 1:]], 0)
        x = reduced
        n = x.size(0)
    return x.reshape(m, m)
def _random_matrices(count, size, device):
    """Return ``count`` random ``(size, size)`` matrices on ``device``.

    Entries are standard normal, drawn on the CPU, divided by ``sqrt(size)``
    and moved to ``device``.
    """
    return [torch.randn(size, size).to(device) / math.sqrt(size)
            for _ in range(count)]


def _time_call(fn, arg, number, device):
    """Return the total seconds taken by ``number`` calls of ``fn(arg)``.

    CUDA kernels launch asynchronously, so for ``device == 'cuda'`` every
    timed call is followed by ``torch.cuda.synchronize()``. Without it, only
    launch overhead would be measured. One untimed warm-up call keeps
    one-off setup costs out of the measurement.
    """
    import timeit

    if device == 'cuda':
        def call():
            fn(arg)
            torch.cuda.synchronize()
    else:
        def call():
            fn(arg)
    call()  # warm-up, excluded from timing
    return timeit.timeit(call, number=number)


if __name__ == '__main__':
    # --- Correctness self-check: both reductions agree on 16 large matrices ---
    M = 1000
    X_list = _random_matrices(16, M, 'cpu')
    f_X = functools_reduce(X_list)
    error = torch.abs(f_X - recursive_reduce(X_list))
    assert error.max() < 1e-3, (error.mean(), error, f_X)
    # not exploding: compare the magnitude, because a hugely negative mean is
    # an explosion too
    assert f_X.mean().abs() < 100., f_X
    assert math.sqrt(f_X.var().item()) < 100., f_X
    # or vanishing
    assert math.sqrt(f_X.var().item()) > 1e-3, math.sqrt(f_X.var().item())

    # --- Benchmarks on 100x100 matrices ---
    # GPU rows are skipped (instead of crashing) when CUDA is unavailable.
    has_cuda = torch.cuda.is_available()
    devices = [('CPU', 'cpu')] + ([('GPU', 'cuda')] if has_cuda else [])
    size = 100
    for count, number in ((16, 1000), (128, 100)):
        inputs = {dev: _random_matrices(count, size, dev) for _, dev in devices}
        for title, fn in (("functools", functools_reduce),
                          ("recursive", recursive_reduce)):
            print("%s (reduce %d matrices):" % (title, count))
            for label, dev in devices:
                print(" %s: " % label, _time_call(fn, inputs[dev], number, dev))
        if has_cuda:
            # Time the reduction alone, with the stacking done up front.
            X = torch.cat([x.unsqueeze(0) for x in inputs['cuda']], 0)
            print(" GPU (cat before): ",
                  _time_call(recursive_reduce_with_tensor, X, number, 'cuda'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Results: