Use CodeLlama on macOS with MPS quickly
# Usage
# =====
#
# ## Prerequisite
#
# Prepare Python 3; for example, install Homebrew and run `brew install python`.
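#
# For example (the first command is the standard Homebrew installer):
#
# $ /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
# $ brew install python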
#
# ## Install dependencies
#
# $ python3 -m venv .venv
# $ .venv/bin/pip3 install --pre --index-url https://download.pytorch.org/whl/nightly/cpu torch
# $ .venv/bin/pip3 install git+https://github.com/huggingface/transformers.git
#
# Due to a known issue, you need the latest PyTorch nightly build to use MPS.
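#
# To check that the MPS backend is available in the installed PyTorch:
#
# $ .venv/bin/python3 -c "import torch; print(torch.backends.mps.is_available())"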
#
# ## Patch Transformers
#
# Patch `.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py` to change
# `long()` to `int()` on the line `position_ids = attention_mask.long().cumsum(-1) - 1`.
# This is required because the MPS backend does not support `cumsum` on int64 tensors.
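#
# For example, this sed one-liner applies the patch in place (the `python3.11`
# path segment depends on your Python version):
#
# $ sed -i '' 's/attention_mask.long().cumsum(-1)/attention_mask.int().cumsum(-1)/' \
#       .venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py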
#
# ## Run test script on Python REPL
#
# $ .venv/bin/python3 -i test.py
# >>> gen('Write a function prints "Hello World"')
#
# The first time you run this script, it downloads the large model weights and metadata
# into `~/.cache`. Use a fast internet connection and make sure you have enough free
# storage (the 7B model weights are on the order of 10 GB).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Use the Metal Performance Shaders (MPS) backend on Apple Silicon.
device = "mps"
model_name = "codellama/CodeLlama-7b-Instruct-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
)
#print(model)
model.to(device)


@torch.no_grad()
def gen(instruction, system=None):
    # Build a Llama 2 chat-style prompt, with an optional system message.
    if system is not None:
        instruction = f"<<SYS>>\n{system}\n<</SYS>>\n\n{instruction}"
    prompt = f"[INST] {instruction} [/INST]"
    print(prompt)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
    )
    # BatchEncoding.to() moves the tensors to the device in place.
    inputs.to(device)
    #print(inputs)
    tokens = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.2,
        do_sample=True,
    )
    #print(tokens)
    print(tokenizer.decode(tokens[0], skip_special_tokens=True))


# Example call with a system prompt that conflicts with the instruction:
# gen(
#     'Write a function that computes the set of sums of all contiguous sublists of a given list in Python.',
#     system="Provide answers in JavaScript"
# )
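# With the system prompt above, gen builds and prints a prompt of the form:
#
# [INST] <<SYS>>
# Provide answers in JavaScript
# <</SYS>>
#
# Write a function that computes the set of sums of all contiguous sublists of a given list in Python. [/INST]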