@vuiseng9
Last active September 11, 2023 05:27

Setup

pip install transformers torch
git clone https://huggingface.co/EleutherAI/gpt-j-6b # depends on git-lfs 

Run the following as a Python script:

from transformers import AutoTokenizer, pipeline

model_dir = "/path/to/gpt-j-6b"  # path to the cloned model directory
tokenizer = AutoTokenizer.from_pretrained(model_dir)
generator_pipe = pipeline('text-generation', model=model_dir, tokenizer=tokenizer)
output = generator_pipe("I love the Avengers", max_length=30, num_return_sequences=1)
print(output)
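
The pipeline returns a list of dicts, one per returned sequence, with the text itself under the generated_text key:

print(output[0]["generated_text"])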

For large models, the PyTorch checkpoint is split (sharded) into multiple files during cloning. Running a pipeline with it will consolidate them into a single pytorch_model.bin.
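
For checkpoints sharded in the standard Hugging Face layout, a pytorch_model.bin.index.json file maps every parameter name to the shard that holds it. A minimal sketch to inspect that mapping (assuming the index file exists in the cloned model directory):

import json

with open(f"{model_dir}/pytorch_model.bin.index.json") as f:
    index = json.load(f)

# weight_map: parameter name -> shard filename
for name, shard in list(index["weight_map"].items())[:5]:
    print(shard, name)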

Loading weights (Python)

import torch

sd = torch.load("/path/to/pytorch_model.bin", map_location="cpu")

# Number of parameters by layer
for k, v in sd.items():
    print(f'{v.numel():10} | {k}')
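
A one-liner over the same state dict gives the total parameter count:

total = sum(v.numel() for v in sd.values())
print(f"total parameters: {total:,}")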
vuiseng9 commented Jun 29, 2023

8-bit quantized LLM (SmoothQuant)

git clone https://huggingface.co/mit-han-lab/opt-30b-smoothquant

See the model list at https://huggingface.co/mit-han-lab and change the repo name accordingly (official code: https://github.com/mit-han-lab/smoothquant, paper: https://arxiv.org/abs/2211.10438).

3/4-bit quantized LLM torch checkpoints (AWQ)

git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo

See the model list in that dataset repo (official code: https://github.com/mit-han-lab/llm-awq, paper: https://arxiv.org/abs/2306.00978). The cloned files are the optimized quantization hyperparameters (AWQ search results) of the respective models. To realize the weight quantization, follow step 3 of the usage section in the official repo.
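
Assuming the cloned *.pt files are plain torch checkpoints (they hold AWQ search results rather than full model weights), a quick way to peek at their contents:

import torch

# hypothetical path; point this at one of the cloned files
awq_results = torch.load("/path/to/awq-model-zoo/model.pt", map_location="cpu")
print(type(awq_results))
if isinstance(awq_results, dict):
    print(list(awq_results.keys()))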

vuiseng9 commented Jul 6, 2023

Weight Distribution Plot (Histogram)

# assuming weight_tensor (a torch.Tensor) is defined

import torch
import matplotlib.pyplot as plt

nbin = 50
hist = torch.histc(weight_tensor, bins=nbin, min=weight_tensor.min().item(), max=weight_tensor.max().item())

x = range(nbin)
plt.bar(x, hist, align='center')
plt.xlabel('Bins')
plt.show()

do play with bins, min, and max to get the viz you like :)

vuiseng9 commented Jul 7, 2023

Better way to plot

The x-axis now shows the actual weight values:

counts, bins = torch.histogram(weight_tensor, bins=100)
plt.hist(bins[:-1], bins, weights=counts)
# or
plt.stairs(counts, bins)
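
When hunting for outlier weights (outliers are a central concern in SmoothQuant/AWQ-style quantization), a log-scaled y-axis makes the sparsely populated tail bins visible:

plt.stairs(counts, bins)
plt.yscale('log')
plt.show()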

Quantization function

def quantize(tensor, precision: int = 8, return_numpy=True):
    # per-tensor symmetric quantization using the full dynamic range

    if not isinstance(precision, int):
        raise ValueError("precision must be of type integer")

    n_step = 2 ** (precision - 1) - 1

    if isinstance(tensor, torch.Tensor):
        range_min = tensor.min()
        range_max = tensor.max()

        scale = torch.max(range_min.abs(), range_max.abs()) / n_step

        # note: .to(torch.int) truncates toward zero; torch.round(tensor / scale)
        # before casting is the more common choice
        qtensor = (tensor / scale).to(torch.int)
    else:
        raise NotImplementedError("Unsupported tensor type")

    if return_numpy:
        return qtensor.numpy()
    return qtensor

quantized_tensor = quantize(weight_tensor)
# print(weight_tensor)
# >> tensor([0.9448, 0.9710, 0.8911,  ..., 0.8735, 0.9124, 0.9592])
# print(quantized_tensor)
# >> array([116, 120, 110, ..., 108, 112, 118], dtype=int32)

quantize(weight_tensor, precision=4)
# >> array([6, 6, 6, ..., 5, 6, 6], dtype=int32)
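
A sketch to sanity-check the quantizer above: quant_error is a hypothetical helper (not part of the original gist) that recomputes the same per-tensor scale, dequantizes, and reports the worst-case round-trip deviation:

def quant_error(tensor, precision=8):
    # hypothetical helper; mirrors the scale computation in quantize()
    n_step = 2 ** (precision - 1) - 1
    scale = torch.max(tensor.min().abs(), tensor.max().abs()) / n_step
    q = quantize(tensor, precision=precision, return_numpy=False)
    dq = q.to(tensor.dtype) * scale   # dequantize back to float
    return (tensor - dq).abs().max()  # worst-case reconstruction error

# quant_error(weight_tensor, precision=4)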


Setup

pip install transformers
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Snippet to retrieve SmoothQuant-ed Weights

import torch
from glob import glob

model_dir = "/data1/vchua/hf-model/opt-13b-smoothquant"

sd = dict()

# large models ship as sharded binaries
ptbins = sorted(glob(f"{model_dir}/pytorch_model*.bin"))
for ptbin in ptbins:
    sd.update(torch.load(ptbin, map_location=torch.device('cpu')))

for k, tensor in sd.items():
    if "weight" in k and "layer_norm" not in k:
        print(k) 
        # tensor is weight tensor of interest

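Assuming the SmoothQuant checkpoint stores already-quantized weights, printing dtype and shape in the same loop shows which tensors are actually integer-typed (worth verifying rather than taking on faith):

for k, tensor in sd.items():
    if "weight" in k and "layer_norm" not in k:
        print(tensor.dtype, tuple(tensor.shape), k)
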
SmoothQuant Models

https://huggingface.co/mit-han-lab
├── opt-13b-smoothquant
├── opt-30b-smoothquant
├── opt-66b-smoothquant
└── opt-6.7b-smoothquant
