@fxmarty
Created July 17, 2024 14:49
profile quanto: compare optimum-quanto's fp8 Marlin linear kernel against a plain fp16 linear
import torch
import torch.nn as nn
from optimum.quanto import Calibration, freeze, qint4, qint8, quantize, qfloat8, qfloat8_e4m3fn
from torch.profiler import ProfilerActivity, profile
M_SHAPE = 4096

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A single square fp16 linear layer (no bias) is enough to profile the matmul kernels.
        self.lin1 = nn.Linear(M_SHAPE, 4096, bias=False)

    def forward(self, x):
        return self.lin1(x)


def keyword_to_itype(k):
    # Map a keyword to the corresponding quanto itype (None disables quantization).
    return {"none": None, "int4": qint4, "int8": qint8, "float8": qfloat8, "float8_e4m3fn": qfloat8_e4m3fn}[k]

model = MyModel().to(torch.float16)
model = model.eval()

seed = 42
weights = "float8_e4m3fn"
activations = "none"

torch.manual_seed(seed)
device = torch.device("cuda")
model = model.to(device)
# Keep a copy of the original fp16 weight for the unquantized reference matmul.
original_weight = model.lin1.weight.data.clone()
print("Float model")
weights = keyword_to_itype(weights)
activations = keyword_to_itype(activations)
print("------ QUANTIZING")
quantize(model, weights=weights, activations=activations)
print("------ FREEZING")
freeze(model)
print(f"Quantized model (w: {weights}, a: {activations})")
print("--------- INFERENCE")
# Single-row input (size_m == 1), so the matmul is memory-bound.
inp = torch.rand(1, M_SHAPE, dtype=torch.float16, device=device)

def run_linear_marlin(inp, weight):
    # Call the quanto fp8 Marlin kernel directly on the packed quantized weight.
    workspace = weight._workspace
    scale = weight._scale
    input_flat = inp.view(-1, inp.shape[-1])
    out = torch.ops.quanto_ext.fp8_marlin(
        input_flat,
        b_q_weight=weight._data,
        b_scales=scale.to(input_flat.dtype),
        workspace=workspace,
        num_bits=8,
        size_m=input_flat.shape[0],
        size_n=scale.shape[1],
        size_k=input_flat.shape[1],
    )
    return out.reshape(inp.shape[:-1] + (scale.shape[1],))

def run_native_linear(inp, weight):
    # Reference path: plain fp16 matmul through the standard linear op.
    return torch.nn.functional.linear(inp, weight)

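# --- Added sketch (not in the original gist): run each path once before
# profiling so one-time CUDA context and kernel-cache setup stays out of the
# measured region.
with torch.no_grad():
    _ = model(inp)
    _ = run_linear_marlin(inp, model.lin1.weight)
    _ = run_native_linear(inp, original_weight)
torch.cuda.synchronize()
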
with torch.no_grad():
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        for _ in range(10):
            # Each iteration runs three variants: the quantized module, the
            # direct Marlin kernel call, and the fp16 reference.
            res = model(inp)
            res = run_linear_marlin(inp, model.lin1.weight)
            res = run_native_linear(inp, original_weight)

prof.export_chrome_trace("trace.json")
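
# --- Added sketch (not in the original gist): print a console summary of the
# hottest CUDA ops next to the exported Chrome trace. key_averages() and
# table() are standard torch.profiler APIs.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))

# --- Added sketch (assumption): quick numerical sanity check that the fp8
# Marlin path stays close to the unquantized fp16 reference.
with torch.no_grad():
    out_marlin = run_linear_marlin(inp, model.lin1.weight)
    out_ref = run_native_linear(inp, original_weight)
    print("max abs diff (marlin vs fp16):", (out_marlin.float() - out_ref.float()).abs().max().item())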