Created March 21, 2024 15:22
Test showing an error when compiling model and using direct assignment
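
In short: after wrapping the model with torch.compile(backend="openxla"), a single-element write into an input tensor via direct assignment (input_ids[3, 0] = token) appears to be ignored by the compiled call, while the equivalent index_put_ is honored. A minimal sketch of the pattern involved (an illustration, not the original repro: it assumes torch_xla is installed and an XLA device is available, and the shapes/values here are made up):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    compiled = torch.compile(lambda t: t * 2, backend="openxla")

    x = torch.zeros(4, 1, dtype=torch.int64, device=device)
    x[3, 0] = 42  # direct assignment: reportedly not seen by the compiled call
    # x.index_put_([torch.tensor([3])], torch.tensor(42))  # reported workaround
    xm.mark_step()
    print(compiled(x))

The full reproduction script follows.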
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
import os
import torch_xla.core.xla_model as xm
from datetime import datetime

os.environ["PJRT_DEVICE"] = "TPU"


def dprint(*args, **kwargs):
    print(datetime.now(), *args, **kwargs)


def sample_greedy(logits):
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id
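
# Note on shapes: sample_greedy takes logits of shape [batch, seq_len, vocab]
# and returns a [batch, 1] int tensor holding the argmax at the last position
# of each row.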


@torch.no_grad
def test_gemma_model():
    model_id = "google/gemma-2b"
    dprint("Setting up device.")
    device = xm.xla_device(devkind="TPU")
    torch_dtype = torch.bfloat16

    # Prepare inputs
    dprint("Preparing inputs.")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompts = ["", "", "", "It was a bright cold day in April, and the clocks were striking thirteen."]
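    # Rows 0-2 are empty prompts; row 3 is the only one with real content,
    # which is why the script tracks next_token_id[3, 0] as the "changing" token.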
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    batch_size, seq_length = input_ids.shape
    position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    cache_position = torch.arange(seq_length, device=device)
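    # Worked example for the position_ids formula above: a padded row with
    # attention_mask [0, 0, 1, 1] gives cumsum - 1 = [-1, -1, 0, 1], and the
    # masked_fill resets the padded slots to 0, yielding [0, 0, 0, 1].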

    # Load model
    dprint("Loading and preparing model.")
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map=device)
    model = model.eval()
    max_cache_length = 1024
    # Setup cache
    model._setup_cache(StaticCache, batch_size, max_cache_length)
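    # _setup_cache is a private transformers API (present around v4.38/v4.39)
    # that attaches a fixed-shape StaticCache to each layer, so the compiled
    # graph sees constant KV-cache shapes across decode steps.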
    # This is probably not useful here, but I do it in the original code
    xm.mark_step()

    # Run model prefill
    dprint("Running prefill.")
    outputs = model(
        input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=True,
    )
    logits = outputs.logits
    xm.mark_step()
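    # (Each xm.mark_step() cuts the lazy-tensor trace: torch_xla compiles and
    # executes the graph accumulated since the previous step.)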
    next_token_id = sample_greedy(logits)
    next_changing_token = next_token_id[3, 0].item()
    dprint(f"Prefill output: {next_changing_token} {tokenizer.decode(next_changing_token)}")

    # Prepare next position_ids
    position_ids = position_ids.max(dim=-1)[0].unsqueeze(1) + 1
    cur_position = seq_length

    # Compile model
    model = torch.compile(model, backend="openxla")
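    # With backend="openxla", torch.compile routes the traced graph through
    # torch_xla. Note the model is wrapped *after* prefill, so only the decode
    # steps below run through the compiled path.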

    # Run for a few more steps
    for i in range(3):
        position_ids = torch.zeros(
            [batch_size, 1],
            dtype=torch.int64,
            device=device,
        )
        position_ids[3, 0] = cur_position
        cache_position = position_ids.max().unsqueeze(0)
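        # cache_position is the single slot of the static KV cache written on
        # this decode step: a 1-element tensor holding the current position.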
        input_ids = torch.full(
            [batch_size, 1],
            fill_value=tokenizer.eos_token_id,
            dtype=torch.int64,
            device=device,
        )
        # The following input only has its third value set.
        # NOTE: here's where things go wrong: if I assign the value directly
        # into the tensor, the compiled call does not work as expected, as if
        # the assignment were not taken into account.
        input_ids[3, 0] = next_changing_token
        # I found a workaround: using index_put_ instead of direct assignment.
        # It would be good to understand why direct assignment does not work.
        # The commented code below works as expected:
        #
        # input_ids.index_put_([torch.tensor([3])], torch.tensor(next_changing_token))
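        # Another possible workaround (an untested sketch, not from the
        # original gist): build the batch with an out-of-place scatter instead
        # of mutating input_ids at all, e.g.
        #
        # input_ids = input_ids.scatter(
        #     0,
        #     torch.tensor([[3]], device=device),
        #     torch.tensor([[next_changing_token]], device=device),
        # )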
        xm.mark_step()
        outputs = model(
            input_ids,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=True,
        )
        logits = outputs.logits
        xm.mark_step()
        next_token_id = sample_greedy(logits)
        next_changing_token = next_token_id[3, 0].item()
        cur_position += 1
        dprint(f"Output {i} {next_changing_token} {tokenizer.decode(next_changing_token)}")
        assert next_changing_token != tokenizer.eos_token_id


if __name__ == "__main__":
    test_gemma_model()
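
If the compiled decode step ignores the direct assignment, row 3 effectively receives the EOS filler token, so the assert above is what surfaces the bug: presumably it fires on the broken path and passes with the index_put_ workaround.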