Last active
June 21, 2024 22:41
-
-
Save Chillee/42e4635c59760a74cb3b4ba7ea5ad9f8 to your computer and use it in GitHub Desktop.
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Benchmark: compare GEMM throughput when operands have different value distributions.
import torch
torch.set_default_device('cuda')  # every tensor below is allocated on the GPU
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
random.seed(0)  # deterministic shuffling of the benchmark order below
def get_flops(A, B):
    """Benchmark ``torch.mm(A, B)`` and return the achieved FLOP/s.

    ``do_bench`` reports a runtime in milliseconds; a matmul of an
    (M, K) by (K, N) pair performs 2*M*K*N floating-point operations,
    so throughput is that count divided by the runtime in seconds.
    """
    runtime_ms = do_bench(lambda: torch.mm(A, B))
    op_count = 2 * A.shape[0] * A.shape[1] * B.shape[1]
    return op_count * (1e3 / runtime_ms)
# Square GEMM problem size: each benchmark multiplies an (M, K) by a (K, N) matrix.
M = 8192
N = 8192
K = 8192
def get_tensors(f):
    """Return a bfloat16 operand pair (A: (M, K), B: (K, N)) built by generator ``f``.

    B is materialized as (N, K) and handed back transposed.
    """
    lhs = f(M, K, dtype=torch.bfloat16)
    rhs = f(N, K, dtype=torch.bfloat16)
    return lhs, rhs.t()
def one_bit_random(*shape, dtype=torch.bfloat16):
    """Tensor carrying one random bit per element: values are 0 or 8.

    Bit 3 of each randn draw's 16-bit pattern is isolated (giving int16
    values 0 or 8), then numerically cast back to ``dtype``.
    """
    noise = torch.randn(*shape, dtype=dtype)
    kept_bit = noise.view(torch.int16) & 0b1000
    return kept_bit.to(dtype=dtype)
def sparse(*shape, dtype=torch.bfloat16):
    """Half-sparse random data: a randn tensor with all negative entries zeroed."""
    vals = torch.randn(*shape, dtype=dtype)
    is_negative = vals < 0
    return torch.where(is_negative, 0, vals)
# (name, generator) pairs: each generator fills the operands with a different
# value distribution so data-dependent throughput differences become visible.
original_setups = [
    ("randn", torch.randn),
    ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
    ("sparse", sparse),
    ("one bit", one_bit_random),
    ("rand", torch.rand),
    ("zeros", torch.zeros),
]
results = defaultdict(list)  # setup name -> list of measured FLOP/s samples
setups = list(original_setups)  # mutable copy, so shuffling leaves the original order intact
ITERS = 10  # samples per setup; medians are reported below
# Run every setup ITERS times, reshuffling the order each pass (seeded above)
# so no setup consistently runs first — presumably to avoid ordering effects
# such as GPU clock/thermal drift biasing one setup.
for _ in range(ITERS):
    random.shuffle(setups)
    for name, f in setups:
        results[name].append(get_flops(*get_tensors(f)))
def median(x):
    """Median of a non-empty sequence; averages the two middle values
    when the length is even."""
    ordered = sorted(x)
    n = len(ordered)
    mid = n // 2
    if n % 2:
        return ordered[mid]
    return (ordered[mid] + ordered[mid - 1]) / 2
# Report the median throughput per setup, scaled to TFLOP/s (the 1e12 divide).
for name, _ in original_setups:
    print(f"{name}: {median(results[name])/1e12}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Variant of the benchmark that additionally sweeps GPU clock and power limits.
import torch
torch.set_default_device('cuda')  # every tensor below is allocated on the GPU
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
import subprocess
random.seed(0)  # deterministic shuffle order across runs
def set_gpu_limits(ref_sm_clock=1810, power_limit=330):
    """Pin GPU 0's SM clock and power limit via ``nvidia-smi`` (requires sudo).

    Locking the clock range (min == max == ref_sm_clock) and the power
    limit keeps the hardware operating point fixed across measurements.
    """
    base_cmd = ["sudo", "nvidia-smi", "-i", "0"]
    subprocess.check_output(
        base_cmd + [f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}"]
    )
    subprocess.check_output(base_cmd + [f"-pl={power_limit}"])
def get_flops(A, B):
    """Benchmark ``torch.mm(A, B)`` and return the achieved FLOP/s.

    ``do_bench`` reports a runtime in milliseconds; a matmul of an
    (M, K) by (K, N) pair performs 2*M*K*N floating-point operations,
    so throughput is that count divided by the runtime in seconds.
    """
    runtime_ms = do_bench(lambda: torch.mm(A, B))
    op_count = 2 * A.shape[0] * A.shape[1] * B.shape[1]
    return op_count * (1e3 / runtime_ms)
# Square GEMM problem size: each benchmark multiplies an (M, K) by a (K, N) matrix.
M = 8192
N = 8192
K = 8192
def get_tensors(f):
    """Return a bfloat16 operand pair (A: (M, K), B: (K, N)) built by generator ``f``.

    B is materialized as (N, K) and handed back transposed.
    """
    lhs = f(M, K, dtype=torch.bfloat16)
    rhs = f(N, K, dtype=torch.bfloat16)
    return lhs, rhs.t()
def one_bit_random(*shape, dtype=torch.bfloat16):
    """Tensor carrying one random bit per element: values are 0 or 8.

    Bit 3 of each randn draw's 16-bit pattern is isolated (giving int16
    values 0 or 8), then numerically cast back to ``dtype``.
    """
    noise = torch.randn(*shape, dtype=dtype)
    kept_bit = noise.view(torch.int16) & 0b1000
    return kept_bit.to(dtype=dtype)
def sparse(*shape, dtype=torch.bfloat16):
    """~90%-sparse random data: each randn entry survives only when an
    independent uniform draw falls at or below 0.1; the rest become 0."""
    vals = torch.randn(*shape, dtype=dtype)
    drop_gate = torch.rand_like(vals) > 0.1
    return torch.where(drop_gate, 0, vals)
def checkerboard(*shape, dtype=torch.bfloat16):
    """Randn tensor with a checkerboard sparsity pattern.

    Entries where (row - col) is odd are zeroed; entries with even
    parity keep their Gaussian value.

    Bug fix: the original built the parity mask with the arange views
    swapped, producing a (shape[1], shape[0]) mask — transposed relative
    to ``x``. The parity check is symmetric, so square inputs happened to
    work, but any non-square shape failed to broadcast. The mask is now
    constructed with x's own (shape[0], shape[1]) layout.
    """
    x = torch.randn(*shape, dtype=dtype)
    rows = torch.arange(shape[0]).view(-1, 1)
    cols = torch.arange(shape[1]).view(1, -1)
    x = torch.where((rows - cols) % 2 == 0, x, 0)
    return x
def ternary(*shape, dtype=torch.bfloat16):
    """Tensor of values drawn uniformly from {-1, 0, 1}.

    Bug fix: the original hard-coded ``torch.bfloat16`` in the randint
    call, silently ignoring its own ``dtype`` keyword; it now honors
    ``dtype`` like the sibling generators (default unchanged).
    """
    return torch.randint(low=-1, high=2, size=shape, dtype=dtype)
# Benchmark setups; everything except "randn" is commented out for the
# clock/power sweep, which only tracks a single setup's throughput.
original_setups = [
    # ("zeros", torch.zeros),
    ("randn", torch.randn),
    # ("checkerboard", checkerboard),
    # ("sparse", sparse),
    # ("rand", torch.rand),
    # ("ternary", ternary),
    # ("one bit", one_bit_random),
    # ("all_pi", lambda *shape, dtype: torch.full(shape, fill_value=3.1415926535897932384626, dtype=dtype)),
    # ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
]
def get_results(clocks, power):
    """Lock GPU 0 to the given SM clock and power limit, then return the
    median measured FLOP/s for the 'randn' setup over ITERS shuffled runs."""
    set_gpu_limits(clocks, power)
    results = defaultdict(list)  # setup name -> list of FLOP/s samples
    setups = list(original_setups)  # copy so shuffling leaves the original order intact
    ITERS = 10  # samples per setup
    for _ in range(ITERS):
        random.shuffle(setups)  # randomize run order within each pass
        for name, f in setups:
            results[name].append(get_flops(*get_tensors(f)))
    # Local median helper: averages the two middle values for even lengths.
    def median(x):
        x = sorted(x)
        if len(x) % 2 == 0:
            return (x[len(x)//2] + x[(len(x) - 1)//2])/2
        else:
            return x[len(x)//2]
    # for name, _ in original_setups:
    #     print(f"{name}: {median(results[name])/1e12}")
    # print(median(results['zeros']) / median(results["randn"]))
    return median(results['randn'])
# Sweep: for each power limit (highest first), step the locked clock down from
# the maximum and report the first clock at which throughput falls below 90%
# of that power's max-clock baseline. The next (lower) power limit resumes its
# walk from the clock found here, since its knee can only be at or below it.
start_clocks = 1980 # H100 maximum SM clock
for power in reversed([150, 200, 250, 300, 350, 400, 450, 500]):
    max_clocks = 1980 # H100
    start_flops = get_results(max_clocks, power)
    for clocks in range(start_clocks, 200, -100):
        # print(power, clocks)
        cur_flops = get_results(clocks, power)
        if cur_flops < start_flops * 0.9:
            print("Done: ", power, clocks)
            start_clocks = clocks
            break
Author
Chillee
commented
Apr 28, 2024
![image](https://private-user-images.githubusercontent.com/6355099/326289577-32d77672-fc8e-4151-acb5-1f9b1ee187b4.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjE3MjM2MjcsIm5iZiI6MTcyMTcyMzMyNywicGF0aCI6Ii82MzU1MDk5LzMyNjI4OTU3Ny0zMmQ3NzY3Mi1mYzhlLTQxNTEtYWNiNS0xZjliMWVlMTg3YjQucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDcyMyUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA3MjNUMDgyODQ3WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9ZTBmNWJjOGNiZmU3N2ZkYmRlYTUzZjViMmZlYTE0NjliMTI2ZmI3MGU3YjI0MWNkMzU4NTg0MTEzMWRlNGQ1YiZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.7bFsoBcMA4fP0HdtV1IUKaxZlXtMCSck8DGSDQXeUOA)
Hi! Thanks for an amazing post! I've run the mm_weird.py
benchmark w/ H100 and I get the following results:
Run 1:
randn: 1024.3100185282144
twos: 803.3550678054524
sparse: 1086.47683669488
one bit: 830.4096678480972
rand: 837.8445385632689
zeros: 810.90379017118
Run 2:
randn: 1020.114801596814
twos: 803.5072206112413
sparse: 1060.8216568964108
one bit: 828.119089454572
rand: 832.9280217508104
zeros: 815.2949820775259
Run 3:
randn: 1015.1157728697485
twos: 808.3138761162128
sparse: 1074.7391939180266
one bit: 835.835139020573
rand: 836.061297508501
zeros: 812.8299565166335
I don't know what's even more strange: getting over 1000 TFLOPs, or getting the opposite results...
Toni
PD: I changed L28 with x = torch.randn(*shape, dtype=dtype)
@TJ-Solergibert That's indeed quite strange 🤔 In particular, you also see very high FLOPs for randn
compared to zeros
.
And I got the same behavior with 80GB A100…
@TJ-Solergibert Can you show your nvidia-smi
?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment