# %%
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
# %%
import bitsandbytes as bnb # type: ignore
def dump_4bit_weight(m: bnb.nn.modules.Linear4bit, path: str):
    """Dump a Linear4bit layer's packed weight and its quantization state to disk."""
    assert isinstance(m, bnb.nn.modules.Linear4bit), "Only Linear4bit is supported"
    # quant_state layout follows the bitsandbytes release used here:
    # [absmax, shape, dtype, blocksize, compressed_stats, quant_type, code]
    obj = {
        "weight": m.weight,
        "bias": m.bias,
        "absmax": m.weight.quant_state[0],
        "shape": m.weight.quant_state[1],
        "blocksize": m.weight.quant_state[3],
        "compressed_stats": m.weight.quant_state[4],
        "code": m.weight.quant_state[6],
    }
    torch.save(obj, path)
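# %%
# Round-trip sketch: recover an fp16 view of the packed 4-bit weight, e.g. to
# compare against the original checkpoint. This assumes the bitsandbytes release
# in use exposes bnb.functional.dequantize_4bit(packed, quant_state), which
# consumes the same state captured by dump_4bit_weight above.
def dequantize_4bit_weight(m: bnb.nn.modules.Linear4bit) -> torch.Tensor:
    assert isinstance(m, bnb.nn.modules.Linear4bit), "Only Linear4bit is supported"
    return bnb.functional.dequantize_4bit(m.weight.data, m.weight.quant_state)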
# %%
# Llama-2-7B with weights quantized to 4-bit NF4 (no double quantization).
model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    ),
)
# %%
# Unquantized fp16 reference copy of the same model, placed on GPU 1.
model_orig = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map={"": 1},
)
# %%
# Compare the same layer in both models: the first decoder block's MLP gate projection.
lin_4bit = model_4bit.model.layers[0].mlp.gate_proj
lin_orig = model_orig.model.layers[0].mlp.gate_proj
# %%
# torch.cuda.manual_seed returns None, so build an explicit CUDA generator instead.
gen = torch.Generator(device="cuda").manual_seed(42)
x = torch.randn((1, lin_4bit.in_features), generator=gen, device="cuda", dtype=torch.float16)
# %%
torch.save({
    "input": x,
    "output_4bit": lin_4bit(x.to(model_4bit.device)).detach().cpu(),
    "output_orig": lin_orig(x.to(model_orig.device)).detach().cpu(),
}, "input_and_outputs.pt")
# %%
dump_4bit_weight(lin_4bit, "lin_4bit_dump.pt")
torch.save(lin_orig.weight.detach().cpu(), "lin_orig_weight.pt")
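# %%
# Optional sanity check: reload the saved activations and report how far the
# NF4 output drifts from the fp16 reference for this input.
saved = torch.load("input_and_outputs.pt")
diff = (saved["output_4bit"].float() - saved["output_orig"].float()).abs()
print(f"max abs diff: {diff.max().item():.6f}  mean abs diff: {diff.mean().item():.6f}")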