@abacaj
Last active April 8, 2024 14:11
Stream HF transformer token generation
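A minimal iterator-style streamer for Hugging Face transformers: model.generate() runs in a background thread and pushes each newly generated token id into a queue through the streamer's put()/end() callbacks, while the main thread iterates over the streamer and decodes tokens as they arrive.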
from queue import Queue
from threading import Thread

import torch
import transformers


class TextIteratorStreamer:
    """Collects tokens pushed by model.generate() and yields them one at a time."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        # generate() first sends the prompt tokens; skip them so only new tokens are streamed
        if self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        self.queue.put(value)

    def end(self):
        # called by generate() once generation is finished
        self.queue.put("[STOP]")

    def __iter__(self):
        return self

    def __next__(self):
        value = self.queue.get(timeout=1)  # you should adjust this for models that take longer to generate
        if isinstance(value, str) and value == "[STOP]":
            raise StopIteration
        return value.item()


model_name = "model_name"

# add your tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# load your model
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
)

streamer = TextIteratorStreamer(tokenizer)

inputs = tokenizer(
    "Who was the first president? The",
    return_tensors="pt",
).to(model.device)

# run generation in a background thread so tokens can be consumed as they are produced
thread = Thread(
    target=model.generate,
    kwargs=dict(
        inputs,
        use_cache=True,
        do_sample=True,
        temperature=0.4,
        top_p=0.95,
        streamer=streamer,
        max_new_tokens=128,
        pad_token_id=tokenizer.pad_token_id,
    ),
)
thread.start()

# print stream of tokens
for token in streamer:
    word = tokenizer.decode(token, skip_special_tokens=True)
    print(word, end="", flush=True)
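
The same pattern can be wrapped into a reusable generator so a caller simply iterates over decoded text pieces. A minimal sketch reusing the model, tokenizer, Thread, and TextIteratorStreamer defined above; the stream_generate name and its defaults are illustrative, not part of the original gist:

def stream_generate(prompt: str, max_new_tokens: int = 128):
    """Yield decoded text pieces as the model produces them (illustrative helper)."""
    streamer = TextIteratorStreamer(tokenizer)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            inputs,
            use_cache=True,
            do_sample=True,
            temperature=0.4,
            top_p=0.95,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        ),
    )
    thread.start()
    try:
        for token in streamer:
            yield tokenizer.decode(token, skip_special_tokens=True)
    finally:
        thread.join()  # wait for the background generation thread to finish


# usage: prints the stream exactly like the loop above
for piece in stream_generate("Who was the first president? The"):
    print(piece, end="", flush=True)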