Created
December 23, 2022 04:19
-
-
Save ImanHosseini/ad922cc9c01e05bc16c59926b6f35fd9 to your computer and use it in GitHub Desktop.
SmoothQuant changes for GPT-J
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from datasets import load_dataset | |
import functools | |
from collections import defaultdict | |
from transformers.models.opt.modeling_opt import OPTForCausalLM | |
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM | |
from functools import partial | |
import numpy as np | |
from tqdm import tqdm | |
@torch.no_grad()
def get_act_scales(model, tokenizer, dataset_path, num_samples=10, seq_len=512):
    """Collect per-channel peak input magnitudes for every ``nn.Linear``.

    A forward hook is attached to each ``nn.Linear`` in ``model``. The model
    is then run on ``num_samples`` texts taken from the shuffled first 512
    rows of the ``lambada`` validation split, and each hook keeps a running
    element-wise maximum of ``abs(input)`` over all rows of the input
    (the input is flattened to ``(-1, last_dim)``).

    Args:
        model: Module to calibrate; put in eval mode by this function.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", max_length=..., truncation=True)``
            whose result exposes ``input_ids``.
        dataset_path: Currently ignored (calibration text comes from
            ``lambada``); kept so existing callers keep working.
        num_samples: Number of calibration texts fed through the model.
        seq_len: Truncation length passed to the tokenizer.

    Returns:
        dict mapping the qualified module name of each hooked ``nn.Linear``
        to a 1-D float CPU tensor with one entry per input channel.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        hidden_dim = tensor.shape[-1]
        # reshape (not view) so non-contiguous activations are accepted too.
        flat = tensor.reshape(-1, hidden_dim).abs().detach()
        incoming_max = torch.max(flat, dim=0)[0].float().cpu()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], incoming_max)
        else:
            act_scales[name] = incoming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the positional inputs as a tuple.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = [
        m.register_forward_hook(functools.partial(stat_input_hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear)
    ]

    try:
        dataset = load_dataset("lambada", split='validation[:512]')
        dataset = dataset.shuffle(seed=42)

        for i in tqdm(range(num_samples)):
            input_ids = tokenizer(dataset[i]["text"], return_tensors="pt",
                                  max_length=seq_len, truncation=True).input_ids.to(device)
            model(input_ids)
    finally:
        # Always detach the hooks, even if loading data or a forward pass
        # fails, so the model is not left instrumented.
        for h in hooks:
            h.remove()

    return act_scales
@torch.no_grad()
def get_static_decoder_layer_scales(model,
                                    tokenizer,
                                    dataset_path,
                                    num_samples=10,
                                    seq_len=512,
                                    ):
    """Derive static per-layer activation scales for OPT and GPT-J decoders.

    Every ``torch.nn.Linear`` in ``model`` is hooked to track the largest
    absolute value seen at its input and output while the model runs on
    ``num_samples`` texts from the shuffled first 512 rows of the ``lambada``
    validation split. The peaks of selected projections are then divided by
    127 and grouped into one dict per decoder layer.

    Args:
        model: ``OPTForCausalLM`` or ``GPTJForCausalLM`` instance; put in
            eval mode by this function.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", max_length=..., truncation=True)``
            whose result exposes ``input_ids``.
        dataset_path: Currently ignored (calibration text comes from
            ``lambada``); kept so existing callers keep working.
        num_samples: Number of calibration texts fed through the model.
        seq_len: Truncation length passed to the tokenizer.

    Returns:
        list with one dict per decoder layer, holding the float entries
        ``attn_input_scale``, ``q_output_scale``, ``k_output_scale``,
        ``v_output_scale``, ``out_input_scale``, ``fc1_input_scale`` and
        ``fc2_input_scale``. The list is empty when ``model`` is neither of
        the two supported classes.
    """
    model.eval()
    device = next(model.parameters()).device
    act_dict = defaultdict(dict)

    def record_peak(name, kind, tensor):
        # Hook inputs (and some module outputs) arrive wrapped in a tuple.
        if isinstance(tensor, tuple):
            tensor = tensor[0]
        peak = tensor.detach().abs().max().item()
        stats = act_dict[name]
        stats[kind] = max(stats[kind], peak) if kind in stats else peak

    def stat_io_hook(m, x, y, name):
        record_peak(name, "input", x)
        record_peak(name, "output", y)

    hooks = [
        m.register_forward_hook(partial(stat_io_hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, torch.nn.Linear)
    ]

    try:
        print("Collecting activation scales...")
        pbar = tqdm(range(num_samples))
        dataset = load_dataset("lambada", split='validation[:512]')
        dataset = dataset.shuffle(seed=42)
        for i in pbar:
            input_ids = tokenizer(dataset[i]["text"], return_tensors="pt",
                                  max_length=seq_len, truncation=True).input_ids.to(device)
            model(input_ids)
            mean_scale = np.mean([v["input"] for v in act_dict.values()])
            pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
    finally:
        # Always detach the hooks, even if loading data or a forward pass
        # fails, so the model is not left instrumented.
        for hook in hooks:
            hook.remove()

    # Per-architecture layout: number of layers, name prefix of layer ``idx``
    # and the sub-module names of the q/k/v/out/fc1/fc2 projections.
    if isinstance(model, OPTForCausalLM):
        num_layers = model.config.num_hidden_layers
        prefix = "model.decoder.layers.{idx}."
        q, k, v, out, fc1, fc2 = (
            "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
            "self_attn.out_proj", "fc1", "fc2")
    elif isinstance(model, GPTJForCausalLM):
        num_layers = len(model.transformer.h)
        prefix = "transformer.h.{idx}."
        q, k, v, out, fc1, fc2 = (
            "attn.q_proj", "attn.k_proj", "attn.v_proj",
            "attn.out_proj", "mlp.fc_in", "mlp.fc_out")
    else:
        return []

    # (result key, projection, which side of the projection was measured)
    layout = (
        ("attn_input_scale", q, "input"),
        ("q_output_scale", q, "output"),
        ("k_output_scale", k, "output"),
        ("v_output_scale", v, "output"),
        ("out_input_scale", out, "input"),
        ("fc1_input_scale", fc1, "input"),
        ("fc2_input_scale", fc2, "input"),
    )

    decoder_layer_scales = []
    for idx in range(num_layers):
        base = prefix.format(idx=idx)
        decoder_layer_scales.append({
            key: act_dict[base + proj][side] / 127
            for key, proj, side in layout
        })

    return decoder_layer_scales
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import argparse | |
import os | |
import json | |
from transformers.models.opt.modeling_opt import OPTForCausalLM | |
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM | |
from transformers import AutoTokenizer | |
from smoothquant.opt import Int8OPTForCausalLM | |
from smoothquant.gptj import Int8GPTJForCausalLM | |
from smoothquant.smooth import smooth_lm | |
from smoothquant.calibration import get_static_decoder_layer_scales | |
# Script entry point: smooth a float OPT / GPT-J checkpoint with pre-computed
# activation scales, calibrate static per-layer scales, and save the int8 model.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default='facebook/opt-13b')
    parser.add_argument("--num-samples", type=int, default=10)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--act-scales", type=str,
                        default='act_scales/opt-13b.pt')
    parser.add_argument("--output-path", type=str,
                        default='int8_models/opt-13b-smoothquant')
    parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst',
                        help='location of the calibration dataset, we use the validation set of the Pile dataset')

    args = parser.parse_args()

    # Resolve the model family once; "opt" takes precedence over "gptj".
    if "opt" in args.model_name:
        family = "opt"
    elif "gptj" in args.model_name:
        family = "gptj"
    else:
        raise ValueError(
            f'Unsupported model name {args.model_name!r}: '
            'expected it to contain "opt" or "gptj"')

    # Fail fast, before any large checkpoint is loaded.
    # NOTE(review): the calibration helper in this gist loads "lambada" and
    # appears to ignore this path - confirm whether this check is still needed.
    if not os.path.exists(args.dataset_path):
        print(f'Cannot find the dataset at {args.dataset_path}')
        print('Please download the Pile dataset and put the validation set at the path')
        print('You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst')
        raise FileNotFoundError(args.dataset_path)

    if family == "opt":
        model = OPTForCausalLM.from_pretrained(args.model_name, device_map="sequential", torch_dtype=torch.float16)
    else:
        model = GPTJForCausalLM.from_pretrained(args.model_name, device_map="sequential", torch_dtype=torch.float16)

    # NOTE: torch.load unpickles its input; only load scale files you trust.
    act_scales = torch.load(args.act_scales)
    smooth_lm(model, act_scales, 0.1)

    # "moyix/...-gptj" checkpoints reuse the tokenizer of the matching
    # "Salesforce/..." repository.
    tkn = args.model_name
    if tkn.startswith("moyix"):
        tkn = tkn.replace("moyix", "Salesforce")
        tkn = tkn.replace("-gptj", "")
    tokenizer = AutoTokenizer.from_pretrained(tkn)

    decoder_layer_scales = get_static_decoder_layer_scales(model,
                                                           tokenizer,
                                                           args.dataset_path,
                                                           num_samples=args.num_samples,
                                                           seq_len=args.seq_len)

    if family == "gptj":
        # Debug output; `model.transformer.h` only exists on GPT-J models, so
        # this must not run for OPT.
        print("===:", end="")
        print(model.transformer.h[0].attn.k_proj.weight)

    # dump layer scales
    with open('model_dec_scales.json', 'w') as fp:
        json.dump(decoder_layer_scales, fp)

    if family == "opt":
        int8_model = Int8OPTForCausalLM.from_float(model, decoder_layer_scales)
    else:
        int8_model = Int8GPTJForCausalLM.from_float(model, decoder_layer_scales)
        # Debug output, GPT-J layout only.
        print(int8_model.transformer.h[0].mlp.fc1.bias)
    int8_model.save_pretrained(args.output_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import os | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
) | |
import argparse | |
from smoothquant.calibration import get_act_scales | |
def build_model_and_tokenizer(model_name):
    """Load the fp16 causal LM ``model_name`` together with its tokenizer.

    Checkpoints whose name starts with ``moyix`` take their tokenizer from
    the corresponding ``Salesforce`` repository (with ``-gptj`` removed from
    the name). The tokenizer is created with ``model_max_length=512``.

    Returns:
        ``(model, tokenizer)`` tuple.
    """
    tokenizer_name = model_name
    if model_name.startswith("moyix"):
        tokenizer_name = model_name.replace(
            "moyix", "Salesforce").replace("-gptj", "")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name, model_max_length=512)

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="sequential")
    return model, tokenizer
def parse_args(argv=None):
    """Parse the command-line options of the activation-scale script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` reads ``sys.argv[1:]``, as before; passing a list
            allows programmatic use.

    Returns:
        ``argparse.Namespace`` with ``model_name``, ``output_path``,
        ``dataset_path``, ``num_samples`` and ``seq_len``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str,
                        default='facebook/opt-1.3b', help='model name')
    parser.add_argument('--output-path', type=str, default='act_scales/opt-1.3b.pt',
                        help='where to save the act scales')
    parser.add_argument('--dataset-path', type=str, default='dataset/val.jsonl.zst',
                        help='location of the calibration dataset, we use the validation set of the Pile dataset')
    parser.add_argument('--num-samples', type=int, default=10)
    parser.add_argument('--seq-len', type=int, default=512)
    return parser.parse_args(argv)
@torch.no_grad()
def main():
    """Compute activation scales for the selected model and save them.

    Reads the options via ``parse_args()``, builds model and tokenizer, runs
    ``get_act_scales`` and stores the result with ``torch.save`` at
    ``--output-path``.

    Raises:
        FileNotFoundError: if ``--dataset-path`` does not exist.
    """
    args = parse_args()

    # Fail fast, before the (potentially very large) model is loaded.
    if not os.path.exists(args.dataset_path):
        print(f'Cannot find the dataset at {args.dataset_path}')
        print('Please download the Pile dataset and put the validation set at the path')
        print('You can download the validation dataset of the Pile at https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst')
        raise FileNotFoundError(args.dataset_path)

    model, tokenizer = build_model_and_tokenizer(args.model_name)

    act_scales = get_act_scales(model, tokenizer, args.dataset_path,
                                args.num_samples, args.seq_len)

    # os.makedirs('') raises, so only create a directory when the output
    # path actually has a directory component.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(act_scales, args.output_path)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from transformers.models.opt.modeling_opt import OPTDecoderLayer | |
from transformers.models.gptj.modeling_gptj import GPTJBlock | |
from transformers.models.bloom.modeling_bloom import BloomBlock | |
@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    """Move per-channel scale from a LayerNorm's output into the following linears.

    Computes ``scales = act_scales**alpha / weight_scales**(1 - alpha)``
    (both factors clamped to at least ``1e-5``), where ``weight_scales`` is
    the per-input-channel maximum of ``abs(weight)`` across all ``fcs``.
    ``ln.weight`` (and ``ln.bias`` when present) are divided by ``scales``
    and every ``fc.weight`` has its input columns multiplied by ``scales``,
    so ``fc(ln(x))`` is mathematically unchanged. All updates are in place.

    Args:
        ln: ``nn.LayerNorm`` whose output feeds every layer in ``fcs``.
        fcs: A single ``nn.Linear`` or a list/tuple of them.
        act_scales: Tensor with one entry per channel of the LayerNorm output.
        alpha: Exponent balancing activation versus weight magnitudes.

    Returns:
        None.
    """
    # Accept a bare layer as well as any list/tuple of layers.
    if not isinstance(fcs, (list, tuple)):
        fcs = [fcs]
    assert isinstance(ln, nn.LayerNorm)
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        assert ln.weight.numel() == fc.in_features == act_scales.numel()

    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)

    # Row i holds the per-input-channel |weight| maximum of layer i; the
    # max over rows gives the joint per-channel weight scale.
    weight_scales = torch.cat(
        [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

    # Both operands already live on (device, dtype), so no extra cast is needed.
    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)
              ).clamp(min=1e-5)

    ln.weight.div_(scales)
    # A LayerNorm may have no bias; there is nothing to rescale then.
    if ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))
@torch.no_grad()
def smooth_lm(model, scales, alpha=0.0):
    """Apply ``smooth_ln_fcs`` to every supported decoder block of ``model``.

    Walks ``model.named_modules()`` and, for each ``OPTDecoderLayer``,
    ``BloomBlock`` or ``GPTJBlock``, rescales its LayerNorm(s) together with
    the linear layers that read the LayerNorm output. Modules of any other
    type are skipped. The model is modified in place.

    Args:
        model: Module tree to smooth.
        scales: Mapping from qualified linear-layer name (for example
            ``"<block name>.attn.q_proj"``) to its per-channel activation
            scales.
        alpha: Exponent forwarded to ``smooth_ln_fcs``.

    Returns:
        None.
    """
    for name, module in model.named_modules():
        if isinstance(module, OPTDecoderLayer):
            attn_ln = module.self_attn_layer_norm
            qkv = [module.self_attn.q_proj,
                   module.self_attn.k_proj, module.self_attn.v_proj]
            qkv_input_scales = scales[name + '.self_attn.q_proj']
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.final_layer_norm
            fc1 = module.fc1
            fc1_input_scales = scales[name + '.fc1']
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, BloomBlock):
            attn_ln = module.input_layernorm
            qkv = module.self_attention.query_key_value
            qkv_input_scales = scales[name + '.self_attention.query_key_value']
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.post_attention_layernorm
            fc1 = module.mlp.dense_h_to_4h
            fc1_input_scales = scales[name + '.mlp.dense_h_to_4h']
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, GPTJBlock):
            print("SMOOTHING")
            attn_ln = module.ln_1
            # A GPT-J block runs attention and MLP in parallel on the SAME
            # ln_1 output, so every consumer of ln_1 must absorb the scale:
            # q/k/v and mlp.fc_in. Rescaling ln_1 while compensating only
            # q/k/v would silently change the MLP's input and therefore the
            # model's output. Because all four layers see the same tensor,
            # the activation scales recorded for q_proj apply to fc_in too.
            ln_consumers = [module.attn.q_proj,
                            module.attn.k_proj,
                            module.attn.v_proj,
                            module.mlp.fc_in]
            shared_input_scales = scales[name + '.attn.q_proj']
            smooth_ln_fcs(attn_ln, ln_consumers, shared_input_scales, alpha)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment