Skip to content

Instantly share code, notes, and snippets.

@lucataco
Last active June 30, 2023 23:55
Show Gist options
  • Save lucataco/013560e99b89f64e346ff9ed803a9699 to your computer and use it in GitHub Desktop.
Falcon7B HF speed test
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
import time

# Benchmark: measure Falcon-7B generation throughput (tokens/second) using the
# Hugging Face text-generation pipeline. Downloads the model on first run.
model = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,  # half the memory of fp32 with fp32-like range
    trust_remote_code=True,      # Falcon ships custom modeling code on the Hub
    device_map="auto",           # place/shard weights across available devices
)
text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
max_length = 200             # total length cap (prompt + generated tokens)
do_sample = True
top_k = 10
num_return_sequences = 1
eos_token_id = tokenizer.eos_token_id

# Time only the generation call, not model loading.
start_time = time.time()
sequences = pipeline(
    text,
    max_length=max_length,
    do_sample=do_sample,
    top_k=top_k,
    num_return_sequences=num_return_sequences,
    eos_token_id=eos_token_id,
)
end_time = time.time()

for seq in sequences:
    print(f"Result: {seq['generated_text']}")

# FIX: the original counted whitespace-separated *words* over the whole output
# (prompt included), which is not what `max_length` or the model meter in
# tokens. Count real tokens with the model's own tokenizer and subtract the
# prompt tokens so only newly generated tokens are attributed to the elapsed
# time.
prompt_tokens = len(tokenizer(text)["input_ids"])
num_tokens = sum(
    max(len(tokenizer(seq["generated_text"])["input_ids"]) - prompt_tokens, 0)
    for seq in sequences
)
duration = end_time - start_time
tokens_per_second = num_tokens / duration
print(f"Number of tokens generated per second: {tokens_per_second:.2f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment