🤗 transformers with google/gemma-2b on TPU test
#!/usr/bin/python
# Mostly inspired by:
# https://twitter.com/reach_vb/status/1759716375033938291
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
import time
import datetime
import os
import platform

os.environ["PJRT_DEVICE"] = "TPU"

# Device selection: set the CPU env var to force CPU; otherwise use the TPU
# through torch_xla if it is installed, falling back to MPS (e.g. on a Mac).
if "CPU" in os.environ:
    device = "cpu"
else:
    try:
        import torch_xla.core.xla_model as xm

        device = "xla"
    except ModuleNotFoundError:
        device = "mps"


def sample_greedy(logits):
    # Greedy sampling: pick the most likely token at the last position.
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id


def decode_one_tokens(model, cur_token, input_pos, cache_position):
    # Run a single decoding step and greedily pick the next token.
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
    )[0]
    new_token = sample_greedy(logits)
    return new_token


def conditional_compile(func):
    # Only compile when the DBG_COMPILE env var is set, choosing a backend
    # suited to the current device.
    if "DBG_COMPILE" in os.environ:
        if device == "mps":
            compile_params = {
                "backend": "aot_eager",
                "fullgraph": True,
            }
        else:
            compile_params = {
                "backend": "openxla_eval",
                # "mode": "reduce-overhead",
                # "fullgraph": True,
            }
        compiled = torch.compile(func, **compile_params)
        return compiled
    return func


prg_start = time.time()
model_id = "google/gemma-2b"
torch_dtype = torch.float16 if device == "mps" else torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map=device)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompts = ["Here's a funny thing:", "This is a good recipe for a dessert:"]
# 4 prompts test, for a change
# prompts = ["Here's a funny thing:", "This is a good recipe for a dessert:", "Give me one more chance", ""]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
batch_size, sequence_length = inputs["input_ids"].shape
max_cache_length = 1024
max_new_tokens = 20

with torch.no_grad():
    start = time.time()
    # Set up a static KV cache (note: _setup_cache is a private transformers API).
    model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
    end = time.time()
    print(f"Model cache setup took {end - start} seconds.")
    start = time.time()
    cache_position = torch.arange(sequence_length, device=device)
    generated_ids = torch.zeros(
        (batch_size, sequence_length + max_new_tokens + 1),
        dtype=torch.int,
        device=device,
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)

    # Prefill: run the whole prompt through the model once, then sample the
    # first new token.
    logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    next_token = sample_greedy(logits)
    if device == "xla":
        xm.mark_step()
    generated_ids[:, sequence_length] = next_token[:, 0]
    end = time.time()
    print(f"Prefill took {end - start} seconds.")

    # Position ids for decoding: one past the last non-padding position of
    # each prompt.
    attention_mask = inputs["attention_mask"]
    pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    pos_ids = pos_ids.max(dim=-1)[0].unsqueeze(1) + 1
    # Sanity check: the two prompts have different lengths, so padding is
    # actually exercised.
    assert pos_ids[0, 0] != pos_ids[1, 0]

    # decode_one_tokens = conditional_compile(decode_one_tokens)
    model = conditional_compile(model)
    cache_position = torch.tensor([sequence_length + 1], device=device)
    start = time.time()
    # Decode max_new_tokens tokens, one step at a time.
    for i in range(max_new_tokens):
        next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position)
        print(f"{i} - {datetime.datetime.now()}")
        generated_ids[:, cache_position] = next_token
        cache_position += 1
        pos_ids += 1
        if device == "xla":
            xm.mark_step()  # trigger execution of the pending XLA graph
    end = time.time()
    print(f"Eval took {end - start} seconds.")

print(f"Getting data back {datetime.datetime.now()}")
decoded_texts = tokenizer.batch_decode(generated_ids)
for i, text in enumerate(decoded_texts):
    print(i, text)
end = time.time()
print(f"Program run in {end - prg_start} seconds. Device: {device} System: {platform.system()}")
Note that you will need to be authorized to use google/gemma-2b (you can run huggingface-cli login to do that). To try with compilation on, set the DBG_COMPILE environment variable when launching the script, e.g. DBG_COMPILE=1 python <script>.py; the conditional_compile helper only wraps the model in torch.compile when that variable is set.
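For completeness, here is a minimal sketch of a programmatic alternative to the CLI login, assuming the huggingface_hub package is installed; the token string is a placeholder to replace with your own:

# Sketch: programmatic alternative to `huggingface-cli login`.
# The token below is a placeholder; create one at https://huggingface.co/settings/tokens
# on an account that has accepted the google/gemma-2b license.
from huggingface_hub import login

login(token="hf_...")  # after this, from_pretrained can download the gated weights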