import re
import json
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, StoppingCriteria

MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
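
# The script below compares three ways of producing the same structured JSON output:
#   1. Plain generation: the model writes the entire JSON on its own.
#   2. Scaffolded generation: the deterministic parts of the JSON (brackets, keys,
#      quotes, the known "id" values) are appended to the context directly, and the
#      model only generates the string values, stopping at each closing quote.
#   3. Like 2, but the input texts are injected into the JSON scaffold as a "text"
#      field instead of being listed in the prompt.
# A shared DynamicCache is reused across generate() calls, so previously processed
# tokens are not re-encoded on every fill-in step.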

PROMPT_TEMPLATE_A = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|><|start_header_id|>user<|end_header_id|>
Extract the 'id', 'city', and 'time' from the following text and return them in JSON format.
The text consists of an 'id' and a message separated by a colon. Do not generate anything other than JSON.
{ids_and_text}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

PROMPT_TEMPLATE_B = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|><|eot_id|><|start_header_id|>user<|end_header_id|>
Create a JSON with keys:
id: Unique identifier.
text: Original text content.
city: City name from the text.
time: Time or date from the text.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

DATA = {
    90: "At midnight in Paris, the harbor lights reflect beautifully on the calm waters",
    215: "The bustling streets of Amsterdam are particularly vibrant at 5:30 PM on weekdays when office hours end.",
    577: "In Rome, the sun sets around 8 PM during the summer months, creating a beautiful evening atmosphere.",
    664: "At noon in Paris, the aroma of freshly baked bread fills the air as locals enjoy their lunch breaks.",
}

END_OF_STRING_REGEXP = re.compile(r"(?<!\\)\"")


def is_end_of_string(token):
    return END_OF_STRING_REGEXP.search(token) is not None
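
# Sanity check of the regexp on raw token strings (illustrative examples, not part
# of the original gist):
#   is_end_of_string('City"')  -> True   (unescaped closing quote)
#   is_end_of_string('\\"')    -> False  (escaped quote inside a JSON string)
#   is_end_of_string('Paris')  -> False  (no quote at all)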

def create_prompt(data):
    return PROMPT_TEMPLATE_A.format(ids_and_text="\n".join(f"{identifier}: {text}" for identifier, text in data.items()))

def iter_json_parts_template_A(identifiers, desired_keys):
    entries = [{"id": identifier} | dict.fromkeys(desired_keys, "PLACEHOLDER") for identifier in identifiers]
    return json.dumps(entries, indent=2).split("PLACEHOLDER")

def iter_json_parts_template_B(data, desired_keys):
    entries = [{"id": identifier, "text": text} | dict.fromkeys(desired_keys, "PLACEHOLDER") for identifier, text in data.items()]
    return json.dumps(entries, indent=2).split("PLACEHOLDER")
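
# For example (worked out from the code above), iter_json_parts_template_A([90], ["city"])
# returns two pieces:
#   '[\n  {\n    "id": 90,\n    "city": "'  <- ends with '"', so the model fills in the value
#   '"\n  }\n]'                             <- purely structural, appended verbatim
# The test cases below interleave these fixed pieces with short model generations.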

class StopOnJSONStringEnd(StoppingCriteria):
    def __init__(self, llm):
        self.llm = llm
        # Collect every vocabulary token that contains an unescaped double quote
        self.special_token_ids = [token_id for token, token_id in llm.tokenizer.vocab.items() if is_end_of_string(token)]

    def __call__(self, input_ids, *args, **kwargs):
        last_token_id = self.get_last_token_id(input_ids)
        return last_token_id in self.special_token_ids

    def get_last_token_id(self, input_ids):
        # Note: for simplicity, we handle only one sequence at a time
        assert len(input_ids) == 1, "Expected input with exactly one sequence in it"
        return input_ids[0, -1].item()

    def clean_last_token(self, input_ids):
        last_token_id = self.get_last_token_id(input_ids)
        last_token = self.llm.tokenizer.decode(last_token_id)
        last_token_clean = END_OF_STRING_REGEXP.split(last_token, maxsplit=1)[0]
        sequence = input_ids[:, :-1]
        if last_token_clean:  # making sure that it's not an empty string
            sequence = self.llm.append_tokens(sequence, text_to_append=last_token_clean)
        return sequence
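
# Generation stops as soon as the model emits any token containing an unescaped
# quote. Because that token may also carry extra characters (e.g. '",' or '"\n'),
# clean_last_token() drops the whole token and re-appends only the text that came
# before the quote; the structural '"' itself is supplied by the next scaffold piece.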

class LLM:
    def __init__(self, model_path, device):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
        self.stop_criteria = StopOnJSONStringEnd(self)
        self.reset_cache()

    def reset_cache(self):
        self.past_key_values = DynamicCache()

    def tokenize(self, text):
        return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)

    def append_tokens(self, token_ids, text_to_append):
        to_append_ids = self.tokenizer(text_to_append, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device)
        return torch.cat([token_ids, to_append_ids], dim=-1)

    def generate(self, token_ids):
        token_ids = self.model.generate(
            token_ids,
            max_new_tokens=50,  # upper bound on the length of a single filled-in string value
            pad_token_id=self.tokenizer.eos_token_id,
            stopping_criteria=[self.stop_criteria],
            use_cache=True,
            past_key_values=self.past_key_values,
        )
        if self.stop_criteria(token_ids):
            token_ids = self.stop_criteria.clean_last_token(token_ids)
        return token_ids
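
# Reusing one DynamicCache across calls means each generate() call only has to
# process the tokens appended since the previous call, which is where the speed-up
# in the second and third test cases comes from. The cache is reset whenever the
# context is rebuilt from scratch.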

def run_first_testcase(llm):
    print("Running first test case")
    prompt = create_prompt(DATA)
    token_ids = llm.tokenize(prompt)
    start_inference_time = time.time()
    token_ids = llm.model.generate(token_ids, max_new_tokens=500)
    print(f"Inference time: {time.time() - start_inference_time:.3f} sec")
    print(llm.tokenizer.decode(token_ids[0]))
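
# The two scaffolded variants below differ only in where the input texts live:
# in the prompt (second test case) or inside the JSON scaffold itself (third).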

def run_second_testcase(llm):
    print("Running second test case")
    llm.reset_cache()
    prompt = create_prompt(DATA)
    token_ids = llm.tokenize(prompt)
    start_inference_time = time.time()
    for json_part in iter_json_parts_template_A(identifiers=DATA.keys(), desired_keys=["city", "time"]):
        token_ids = llm.append_tokens(token_ids, text_to_append=json_part)
        if json_part.endswith('"'):  # part of the JSON which starts a new string and has to be filled in by the LLM
            token_ids = llm.generate(token_ids)
    print(f"Inference time: {time.time() - start_inference_time:.3f} sec")
    print(llm.tokenizer.decode(token_ids[0]))

def run_third_testcase(llm):
    print("Running third test case")
    llm.reset_cache()
    token_ids = llm.tokenize(PROMPT_TEMPLATE_B)
    start_inference_time = time.time()
    for json_part in iter_json_parts_template_B(DATA, desired_keys=["city", "time"]):
        token_ids = llm.append_tokens(token_ids, text_to_append=json_part)
        if json_part.endswith('"'):  # part of the JSON which starts a new string and has to be filled in by the LLM
            token_ids = llm.generate(token_ids)
    print(f"Inference time: {time.time() - start_inference_time:.3f} sec")
    print(llm.tokenizer.decode(token_ids[0]))

if __name__ == "__main__":
    # "mps" assumes an Apple Silicon machine; swap in "cuda" or "cpu" elsewhere
    llm = LLM(MODEL_PATH, device=torch.device("mps"))
    run_first_testcase(llm)
    run_second_testcase(llm)
    run_third_testcase(llm)