import re
import json
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, StoppingCriteria
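
# Demonstrates constrained JSON generation with Llama 3.1: instead of letting the
# model free-generate an entire JSON document, the fixed skeleton (brackets, keys,
# quotes) is appended to the context directly and the model only generates the
# string values, stopping at each closing quote. A shared DynamicCache is reused
# across calls so prompt and skeleton tokens are prefilled only once.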
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"

PROMPT_TEMPLATE_A = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|><|start_header_id|>user<|end_header_id|>
Extract the 'id', 'city', and 'time' from the following text and return them in a JSON format.
The text consists of an 'id' and a message separated by a colon. Do not generate anything other than JSON.
{ids_and_text}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
PROMPT_TEMPLATE_B = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|><|start_header_id|>user<|end_header_id|>
Create a JSON with keys:
id: Unique identifier.
text: Original text content.
city: City name from the text.
time: Time or date from the text.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
DATA = {
90: "At midnight in Paris, the harbor lights reflect beautifully on the calm waters",
215: "The bustling streets of Amsterdam are particularly vibrant at 5:30 PM on weekdays when office hours end.",
577: "In Rome, the sun sets around 8 PM during the summer months, creating a beautiful evening atmosphere.",
664: "At noon in Paris, the aroma of freshly baked bread fills the air as locals enjoy their lunch breaks.",
}
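
# Matches the first unescaped double quote, i.e. the end of a JSON string value.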
END_OF_STRING_REGEXP = re.compile(r"(?<!\\)\"")

def is_end_of_string(token):
    return END_OF_STRING_REGEXP.search(token) is not None

def create_prompt(data):
return PROMPT_TEMPLATE_A.format(ids_and_text="\n".join(f"{identifier}: {text}" for identifier, text in data.items()))
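
# Serialize the target JSON skeleton with PLACEHOLDER markers where the model should
# fill in values, then split on the marker: everything between two placeholders is
# fixed structure that can be appended to the context verbatim instead of generated.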
def iter_json_parts_template_A(identifiers, desired_keys):
entries = [{"id": identifier} | dict.fromkeys(desired_keys, "PLACEHOLDER") for identifier in identifiers]
    return json.dumps(entries, indent=2).split("PLACEHOLDER")

def iter_json_parts_template_B(data, desired_keys):
entries = [{"id": identifier, "text": text} | dict.fromkeys(desired_keys, "PLACEHOLDER") for identifier, text in data.items()]
return json.dumps(entries, indent=2).split("PLACEHOLDER")
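
# For example, with identifiers [90] and desired_keys ["city", "time"] the parts are:
#   '[\n  {\n    "id": 90,\n    "city": "'  <- ends with an opening quote
#   '",\n    "time": "'                     <- ends with an opening quote
#   '"\n  }\n]'

# Stops generation as soon as the model emits a token containing an unescaped double
# quote, i.e. as soon as the current JSON string value has been closed.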
class StopOnJSONStringEnd(StoppingCriteria):
def __init__(self, llm):
self.llm = llm
        self.special_token_ids = [token_id for token, token_id in llm.tokenizer.vocab.items() if is_end_of_string(token)]

    def __call__(self, input_ids, *args, **kwargs):
        last_token_id = self.get_last_token_id(input_ids)
        return last_token_id in self.special_token_ids

    def get_last_token_id(self, input_ids):
# Note: for simplicity, we handle only one sequence at a time
assert len(input_ids) == 1, "Expected input with exactly one sequence in it"
return input_ids[0, -1].item()
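
    # The stopping token may contain characters after the closing quote (e.g. '",');
    # drop it and re-append only the clean prefix that belongs to the string value.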
def clean_last_token(self, input_ids):
last_token_id = self.get_last_token_id(input_ids)
last_token = self.llm.tokenizer.decode(last_token_id)
last_token_clean = END_OF_STRING_REGEXP.split(last_token, maxsplit=1)[0]
sequence = input_ids[:, :-1]
if last_token_clean: # making sure that it's not an empty string
sequence = self.llm.append_tokens(sequence, text_to_append=last_token_clean)
return sequence
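
# Thin wrapper around the tokenizer and model that keeps a DynamicCache alive across
# generate() calls, so the prompt and skeleton tokens only have to be prefilled once.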
class LLM:
def __init__(self, model_path, device):
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
self.stop_criteria = StopOnJSONStringEnd(self)
        self.reset_cache()

    def reset_cache(self):
        self.past_key_values = DynamicCache()

    def tokenize(self, text):
        return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)

    def append_tokens(self, token_ids, text_to_append):
to_append_ids = self.tokenizer(text_to_append, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device)
        return torch.cat([token_ids, to_append_ids], dim=-1)

    def generate(self, token_ids):
token_ids = self.model.generate(
token_ids,
            max_new_tokens=50,  # cap on the length of a single generated string value
pad_token_id=self.tokenizer.eos_token_id,
stopping_criteria=[self.stop_criteria],
use_cache=True,
past_key_values=self.past_key_values,
)
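        # Trim only if generation actually stopped on a closing quote
        # (it may have hit max_new_tokens instead).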
if self.stop_criteria(token_ids):
token_ids = self.stop_criteria.clean_last_token(token_ids)
return token_ids
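
# Baseline: a single unconstrained generate() call produces the entire JSON output.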
def run_first_testcase(llm):
print("Running first test case")
prompt = create_prompt(DATA)
token_ids = llm.tokenize(prompt)
start_inference_time = time.time()
token_ids = llm.model.generate(token_ids, max_new_tokens=500)
print(f"Inference time: {time.time() - start_inference_time:.3f} sec")
print(llm.tokenizer.decode(token_ids[0]))
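
# Constrained variant of the first test case: the fixed JSON skeleton parts are
# appended directly and the model only generates the string values, reusing the
# KV cache throughout.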
def run_second_testcase(llm):
print("Running second test case")
llm.reset_cache()
prompt = create_prompt(DATA)
token_ids = llm.tokenize(prompt)
start_inference_time = time.time()
for json_part in iter_json_parts_template_A(identifiers=DATA.keys(), desired_keys=["city", "time"]):
token_ids = llm.append_tokens(token_ids, text_to_append=json_part)
        if json_part.endswith('"'):  # this part opens a new string value that the LLM has to fill in
token_ids = llm.generate(token_ids)
print(f"Inference time: {time.time() - start_inference_time:.3f} sec")
print(llm.tokenizer.decode(token_ids[0]))
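
# Like the second test case, but the source text is injected straight into the JSON
# skeleton (template B) instead of being repeated in the prompt.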
def run_third_testcase(llm):
print("Running third test case")
llm.reset_cache()
token_ids = llm.tokenize(PROMPT_TEMPLATE_B)
start_inference_time = time.time()
for json_part in iter_json_parts_template_B(DATA, desired_keys=["city", "time"]):
token_ids = llm.append_tokens(token_ids, text_to_append=json_part)
        if json_part.endswith('"'):  # this part opens a new string value that the LLM has to fill in
token_ids = llm.generate(token_ids)
print(f"Inference time: {time.time() - start_inference_time:.3f} sec")
    print(llm.tokenizer.decode(token_ids[0]))

if __name__ == "__main__":
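    # "mps" targets Apple Silicon; use "cuda" or "cpu" on other hardware.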
llm = LLM(MODEL_PATH, device=torch.device("mps"))
run_first_testcase(llm)
run_second_testcase(llm)
run_third_testcase(llm)