Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active May 1, 2024 18:23
Show Gist options
  • Save Chillee/42e4635c59760a74cb3b4ba7ea5ad9f8 to your computer and use it in GitHub Desktop.
Save Chillee/42e4635c59760a74cb3b4ba7ea5ad9f8 to your computer and use it in GitHub Desktop.
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data!
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
random.seed(0)
def get_flops(A, B):
    """Benchmark ``torch.mm(A, B)`` and return the achieved FLOP/s.

    The matmul is timed with ``do_bench``; its result is treated as a
    duration in milliseconds. The operation count is
    ``2 * A.shape[0] * A.shape[1] * B.shape[1]``.
    """
    elapsed_ms = do_bench(lambda: torch.mm(A, B))
    rows, inner = A.shape[0], A.shape[1]
    cols = B.shape[1]
    total_flops = rows * inner * cols * 2
    # 1e3 / ms converts "milliseconds per run" into "runs per second".
    runs_per_second = 1e3 / elapsed_ms
    return runs_per_second * total_flops
# Problem size for the benchmarked matmul: get_tensors builds an (M, K)
# operand and an (N, K) operand that is then transposed.
M = N = K = 8192
def get_tensors(f):
    """Create the two matmul operands with the tensor factory ``f``.

    ``f`` is invoked as ``f(rows, cols, dtype=torch.bfloat16)``. The second
    operand is created with shape (N, K) and returned transposed.

    Returns:
        Tuple ``(A, B)``.
    """
    make = partial(f, dtype=torch.bfloat16)
    lhs = make(M, K)
    rhs_untransposed = make(N, K)
    return lhs, rhs_untransposed.t()
def one_bit_random(*shape, dtype=torch.bfloat16):
    """Return a tensor whose entries carry a single random bit each.

    Gaussian noise is drawn, its storage is reinterpreted as int16, and
    everything except bit 3 (mask ``0b1000``) is cleared. The resulting
    integers (0 or 8) are then converted numerically back to ``dtype``.
    """
    noise = torch.randn(*shape, dtype=dtype)
    raw_bits = noise.view(torch.int16)
    single_bit = raw_bits & 0b1000
    return single_bit.to(dtype=dtype)
def sparse(*shape, dtype=torch.bfloat16):
    """Return Gaussian noise with every negative entry replaced by zero.

    Args:
        *shape: Dimensions of the tensor to create.
        dtype: Element type of the result.

    Returns:
        A tensor of the requested shape whose entries are either 0 or a
        positive sample from ``torch.randn``.
    """
    # The original body referenced undefined ``args``/``kwargs`` and raised
    # NameError on every call; forward the actual parameters instead.
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where(x < 0, 0, x)
    return x
def _twos(*shape, dtype):
    """Tensor factory that fills the requested shape with the constant 2."""
    return torch.full(shape, fill_value=2, dtype=dtype)


# Benchmarked input distributions as (label, factory) pairs. Each factory is
# called as ``factory(rows, cols, dtype=...)``. The order here is the order
# in which results are reported.
original_setups = [
    ("randn", torch.randn),
    ("twos", _twos),
    ("sparse", sparse),
    ("one bit", one_bit_random),
    ("rand", torch.rand),
    ("zeros", torch.zeros),
]
# Gather ITERS throughput samples for every setup, keyed by setup label.
ITERS = 10
setups = list(original_setups)  # shuffled copy; original order is kept for reporting
results = defaultdict(list)
for _ in range(ITERS):
    # Visit the setups in a fresh random order each round
    # (presumably to avoid systematic ordering effects such as thermal drift).
    random.shuffle(setups)
    for label, factory in setups:
        results[label].append(get_flops(*get_tensors(factory)))
def median(x):
    """Return the median of the values in ``x``.

    For an odd number of values the middle element is returned unchanged;
    for an even number the mean of the two middle elements is returned.
    An empty input raises IndexError.
    """
    ordered = sorted(x)
    upper = len(ordered) // 2
    lower = (len(ordered) - 1) // 2
    if upper == lower:
        # Odd length: both indices point at the single middle element.
        return ordered[upper]
    return (ordered[upper] + ordered[lower]) / 2
# Print each setup's median throughput scaled by 1e12 (TFLOP/s), following
# the declaration order of original_setups rather than the shuffled order.
for label, _factory in original_setups:
    tflops = median(results[label]) / 1e12
    print(f"{label}: {tflops}")
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
import subprocess
random.seed(0)
def set_gpu_limits(ref_sm_clock=1810, power_limit=330):
    """Apply a clock lock and a power limit to GPU 0 via ``sudo nvidia-smi``.

    Two commands are run in order: one passing
    ``--lock-gpu-clocks=<ref_sm_clock>,<ref_sm_clock>`` and one passing
    ``-pl=<power_limit>``. Each uses ``subprocess.check_output``, so a
    non-zero exit status raises ``subprocess.CalledProcessError`` and the
    captured output is discarded.
    """
    base_cmd = ["sudo", "nvidia-smi", "-i", "0"]
    option_flags = (
        f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
        f"-pl={power_limit}",
    )
    for flag in option_flags:
        subprocess.check_output(base_cmd + [flag])
def get_flops(A, B):
    """Time ``torch.mm(A, B)`` with ``do_bench`` and convert to FLOP/s.

    The ``do_bench`` result is interpreted as milliseconds per call, and the
    work per call is counted as ``2 * A.shape[0] * A.shape[1] * B.shape[1]``.
    """
    def run_matmul():
        return torch.mm(A, B)

    millis = do_bench(run_matmul)
    m_dim, k_dim = A.shape[0], A.shape[1]
    n_dim = B.shape[1]
    work = m_dim * k_dim * n_dim * 2
    calls_per_second = 1e3 / millis
    return calls_per_second * work
