@younesbelkada
Created August 13, 2022 18:00
A minimal script to run `bitsandbytes` int8 inference
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear8bitLt
# Utility function
def get_model_memory_footprint(model):
    r"""
    Returns the memory footprint of a model in bytes (parameters only).
    Partially copied and inspired from: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
    """
    return sum(param.nelement() * param.element_size() for param in model.parameters())
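# A quick sanity check (not part of the original gist): a single fp16 nn.Linear(64, 64)
# holds 64*64 weight values plus 64 bias values = 4160 parameters at 2 bytes each.
assert get_model_memory_footprint(nn.Linear(64, 64).to(torch.float16)) == 4160 * 2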
# Main script
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64),
).to(torch.float16)
# Train and save your model!
torch.save(fp16_model.state_dict(), "model.pt")
# Define your int8 model!
int8_model = nn.Sequential(
    # has_fp16_weights=False keeps the quantized int8 weights, which is what we want for
    # memory-efficient inference (the default, True, is meant for mixed int8/fp16 training).
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False),
)
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0)  # Quantization happens here: moving the model to the GPU converts the fp16 weights to int8
input_ = torch.randn(8, 64, dtype=torch.float16)
hidden_states = int8_model(input_.to(0))  # run inference with the quantized model on the GPU
mem_int8 = get_model_memory_footprint(int8_model)
mem_fp16 = get_model_memory_footprint(fp16_model)
print(f"Relative difference: {mem_fp16/mem_int8}")