# Dequantizing bitsandbytes 4-bit linear weights by hand
# Gist by @sekstini, created July 31, 2023
# https://gist.github.com/sekstini/8ec95cbe34eb40a4094caf715a10157a
# %%
import torch
# %%
# == Load 4bit weights and original weights ==
lin_4bit_dump = torch.load("lin_4bit_dump.pt", map_location="cuda")
lin_orig_weight = torch.load("lin_orig_weight.pt", map_location="cuda")
tmp = torch.load("input_and_outputs.pt", map_location="cuda")
x, output_4bit, output_orig = tmp["input"], tmp["output_4bit"], tmp["output_orig"]
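# %%
# Added sanity check (not in the original gist): confirm the dumped tensors
# loaded with compatible shapes and dtypes before comparing outputs.
for name, t in [("x", x), ("output_4bit", output_4bit), ("output_orig", output_orig)]:
    print(f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}")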
# %%
def dequantize_bnb_4bit(
    weight: torch.Tensor,   # packed 4-bit weights, two values per byte
    absmax: torch.Tensor,   # groupwise absmax, fp32 if not double-quantized
    shape: torch.Size,      # shape of the unpacked weight matrix
    blocksize: int,         # number of values per quantization block
    compressed_stats,       # only set when double quantization is used
    code: torch.Tensor,     # 4-bit (i.e. length-16) lookup table, fp32
) -> torch.Tensor:
    assert compressed_stats is None, "Double quantization not implemented"
    m, n = shape
    w = weight.view(-1).to(torch.int32)
    out = torch.empty((m, n), dtype=torch.float32, device=w.device)
    # The high nibble of each byte fills the even columns, the low nibble
    # the odd columns.
    out[:, 0::2] = code[w >> 4].view((m, n // 2))
    out[:, 1::2] = code[w & 0xF].view((m, n // 2))
    # Rescale each block by its absmax, then cast back to fp16.
    out.view(-1, blocksize).mul_(absmax.view(-1, 1))
    return out.half()
# Unpack the dumped quantized-layer state; `bias` is loaded but unused below.
keys = ["weight", "bias", "absmax", "shape", "blocksize", "compressed_stats", "code"]
weight, bias, absmax, shape, blocksize, compressed_stats, code = [lin_4bit_dump[k] for k in keys]
lin_4bit_weight = dequantize_bnb_4bit(weight, absmax, shape, blocksize, compressed_stats, code)
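# %%
# Illustrative sketch (not in the original gist): a toy round trip that
# demonstrates the packing convention above: high nibble to even columns,
# low nibble to odd columns. The linear `code_demo` table is a stand-in for
# the real NF4/FP4 lookup table.
code_demo = torch.linspace(-1, 1, 16, device="cuda")
idx = torch.randint(0, 16, (4, 8), device="cuda")             # unpacked 4-bit indices
packed = ((idx[:, 0::2] << 4) | idx[:, 1::2]).to(torch.uint8).view(-1, 1)
absmax_demo = torch.ones(1, device="cuda")                    # single block, scale 1.0
w_demo = dequantize_bnb_4bit(packed, absmax_demo, idx.shape, 32, None, code_demo)
assert torch.equal(w_demo, code_demo[idx].half())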
# %%
# Max absolute error of the reference matmul vs. the dumped fp16 output.
((x @ lin_orig_weight.T) - output_orig).abs().max()
# %%
# Max absolute error of the manually dequantized matmul vs. the dumped
# 4-bit layer output.
((x @ lin_4bit_weight.T) - output_4bit).abs().max()
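# %%
# Added check (not in the original gist): elementwise quantization error of
# the dequantized weights themselves, independent of any input.
(lin_4bit_weight - lin_orig_weight).abs().max()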