@moyix
Created September 6, 2023 22:06
StoppingCriteria abused to print tokens to stdout as they're generated: generate() evaluates every stopping criterion after each new token, so a criterion that always returns False works as a per-token streaming callback.
import sys

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)
class StreamPrinter(StoppingCriteria):
    """Streams tokens to stdout by hijacking the stopping-criteria hook,
    which generate() invokes after every new token."""

    def __init__(self):
        super().__init__()
        self.pos = 0  # number of tokens already printed

    def __call__(self, input_ids, scores, **kwargs):
        # Print only the tokens we haven't seen yet. On the first call this
        # also echoes the prompt, since input_ids begins with it.
        # Relies on the module-level `tokenizer` defined below.
        new_tokens = input_ids[0][self.pos:]
        sys.stdout.write(tokenizer.decode(new_tokens))
        sys.stdout.flush()
        self.pos += len(new_tokens)
        return False  # never actually stop generation
tokenizer = AutoTokenizer.from_pretrained('/data/research/falcon-180B')
model = AutoModelForCausalLM.from_pretrained(
    '/data/research/falcon-180B',
    torch_dtype=torch.bfloat16,  # compute dtype for the non-quantized parts
    low_cpu_mem_usage=True,
    device_map='auto',           # shard layers across available GPUs
    load_in_4bit=True,           # bitsandbytes 4-bit quantization
)
prompt = "The UAE is known for its love of falcons."
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=True,      # sample rather than greedy-decode
    temperature=0.6,
    top_p=0.9,           # nucleus sampling
    max_new_tokens=512,
    # StreamPrinter never stops generation; it's only here so its __call__
    # runs after each token and prints it.
    stopping_criteria=StoppingCriteriaList([StreamPrinter()]),
)