@HDCharles
Created January 25, 2024 03:07
compare bitsandbytes with torchao
######################################################################
# Comparing torchao and bitsandbytes
######################################################################
# Set up Your Environment
# --------------------------------
#
# First, let's configure your environment. This guide requires CUDA 12.1.
# We ran this tutorial on an A100-PG509-200 power limited to 330.00 W. If you
# are using different hardware, you may see different performance numbers.
#
#
# .. code-block:: bash
#
# > conda create -n myenv python=3.10
# > conda activate myenv
# > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
# > pip install git+https://github.com/pytorch-labs/ao.git
#
#
import bitsandbytes
import torch
from torch.utils.benchmark import Timer
from torchao.quantization import (
    change_linear_weights_to_int4_woqtensors,
    change_linear_weights_to_int8_dqtensors,
    change_linear_weights_to_int8_woqtensors,
)
torch._inductor.config.use_mixed_mm = True
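# Optional sanity check (an addition, not in the original gist): the numbers
# below assume an A100 with a CUDA 12.1 build of PyTorch, so it's worth
# confirming the environment before benchmarking.
assert torch.cuda.is_available(), "this benchmark requires a CUDA GPU"
print(f"torch {torch.__version__}, CUDA {torch.version.cuda}, "
      f"device: {torch.cuda.get_device_name(0)}")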
@torch.no_grad()
def benchmark(f, *args, **kwargs):
    # warmup so compilation/autotuning cost doesn't pollute the measurement
    for _ in range(3):
        f(*args, **kwargs)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.blocked_autorange()
    # median latency in milliseconds and peak CUDA memory in GB
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}
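# A quick usage sketch (an addition, not in the original gist): benchmark()
# times any callable; a plain bf16 matmul shows the output format that the
# comparisons below rely on.
_a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
_b = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
print("sanity matmul:", benchmark(torch.mm, _a, _b))  # {'time': <ms>, 'memory': <GB>}
del _a, _b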
# each shape is (i, j, k, model): an (i, j) input activation fed through a
# Linear(j, k), with shapes drawn from SAM and LLaMA workloads
shapes = [
    (78400, 1280, 3840, "SAM"),
    (78400, 1280, 1280, "SAM"),
    (65536, 1280, 5120, "SAM"),
    (65536, 5120, 1280, "SAM"),
    (65536, 1280, 3840, "SAM"),
    (65536, 1280, 1280, "SAM"),
    (1, 4096, 4096, "LLAMA"),
    (1, 4096, 11008, "LLAMA"),
    (1, 11008, 4096, "LLAMA"),
    (1, 4096, 12288, "LLAMA"),
    (1, 4096, 32000, "LLAMA"),
]
for i, j, k, m in shapes:
    bias = (m == "SAM")  # the SAM linears use bias, the LLaMA ones here do not
    res = {}
    image = torch.randn(i, j, device='cuda', dtype=torch.bfloat16)
    if m == "SAM":
        # torchao int8 dynamic quantization (int8 activations x int8 weights), compiled
        lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
        change_linear_weights_to_int8_dqtensors(lin)
        lin_c = torch.compile(lin, mode='max-autotune')
        res["ao-int8dq-c"] = benchmark(lin_c, image)
        del lin, lin_c
        # bitsandbytes int8 (weights quantized when the module moves to CUDA); expects fp16 inputs
        image = image.to(torch.float16)
        lin = bitsandbytes.nn.Linear8bitLt(j, k, bias=bias, has_fp16_weights=False).cuda()
        res["bb-int8"] = benchmark(lin, image)
        del lin
    if m == "LLAMA":
        # torchao int4 weight-only quantization; runs its own kernel, no compile needed
        lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
        change_linear_weights_to_int4_woqtensors(lin, groupsize=64)
        res["ao-int4wo"] = benchmark(lin, image)
        del lin
        # torchao int8 weight-only quantization, compiled
        lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
        change_linear_weights_to_int8_woqtensors(lin)
        lin_c = torch.compile(lin, mode='max-autotune')
        res["ao-int8wo-c"] = benchmark(lin_c, image)
        del lin, lin_c
        # bitsandbytes 4-bit; expects fp16 inputs
        image = image.to(torch.float16)
        lin = bitsandbytes.nn.Linear4bit(j, k, bias=bias, device='cuda').cuda()
        res["bb-int4"] = benchmark(lin, image)
        del lin
    # bf16 baseline, eager and compiled
    image = image.to(torch.bfloat16)
    lin = torch.nn.Linear(j, k, bias=bias).to(torch.bfloat16).cuda()
    res["bf16"] = benchmark(lin, image)
    lin_c = torch.compile(lin, mode='max-autotune')
    res["bf16-c"] = benchmark(lin_c, image)
    del lin, lin_c
    perf = "perf:"
    mem = "mem: "
    for key, res_data in res.items():
        perf += f"{key}({res_data['time']:0.2f}ms) "
        mem += f"{key}({res_data['memory']:0.2f}GB) "
    print(f"for shape {i,j,k}, bias={bias}, from model {m}\n{perf}\n{mem}")
Raw Output

for shape (78400, 1280, 3840), bias=True, from model SAM
perf:ao-int8dq-c(3.78ms) bb-int8(6.17ms) bf16(3.55ms) bf16-c(3.77ms)
mem: ao-int8dq-c(1.07GB) bb-int8(2.32GB) bf16(0.83GB) bf16-c(1.08GB)
for shape (78400, 1280, 1280), bias=True, from model SAM
perf:ao-int8dq-c(1.66ms) bb-int8(2.98ms) bf16(1.22ms) bf16-c(1.43ms)
mem: ao-int8dq-c(1.02GB) bb-int8(1.12GB) bf16(0.43GB) bf16-c(0.83GB)
for shape (65536, 1280, 5120), bias=True, from model SAM
perf:ao-int8dq-c(6.45ms) bb-int8(6.36ms) bf16(3.98ms) bf16-c(4.17ms)
mem: ao-int8dq-c(1.47GB) bb-int8(2.47GB) bf16(0.89GB) bf16-c(1.48GB)
for shape (65536, 5120, 1280), bias=True, from model SAM
perf:ao-int8dq-c(4.00ms) bb-int8(7.16ms) bf16(3.66ms) bf16-c(4.41ms)
mem: ao-int8dq-c(1.72GB) bb-int8(2.23GB) bf16(0.91GB) bf16-c(1.58GB)
for shape (65536, 1280, 3840), bias=True, from model SAM
perf:ao-int8dq-c(3.45ms) bb-int8(5.20ms) bf16(2.97ms) bf16-c(3.12ms)
mem: ao-int8dq-c(1.41GB) bb-int8(2.00GB) bf16(0.75GB) bf16-c(1.42GB)
for shape (65536, 1280, 1280), bias=True, from model SAM
perf:ao-int8dq-c(1.50ms) bb-int8(2.56ms) bf16(1.02ms) bf16-c(1.24ms)
mem: ao-int8dq-c(1.08GB) bb-int8(1.00GB) bf16(0.42GB) bf16-c(1.09GB)
for shape (1, 4096, 4096), bias=False, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.12ms) bb-int4(0.11ms) bf16(0.02ms) bf16-c(0.10ms)
mem: ao-int4wo(0.09GB) ao-int8wo-c(0.77GB) bb-int4(0.11GB) bf16(0.14GB) bf16-c(0.80GB)
for shape (1, 4096, 11008), bias=False, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.13ms) bb-int4(0.10ms) bf16(0.06ms) bf16-c(0.10ms)
mem: ao-int4wo(0.15GB) ao-int8wo-c(0.85GB) bb-int4(0.20GB) bf16(0.27GB) bf16-c(0.94GB)
for shape (1, 11008, 4096), bias=False, from model LLAMA
perf:ao-int4wo(0.08ms) ao-int8wo-c(0.15ms) bb-int4(0.10ms) bf16(0.06ms) bf16-c(0.10ms)
mem: ao-int4wo(0.29GB) ao-int8wo-c(0.98GB) bb-int4(0.34GB) bf16(0.41GB) bf16-c(1.07GB)
for shape (1, 4096, 12288), bias=False, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.12ms) bb-int4(0.10ms) bf16(0.07ms) bf16-c(0.11ms)
mem: ao-int4wo(0.43GB) ao-int8wo-c(1.12GB) bb-int4(0.48GB) bf16(0.56GB) bf16-c(1.22GB)
for shape (1, 4096, 32000), bias=False, from model LLAMA
perf:ao-int4wo(0.07ms) ao-int8wo-c(0.22ms) bb-int4(0.13ms) bf16(0.17ms) bf16-c(0.17ms)
mem: ao-int4wo(0.62GB) ao-int8wo-c(1.35GB) bb-int4(0.76GB) bf16(0.95GB) bf16-c(1.62GB)
