@Lawrencium77
Last active December 29, 2023 09:30
Simple SmoothQuant Implementation
"""
SmoothQuant implementation. See: https://arxiv.org/pdf/2211.10438.pdf
Some details are model-specific, so the code may need tweaking.
"""
import functools
import torch
from torch import nn, Tensor
from typing import Dict, Iterable, Tuple
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#####################################################################################################
# Compute max statistics for activations and weights
def get_tensor_max(weight: Tensor) -> Tensor:
    """Compute the per-input-channel (column-wise) max of a 2D weight tensor"""
    return torch.max(weight.abs(), dim=0).values
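# Q/K/V projections all consume the same LayerNorm output, so they must share a single
# per-channel smoothing factor; the helper below therefore merges their per-channel weight
# maxs (my reading of the code below).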
def combine_qkv_maxs(model: nn.Module, maxs: Dict[str, torch.Tensor]) -> None:
    """Takes the max of the Q/K/V weight maxs for each layer"""
    # Group Q/K/V weight maxs by layer
    qkv_names = []
    for i in range(0, model.n_layers):
        names = []
        for name in maxs.keys():
            layer_num = name.split(".")[2]
            if int(layer_num) == i:
                if any(n in name for n in ["query", "key", "value"]):
                    names.append(name)
        qkv_names.append(names)
    # Combine Q/K/V weight maxs for each layer
    for qkv_name in qkv_names:
        q_name, k_name, v_name = qkv_name
        q, k, v = maxs[q_name], maxs[k_name], maxs[v_name]
        max_tensor = torch.max(torch.max(q, k), v)
        maxs[q_name] = maxs[k_name] = maxs[v_name] = max_tensor
def get_weight_maxs(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Calculate per-input-channel max of weights for Linear layers following a LayerNorm"""
    weight_maxs = {}
    module_names = ["query", "key", "value", "ff.0"]
    for name, module in model.named_modules():
        if any(n in name for n in module_names) and isinstance(module, nn.Linear):
            weight_maxs[name] = get_tensor_max(module.weight)
    combine_qkv_maxs(model, weight_maxs)
    return weight_maxs
def get_actv_maxs(model: nn.Module, datastream: Iterable, num_batches: int = 10) -> Dict[str, torch.Tensor]:
    """Calculate channel-wise max activations for LayerNorm layers"""
    actv_maxs = {}

    def update_actv_stats(name: str, input: Tensor) -> None:
        """Update running max value"""
        x = torch.flatten(input, start_dim=0, end_dim=1)
        maxs = torch.max(x.abs(), dim=0).values
        if name in actv_maxs:
            actv_maxs[name] = torch.max(actv_maxs[name], maxs)
        else:
            actv_maxs[name] = maxs

    def actv_max_hook(model: nn.Module, input: Tuple[torch.Tensor, ...], output: Tensor, name: str) -> None:
        """Forward hook"""
        if isinstance(output, tuple):
            output = output[0]
        update_actv_stats(name, output)

    # Register hooks
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.LayerNorm):
            hooks.append(module.register_forward_hook(functools.partial(actv_max_hook, name=name)))
    # Collect activation stats on a few batches
    for i, batch in enumerate(datastream):
        data = batch[0]
        model(data.to(device), chunk_size=20000)
        if i + 1 >= num_batches:
            break
    # Remove hooks
    for h in hooks:
        h.remove()
    return actv_maxs
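# Note (my reading of the code above): the hooks record each LayerNorm's *output*, which is
# exactly the input seen by the Q/K/V and first feed-forward Linear layers, so these
# per-channel activation maxs line up with the per-input-channel weight maxs from get_weight_maxs.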
#####################################################################################################
# Apply smoothing
def smoothing_from_maxs(actv_maxs: Tensor, weight_maxs: Tensor, alpha: float = 0.75) -> Tensor:
    """Calculate smoothing factors for an individual weight & actv max tensor pair"""
    return (actv_maxs ** alpha) / (weight_maxs ** (1 - alpha))
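# Toy worked example (illustrative numbers, not from the gist): with alpha = 0.75, an
# activation channel whose max is 8.0 and a weight column whose max is 0.5 give
#     s = 8.0 ** 0.75 / 0.5 ** 0.25 ≈ 4.76 / 0.84 ≈ 5.66
# so that channel of the LayerNorm output is scaled by ~1/5.66 (the `** -1` below) and the
# matching weight column is scaled by ~5.66, leaving the layer's output unchanged.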
def get_smoothing_factors(actv_maxs: Dict[str, Tensor], weight_maxs: Dict[str, Tensor], alpha: float = 0.75) -> Dict[str, Tensor]:
    """Calculate smoothing factors, given actv & weight maxs"""
    smoothing_factors = {}
    for actv_name, actv_max in actv_maxs.items():
        actv_layer_num = actv_name.split(".")[2]
        for weight_name, weight_max in weight_maxs.items():
            weight_layer_num = weight_name.split(".")[2]
            if actv_layer_num == weight_layer_num:
                if all("attention" in name for name in [actv_name, weight_name]):
                    smoothing_factors[actv_name] = smoothing_from_maxs(actv_max, weight_max, alpha) ** -1
                    smoothing_factors[weight_name] = smoothing_from_maxs(actv_max, weight_max, alpha)
                elif all("feed_forward" in name for name in [actv_name, weight_name]):
                    smoothing_factors[actv_name] = smoothing_from_maxs(actv_max, weight_max, alpha) ** -1
                    smoothing_factors[weight_name] = smoothing_from_maxs(actv_max, weight_max, alpha)
    return smoothing_factors
def smooth_parameters(model: nn.Module, smoothing_factors: Dict[str, Tensor]) -> None:
    """Apply smoothing to model parameters"""
    for name, module in model.named_modules():
        if name in smoothing_factors.keys():
            factor = smoothing_factors[name]
            module.weight = nn.Parameter(module.weight * factor)
            # Only rescale the bias for LayerNorms: their bias is per-channel and must be
            # divided by s along with the weight. A Linear bias acts on output channels and
            # is unaffected by input-channel smoothing (and would shape-mismatch here).
            if isinstance(module, nn.LayerNorm) and module.bias is not None:
                module.bias = nn.Parameter(module.bias * factor)
#####################################################################################################
# Wrapper function
def apply_smoothquant(model: nn.Module, datastream: Iterable, alpha: float = 0.75) -> None:
    """
    Applies SmoothQuant to the given model using the provided datastream.

    Parameters:
    - model: PyTorch model to which SmoothQuant will be applied.
    - datastream: An iterable (e.g., DataLoader) that yields batches of input data for the model.
    - alpha: The balance factor between activations and weights, as described in the paper.
    """
    actv_maxs = get_actv_maxs(model, datastream=datastream, num_batches=10)
    weight_maxs = get_weight_maxs(model)
    smoothing_factors = get_smoothing_factors(actv_maxs, weight_maxs, alpha)
    smooth_parameters(model, smoothing_factors)
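# Hypothetical usage sketch (not part of the original gist; `MyTransformer` and
# `calibration_set` are placeholders for a model/dataset matching the naming conventions
# assumed above, e.g. parameter names like "encoder.layers.<i>.attention.query"):
#
#     from torch.utils.data import DataLoader
#
#     model = MyTransformer().to(device).eval()
#     calibration_loader = DataLoader(calibration_set, batch_size=8)
#     with torch.no_grad():
#         apply_smoothquant(model, calibration_loader, alpha=0.75)
#
# After this call the model has been smoothed in place; the usual next step is to quantize
# the smoothed weights and activations (e.g. to INT8) with whatever quantization flow you use.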