@mobicham
Created June 14, 2024 11:50
#torch.__version__: 2.5.0.dev20240613+cu121
###########################################################################################
import torch, time
import numpy as np
dtype = torch.float16
def eval_time(fct, params, iters=10000):
    t = []
    for _ in range(iters):
        t1 = time.time()
        _ = fct(**params)
        torch.cuda.synchronize()
        t2 = time.time()
        t.append(t2 - t1)
    return np.mean(t[-iters//2:])  # average the second half only, so the first half acts as warm-up
matmul = lambda a, b: torch.matmul(a, b.T)
batch_size = 1
shapes = [
    (batch_size, 2048, 2048),
    (batch_size, 2048, 4096),
    (batch_size, 4096, 2048),
    (batch_size, 4096, 4096),
    (batch_size, 4096, 4096*2),
    (batch_size, 4096*2, 4096),
    (batch_size, 4096*2, 4096*2),
    #(batch_size, 4096*3, 4096*3),
    #(batch_size, 4096*4, 4096*4),
]
for b, K, N in shapes:
    x  = torch.randn((b, K), device='cuda', dtype=dtype).contiguous() / 10.
    W  = torch.randn((N, K), device='cuda', dtype=dtype).contiguous() / 10.
    W2 = W.clone()
    assert W.is_contiguous()
    assert x.is_contiguous()
    assert W2.is_contiguous()
    W_time  = eval_time(matmul, {'a': x, 'b': W})
    Wq_time = eval_time(matmul, {'a': x, 'b': W2})
    print('----------------------------------------------------------------------')
    print("Shape:", str(b) + 'x' + str(K) + ' , ' + str(K) + 'x' + str(N))
    print('processed vs random |', 'speed-up', str(np.round(W_time / Wq_time, 6)) + 'x')
    print()
########################################################################################################
#GPU: 4090
# ----------------------------------------------------------------------
# Shape: 1x2048 , 2048x2048
# processed vs random | speed-up 0.999612x
# ----------------------------------------------------------------------
# Shape: 1x2048 , 2048x4096
# processed vs random | speed-up 0.977499x
# ----------------------------------------------------------------------
# Shape: 1x4096 , 4096x2048
# processed vs random | speed-up 0.996997x
# ----------------------------------------------------------------------
# Shape: 1x4096 , 4096x4096
# processed vs random | speed-up 1.008123x
# ----------------------------------------------------------------------
# Shape: 1x4096 , 4096x8192
# processed vs random | speed-up 2.438954x
# ----------------------------------------------------------------------
# Shape: 1x8192 , 8192x4096
# processed vs random | speed-up 2.494266x
# ----------------------------------------------------------------------
# Shape: 1x8192 , 8192x8192
# processed vs random | speed-up 0.99959x
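########################################################################################################
# Alternative timing sketch (not part of the original gist): CUDA events measure on-device
# duration and avoid host-side timer granularity and the per-iteration synchronize overhead
# of eval_time above. The function name is illustrative; it reuses torch, np and matmul from
# the script and can be dropped in as a replacement for eval_time. Results may differ slightly.
def eval_time_cuda_events(fct, params, iters=10000):
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends   = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for i in range(iters):
        starts[i].record()
        _ = fct(**params)
        ends[i].record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds; convert to seconds to match eval_time's units
    t = [starts[i].elapsed_time(ends[i]) / 1000. for i in range(iters)]
    return np.mean(t[-iters//2:])  # average the second half only, same warm-up convention as eval_time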