Create a huggingface pipeline with a lora-trained alpaca
from typing import Optional

import torch
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    pipeline,
)
from peft import PeftModel

# Prompt template the Alpaca LoRA adapters were trained with
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
def load_adapted_hf_generation_pipeline(
    base_model_name,
    lora_model_name,
    temperature: float = 0,
    top_p: float = 1.0,
    max_tokens: int = 50,
    batch_size: int = 16,
    device: str = "cpu",
    load_in_8bit: bool = True,
    generation_kwargs: Optional[dict] = None,
):
    """
    Load a huggingface model & adapt with PEFT.
    Borrowed from https://github.com/tloen/alpaca-lora/blob/main/generate.py
    """
    if device == "cuda":
        if not is_accelerate_available():
            raise ValueError("Install `accelerate`")
        # 8-bit loading is only used on the CUDA path, so only require bitsandbytes there
        if load_in_8bit and not is_bitsandbytes_available():
            raise ValueError("Install `bitsandbytes`")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    task = "text-generation"
    # Load the base model, then layer the LoRA adapter weights on top with PEFT
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
        )

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_in_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()

    generation_kwargs = generation_kwargs if generation_kwargs is not None else {}
    config = GenerationConfig(
        # Sampling with temperature 0 is undefined, so fall back to greedy decoding
        do_sample=temperature > 0,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        **generation_kwargs,
    )
    pipe = pipeline(
        task,
        model=model,
        tokenizer=tokenizer,
        batch_size=batch_size,  # prompts are processed in chunks of this size
        generation_config=config,
        framework="pt",
    )

    return pipe
if __name__ == "__main__":
    pipe = load_adapted_hf_generation_pipeline(
        base_model_name="decapoda-research/llama-7b-hf",
        lora_model_name="tloen/alpaca-lora-7b",
    )
    prompt = ALPACA_TEMPLATE.format(
        instruction="Paraphrase the sentence.",
        input="The quick brown fox jumped over the lazy dog.",
    )
    print(pipe(prompt))
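A usage note: the returned object is an ordinary transformers text-generation pipeline, so it also accepts a list of prompts and processes them in chunks of batch_size. A minimal sketch of batched generation, assuming the setup above (the second example sentence is just an illustration):

prompts = [
    ALPACA_TEMPLATE.format(instruction="Paraphrase the sentence.", input=sentence)
    for sentence in [
        "The quick brown fox jumped over the lazy dog.",
        "It was the best of times, it was the worst of times.",
    ]
]
# For a list of inputs, the pipeline returns one list of generations per prompt
for result in pipe(prompts):
    print(result[0]["generated_text"])

For a decoder-only model like LLaMA you may also need to set tokenizer.padding_side = "left" before batched generation; depending on the transformers version, the pipeline may warn about this.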
Thank you so much! I'm gonna test it and let you know.
Thank you once more! It works well :)
Heads up, I believe you're missing a comma at the end of line 116.
Nice catch, thanks!
I believe that is just a warning that you can safely ignore. For the versions of transformers and PEFT I was using (4.28.1 and 0.3.0.dev0, respectively), PeftModelForCausalLM had not been added to the text-generation pipeline's list of supported models. But, as you can see, the underlying LlamaForCausalLM that the PEFT adapter wraps is supported, i.e., the warning is spurious. It's possible I'm wrong here, but I did get this to work.
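If you'd rather silence that message than just ignore it, one option (not something the gist itself does) is to raise the log level of the transformers library before building the pipeline; depending on the version, the notice may be emitted at error rather than warning level, so the threshold needs to be CRITICAL:

import logging

# Hides all transformers log output below CRITICAL, including the
# "PeftModelForCausalLM is not supported for text-generation" notice
logging.getLogger("transformers").setLevel(logging.CRITICAL)

pipe = load_adapted_hf_generation_pipeline(
    base_model_name="decapoda-research/llama-7b-hf",
    lora_model_name="tloen/alpaca-lora-7b",
)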