-
-
Save 0cc4m/a753b6a16a618cfbe747a74920dc50f6 to your computer and use it in GitHub Desktop.
8-bit test for bitsandbytes that compares generation on GPUs that support igemmlt with the fallback code path used on GPUs that lack it
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 8-bit generation check for bitsandbytes: runs the same prompt through an
# 8-bit model twice, first with the default matmul path and then with igemmlt
# forcibly disabled, so GPUs with and without igemmlt support can be compared.
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from bitsandbytes.autograd._functions import MatMul8bitLt, MatmulLtState

# Number of tokens generated on top of the prompt length.
MAX_NEW_TOKENS = 32
model_name = "EleutherAI/pythia-410m-deduped"
text = """
Q: On average Joe throws 25 punches per minute. A fight lasts 5 rounds of 3 minutes.
How many punches did he throw?\n
A: Let’s think step by step.\n"""

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prompt token ids, moved to CUDA device 0.
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(0)

# Per-GPU memory budget passed to from_pretrained(max_memory=...): each
# visible GPU is capped at its OWN currently free memory, rounded down to
# whole GiB. Querying every device (instead of applying the current device's
# figure to all of them) keeps the budget right when GPUs differ in size or
# are already partially occupied.
n_gpus = torch.cuda.device_count()
max_memory = {
    i: f"{torch.cuda.mem_get_info(i)[0] // 1024**3}GB" for i in range(n_gpus)
}
# Baseline run: load the model in 8-bit and generate with bitsandbytes'
# unmodified matmul path. NOTE(review): whether igemmlt is actually used here
# depends on the installed bitsandbytes and the GPU; not checked in this script.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    max_memory=max_memory,
)
# Seed the RNG (all devices) so that, with do_sample=True, this run and the
# patched run start from the same random state; any divergence then comes
# from the matmul path rather than from different random draws.
torch.manual_seed(0)
generated_ids = model.generate(
    input_ids,
    max_length=len(input_ids[0]) + MAX_NEW_TOKENS,
    do_sample=True,
)
result = tokenizer.decode(generated_ids.cpu().squeeze())
print(f"8bit-reg: {result}")

# Free the first model before loading the second: drop the reference, collect
# it, then return PyTorch's cached CUDA blocks to the driver.
del model
gc.collect()
torch.cuda.empty_cache()
# Monkey-patch bitsandbytes so every 8-bit matmul is forced off igemmlt,
# simulating a GPU that does not support it. This must run before the second
# model's forward passes.
old_method = MatMul8bitLt.forward


@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
    """Set ``state.force_no_igemmlt`` and delegate to the original forward.

    The signature presumably mirrors bitsandbytes' own
    ``MatMul8bitLt.forward`` so it can be swapped in place — confirm against
    the installed version. NOTE(review): if ``state`` is omitted, the default
    is the ``MatmulLtState`` class itself, so the flag is set on the class.
    """
    state.force_no_igemmlt = True
    return old_method(ctx, A, B, out, bias, state)


# Replace the forward on the class. The former `global force_no_igemmlt` /
# `force_no_igemmlt = True` lines were dead code: `global` is a no-op at
# module scope and nothing reads that module-level name.
MatMul8bitLt.forward = forward
# Second run: same model, prompt and budget, now with igemmlt disabled by the
# patched MatMul8bitLt.forward.
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    max_memory=max_memory,
)
# Same seed as the baseline run, so sampling randomness is not what differs.
torch.manual_seed(0)
generated_ids_8bit = model_8bit.generate(
    input_ids,
    max_length=len(input_ids[0]) + MAX_NEW_TOKENS,
    do_sample=True,
)
# Decode THIS run's output (generated_ids_8bit), not the baseline's
# generated_ids; otherwise both prints show the baseline text.
result_8bit = tokenizer.decode(generated_ids_8bit.cpu().squeeze())
print(f"8bit-old: {result_8bit}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment