Created
January 25, 2024 03:07
-
-
Save HDCharles/3c00ec8a02b58983195bb2a3a3015c55 to your computer and use it in GitHub Desktop.
compare bitsandbytes with torchao
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
######################################################################
#                         Comparing torchao                          #
#                         and bitsandbytes                           #
######################################################################
# Set up Your Environment
# --------------------------------
#
# First, let's configure your environment. This guide requires CUDA 12.1.
# We ran this tutorial on an A100-PG509-200 power-limited to 330 W; if you
# are using different hardware, you may see different performance numbers.
#
#
# .. code-block:: bash
#
#    > conda create -n myenv python=3.10
#    > conda activate myenv
#    > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
#    > pip install git+https://github.com/pytorch-labs/ao.git
#    > pip install bitsandbytes
#
#
# This was run on an A100-PG509-200 power limited to 330.00 W
import bitsandbytes
import torch
from torch.utils.benchmark import Timer
from torchao.quantization import (
    change_linear_weights_to_int4_woqtensors,
    change_linear_weights_to_int8_dqtensors,
    change_linear_weights_to_int8_woqtensors,
)
# Enable Inductor's mixed-dtype matmul lowering for every torch.compile call
# below. NOTE(review): presumably this is what lets the compiled int8
# weight-only path fuse the weight upcast into the matmul -- confirm against
# torch/_inductor/config.py for the PyTorch version in use.
torch._inductor.config.use_mixed_mm = True
@torch.no_grad()
def benchmark(f, *args, **kwargs):
    """Measure median latency and peak CUDA memory of ``f(*args, **kwargs)``.

    ``f`` is first called three times untimed (warm-up; for the
    ``torch.compile``-wrapped modules this script passes in, that is also when
    compilation happens). The queued GPU work is then drained, the CUDA
    peak-memory counter is reset, and
    ``torch.utils.benchmark.Timer.blocked_autorange`` times the call.
    Gradient tracking is disabled for the whole measurement.

    Returns:
        A dict with ``'time'``, the median runtime in milliseconds, and
        ``'memory'``, ``torch.cuda.max_memory_allocated()`` observed after the
        reset, in GB (1e9 bytes). That peak counts every live CUDA
        allocation, including the weights and inputs, not only what ``f``
        itself allocates.
    """
    warmup_runs = 3
    for _ in range(warmup_runs):
        f(*args, **kwargs)
    # Wait for the warm-up kernels, then track the peak from this point on.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    timer = Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    measurement = timer.blocked_autorange()
    elapsed_ms = measurement.median * 1e3
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    return {'time': elapsed_ms, 'memory': peak_gb}
# Each entry is (i, j, k, model): the benchmark loop feeds an (i, j) input
# through a Linear(j, k), so i is the number of input rows, j is in_features
# and k is out_features; model names the model the shape is reported for.
shapes = [(i, j, k, "SAM") for i, j, k in (
    (78400, 1280, 3840),
    (78400, 1280, 1280),
    (65536, 1280, 5120),
    (65536, 5120, 1280),
    (65536, 1280, 3840),
    (65536, 1280, 1280),
)] + [(i, j, k, "LLAMA") for i, j, k in (
    (1, 4096, 4096),
    (1, 4096, 11008),
    (1, 11008, 4096),
    (1, 4096, 12288),
    (1, 4096, 32000),
)]
for i, j, k, m in shapes:
    # SAM shapes are benchmarked with a bias term, LLaMA shapes without.
    bias = (m == "SAM")
    res = {}
    # Input of shape (i, j). It starts in bf16 for the torchao and baseline
    # runs and is temporarily cast to fp16 for the bitsandbytes runs.
    image = torch.randn(i, j, device='cuda').to(torch.bfloat16)

    if m == "SAM":
        # torchao int8 "dq" tensor subclass weights, compiled with max-autotune.
        lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
        change_linear_weights_to_int8_dqtensors(lin)
        lin_c = torch.compile(lin, mode='max-autotune')
        res["ao-int8dq-c"] = benchmark(lin_c, image)
        del lin, lin_c
        # Start the next compiled variant from a clean Dynamo cache; every new
        # Linear instance otherwise adds guarded entries, which over this many
        # shapes can run into Dynamo's recompile limit.
        torch._dynamo.reset()

        # bitsandbytes 8-bit linear (Linear8bitLt) on an fp16 input.
        image = image.to(torch.float16)
        # Single device transfer: bitsandbytes quantizes the Int8Params weight
        # when it is moved to the GPU, and chaining a second `.cuda()` may
        # quantize the already-int8 data again.
        lin = bitsandbytes.nn.Linear8bitLt(
            j, k, bias=bias, has_fp16_weights=False
        ).to(0)
        res["bb-int8"] = benchmark(lin, image)
        del lin

    if m == "LLAMA":
        # torchao int4 weight-only tensors (groupsize 64), run eagerly.
        lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
        change_linear_weights_to_int4_woqtensors(lin, groupsize=64)
        res["ao-int4wo"] = benchmark(lin, image)
        del lin

        # torchao int8 weight-only tensors, compiled with max-autotune.
        lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
        change_linear_weights_to_int8_woqtensors(lin)
        lin_c = torch.compile(lin, mode='max-autotune')
        res["ao-int8wo-c"] = benchmark(lin_c, image)
        del lin, lin_c
        torch._dynamo.reset()

        # bitsandbytes 4-bit linear on an fp16 input. The explicit `.cuda()`
        # is kept: the layer is constructed on 'cuda' and this call is where
        # Params4bit presumably quantizes -- confirm for the installed version.
        image = image.to(torch.float16)
        lin = bitsandbytes.nn.Linear4bit(j, k, bias=bias, device='cuda').cuda()
        res["bb-int4"] = benchmark(lin, image)
        del lin

    # Unquantized bf16 baseline, eager and compiled.
    image = image.to(torch.bfloat16)
    lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
    res["bf16"] = benchmark(lin, image)
    lin_c = torch.compile(lin, mode='max-autotune')
    res["bf16-c"] = benchmark(lin_c, image)
    del lin, lin_c
    torch._dynamo.reset()

    perf = "perf:" + "".join(
        f"{key}({res_data['time']:0.2f}ms) " for key, res_data in res.items()
    )
    mem = "mem: " + "".join(
        f"{key}({res_data['memory']:0.2f}GB) " for key, res_data in res.items()
    )
    # Report the bias setting actually used; previously the label said
    # "with bias" even for the LLaMA shapes, which run with bias=False.
    bias_label = "with bias" if bias else "without bias"
    print(f"for shape {(i, j, k)}, {bias_label}, from model {m}\n{perf}\n{mem}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Raw Output
for shape (78400, 1280, 3840), with bias, from model SAM
perf:ao-int8dq-c(3.78ms) bb-int8(6.17ms) bf16(3.55ms) bf16-c(3.77ms)
mem: ao-int8dq-c(1.07GB) bb-int8(2.32GB) bf16(0.83GB) bf16-c(1.08GB)
for shape (78400, 1280, 1280), with bias, from model SAM
perf:ao-int8dq-c(1.66ms) bb-int8(2.98ms) bf16(1.22ms) bf16-c(1.43ms)
mem: ao-int8dq-c(1.02GB) bb-int8(1.12GB) bf16(0.43GB) bf16-c(0.83GB)
for shape (65536, 1280, 5120), with bias, from model SAM
perf:ao-int8dq-c(6.45ms) bb-int8(6.36ms) bf16(3.98ms) bf16-c(4.17ms)
mem: ao-int8dq-c(1.47GB) bb-int8(2.47GB) bf16(0.89GB) bf16-c(1.48GB)
for shape (65536, 5120, 1280), with bias, from model SAM
perf:ao-int8dq-c(4.00ms) bb-int8(7.16ms) bf16(3.66ms) bf16-c(4.41ms)
mem: ao-int8dq-c(1.72GB) bb-int8(2.23GB) bf16(0.91GB) bf16-c(1.58GB)
for shape (65536, 1280, 3840), with bias, from model SAM
perf:ao-int8dq-c(3.45ms) bb-int8(5.20ms) bf16(2.97ms) bf16-c(3.12ms)
mem: ao-int8dq-c(1.41GB) bb-int8(2.00GB) bf16(0.75GB) bf16-c(1.42GB)
for shape (65536, 1280, 1280), with bias, from model SAM
perf:ao-int8dq-c(1.50ms) bb-int8(2.56ms) bf16(1.02ms) bf16-c(1.24ms)
mem: ao-int8dq-c(1.08GB) bb-int8(1.00GB) bf16(0.42GB) bf16-c(1.09GB)
for shape (1, 4096, 4096), with bias, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.12ms) bb-int4(0.11ms) bf16(0.02ms) bf16-c(0.10ms)
mem: ao-int4wo(0.09GB) ao-int8wo-c(0.77GB) bb-int4(0.11GB) bf16(0.14GB) bf16-c(0.80GB)
for shape (1, 4096, 11008), with bias, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.13ms) bb-int4(0.10ms) bf16(0.06ms) bf16-c(0.10ms)
mem: ao-int4wo(0.15GB) ao-int8wo-c(0.85GB) bb-int4(0.20GB) bf16(0.27GB) bf16-c(0.94GB)
for shape (1, 11008, 4096), with bias, from model LLAMA
perf:ao-int4wo(0.08ms) ao-int8wo-c(0.15ms) bb-int4(0.10ms) bf16(0.06ms) bf16-c(0.10ms)
mem: ao-int4wo(0.29GB) ao-int8wo-c(0.98GB) bb-int4(0.34GB) bf16(0.41GB) bf16-c(1.07GB)
for shape (1, 4096, 12288), with bias, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.12ms) bb-int4(0.10ms) bf16(0.07ms) bf16-c(0.11ms)
mem: ao-int4wo(0.43GB) ao-int8wo-c(1.12GB) bb-int4(0.48GB) bf16(0.56GB) bf16-c(1.22GB)
for shape (1, 4096, 32000), with bias, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.22ms) bb-int4(0.13ms) bf16(0.17ms) bf16-c(0.17ms)
mem: ao-int4wo(0.62GB) ao-int8wo-c(1.35GB) bb-int4(0.76GB) bf16(0.95GB) bf16-c(1.62GB)