Skip to content

Instantly share code, notes, and snippets.

@ahoho
Last active May 7, 2024 04:28
Show Gist options
  • Save ahoho/ba41c42984faf64bf4302b2b1cd7e0ce to your computer and use it in GitHub Desktop.
Save ahoho/ba41c42984faf64bf4302b2b1cd7e0ce to your computer and use it in GitHub Desktop.
Create a Hugging Face pipeline with a LoRA-trained Alpaca model
from typing import Optional, Any
import torch
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
GenerationConfig,
pipeline,
)
from peft import PeftModel
# Alpaca instruction-with-input prompt. Fill it with
# ALPACA_TEMPLATE.format(instruction=..., input=...); the model's answer is
# expected to follow the trailing "### Response:" header.
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task, "
    "paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n"
)
def load_adapted_hf_generation_pipeline(
    base_model_name: str,
    lora_model_name: str,
    temperature: float = 0,
    top_p: float = 1.,
    max_tokens: int = 50,
    batch_size: int = 16,
    device: str = "cpu",
    load_in_8bit: bool = True,
    generation_kwargs: Optional[dict] = None,
):
    """
    Load a huggingface model & adapt with PEFT.

    Builds a ``text-generation`` pipeline from a causal-LM base model with a
    PEFT (LoRA) adapter applied on top.
    Borrowed from https://github.com/tloen/alpaca-lora/blob/main/generate.py

    Args:
        base_model_name: Hub id or local path of the base causal LM.
        lora_model_name: Hub id or local path of the PEFT adapter.
        temperature: Sampling temperature. A value <= 0 (the default) selects
            greedy decoding; sampling is only enabled for positive values.
        top_p: Nucleus-sampling threshold; only applied when sampling.
        max_tokens: Maximum number of newly generated tokens.
        batch_size: Batch size handed to the pipeline.
        device: "cuda" (loaded with ``device_map="auto"``), "mps", or any
            other device string such as "cpu" (loaded onto that device).
        load_in_8bit: Load the base model in 8-bit. Only the "cuda" path
            passes this to ``from_pretrained``; elsewhere it only disables
            the fp16 cast below.
        generation_kwargs: Extra ``GenerationConfig`` fields. They override
            the defaults derived from the arguments above.

    Returns:
        A transformers text-generation pipeline.

    Raises:
        ValueError: if ``device == "cuda"`` and ``accelerate`` (or, with
            ``load_in_8bit``, ``bitsandbytes``) is not installed.
    """
    if device == "cuda":
        if not is_accelerate_available():
            raise ValueError("Install `accelerate`")
        if load_in_8bit and not is_bitsandbytes_available():
            raise ValueError("Install `bitsandbytes`")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    task = "text-generation"

    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
        )

    # unwind broken decapoda-research config
    # NOTE(review): unk=0 / bos=1 / eos=2 are the LLaMA tokenizer ids; they are
    # applied regardless of base model -- confirm before using a non-LLaMA base.
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    # fp16 kernels such as addmm are missing on CPU in many torch versions,
    # so the cast is only applied on accelerator devices.
    if not load_in_8bit and device != "cpu":
        model.half()  # seems to fix bugs for some users.
    model.eval()

    # transformers rejects a non-positive temperature when sampling, so the
    # default temperature=0 means greedy decoding.
    do_sample = temperature > 0
    config_kwargs = {
        "do_sample": do_sample,
        "max_new_tokens": max_tokens,
        # Carry the corrected special-token ids into generation as well.
        "pad_token_id": model.config.pad_token_id,
        "bos_token_id": model.config.bos_token_id,
        "eos_token_id": model.config.eos_token_id,
    }
    if do_sample:
        config_kwargs["temperature"] = float(temperature)
        config_kwargs["top_p"] = top_p
    # Caller-supplied fields win instead of colliding as duplicate keywords.
    if generation_kwargs:
        config_kwargs.update(generation_kwargs)
    config = GenerationConfig(**config_kwargs)

    return pipeline(
        task,
        model=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
        generation_config=config,
        framework="pt",
    )
if __name__ == "__main__":
    # Manual smoke test: paraphrase one sentence with LLaMA-7B plus the
    # Alpaca LoRA adapter, using the function's default generation settings.
    example_prompt = ALPACA_TEMPLATE.format(
        instruction="Paraphrase the sentence.",
        input="The quick brown fox jumped over the lazy dog.",
    )
    generator = load_adapted_hf_generation_pipeline(
        base_model_name="decapoda-research/llama-7b-hf",
        lora_model_name="tloen/alpaca-lora-7b",
    )
    outputs = generator(example_prompt)
    print(outputs)
@AleksandrTarasov07
Copy link

Thank you so much! I'm gonna test it and let you know.

@AleksandrTarasov07
Copy link

One more time thank you so much! It works well :)

@nuri428
Copy link

nuri428 commented Oct 14, 2023

Thank you so much! I

@hoffm386
Copy link

Heads up: I believe you're missing a comma at the end of line 116.

@ahoho
Copy link
Author

ahoho commented Dec 14, 2023

Nice catch, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment