@jooray, created March 28, 2023 14:42
Running the Alpacoom model on MPS (Apple Silicon) using Hugging Face Transformers and PEFT
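The script below loads the BLOOM-7b1 base model, applies the Alpacoom LoRA adapter to it with PEFT, and generates a response to an instruction passed on the command line. The prerequisites are an assumption on my part rather than something the gist pins down: a recent `torch` build with MPS support plus `transformers` and `peft`, e.g. `pip install torch transformers peft`.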
import os
import sys

# Must be set before importing torch so that ops not yet implemented on MPS
# fall back to the CPU instead of raising an error.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
peft_model_id = "mrm8488/Alpacoom"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the BLOOM-7b1 base model in full precision (8-bit loading via
# bitsandbytes requires CUDA) and move it to the Apple GPU.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path, return_dict=True, load_in_8bit=False
).to("mps")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-7b1")
# Apply the Alpacoom LoRA adapter on top of the base model.
model = PeftModel.from_pretrained(model, peft_model_id).to("mps")
model.eval()
# Based on the inference code by `tloen/alpaca-lora`
def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""
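
# Example (hypothetical values): generate_prompt("Translate to French", "Hello")
# returns the template above with both fields substituted, ending in
# "### Response:" so that the model's continuation is the answer.
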
def generate(
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    **kwargs,
):
    prompt = generate_prompt(instruction, input)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("mps")
    # Note: with num_beams > 1 and do_sample left at its default (False),
    # transformers runs beam search and the sampling parameters
    # (temperature/top_p/top_k) have no effect; pass do_sample=True via
    # **kwargs to sample instead.
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    # Keep only the text after "### Response:", dropping the echoed prompt and
    # anything after a repeated "Below is an instruction..." preamble.
    return output.split("### Response:")[1].strip().split("Below")[0]
instruction = sys.argv[1]
# The optional second argument provides additional input/context.
input_text = sys.argv[2] if len(sys.argv) > 2 else None

print("Instruction:", instruction)
print("Response:", generate(instruction, input_text))