Skip to content

Instantly share code, notes, and snippets.

@gngdb
Last active Oct 13, 2018
Embed
What would you like to do?
Trying to make a matmul reduction (a chained matrix product) run faster in PyTorch.
import torch
from functools import reduce
import math
def is_power(n):
    """Return True when ``n`` is a power of two (1, 2, 4, 8, ...).

    Args:
        n: The number to test. Integers are tested directly; floats are
            accepted only when they hold an integral value (e.g. ``4.0``).

    Returns:
        bool: True if ``n`` equals ``2**k`` for some integer ``k >= 0``,
        otherwise False (including for zero, negatives and non-integral
        floats).
    """
    # Background on power-of-two checks:
    # https://stackoverflow.com/a/29480710/6937913
    if isinstance(n, float):
        # inf/nan and fractional values can never be a power of two here.
        if not n.is_integer():
            return False
        n = int(n)
    # A positive power of two has exactly one bit set, so clearing the lowest
    # set bit with n & (n - 1) must leave zero.
    return n > 0 and (n & (n - 1)) == 0
def functools_reduce(x):
    """Multiply a sequence of matrices left to right with ``functools.reduce``.

    Args:
        x: Iterable of tensors accepted by ``torch.matmul``.

    Returns:
        The chained product ``((x[0] @ x[1]) @ x[2]) @ ...``.
    """
    # Baseline implementation: one torch.matmul call per consecutive pair,
    # folded sequentially from the left.
    chained_product = reduce(torch.matmul, x)
    return chained_product
def recursive_reduce(x):
    """Reduce a sequence of square matrices by pairwise batched matmul.

    Args:
        x: Sized sequence of same-shaped ``(M, M)`` tensors; its length must
            pass the ``is_power`` check.

    Returns:
        The ``(M, M)`` product computed by ``recursive_reduce_with_tensor``.

    Raises:
        AssertionError: If ``is_power(len(x))`` is false.
    """
    n_matrices = len(x)
    # only going to work with powers of 2
    assert is_power(n_matrices)
    # stack everything into one (N, M, M) tensor
    stacked = torch.stack(tuple(x), 0)
    return recursive_reduce_with_tensor(stacked)
def recursive_reduce_with_tensor(x):
    """Multiply a stack of square matrices together in order.

    Adjacent matrices are multiplied pairwise with a single batched
    ``torch.matmul`` per round, roughly halving the stack each time, until
    one matrix is left. Despite the name the implementation is iterative.

    Args:
        x: Tensor of shape ``(N, M, M)`` holding ``N`` square matrices in
            multiplication order. ``N`` may be any positive count; when a
            round has an odd number of matrices the last one is carried
            over unchanged to the next round.

    Returns:
        Tensor of shape ``(M, M)`` equal to ``x[0] @ x[1] @ ... @ x[N-1]``.
    """
    N, M, _ = x.size()
    while N > 1:
        leftover = None
        if N % 2 == 1:
            # Odd count: set the final matrix aside so the rest pair up.
            # Matrix multiplication is associative, so appending it after
            # the pairwise products keeps the overall product unchanged.
            leftover = x[-1:]
            x = x[:-1]
        # Group adjacent matrices: pairs[:, 0] are the left factors and
        # pairs[:, 1] the right factors of each product.
        pairs = x.view(N // 2, 2, M, M)
        x = torch.matmul(pairs[:, 0], pairs[:, 1])
        # (Multiplying strided slices of x directly is equivalent; the
        # original author noted it as slower.)
        if leftover is not None:
            x = torch.cat([x, leftover], 0)
        N = x.size(0)
    return x.view(M, M)
if __name__ == '__main__':
    import timeit

    # ---- Sanity check: both reductions must agree on random input ----
    M = 1000
    # Scaling by 1/sqrt(M) keeps the product's magnitude roughly stable.
    X_list = [torch.randn(M, M) / math.sqrt(M) for i in range(16)]
    f_X = functools_reduce(X_list)
    error = torch.abs(f_X - recursive_reduce(X_list))
    assert error.max() < 1e-3, (error.mean(), error, f_X)
    # not exploding
    assert f_X.mean() < 100., f_X
    assert math.sqrt(f_X.var().item()) < 100., f_X
    # or vanishing
    assert math.sqrt(f_X.var().item()) > 1e-3, math.sqrt(f_X.var().item())

    # ---- Benchmarks ----
    # GPU runs are skipped (not crashed) on machines without CUDA.
    cuda_available = torch.cuda.is_available()

    def _benchmark(function_name, n_matrices, device, number, cat_before=False):
        """Time ``number`` calls of a reduction on ``n_matrices`` 100x100 inputs.

        ``function_name`` is imported from ``__main__`` inside the timeit
        setup. With ``cat_before`` the matrices are concatenated into one
        tensor during setup, so that cost is excluded from the timing.
        Returns the elapsed seconds reported by ``timeit.timeit``.
        """
        setup = (
            "from __main__ import {fn} as reduce_function; import torch,math;"
            "M = 100; X_list = [torch.randn(M,M).to('{dev}')/math.sqrt(M) "
            "for i in range({n})]"
        ).format(fn=function_name, dev=device, n=n_matrices)
        argument = "X_list"
        if cat_before:
            setup += "; X = torch.cat([x.unsqueeze(0) for x in X_list], 0)"
            argument = "X"
        statement = "_ = reduce_function({})".format(argument)
        if device == 'cuda':
            # CUDA kernels launch asynchronously; without synchronizing, the
            # timer would mostly measure launch overhead, not the matmuls.
            setup += "; torch.cuda.synchronize()"
            statement += "; torch.cuda.synchronize()"
        return timeit.timeit(statement, setup=setup, number=number)

    for n_matrices, number in ((16, 1000), (128, 100)):
        print("functools (reduce {} matrices):".format(n_matrices))
        print(" CPU: ", _benchmark('functools_reduce', n_matrices, 'cpu', number))
        if cuda_available:
            print(" GPU: ", _benchmark('functools_reduce', n_matrices, 'cuda', number))
        else:
            print(" GPU: ", "skipped (CUDA not available)")

        print("recursive (reduce {} matrices):".format(n_matrices))
        print(" CPU: ", _benchmark('recursive_reduce', n_matrices, 'cpu', number))
        if cuda_available:
            print(" GPU: ", _benchmark('recursive_reduce', n_matrices, 'cuda', number))
            print(" GPU (cat before): ",
                  _benchmark('recursive_reduce_with_tensor', n_matrices, 'cuda',
                             number, cat_before=True))
        else:
            print(" GPU: ", "skipped (CUDA not available)")
            print(" GPU (cat before): ", "skipped (CUDA not available)")
@gngdb
Copy link
Author

gngdb commented Oct 11, 2018

Results:

functools (reduce 16 matrices):                                                                                                              
  CPU:  0.5009437309927307                                                                                                                   
  GPU:  0.12856827297946438                                                                                                                  
recursive (reduce 16 matrices):                                                                                                              
  CPU:  0.9862206479883753                                                                                                                   
  GPU:  0.3362074689939618                                                                                                                   
  GPU (cat before):  0.20793896398390643                                                                                                     
functools (reduce 128 matrices):                                                                                                             
  CPU:  0.448671635997016
  GPU:  0.13765212500584312                                                                                                                  
recursive (reduce 128 matrices):
  CPU:  0.5517097190022469
  GPU:  0.1064921960060019
  GPU (cat before):  0.046707449975656345

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment