{{ message }}

Instantly share code, notes, and snippets.

# gngdb/reduce.py

Last active Oct 13, 2018
Trying to make a matmul reduction (the product of a list of matrices) faster in PyTorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 # reduce.py: benchmark two ways of multiplying a list of square matrices in PyTorch.
"""Compare strategies for multiplying a chain of square matrices in PyTorch.

* ``functools_reduce`` multiplies the matrices strictly left to right,
  one ``torch.matmul`` call per matrix.
* ``recursive_reduce`` / ``recursive_reduce_with_tensor`` stack the matrices
  into one tensor and repeatedly multiply adjacent pairs with a single
  batched ``torch.matmul`` per round, so only about ``log2(N)`` calls are
  needed.

Running the module as a script checks that the strategies agree and then
times them on CPU (and on GPU when CUDA is available).
"""
import math
import timeit
from functools import reduce

import torch


def is_power(n):
    """Return True if ``n`` is a positive integral power of two (1, 2, 4, ...).

    Inspired by https://stackoverflow.com/a/29480710/6937913.

    Args:
        n: an int; integral floats such as ``4.0`` are also accepted.

    Returns:
        bool
    """
    if n < 1:
        return False
    # Repeatedly halve; any odd factor above 1 means n is not a power of two.
    # Floor division keeps ints exact (no float overflow for huge n).
    while n > 1:
        if n % 2:
            return False
        n //= 2
    return n == 1


def functools_reduce(x):
    """Multiply matrices left to right: ``x[0] @ x[1] @ ... @ x[-1]``.

    Args:
        x: a non-empty sequence of tensors whose shapes chain under
            ``torch.matmul``.

    Returns:
        The product tensor. ``reduce`` raises TypeError if ``x`` is empty.
    """
    return reduce(torch.matmul, x)


def recursive_reduce(x):
    """Multiply a list of same-shaped square matrices via pairwise reduction.

    Args:
        x: a non-empty sequence of (M, M) tensors (any length >= 1).

    Returns:
        The (M, M) product ``x[0] @ x[1] @ ... @ x[-1]``.

    Raises:
        ValueError: if ``x`` is empty.
    """
    # len() rather than truthiness so a stacked tensor argument also works.
    if len(x) == 0:
        raise ValueError("cannot reduce an empty list of matrices")
    # Stack into one (N, M, M) tensor so each round is a single batched matmul.
    return recursive_reduce_with_tensor(torch.stack(list(x), 0))


def recursive_reduce_with_tensor(x):
    """Multiply a stack of square matrices by reducing adjacent pairs.

    Each round multiplies matrices (0, 1), (2, 3), ... in one batched
    ``torch.matmul`` call, roughly halving the stack, until one matrix
    remains. Pairing only adjacent matrices keeps the left-to-right order,
    so the result equals ``x[0] @ x[1] @ ... @ x[-1]``. When a round has an
    odd count, the last matrix is carried unchanged into the next round.

    Args:
        x: tensor of shape (N, M, M) with N >= 1.

    Returns:
        Tensor of shape (M, M).

    Raises:
        ValueError: if ``x`` holds no matrices (N == 0).
    """
    N, M, _ = x.size()
    if N == 0:
        raise ValueError("cannot reduce an empty stack of matrices")
    while N > 1:
        half = N // 2
        # (half, 2, M, M): pairs[:, 0] holds the even-indexed matrices and
        # pairs[:, 1] the odd-indexed ones. reshape (not view) also accepts
        # non-contiguous input.
        pairs = x[:2 * half].reshape(half, 2, M, M)
        reduced = torch.matmul(pairs[:, 0], pairs[:, 1])
        if N % 2:
            # Odd count: the last matrix has no partner in this round.
            reduced = torch.cat([reduced, x[-1:]], 0)
        x = reduced
        N = x.size(0)
    return x[0]


def _make_matrices(n_matrices, M, device):
    """Return ``n_matrices`` random (M, M) matrices on ``device``, scaled by 1/sqrt(M)."""
    return [torch.randn(M, M).to(device) / math.sqrt(M) for _ in range(n_matrices)]


def _check_correctness(M=1000, n_matrices=16):
    """Assert both reductions agree and the product neither explodes nor vanishes."""
    X_list = _make_matrices(n_matrices, M, 'cpu')
    f_X = functools_reduce(X_list)
    error = torch.abs(f_X - recursive_reduce(X_list))
    assert error.max() < 1e-3, (error.mean(), error, f_X)
    std = math.sqrt(f_X.var().item())
    # not exploding
    assert f_X.mean() < 100., f_X
    assert std < 100., f_X
    # or vanishing
    assert std > 1e-3, std


def _benchmark(reduce_function, n_matrices, number, device='cpu',
               cat_before=False, M=100):
    """Return the total seconds taken by ``number`` calls of ``reduce_function``.

    Args:
        reduce_function: one of the reduction functions above.
        n_matrices: how many (M, M) matrices to multiply.
        number: how many timed calls to make.
        device: 'cpu' or 'cuda'.
        cat_before: pass a pre-stacked (N, M, M) tensor instead of a list,
            so stacking is excluded from the timing.
        M: matrix size.
    """
    X_list = _make_matrices(n_matrices, M, device)
    arg = torch.stack(X_list, 0) if cat_before else X_list

    def run():
        reduce_function(arg)
        if device == 'cuda':
            # CUDA kernels launch asynchronously; wait for completion so the
            # timing measures the work, not just the kernel launches.
            torch.cuda.synchronize()

    run()  # warm-up so one-off initialization costs are not timed
    return timeit.timeit(run, number=number)


def main():
    """Check correctness, then time each reduction on CPU and (if available) GPU."""
    _check_correctness()
    has_cuda = torch.cuda.is_available()
    # (number of matrices, timed repetitions)
    for n_matrices, number in ((16, 1000), (128, 100)):
        print("functools (reduce {} matrices):".format(n_matrices))
        print(" CPU: ", _benchmark(functools_reduce, n_matrices, number))
        if has_cuda:
            print(" GPU: ", _benchmark(functools_reduce, n_matrices, number, 'cuda'))
        print("recursive (reduce {} matrices):".format(n_matrices))
        print(" CPU: ", _benchmark(recursive_reduce, n_matrices, number))
        if has_cuda:
            print(" GPU: ", _benchmark(recursive_reduce, n_matrices, number, 'cuda'))
            print(" GPU (cat before): ",
                  _benchmark(recursive_reduce_with_tensor, n_matrices, number,
                             'cuda', cat_before=True))
    if not has_cuda:
        print("CUDA not available; GPU timings skipped.")


if __name__ == '__main__':
    main()

### gngdb commented Oct 11, 2018 • edited

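Results (total seconds reported by `timeit`: 1000 calls for the 16-matrix runs, 100 calls for the 128-matrix runs; all matrices are 100×100). The GPU timings were taken without `torch.cuda.synchronize()`, so they may largely reflect kernel-launch overhead rather than compute time: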

```
functools (reduce 16 matrices):
CPU:  0.5009437309927307
GPU:  0.12856827297946438
recursive (reduce 16 matrices):
CPU:  0.9862206479883753
GPU:  0.3362074689939618
GPU (cat before):  0.20793896398390643
functools (reduce 128 matrices):
CPU:  0.448671635997016
GPU:  0.13765212500584312
recursive (reduce 128 matrices):
CPU:  0.5517097190022469
GPU:  0.1064921960060019
GPU (cat before):  0.046707449975656345
``````