Skip to content

Instantly share code, notes, and snippets.

@itdxer
Created March 30, 2025 12:41
Show Gist options
  • Select an option

  • Save itdxer/fc96b8861422b7d504b2b7d121a440d7 to your computer and use it in GitHub Desktop.

Select an option

Save itdxer/fc96b8861422b7d504b2b7d121a440d7 to your computer and use it in GitHub Desktop.
import json
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, StoppingCriteria
# Hugging Face model id; used for both the tokenizer and the causal-LM weights.
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
# Llama-3.1 chat template written out by hand (special tokens included) so the
# raw string can be tokenized directly; `{ids_and_text}` is filled per run.
PROMPT = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|><|start_header_id|>user<|end_header_id|>
Extract the 'id', 'city', and 'time' from the following text and return them in a JSON format.
The text consists of an 'id' and a message separated by a semicolon.
{ids_and_text}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
# Toy extraction dataset: maps an integer id to a free-text sentence that
# mentions a city and a time of day.
DATA = {
90: "At midnight in Paris, the harbor lights reflect beautifully on the calm waters",
215: "The bustling streets of Amsterdam are particularly vibrant at 5:30 PM on weekdays when office hours end.",
577: "In Rome, the sun sets around 8 PM during the summer months, creating a beautiful evening atmosphere.",
664: "At noon in Paris, the aroma of freshly baked bread fills the air as locals enjoy their lunch breaks.",
}
def create_prompt(data):
    """Fill the chat template with one ``<id>: <text>`` line per entry of *data*."""
    id_text_lines = [f"{identifier}: {text}" for identifier, text in data.items()]
    return PROMPT.format(ids_and_text="\n".join(id_text_lines))
def iter_json_parts(identifiers, desired_keys):
    """Render the target JSON with a sentinel in every value slot, split on it.

    The resulting pieces alternate between fixed JSON scaffolding and the
    positions where a string value still has to be filled in.
    """
    sentinel = "PLACEHOLDER"
    entries = []
    for identifier in identifiers:
        entry = {"id": identifier}
        for key in desired_keys:
            entry[key] = sentinel
        entries.append(entry)
    rendered = json.dumps(entries, indent=2)
    return rendered.split(sentinel)
class StopOnJSONStringEnd(StoppingCriteria):
    """Stop generation once the newest token contains a double quote.

    Halts the model right after it closes a JSON string value, so the
    surrounding scaffolding can be appended verbatim instead of generated.
    """

    def __init__(self, tokenizer):
        # Any vocabulary entry whose surface form contains '"' can terminate
        # a JSON string, so collect all of them up front.
        self.special_token_ids = [
            token_id
            for token, token_id in tokenizer.vocab.items()
            if '"' in token
        ]

    def __call__(self, input_ids, *args, **kwargs):
        # Note: for simplicity, only the first sequence of the batch is checked.
        last_token_id = input_ids[0, -1].item()
        return last_token_id in self.special_token_ids
if __name__ == "__main__":
    # Prefer the Apple-silicon GPU when present, but fall back to CPU so the
    # script is also runnable on non-Mac machines (original hard-coded "mps").
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).to(device)
    stop_criteria = StopOnJSONStringEnd(tokenizer)

    # One KV cache shared across all generate() calls: each call only processes
    # the newly appended tokens instead of re-encoding the whole prefix.
    past_key_values = DynamicCache()

    prompt = create_prompt(DATA)
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    start_inference_time = time.time()
    for part in iter_json_parts(identifiers=DATA.keys(), desired_keys=["city", "time"]):
        # Append the fixed JSON scaffolding verbatim — the model never generates
        # structural tokens, only the string values between the quotes.
        part_ids = tokenizer(part, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
        token_ids = torch.cat([token_ids, part_ids], dim=-1)  # `dim`: idiomatic torch spelling of numpy's `axis`
        if part.endswith('"'):  # part of the JSON which starts a new string and has to be filled in by the LLM
            token_ids = model.generate(
                token_ids,
                max_new_tokens=50,  # maximum length of one generated string value
                pad_token_id=tokenizer.eos_token_id,
                stopping_criteria=[stop_criteria],
                use_cache=True,
                past_key_values=past_key_values,
            )
            # If generation ended on the closing quote (rather than on
            # max_new_tokens), drop it: the next scaffolding part begins with
            # '"' and would otherwise duplicate the quote in the output.
            if stop_criteria(token_ids):
                token_ids = token_ids[:, :-1]

    print(f"Inference time: {time.time() - start_inference_time}")
    print(tokenizer.decode(token_ids[0]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment