Created September 6, 2023 22:06
Save moyix/3ce3cd7f4ae30b3db838151830330a78 to your computer and use it in GitHub Desktop.
StoppingCriteria abused to print tokens to stdout as they're generated
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteria, StoppingCriteriaList
class StreamPrinter(StoppingCriteria):
    """Stream generated tokens to stdout as they are produced.

    (Ab)uses the StoppingCriteria hook: generate() invokes __call__ after
    each newly sampled token, so we decode the tokens added since the last
    call, print them, and always return False so generation is never
    actually stopped by this criterion.
    """

    def __init__(self, tok=None):
        # tok: tokenizer used for decoding. When omitted, falls back to the
        # module-level `tokenizer` global at call time (original behavior).
        super().__init__()
        self.tok = tok
        self.pos = 0  # count of tokens already printed

    def __call__(self, input_ids, scores, **kwargs):
        # **kwargs absorbs extra keyword arguments that newer transformers
        # versions pass to stopping criteria (e.g. per-step state).
        tok = self.tok if self.tok is not None else tokenizer
        new_tokens = input_ids[0][self.pos:]
        sys.stdout.write(tok.decode(new_tokens))
        sys.stdout.flush()
        self.pos += len(new_tokens)
        return False  # never request a stop
# Local checkpoint directory for Falcon-180B; tokenizer and weights are
# loaded from the same path.
MODEL_PATH = '/data/research/falcon-180B'

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    load_in_4bit=True,           # bitsandbytes 4-bit quantization
    device_map='auto',           # shard layers across available devices
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
prompt = "The UAE is known for its love of falcons."

# Send inputs to the device the model's first shard landed on under
# device_map='auto', rather than hard-coding "cuda" — this also works if
# the model ends up on a different device.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    max_new_tokens=512,
    # StreamPrinter never halts generation; it only echoes each new token
    # to stdout as it is sampled.
    stopping_criteria=StoppingCriteriaList([StreamPrinter()]),
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment