from torch import nn
from torchtune.utils import get_memory_stats, get_device
from torchao.dtypes.nf4tensor import to_nf4
from bitsandbytes.functional import quantize_nf4
def main():
    device = get_device('cuda')
    # Size of the Llama3-8B output projection weight
    big_linear = nn.Linear(in_features=4096, out_features=128256, bias=False, device=device)
    memory_stats = get_memory_stats(device=device)
    print(f"before quantize: {memory_stats}")
    # Quantize with torchao's to_nf4 (allocates a new NF4Tensor; the original
    # fp weight stays referenced by the module, so memory usage grows)
    ao_quant = to_nf4(big_linear.weight)
    memory_stats = get_memory_stats(device=device)
    print(f"after ao quant: {memory_stats}")
if __name__ == "__main__":
    main()