@0cc4m
Last active February 25, 2023 09:45
8-bit test for bitsandbytes: compares generation on GPUs that support igemmlt against the fallback code path used on GPUs without that function
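# Assumes a CUDA GPU and that transformers, accelerate and bitsandbytes are installed
# (load_in_8bit relies on bitsandbytes).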
import gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from bitsandbytes.autograd._functions import MatMul8bitLt, MatmulLtState
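# MatMul8bitLt is bitsandbytes' 8-bit matmul autograd function; MatmulLtState holds its
# per-layer settings, including the force_no_igemmlt flag patched further below.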
MAX_NEW_TOKENS = 32
model_name = "EleutherAI/pythia-410m-deduped"
text = """
Q: On average Joe throws 25 punches per minute. A fight lasts 5 rounds of 3 minutes.
How many punches did he throw?\n
A: Let’s think step by step.\n"""
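# Tokenize the prompt and move the input ids onto the first CUDA device.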
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(0)
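# Build a per-GPU memory budget from the free VRAM so accelerate's device_map="auto"
# knows how much it may place on each GPU.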
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f"{free_in_GB}GB"
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    max_memory=max_memory,
)
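# Reference run: generate with the stock 8-bit path (igemmlt on GPUs that support it).
# Note that do_sample=True, so the two runs produce different text even when both paths work.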
generated_ids = model.generate(input_ids, max_length=len(input_ids[0]) + MAX_NEW_TOKENS, do_sample=True)
result = tokenizer.decode(generated_ids.cpu().squeeze())
print(f"8bit-reg: {result}")
del model
gc.collect()
# Monkey-patch bitsandbytes' MatMul8bitLt.forward to force-disable igemmlt, simulating a GPU that doesn't support it
old_method = MatMul8bitLt.forward
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
    state.force_no_igemmlt = True
    return old_method(ctx, A, B, out, bias, state)
MatMul8bitLt.forward = forward
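# Every 8-bit matmul now sets force_no_igemmlt on its state, so the non-igemmlt fallback kernels are used.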
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    max_memory=max_memory,
)
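# Comparison run: same prompt, but with igemmlt disabled via the patched forward.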
generated_ids_8bit = model_8bit.generate(input_ids, max_length=len(input_ids[0]) + MAX_NEW_TOKENS, do_sample=True)
result_8bit = tokenizer.decode(generated_ids_8bit.cpu().squeeze())
print(f"8bit-old: {result_8bit}")