@ahoho
Last active March 4, 2024 07:27
Create a Hugging Face pipeline with a LoRA-trained Alpaca model
from typing import Optional

import torch
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    pipeline,
)
from peft import PeftModel

ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)

def load_adapted_hf_generation_pipeline(
    base_model_name,
    lora_model_name,
    temperature: float = 0,
    top_p: float = 1.0,
    max_tokens: int = 50,
    batch_size: int = 16,
    device: str = "cpu",
    load_in_8bit: bool = True,
    generation_kwargs: Optional[dict] = None,
):
    """
    Load a huggingface model & adapt with PEFT.
    Borrowed from https://github.com/tloen/alpaca-lora/blob/main/generate.py
    """
    if device == "cuda":
        if not is_accelerate_available():
            raise ValueError("Install `accelerate`")
        if load_in_8bit and not is_bitsandbytes_available():
            raise ValueError("Install `bitsandbytes`")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    task = "text-generation"

    # load the base model, then wrap it with the LoRA adapter weights
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
        )

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_in_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()

    generation_kwargs = generation_kwargs if generation_kwargs is not None else {}
    config = GenerationConfig(
        # sample only when temperature > 0 (a temperature of 0 is invalid
        # with do_sample=True); otherwise decode greedily
        do_sample=temperature > 0,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        **generation_kwargs,
    )
    pipe = pipeline(
        task,
        model=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
        generation_config=config,
        framework="pt",
    )
    return pipe

if __name__ == "__main__":
    pipe = load_adapted_hf_generation_pipeline(
        base_model_name="decapoda-research/llama-7b-hf",
        lora_model_name="tloen/alpaca-lora-7b",
    )
    prompt = ALPACA_TEMPLATE.format(
        instruction="Paraphrase the sentence.",
        input="The quick brown fox jumped over the lazy dog.",
    )
    print(pipe(prompt))
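
For reference, the text-generation pipeline returns a list with one dict per generated sequence, where "generated_text" holds the prompt followed by the completion. A minimal sketch for pulling out just the response (it splits on the "### Response:" marker from ALPACA_TEMPLATE above):

outputs = pipe(prompt)  # e.g., [{"generated_text": "<prompt><completion>"}]
response = outputs[0]["generated_text"].split("### Response:")[-1].strip()
print(response)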
@AleksandrTarasov07

Good day! May I ask whether your script works for you? I did almost the same thing, but I ran into this issue:

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].

@ahoho
Author

ahoho commented Aug 9, 2023

I believe that is just a warning that you can safely ignore. For the versions of transformers and PEFT I was using (4.28.1 and 0.3.0.dev0, respectively), PeftModelForCausalLM had not yet been added to the text-generation pipeline's list of supported models (but, as you can see, the underlying LlamaForCausalLM that the PEFT model wraps is supported, i.e., the warning is spurious).

It's possible I'm wrong here, but I did get this to work.
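
As a quick sanity check (a minimal sketch, assuming the load_adapted_hf_generation_pipeline defined above and PEFT's PeftModel.get_base_model method), you can confirm that the supported model sits underneath the wrapper:

pipe = load_adapted_hf_generation_pipeline(
    base_model_name="decapoda-research/llama-7b-hf",
    lora_model_name="tloen/alpaca-lora-7b",
)
print(type(pipe.model))                   # PeftModelForCausalLM -> triggers the warning
print(type(pipe.model.get_base_model()))  # LlamaForCausalLM -> supported by the pipeline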

@AleksandrTarasov07

Thank you so much! I'm gonna test it and let you know.

@AleksandrTarasov07

Thank you so much once again! It works well :)

@nuri428

nuri428 commented Oct 14, 2023

Thank you so much!

@hoffm386

Heads up, I believe you're missing a comma at the end of line 116.

@ahoho
Author

ahoho commented Dec 14, 2023

Nice catch, thanks!