# Matmul dimensions used by get_tensors: an (M, K) operand and an (N, K)
# operand that is transposed before the multiply.
M = N = K = 8192
def get_tensors(f):
    """Return the ``(A, B)`` operand pair produced by tensor factory ``f``.

    ``A`` is ``f(M, K, dtype=torch.bfloat16)``. ``B`` is
    ``f(N, K, dtype=torch.bfloat16)`` transposed with ``.t()``.
    """
    factory_kwargs = {"dtype": torch.bfloat16}
    left = f(M, K, **factory_kwargs)
    right = f(N, K, **factory_kwargs)
    return left, right.t()
def one_bit_random(*shape, dtype=torch.bfloat16):
    """Generate a tensor in which each element holds one random bit.

    A ``torch.randn`` sample is viewed as int16 and masked with ``0b1000``,
    so each integer is 0 or 8. Those integers are then cast to ``dtype``.
    """
    bit_mask = 0b1000
    gaussian = torch.randn(*shape, dtype=dtype)
    as_int16 = gaussian.view(torch.int16)
    return (as_int16 & bit_mask).to(dtype=dtype)
def sparse(*shape, dtype=torch.bfloat16):
    """Return Gaussian noise in which most entries are zeroed at random.

    A uniform sample is drawn for each element; wherever it exceeds 0.1 the
    element is replaced by 0, otherwise the Gaussian value is kept.
    """
    dense = torch.randn(*shape, dtype=dtype)
    drop_mask = torch.rand_like(dense) > 0.1
    return torch.where(drop_mask, 0, dense)
def checkerboard(*shape, dtype=torch.bfloat16):
    """Return Gaussian noise zeroed out in a checkerboard pattern.

    Element ``(i, j)`` keeps its random value when ``i - j`` is even and is
    set to 0 otherwise.

    Args:
        *shape: Exactly two dimensions, ``(rows, cols)``.
        dtype: Element type of the result.
    """
    x = torch.randn(*shape, dtype=dtype)
    # Row indices go down a column vector and column indices across a row
    # vector, so the parity mask has the same (rows, cols) shape as ``x``.
    # The original built it transposed, which failed for non-square shapes.
    row_idx = torch.arange(shape[0]).view(-1, 1)
    col_idx = torch.arange(shape[1]).view(1, -1)
    x = torch.where((row_idx - col_idx) % 2 == 0, x, 0)
    return x
def ternary(*shape, dtype=torch.bfloat16):
    """Return a tensor of random values drawn from {-1, 0, 1}.

    Args:
        *shape: Dimensions of the tensor to create.
        dtype: Element type of the result.
    """
    # ``high`` is exclusive, so the sampled integers are -1, 0 and 1.
    # Pass the caller's dtype through; the original ignored the parameter
    # and always produced bfloat16.
    x = torch.randint(low=-1, high=2, size=shape, dtype=dtype)
    return x
# Input distributions to benchmark, as (label, factory) pairs. Only "randn"
# is active, and get_results reads results['randn'], so it must stay enabled.
#
# Disabled setups (move an entry into the list to re-enable it):
#   ("zeros", torch.zeros),
#   ("checkerboard", checkerboard),
#   ("sparse", sparse),
#   ("rand", torch.rand),
#   ("ternary", ternary),
#   ("one bit", one_bit_random),
#   ("all_pi", lambda *shape, dtype: torch.full(shape, fill_value=3.1415926535897932384626, dtype=dtype)),
#   ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
original_setups = [("randn", torch.randn)]
def get_results(clocks, power):
    """Measure median "randn" matmul throughput under given GPU limits.

    ``set_gpu_limits(clocks, power)`` is applied first. Every setup in
    ``original_setups`` is then benchmarked ten times, in a freshly shuffled
    order each round.

    Returns:
        The median of the FLOP/s samples recorded under the ``'randn'`` key.
    """
    set_gpu_limits(clocks, power)

    def median(x):
        # Middle element for odd lengths, mean of the two middle elements
        # for even lengths.
        ordered = sorted(x)
        upper = len(ordered) // 2
        lower = (len(ordered) - 1) // 2
        if upper == lower:
            return ordered[upper]
        return (ordered[upper] + ordered[lower]) / 2

    ITERS = 10
    setups = list(original_setups)
    results = defaultdict(list)
    for _ in range(ITERS):
        random.shuffle(setups)
        for label, factory in setups:
            results[label].append(get_flops(*get_tensors(factory)))

    # Optional debug output, kept disabled:
    # for name, _ in original_setups:
    #     print(f"{name}: {median(results[name])/1e12}")
    # print(median(results['zeros']) / median(results["randn"]))
    return median(results['randn'])
# For each power limit, from 500 down to 150 in steps of 50, find the first
# clock setting at which throughput drops below 90% of the throughput
# measured at max_clocks under that same power limit.
max_clocks = 1980  # H100
start_clocks = 1980  # H100
for power in range(500, 100, -50):
    start_flops = get_results(max_clocks, power)
    threshold = start_flops * 0.9
    for clocks in range(start_clocks, 200, -100):
        # print(power, clocks)
        cur_flops = get_results(clocks, power)
        if not cur_flops < threshold:
            continue
        print("Done: ", power, clocks)
        # The next (lower) power limit resumes its search from this clock.
        start_clocks = clocks
        break
@Chillee
Copy link
Author

Chillee commented Apr 28, 2024

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment