# gngdb/reduce.py

Last active Oct 13, 2018
An attempt to make a matmul reduction (the product of a list of matrices) faster in PyTorch.
 import torch from functools import reduce import math def is_power(n): # https://stackoverflow.com/a/29480710/6937913 n = n/2 if n == 2: return True elif n > 2: return is_power(n) else: return False def functools_reduce(x): return reduce(torch.matmul, x) #X = x #for Y in x[1:]: # X = torch.matmul(X,Y) #return X def recursive_reduce(x): N = len(x) assert is_power(N) # only going to work with powe of 2s # concatenate everything into one tensor x = torch.cat([tensor.unsqueeze(0) for tensor in x], 0) return recursive_reduce_with_tensor(x) def recursive_reduce_with_tensor(x): # split adjacent elements into separate tensors N, M, _ = x.size() while N > 1: x = x.view(N//2, 2, M, M).permute(1,0,2,3) x = torch.matmul(x, x) #x = torch.matmul(x[:-1:2], x[1::2]) # equivalent but slower N, M, _ = x.size() return x.view(M,M) if __name__ == '__main__': M = 1000 X_list = [torch.randn(M,M)/math.sqrt(M) for i in range(16)] f_X = functools_reduce(X_list) error = torch.abs(f_X - recursive_reduce(X_list)) assert error.max() < 1e-3, (error.mean(), error, f_X) # not exploding assert f_X.mean() < 100., f_X assert math.sqrt(f_X.var().item()) < 100., f_X # or vanishing assert math.sqrt(f_X.var().item()) > 1e-3, math.sqrt(f_X.var().item()) import timeit print("functools (reduce 16 matrices):") setup="from __main__ import functools_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M)/math.sqrt(M) for i in range(16)]" print(" CPU: ", timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=1000)) setup="from __main__ import functools_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M).to('cuda')/math.sqrt(M) for i in range(16)]" print(" GPU: ", timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=1000)) print("recursive (reduce 16 matrices):") setup="from __main__ import recursive_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M)/math.sqrt(M) for i in range(16)]" print(" CPU: ", 
timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=1000)) setup="from __main__ import recursive_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M).to('cuda')/math.sqrt(M) for i in range(16)]" print(" GPU: ", timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=1000)) setup="from __main__ import recursive_reduce_with_tensor as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M).to('cuda')/math.sqrt(M) for i in range(16)]; X = torch.cat([x.unsqueeze(0) for x in X_list], 0)" print(" GPU (cat before): ", timeit.timeit("_ = reduce_function(X)", setup=setup, number=1000)) print("functools (reduce 128 matrices):") setup="from __main__ import functools_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M)/math.sqrt(M) for i in range(128)]" print(" CPU: ", timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=100)) setup="from __main__ import functools_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M).to('cuda')/math.sqrt(M) for i in range(128)]" print(" GPU: ", timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=100)) print("recursive (reduce 128 matrices):") setup="from __main__ import recursive_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M)/math.sqrt(M) for i in range(128)]" print(" CPU: ", timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=100)) setup="from __main__ import recursive_reduce as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M).to('cuda')/math.sqrt(M) for i in range(128)]" print(" GPU: ", timeit.timeit("_ = reduce_function(X_list)", setup=setup, number=100)) setup="from __main__ import recursive_reduce_with_tensor as reduce_function; import torch,math;M = 100; X_list = [torch.randn(M,M).to('cuda')/math.sqrt(M) for i in range(128)]; X = torch.cat([x.unsqueeze(0) for x in X_list], 0)" print(" GPU (cat before): ", timeit.timeit("_ = 
reduce_function(X)", setup=setup, number=100))

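Results (from the original script; its GPU timings were taken without `torch.cuda.synchronize()`, and because CUDA kernels run asynchronously they may understate the actual GPU compute time):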

```
functools (reduce 16 matrices):
CPU:  0.5009437309927307
GPU:  0.12856827297946438
recursive (reduce 16 matrices):
CPU:  0.9862206479883753
GPU:  0.3362074689939618
GPU (cat before):  0.20793896398390643
functools (reduce 128 matrices):
CPU:  0.448671635997016
GPU:  0.13765212500584312
recursive (reduce 128 matrices):
CPU:  0.5517097190022469
GPU:  0.1064921960060019
GPU (cat before):  0.046707449975656345
```