Skip to content

Instantly share code, notes, and snippets.

@Birch-san
Last active November 22, 2022 21:10
Show Gist options
  • Save Birch-san/8f3eb99deffdc3541595e46a01605dea to your computer and use it in GitHub Desktop.
Save Birch-san/8f3eb99deffdc3541595e46a01605dea to your computer and use it in GitHub Desktop.
benchmark: batched matmul with scale factor
import torch
from torch import einsum, tensor, matmul, bmm, baddbmm, empty
import time
scale=2
repeats = 10
# both einsum 0s use the same plan, so whichever batch runs first has to pay the price of warmup
# uncomment this to run a warmup before either batch runs, for fairer comparison of batch avg time
# q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
# k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
# start = time.perf_counter()
# (einsum('b i d, b j d -> b i j', q, k) * scale).max().item()
# duration = time.perf_counter()-start
# print('einsum 0 warmup took %.4f seconds' % (duration))
batch_duration = 0
for ix in range(repeats):
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
start = time.perf_counter()
(einsum('b i d, b j d -> b i j', q, k) * scale).max().item()
duration = time.perf_counter()-start
print('einsum 0 iteration %d took %.4f seconds' % (ix, duration))
batch_duration += duration
print('%d iterations of einsum 0 took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))
batch_duration = 0
for ix in range(repeats):
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
start = time.perf_counter()
(einsum('b n m, b m p -> b n p', q, k.transpose(1, 2)) * scale).max().item()
duration = time.perf_counter()-start
print('einsum 0 transposed k iteration %d took %.4f seconds' % (ix, duration))
batch_duration += duration
print('%d iterations of einsum 0 transposed k took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))
batch_duration = 0
for ix in range(repeats):
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
start = time.perf_counter()
(matmul(q, k.transpose(1, 2)) * scale).max().item()
duration = time.perf_counter()-start
print('matmul iteration %d took %.4f seconds' % (ix, duration))
batch_duration += duration
print('%d iterations of matmul took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))
batch_duration = 0
for ix in range(repeats):
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
start = time.perf_counter()
(bmm(q, k.transpose(1, 2)) * scale).max().item()
duration = time.perf_counter()-start
print('bmm iteration %d took %.4f seconds' % (ix, duration))
batch_duration += duration
print('%d iterations of bmm took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))
e = empty((1, 1, 1), device='mps')
batch_duration = 0
for ix in range(repeats):
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
start = time.perf_counter()
baddbmm(e, q, k.transpose(1, 2), alpha=scale, beta=0).max().item()
duration = time.perf_counter()-start
print('baddbmm iteration %d took %.4f seconds' % (ix, duration))
batch_duration += duration
print('%d iterations of baddbmm took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))
@Birch-san
Copy link
Author

1.14.0.dev20221103, MPS

einsum 0 iteration 0 took 0.2207 seconds
einsum 0 iteration 1 took 0.0213 seconds
einsum 0 iteration 2 took 0.0214 seconds
einsum 0 iteration 3 took 0.0216 seconds
einsum 0 iteration 4 took 0.0215 seconds
einsum 0 iteration 5 took 0.0218 seconds
einsum 0 iteration 6 took 0.0214 seconds
einsum 0 iteration 7 took 0.0217 seconds
einsum 0 iteration 8 took 0.0215 seconds
einsum 0 iteration 9 took 0.0216 seconds
10 iterations of einsum 0 took 0.4145 seconds; avg 0.0415 secs
einsum 0 transposed k iteration 0 took 0.0216 seconds
einsum 0 transposed k iteration 1 took 0.0217 seconds
einsum 0 transposed k iteration 2 took 0.0214 seconds
einsum 0 transposed k iteration 3 took 0.0216 seconds
einsum 0 transposed k iteration 4 took 0.0216 seconds
einsum 0 transposed k iteration 5 took 0.0216 seconds
einsum 0 transposed k iteration 6 took 0.0216 seconds
einsum 0 transposed k iteration 7 took 0.0217 seconds
einsum 0 transposed k iteration 8 took 0.0215 seconds
einsum 0 transposed k iteration 9 took 0.0216 seconds
10 iterations of einsum 0 transposed k took 0.2159 seconds; avg 0.0216 secs
matmul iteration 0 took 0.0200 seconds
matmul iteration 1 took 0.0164 seconds
matmul iteration 2 took 0.0162 seconds
matmul iteration 3 took 0.0159 seconds
matmul iteration 4 took 0.0161 seconds
matmul iteration 5 took 0.0158 seconds
matmul iteration 6 took 0.0158 seconds
matmul iteration 7 took 0.0162 seconds
matmul iteration 8 took 0.0158 seconds
matmul iteration 9 took 0.0159 seconds
10 iterations of matmul took 0.1642 seconds; avg 0.0164 secs
bmm iteration 0 took 0.0159 seconds
bmm iteration 1 took 0.0154 seconds
bmm iteration 2 took 0.0153 seconds
bmm iteration 3 took 0.0157 seconds
bmm iteration 4 took 0.0152 seconds
bmm iteration 5 took 0.0157 seconds
bmm iteration 6 took 0.0153 seconds
bmm iteration 7 took 0.0154 seconds
bmm iteration 8 took 0.0158 seconds
bmm iteration 9 took 0.0154 seconds
10 iterations of bmm took 0.1551 seconds; avg 0.0155 secs
baddbmm iteration 0 took 0.0539 seconds
baddbmm iteration 1 took 0.0110 seconds
baddbmm iteration 2 took 0.0108 seconds
baddbmm iteration 3 took 0.0107 seconds
baddbmm iteration 4 took 0.0108 seconds
baddbmm iteration 5 took 0.0108 seconds
baddbmm iteration 6 took 0.0108 seconds
baddbmm iteration 7 took 0.0107 seconds
baddbmm iteration 8 took 0.0108 seconds
baddbmm iteration 9 took 0.0112 seconds
10 iterations of baddbmm took 0.1513 seconds; avg 0.0151 secs

comparing best iterations (0.0213 s vs 0.0107 s):
baddbmm is roughly 2× — i.e. about 99% — faster than einsum 0.

@Birch-san
Copy link
Author

pytorch 1.12.1

einsum 0 iteration 0 took 0.1395 seconds
einsum 0 iteration 1 took 0.0335 seconds
einsum 0 iteration 2 took 0.0154 seconds
einsum 0 iteration 3 took 0.0152 seconds
einsum 0 iteration 4 took 0.0157 seconds
einsum 0 iteration 5 took 0.0152 seconds
einsum 0 iteration 6 took 0.0153 seconds
einsum 0 iteration 7 took 0.0153 seconds
einsum 0 iteration 8 took 0.0151 seconds
einsum 0 iteration 9 took 0.0152 seconds
10 iterations of einsum 0 took 0.2952 seconds; avg 0.0295 secs
einsum 0 transposed k iteration 0 took 0.0153 seconds
einsum 0 transposed k iteration 1 took 0.0152 seconds
einsum 0 transposed k iteration 2 took 0.0154 seconds
einsum 0 transposed k iteration 3 took 0.0150 seconds
einsum 0 transposed k iteration 4 took 0.0152 seconds
einsum 0 transposed k iteration 5 took 0.0154 seconds
einsum 0 transposed k iteration 6 took 0.0155 seconds
einsum 0 transposed k iteration 7 took 0.0155 seconds
einsum 0 transposed k iteration 8 took 0.0151 seconds
einsum 0 transposed k iteration 9 took 0.0156 seconds
10 iterations of einsum 0 transposed k took 0.1532 seconds; avg 0.0153 secs
matmul iteration 0 took 0.0183 seconds
matmul iteration 1 took 0.0159 seconds
matmul iteration 2 took 0.0158 seconds
matmul iteration 3 took 0.0159 seconds
matmul iteration 4 took 0.0155 seconds
matmul iteration 5 took 0.0155 seconds
matmul iteration 6 took 0.0158 seconds
matmul iteration 7 took 0.0153 seconds
matmul iteration 8 took 0.0158 seconds
matmul iteration 9 took 0.0153 seconds
10 iterations of matmul took 0.1591 seconds; avg 0.0159 secs
bmm iteration 0 took 0.0154 seconds
bmm iteration 1 took 0.0151 seconds
bmm iteration 2 took 0.0152 seconds
bmm iteration 3 took 0.0153 seconds
bmm iteration 4 took 0.0149 seconds
bmm iteration 5 took 0.0153 seconds
bmm iteration 6 took 0.0150 seconds
bmm iteration 7 took 0.0152 seconds
bmm iteration 8 took 0.0150 seconds
bmm iteration 9 took 0.0152 seconds
10 iterations of bmm took 0.1516 seconds; avg 0.0152 secs
baddbmm iteration 0 took 0.0132 seconds
baddbmm iteration 1 took 0.0104 seconds
baddbmm iteration 2 took 0.0105 seconds
baddbmm iteration 3 took 0.0107 seconds
baddbmm iteration 4 took 0.0108 seconds
baddbmm iteration 5 took 0.0105 seconds
baddbmm iteration 6 took 0.0105 seconds
baddbmm iteration 7 took 0.0106 seconds
baddbmm iteration 8 took 0.0105 seconds
baddbmm iteration 9 took 0.0106 seconds
10 iterations of baddbmm took 0.1082 seconds; avg 0.0108 secs

@Birch-san
Copy link
Author

Birch-san commented Nov 5, 2022

8 Heun steps

1.12.1

einsum:
9.688710416987306

baddbmm:
9.598911916022189

1.14.0.dev20221103
einsum:
10.383701542014023

baddbmm:
9.281007582991151

@damian0815
Copy link

thanks for this. similar results here (base M1 16GB)
torch 1.12.1:

einsum 0 iteration 0 took 0.1881 seconds
einsum 0 iteration 1 took 0.0805 seconds
einsum 0 iteration 2 took 0.0803 seconds
einsum 0 iteration 3 took 0.0799 seconds
einsum 0 iteration 4 took 0.0802 seconds
einsum 0 iteration 5 took 0.0831 seconds
einsum 0 iteration 6 took 0.0817 seconds
einsum 0 iteration 7 took 0.0814 seconds
einsum 0 iteration 8 took 0.0810 seconds
einsum 0 iteration 9 took 0.0805 seconds
10 iterations of einsum 0 took 0.9165 seconds; avg 0.0917 secs
einsum 0 transposed k iteration 0 took 0.0811 seconds
einsum 0 transposed k iteration 1 took 0.0810 seconds
einsum 0 transposed k iteration 2 took 0.0796 seconds
einsum 0 transposed k iteration 3 took 0.0803 seconds
einsum 0 transposed k iteration 4 took 0.0803 seconds
einsum 0 transposed k iteration 5 took 0.0809 seconds
einsum 0 transposed k iteration 6 took 0.0808 seconds
einsum 0 transposed k iteration 7 took 0.0801 seconds
einsum 0 transposed k iteration 8 took 0.0803 seconds
einsum 0 transposed k iteration 9 took 0.0805 seconds
10 iterations of einsum 0 transposed k took 0.8051 seconds; avg 0.0805 secs
matmul iteration 0 took 0.0996 seconds
matmul iteration 1 took 0.0830 seconds
matmul iteration 2 took 0.0823 seconds
matmul iteration 3 took 0.0844 seconds
matmul iteration 4 took 0.0856 seconds
matmul iteration 5 took 0.0834 seconds
matmul iteration 6 took 0.0823 seconds
matmul iteration 7 took 0.0826 seconds
matmul iteration 8 took 0.0828 seconds
matmul iteration 9 took 0.0821 seconds
10 iterations of matmul took 0.8483 seconds; avg 0.0848 secs
bmm iteration 0 took 0.0794 seconds
bmm iteration 1 took 0.0796 seconds
bmm iteration 2 took 0.0798 seconds
bmm iteration 3 took 0.0812 seconds
bmm iteration 4 took 0.0800 seconds
bmm iteration 5 took 0.0805 seconds
bmm iteration 6 took 0.0894 seconds
bmm iteration 7 took 0.0887 seconds
bmm iteration 8 took 0.0802 seconds
bmm iteration 9 took 0.0799 seconds
10 iterations of bmm took 0.8188 seconds; avg 0.0819 secs
baddbmm iteration 0 took 0.0535 seconds
baddbmm iteration 1 took 0.0456 seconds
baddbmm iteration 2 took 0.0457 seconds
baddbmm iteration 3 took 0.0456 seconds
baddbmm iteration 4 took 0.0454 seconds
baddbmm iteration 5 took 0.0451 seconds
baddbmm iteration 6 took 0.0458 seconds
baddbmm iteration 7 took 0.0458 seconds
baddbmm iteration 8 took 0.0457 seconds
baddbmm iteration 9 took 0.0453 seconds
10 iterations of baddbmm took 0.4635 seconds; avg 0.0463 secs

torch 1.14.0.dev20221121:

einsum 0 iteration 0 took 0.6219 seconds
einsum 0 iteration 1 took 0.1123 seconds
einsum 0 iteration 2 took 0.1147 seconds
einsum 0 iteration 3 took 0.1134 seconds
einsum 0 iteration 4 took 0.1126 seconds
einsum 0 iteration 5 took 0.1126 seconds
einsum 0 iteration 6 took 0.1122 seconds
einsum 0 iteration 7 took 0.1137 seconds
einsum 0 iteration 8 took 0.1137 seconds
einsum 0 iteration 9 took 0.1126 seconds
10 iterations of einsum 0 took 1.6395 seconds; avg 0.1640 secs
einsum 0 transposed k iteration 0 took 0.1133 seconds
einsum 0 transposed k iteration 1 took 0.1125 seconds
einsum 0 transposed k iteration 2 took 0.1123 seconds
einsum 0 transposed k iteration 3 took 0.1128 seconds
einsum 0 transposed k iteration 4 took 0.1128 seconds
einsum 0 transposed k iteration 5 took 0.1127 seconds
einsum 0 transposed k iteration 6 took 0.1129 seconds
einsum 0 transposed k iteration 7 took 0.1128 seconds
einsum 0 transposed k iteration 8 took 0.1131 seconds
einsum 0 transposed k iteration 9 took 0.1126 seconds
10 iterations of einsum 0 transposed k took 1.1277 seconds; avg 0.1128 secs
matmul iteration 0 took 0.0979 seconds
matmul iteration 1 took 0.0800 seconds
matmul iteration 2 took 0.0808 seconds
matmul iteration 3 took 0.0812 seconds
matmul iteration 4 took 0.0794 seconds
matmul iteration 5 took 0.0792 seconds
matmul iteration 6 took 0.0794 seconds
matmul iteration 7 took 0.0796 seconds
matmul iteration 8 took 0.0790 seconds
matmul iteration 9 took 0.0791 seconds
10 iterations of matmul took 0.8157 seconds; avg 0.0816 secs
bmm iteration 0 took 0.0765 seconds
bmm iteration 1 took 0.0956 seconds
bmm iteration 2 took 0.0757 seconds
bmm iteration 3 took 0.0763 seconds
bmm iteration 4 took 0.0773 seconds
bmm iteration 5 took 0.0764 seconds
bmm iteration 6 took 0.0768 seconds
bmm iteration 7 took 0.0768 seconds
bmm iteration 8 took 0.0767 seconds
bmm iteration 9 took 0.0763 seconds
10 iterations of bmm took 0.7844 seconds; avg 0.0784 secs
baddbmm iteration 0 took 0.4508 seconds
baddbmm iteration 1 took 0.0462 seconds
baddbmm iteration 2 took 0.0442 seconds
baddbmm iteration 3 took 0.0445 seconds
baddbmm iteration 4 took 0.0442 seconds
baddbmm iteration 5 took 0.0453 seconds
baddbmm iteration 6 took 0.0450 seconds
baddbmm iteration 7 took 0.0452 seconds
baddbmm iteration 8 took 0.0448 seconds
baddbmm iteration 9 took 0.0453 seconds
10 iterations of baddbmm took 0.8555 seconds; avg 0.0855 secs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment