HuggingFace Transformers inference for Stanford Alpaca (fine-tuned LLaMA)
# Based on: Original Alpaca Model/Dataset/Inference Code by Tatsu-lab
import time

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

# Load the tokenizer from the fine-tuned checkpoint directory.
tokenizer = LlamaTokenizer.from_pretrained("./checkpoint-1200/")


def generate_prompt(instruction, input=None):
    # Build the Alpaca-style prompt, with or without an input field.
    if input:
        return f"""The following is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:"""
    else:
        return f"""The following is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:"""


# Load the weights in fp16 and let Accelerate place them on the available GPU(s).
model = LlamaForCausalLM.from_pretrained(
    "checkpoint-1200",
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Simple interactive loop: read an instruction, generate, and print the output.
while True:
    text = generate_prompt(input("User: "))
    time.sleep(1)
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to("cuda")
    generated_ids = model.generate(
        input_ids,
        max_new_tokens=250,
        do_sample=True,
        repetition_penalty=1.0,
        temperature=0.8,
        top_p=0.75,
        top_k=40,
    )
    print(tokenizer.decode(generated_ids[0]))
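
The script imports GenerationConfig but never uses it. As an optional variation (not part of the original snippet), the same sampling settings can be bundled into a GenerationConfig object and passed to generate():

# Optional: collect the sampling parameters in a GenerationConfig object.
generation_config = GenerationConfig(
    max_new_tokens=250,
    do_sample=True,
    repetition_penalty=1.0,
    temperature=0.8,
    top_p=0.75,
    top_k=40,
)
generated_ids = model.generate(input_ids, generation_config=generation_config)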

Stanford Alpaca is a model fine-tuned from LLaMA-7B.

The inference code uses the Alpaca Native model, which was fine-tuned with the original tatsu-lab/stanford_alpaca repository. The fine-tuning process does not use LoRA, unlike tloen/alpaca-lora.

Hardware and software requirements

For the Alpaca-7B:

  • Linux, macOS

  • 1x 24 GB GPU in fp16, or 1x 12 GB GPU in int8

  • PyTorch with CUDA (not the CPU version)

  • HuggingFace Transformers library

    pip install git+https://github.com/huggingface/transformers.git

    Currently, the Transformers library only supports LLaMA through the latest GitHub repository, not through a released Python package.

  • To run in 8-bit (quantized), install bitsandbytes and set load_in_8bit=True (a minimal sketch follows this list).
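
A minimal sketch of the 8-bit variant, assuming bitsandbytes and accelerate are installed and using the same placeholder checkpoint path as above:

from transformers import LlamaTokenizer, LlamaForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("./checkpoint-1200/")
# load_in_8bit=True quantizes the weights with bitsandbytes at load time,
# roughly halving GPU memory use relative to fp16.
model = LlamaForCausalLM.from_pretrained(
    "checkpoint-1200",
    load_in_8bit=True,
    device_map="auto",
)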

cedrickchee commented Mar 18, 2023

How to use.

  1. Download model weights from https://huggingface.co/chavinlo/alpaca-native

  2. Change ./checkpoint-1200/ to the directory containing your HuggingFace-format model files (a sketch of both steps follows).
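
A minimal sketch of both steps, assuming huggingface_hub is installed; the repo id comes from the link in step 1, and the local path is whatever snapshot_download returns:

from huggingface_hub import snapshot_download
from transformers import LlamaTokenizer, LlamaForCausalLM

# Step 1: download the Alpaca Native weights; returns the local snapshot directory.
model_dir = snapshot_download(repo_id="chavinlo/alpaca-native")

# Step 2: point the tokenizer and model at that directory instead of ./checkpoint-1200/.
tokenizer = LlamaTokenizer.from_pretrained(model_dir)
model = LlamaForCausalLM.from_pretrained(model_dir)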

@cedrickchee

FAQ:

  1. What if I want to fine-tune Stanford Alpaca myself?

    The Replicate team has replicated the training process and published a tutorial on how they did it; it cost less than $100.

@cedrickchee

I've written a simpler tutorial: Creating a chatbot using Alpaca native and LangChain
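
Not the tutorial itself, but a rough sketch of the idea, assuming langchain is installed and reusing the model and tokenizer loaded by the script above:

from transformers import pipeline
from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFacePipeline

# Wrap the already-loaded Alpaca model in a text-generation pipeline for LangChain.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=250)
llm = HuggingFacePipeline(pipeline=pipe)

# Reuse the Alpaca prompt format as a LangChain prompt template.
template = """The following is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:"""
chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template(template))
print(chain.run("Tell me about alpacas."))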
