SmoothQuant changes for GPT-J
@ImanHosseini, December 23, 2022

The four modified SmoothQuant files, concatenated below with a comment marker at each file boundary. The two smoothquant.* module names follow from the imports; the two example-script names are assumed from the SmoothQuant repo layout.

# ==== smoothquant/calibration.py ====
import torch
import torch.nn as nn
from datasets import load_dataset
import functools
from collections import defaultdict
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
from functools import partial
import numpy as np
from tqdm import tqdm

def get_act_scales(model, tokenizer, dataset_path, num_samples=10, seq_len=512):
    model.eval()
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float().cpu()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], coming_max)
        else:
            act_scales[name] = coming_max

    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            hooks.append(
                m.register_forward_hook(
                    functools.partial(stat_input_hook, name=name))
            )

    # NOTE: dataset_path is currently ignored; calibration is hardcoded to the
    # LAMBADA validation split instead of the Pile file.
    # dataset = load_dataset("json", data_files=dataset_path, split="train")
    dataset = load_dataset("lambada", split='validation[:512]')
    dataset = dataset.shuffle(seed=42)

    for i in tqdm(range(num_samples)):
        input_ids = tokenizer(dataset[i]["text"], return_tensors="pt",
                              max_length=seq_len, truncation=True).input_ids.to(device)
        model(input_ids)

    for h in hooks:
        h.remove()

    return act_scales
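
# stat_tensor above keeps a running per-channel max of |activation| over every
# calibration token, e.g. (illustrative values):
#   torch.tensor([[1., -3.], [2., .5]]).abs().max(dim=0)[0]  ->  tensor([2., 3.])
# These per-channel maxima are what smooth_ln_fcs later balances against the
# per-channel weight maxima.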

@torch.no_grad()
def get_static_decoder_layer_scales(model,
                                    tokenizer,
                                    dataset_path,
                                    num_samples=10,
                                    seq_len=512,
                                    ):
    model.eval()
    device = next(model.parameters()).device
    act_dict = defaultdict(dict)

    def stat_io_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        if name not in act_dict or "input" not in act_dict[name]:
            act_dict[name]["input"] = x.detach().abs().max().item()
        else:
            act_dict[name]["input"] = max(
                act_dict[name]["input"], x.detach().abs().max().item())
        if isinstance(y, tuple):
            y = y[0]
        if name not in act_dict or "output" not in act_dict[name]:
            act_dict[name]["output"] = y.detach().abs().max().item()
        else:
            act_dict[name]["output"] = max(
                act_dict[name]["output"], y.detach().abs().max().item())

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            hooks.append(m.register_forward_hook(
                partial(stat_io_hook, name=name)))

    print("Collecting activation scales...")
    pbar = tqdm(range(num_samples))
    # dataset_path is ignored here as well; LAMBADA is hardcoded.
    # dataset = load_dataset('json', data_files=dataset_path, split="train")
    dataset = load_dataset("lambada", split='validation[:512]')
    dataset = dataset.shuffle(seed=42)
    for i in pbar:
        input_ids = tokenizer(dataset[i]["text"], return_tensors="pt",
                              max_length=seq_len, truncation=True).input_ids.to(device)
        model(input_ids)
        mean_scale = np.mean([v["input"] for v in act_dict.values()])
        pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
    for hook in hooks:
        hook.remove()
    decoder_layer_scales = []
    if isinstance(model, OPTForCausalLM):
        for idx in range(model.config.num_hidden_layers):
            scale_dict = {}
            scale_dict["attn_input_scale"] = act_dict[
                f"model.decoder.layers.{idx}.self_attn.q_proj"]['input'] / 127
            scale_dict["q_output_scale"] = act_dict[
                f"model.decoder.layers.{idx}.self_attn.q_proj"]['output'] / 127
            scale_dict["k_output_scale"] = act_dict[
                f"model.decoder.layers.{idx}.self_attn.k_proj"]['output'] / 127
            scale_dict["v_output_scale"] = act_dict[
                f"model.decoder.layers.{idx}.self_attn.v_proj"]['output'] / 127
            scale_dict["out_input_scale"] = act_dict[
                f"model.decoder.layers.{idx}.self_attn.out_proj"]['input'] / 127
            scale_dict["fc1_input_scale"] = act_dict[
                f"model.decoder.layers.{idx}.fc1"]['input'] / 127
            scale_dict["fc2_input_scale"] = act_dict[
                f"model.decoder.layers.{idx}.fc2"]["input"] / 127
            decoder_layer_scales.append(scale_dict)
    elif isinstance(model, GPTJForCausalLM):
        for idx in range(len(model.transformer.h)):
            scale_dict = {}
            scale_dict["attn_input_scale"] = act_dict[
                f"transformer.h.{idx}.attn.q_proj"]['input'] / 127
            scale_dict["q_output_scale"] = act_dict[
                f"transformer.h.{idx}.attn.q_proj"]['output'] / 127
            scale_dict["k_output_scale"] = act_dict[
                f"transformer.h.{idx}.attn.k_proj"]['output'] / 127
            scale_dict["v_output_scale"] = act_dict[
                f"transformer.h.{idx}.attn.v_proj"]['output'] / 127
            scale_dict["out_input_scale"] = act_dict[
                f"transformer.h.{idx}.attn.out_proj"]['input'] / 127
            scale_dict["fc1_input_scale"] = act_dict[
                f"transformer.h.{idx}.mlp.fc_in"]['input'] / 127
            scale_dict["fc2_input_scale"] = act_dict[
                f"transformer.h.{idx}.mlp.fc_out"]["input"] / 127
            decoder_layer_scales.append(scale_dict)
    return decoder_layer_scales
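
# Shape of the result (values illustrative): one dict per decoder block,
#   [{"attn_input_scale": 0.012, "q_output_scale": 0.031, ...,
#     "fc1_input_scale": 0.008, "fc2_input_scale": 0.055}, ...]
# Each entry is a per-tensor max|x| divided by 127, i.e. the step size of a
# symmetric INT8 quantizer for that tensor.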

# ==== examples/export_int8_model.py (script name assumed from the SmoothQuant repo layout) ====
import torch
import argparse
import os
import json
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
from transformers import AutoTokenizer
from smoothquant.opt import Int8OPTForCausalLM
from smoothquant.gptj import Int8GPTJForCausalLM
from smoothquant.smooth import smooth_lm
from smoothquant.calibration import get_static_decoder_layer_scales

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default='facebook/opt-13b')
    parser.add_argument("--num-samples", type=int, default=10)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--act-scales", type=str,
                        default='act_scales/opt-13b.pt')
    parser.add_argument("--output-path", type=str,
                        default='int8_models/opt-13b-smoothquant')
    parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst',
                        help='location of the calibration dataset; we use the validation set of the Pile dataset')
    args = parser.parse_args()
if "opt" in args.model_name:
model = OPTForCausalLM.from_pretrained(args.model_name, device_map="sequential", torch_dtype=torch.float16)
elif "gptj" in args.model_name:
model = GPTJForCausalLM.from_pretrained(args.model_name, device_map="sequential", torch_dtype=torch.float16)
act_scales = torch.load(args.act_scales)
smooth_lm(model, act_scales, 0.1)
tkn = args.model_name
if tkn.startswith("moyix"):
tkn = tkn.replace("moyix", "Salesforce")
tkn = tkn.replace("-gptj","")
tokenizer = AutoTokenizer.from_pretrained(tkn)
if not os.path.exists(args.dataset_path):
print(f'Cannot find the dataset at {args.dataset_path}')
print('Please download the Pile dataset and put the validation set at the path')
print('You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst')
raise FileNotFoundError
    decoder_layer_scales = get_static_decoder_layer_scales(model,
                                                           tokenizer,
                                                           args.dataset_path,
                                                           num_samples=args.num_samples,
                                                           seq_len=args.seq_len)

    # Debug print: inspect a smoothed weight. Guarded because this attribute
    # path only exists on the GPT-J layout, not on OPT.
    if "gptj" in args.model_name:
        print("===:", end="")
        print(model.transformer.h[0].attn.k_proj.weight)

    # dump layer scales
    with open('model_dec_scales.json', 'w') as fp:
        json.dump(decoder_layer_scales, fp)

    if "opt" in args.model_name:
        int8_model = Int8OPTForCausalLM.from_float(model, decoder_layer_scales)
    elif "gptj" in args.model_name:
        int8_model = Int8GPTJForCausalLM.from_float(model, decoder_layer_scales)
        # Debug print: field names per the author's Int8GPTJForCausalLM.
        print(int8_model.transformer.h[0].mlp.fc1.bias)
    int8_model.save_pretrained(args.output_path)
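
# Example invocation (model name and paths illustrative):
#   python export_int8_model.py --model-name moyix/codegen-350M-mono-gptj \
#       --act-scales act_scales/codegen-350M.pt \
#       --output-path int8_models/codegen-350M-smoothquant
# Run the generate_act_scales.py script below first to produce --act-scales.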

# ==== examples/generate_act_scales.py (script name assumed from the SmoothQuant repo layout) ====
import torch
import os
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
import argparse
from smoothquant.calibration import get_act_scales

def build_model_and_tokenizer(model_name):
    tkn = model_name
    if tkn.startswith("moyix"):
        tkn = tkn.replace("moyix", "Salesforce")
        tkn = tkn.replace("-gptj", "")
    tokenizer = AutoTokenizer.from_pretrained(tkn, model_max_length=512)
    kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"}
    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    return model, tokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str,
                        default='facebook/opt-1.3b', help='model name')
    parser.add_argument('--output-path', type=str, default='act_scales/opt-1.3b.pt',
                        help='where to save the act scales')
    parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst',
                        help='location of the calibration dataset; we use the validation set of the Pile dataset')
    parser.add_argument('--num-samples', type=int, default=10)
    parser.add_argument('--seq-len', type=int, default=512)
    args = parser.parse_args()
    return args

@torch.no_grad()
def main():
    args = parse_args()
    model, tokenizer = build_model_and_tokenizer(args.model_name)
    if not os.path.exists(args.dataset_path):
        print(f'Cannot find the dataset at {args.dataset_path}')
        print('Please download the Pile dataset and put the validation set at the path')
        print('You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst')
        raise FileNotFoundError
    act_scales = get_act_scales(model, tokenizer, args.dataset_path,
                                args.num_samples, args.seq_len)
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    torch.save(act_scales, args.output_path)


if __name__ == '__main__':
    main()
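
# Example invocation (model name and path illustrative):
#   python generate_act_scales.py --model-name moyix/codegen-350M-mono-gptj \
#       --output-path act_scales/codegen-350M.pt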

# ==== smoothquant/smooth.py ====
import torch
import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.gptj.modeling_gptj import GPTJBlock
from transformers.models.bloom.modeling_bloom import BloomBlock

@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    if not isinstance(fcs, list):
        fcs = [fcs]
    assert isinstance(ln, nn.LayerNorm)
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        assert ln.weight.numel() == fc.in_features == act_scales.numel()
    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)
    weight_scales = torch.cat([fc.weight.abs().max(
        dim=0, keepdim=True)[0] for fc in fcs], dim=0)
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
    # SmoothQuant per-channel scale: s_j = max|X_j|^alpha / max|W_j|^(1-alpha)
    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)
              ).clamp(min=1e-5).to(device).to(dtype)

    # Fold 1/s into the LayerNorm affine params and s into the input channels
    # of the following linears, leaving the composed function unchanged.
    ln.weight.div_(scales)
    ln.bias.div_(scales)
    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))
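
# A minimal sanity check (illustrative toy sizes, not part of the original
# gist): smoothing should leave the float output of fc(ln(x)) unchanged.
if __name__ == "__main__":
    torch.manual_seed(0)
    ln = nn.LayerNorm(8)
    fc = nn.Linear(8, 16)
    x = torch.randn(4, 8)
    with torch.no_grad():
        act_scales = ln(x).abs().max(dim=0)[0]
        ref = fc(ln(x))
        smooth_ln_fcs(ln, [fc], act_scales, alpha=0.5)
        out = fc(ln(x))
    assert torch.allclose(ref, out, atol=1e-5)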

@torch.no_grad()
def smooth_lm(model, scales, alpha=0.0):
    for name, module in model.named_modules():
        if isinstance(module, OPTDecoderLayer):
            attn_ln = module.self_attn_layer_norm
            qkv = [module.self_attn.q_proj,
                   module.self_attn.k_proj, module.self_attn.v_proj]
            qkv_input_scales = scales[name + '.self_attn.q_proj']
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
            ffn_ln = module.final_layer_norm
            fc1 = module.fc1
            fc1_input_scales = scales[name + '.fc1']
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, BloomBlock):
            attn_ln = module.input_layernorm
            qkv = module.self_attention.query_key_value
            qkv_input_scales = scales[name + '.self_attention.query_key_value']
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
            ffn_ln = module.post_attention_layernorm
            fc1 = module.mlp.dense_h_to_4h
            fc1_input_scales = scales[name + '.mlp.dense_h_to_4h']
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, GPTJBlock):
            print("SMOOTHING")  # debug
            attn_ln = module.ln_1
            # GPT-J applies attention and MLP in parallel to the same ln_1
            # output, so every consumer of ln_1 must be rescaled together;
            # including mlp.fc_in keeps the block output-equivalent. (The
            # original gist scaled only q/k/v here, which silently changes
            # the MLP path once ln_1 is rescaled.)
            fcs = [module.attn.q_proj, module.attn.k_proj,
                   module.attn.v_proj, module.mlp.fc_in]
            # q_proj and fc_in see the same input tensor, so q_proj's
            # recorded activation scales apply to all four linears.
            qkv_input_scales = scales[name + '.attn.q_proj']
            smooth_ln_fcs(attn_ln, fcs, qkv_input_scales, alpha)
            # A separate FFN smoothing pass does not apply to GPT-J's parallel
            # block (there is no second pre-FFN LayerNorm), hence the disabled
            # code below.
            # ffn_ln = module.ln_1
            # fc1 = module.mlp.fc_in
            # print(scales.keys())
            # fc1_input_scales = scales[name + '.mlp.fc_in']
            # smooth_ln_fcs(None, fc1, fc1_input_scales, alpha)
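
# Note on alpha: the export script above calls smooth_lm(model, act_scales, 0.1).
# With s_j = max|X_j|^alpha / max|W_j|^(1-alpha), a smaller alpha migrates less
# of the activation outlier magnitude into the weights than the SmoothQuant
# paper's default of 0.5; the value here is a per-model tuning choice.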