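# Benchmark driver (file name not shown in this copy): downloads each QuIP#-quantized
# Llama checkpoint from the Hugging Face Hub and times end-to-end generation using
# benchmark_one() from flash_attn_time.py (the second file below).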
from time import perf_counter

import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError

from flash_attn_time import benchmark_one


def benchmark():
    # Sweep the Llama 1/2 model sizes at the (bit rate, QuIP# codebook) pairs published by relaxml.
    for llama in (1, 2):
        sizes = (7, 13, 30, 65) if llama == 1 else (7, 13, 70)
        for size in sizes:
            for rate, method in ((2, "E8P"), (4, "E8PRVQ")):
                if method is None:
                    # Unquantized baseline checkpoints.
                    publisher = "relaxml" if llama == 1 else "meta-llama"
                    repo_id = f"Llama-{llama}-{size}b-hf"
                else:
                    publisher = "relaxml"
                    repo_id = f"Llama-{llama}-{size}b-{method}-{rate}Bit"
                print(">", repo_id)

                # Only benchmark the configurations listed below; skip everything else.
                if (llama, size, rate) not in {(2, 7, 2),
                                               (2, 7, 4),
                                               (2, 13, 2),
                                               (2, 13, 4),
                                               (2, 70, 2),
                                               (1, 30, 2),
                                               (1, 30, 4),
                                               (1, 65, 2)}:
                    print("Skip")
                    continue

                try:
                    snapshot_path = snapshot_download(f"{publisher}/{repo_id}")
                except RepositoryNotFoundError:
                    print("404")
                    continue

                # The original (unquantized) model supplies the config and tokenizer.
                model_name = f"meta-llama/Llama-2-{size}b-hf" if llama == 2 else f"relaxml/Llama-1-{size}b-hf"
                try:
                    start_time = perf_counter()
                    benchmark_one(model_name, snapshot_path, method is not None)
                    end_time = perf_counter()
                except torch.cuda.OutOfMemoryError:
                    print("OOM")
                    continue
                print("Elapsed", end_time - start_time)


if __name__ == "__main__":
    benchmark()
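# flash_attn_time.py -- imported above as `from flash_attn_time import benchmark_one`.
# Builds a flash-attn GPTLMHeadModel from a Llama config, swaps nn.Linear layers for
# QuIP# (Fused)QuantizedLinear kernels when benchmarking a quantized checkpoint, and
# times generation of roughly 1000 tokens.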
import json
import os

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from lib import codebook
from lib.linear.quantized_linear import QuantizedLinear
from lib.linear.fused_quantized_linear import FusedQuantizedLinear

from flash_attn.modules.embedding import GPT2Embeddings
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mlp import GatedMlp
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel, Block
from flash_attn.models.llama import llama_config_to_gpt2_config, remap_state_dict_hf_llama
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.benchmark import pytorch_profiler
from flash_attn.ops.triton.layer_norm import RMSNorm

device = "cuda"
dtype = torch.float16
def benchmark_one(model_name, quip_hf, quantized):
    llama_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    # pretrained_state_dict = remap_state_dict_hf_llama(state_dict_from_pretrained(model_name), config)

    def remap_state_dict(state_dict):
        # Rename QuIP# (HF-style) state dict keys to the flash-attn GPT naming scheme.
        frozen_keys = list(state_dict.keys())
        for key in frozen_keys:
            new_key = key.replace(
                'model', 'transformer'
            ).replace(
                'self_attn.qkv_proj', 'mixer.Wqkv'
            ).replace(
                'self_attn.o_proj', 'mixer.out_proj'
            ).replace(
                'mlp.upgate_proj', 'mlp.fc1'
            ).replace(
                'mlp.down_proj', 'mlp.fc2'
            ).replace(
                'input_layernorm', 'norm1'
            ).replace(
                'post_attention_layernorm', 'norm2'
            ).replace(
                'embed_tokens', 'embeddings.word_embeddings'
            )
            if new_key.endswith('Wqkv.fuse_scales'):
                # Broadcast the stored fuse scales (one for q, one for k/v) to one value
                # per output channel of the fused Wqkv projection.
                head_dim = llama_config.hidden_size // llama_config.num_attention_heads
                device = state_dict[key].device
                dtype = state_dict[key].dtype
                fuse_scales = torch.concat([
                    state_dict[key][0] * torch.ones(llama_config.num_attention_heads * head_dim, device=device, dtype=dtype),
                    state_dict[key][1] * torch.ones(llama_config.num_key_value_heads * head_dim, device=device, dtype=dtype),
                    state_dict[key][1] * torch.ones(llama_config.num_key_value_heads * head_dim, device=device, dtype=dtype),
                ], dim=0)
                state_dict[new_key] = fuse_scales
            elif new_key.endswith('fc1.fuse_scales'):
                # upgate: broadcast the two fused-MLP scales across the up/gate output channels.
                device = state_dict[key].device
                fuse_scales = torch.concat([
                    state_dict[key][0] * torch.ones(llama_config.intermediate_size, device=device, dtype=dtype),
                    state_dict[key][1] * torch.ones(llama_config.intermediate_size, device=device, dtype=dtype),
                ], dim=0)
                state_dict[new_key] = fuse_scales
            else:
                if new_key == 'transformer.norm.weight':
                    new_key = 'transformer.ln_f.weight'
                state_dict[new_key] = state_dict[key]
            if new_key != key:
                del state_dict[key]
        return state_dict
    from lib.utils.unsafe_import import model_from_hf_path
    # Load the QuIP# checkpoint and keep only its state dict.
    m = model_from_hf_path(quip_hf, use_cuda_graph=False)[0].state_dict()

    # Materialize the flash-attn model on the meta device; weights are filled in below.
    model = GPTLMHeadModel(config, device='meta', dtype=dtype)

    if quantized:
        m = remap_state_dict(m)
        quip_params = json.load(open(os.path.join(quip_hf, 'config.json')))['quip_params']

        def replace_linear(module):
            # Recursively swap every nn.Linear (except lm_head) for a QuIP# quantized layer;
            # fused Wqkv/fc1 projections use the fused kernel.
            for name, child in module.named_children():
                if isinstance(child, nn.Linear) and name != 'lm_head':
                    if name.endswith('Wqkv') or name.endswith('fc1'):
                        ql = FusedQuantizedLinear(
                            -1, (child.out_features,), True,
                            child.in_features,
                            child.out_features,
                            quip_params['codesz'],
                            quip_params.get('packsz', 1),
                            quip_params.get('pack_out', False),
                            quip_params['idx_dtype'],
                            quip_params.get('codebook_version', 0),
                            rank=quip_params['lora_rank'],
                            rescale_WH=quip_params['rescale_WH'],
                            resid_scale_override=quip_params.get('resid_scale_override', -1),
                        )
                    else:
                        ql = QuantizedLinear(
                            child.in_features,
                            child.out_features,
                            quip_params['codesz'],
                            quip_params.get('packsz', 1),
                            quip_params.get('pack_out', False),
                            quip_params['idx_dtype'],
                            quip_params.get('codebook_version', 0),
                            rank=quip_params['lora_rank'],
                            rescale_WH=quip_params['rescale_WH'],
                            resid_scale_override=quip_params.get('resid_scale_override', -1),
                        )
                    ql.codebook_id.copy_(codebook.get_id(quip_params['codebook']))
                    setattr(module, name, ql)
                else:
                    replace_linear(child)

        replace_linear(model)

    model.load_state_dict(m, strict=False, assign=True)
    # model.load_state_dict(pretrained_state_dict)

    def replace_meta(module):
        # Fill embedding, rotary, norm, and remaining nn.Linear weights with random CUDA
        # tensors (some may still live on the meta device); only generation speed is
        # measured here, so the actual values do not matter.
        for name, child in module.named_children():
            if isinstance(child, (GPTModel, GPT2Embeddings, nn.ModuleList, Block, GatedMlp)):
                replace_meta(child)
            elif isinstance(child, nn.Embedding):
                child.weight = nn.Parameter(torch.randn_like(child.weight, device="cuda"))
            elif isinstance(child, RotaryEmbedding):
                child.inv_freq = torch.randn_like(child.inv_freq, device="cuda")
            elif isinstance(child, RMSNorm):
                child.weight = nn.Parameter(torch.randn_like(child.weight, device="cuda"))
            elif isinstance(child, nn.Linear):
                child.weight = nn.Parameter(torch.randn_like(child.weight, device="cuda"))

    replace_meta(model)
    model = model.cuda()
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
    input_ids = tokenizer.encode("a", return_tensors="pt").to(device)
    max_length = input_ids.shape[-1] + 1000

    # Generate ~1000 tokens with CUDA graphs; flash-attn's generate() reports timing
    # when enable_timing=True.
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
if __name__ == "__main__":
    model_name = "relaxml/Llama-1-30b-hf"
    quip_hf = '/share/desa/nfs01/qs234/huggingface/hub/models--relaxml--Llama-1-30b-E8P-2Bit/snapshots/42807d6d30647886bfc77072871a960a89919f46/'

    # torch.cuda.memory._record_memory_history(enabled='all')
    benchmark_one(model_name, quip_hf, True)
    # from pickle import dump
    # s = torch.cuda.memory._snapshot()
    # with open("snapshot.pickle", "wb") as f:
    #     dump(s, f)
    #
    # torch.cuda.memory._record_memory_history(enabled=None)
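# requirements.txt -- pinned package versions (pip freeze) for the environment used above.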
absl-py==2.1.0
accelerate==0.26.1
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
attrs==23.2.0
cachetools==5.3.2
certifi==2023.11.17
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
DataProperty==1.0.1
datasets==2.16.1
dill==0.3.7
distro==1.9.0
einops==0.7.0
evaluate==0.4.1
fast-hadamard-transform==1.0.1
filelock==3.13.1
flash-attn==2.4.2
frozenlist==1.4.1
fsspec==2023.10.0
fused-dense-lib==0.0.0
glog==0.3.1
h11==0.14.0
httpcore==1.0.2
httpx==0.26.0
huggingface-hub==0.20.3
icdiff==2.0.7
idna==3.6
Jinja2==3.1.3
joblib==1.3.2
jsonlines==4.0.0
lm-eval==0.3.0
lxml==5.1.0
MarkupSafe==2.1.4
mbstrdecoder==1.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
numexpr==2.8.8
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.535.133
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
nvitop==1.3.2
openai==1.9.0
packaging==23.2
pandas==2.2.0
pathvalidate==3.2.0
peft==0.7.1
pillow==10.2.0
portalocker==2.8.2
primefac==2.0.12
psutil==5.9.8
pyarrow==15.0.0
pyarrow-hotfix==0.6
pybind11==2.11.1
pycountry==23.12.11
pydantic==2.5.3
pydantic_core==2.14.6
pytablewriter==1.2.0
python-dateutil==2.8.2
python-gflags==3.1.2
pytz==2023.3.post1
PyYAML==6.0.1
quiptools-cuda==0.0.0
regex==2023.12.25
requests==2.31.0
responses==0.18.0
rouge-score==0.1.2
sacrebleu==1.5.0
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
sentencepiece==0.1.99
six==1.16.0
sniffio==1.3.0
sqlitedict==2.1.0
sympy==1.12
tabledata==1.3.3
tabulate==0.9.0
tcolorpy==0.1.4
termcolor==2.4.0
threadpoolctl==3.2.0
tokenizers==0.15.1
torch==2.1.2
torchvision==0.16.2
tqdm==4.66.1
tqdm-multiprocess==0.0.11
transformers==4.37.1
triton==2.1.0
typepy==1.3.2
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.1.0
xxhash==3.4.1
yarl==1.9.4
zstandard==0.22.0