@tengomucho
Created March 21, 2024 15:22
Test showing an error when compiling the model and using direct tensor assignment
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache
import torch
import os
import torch_xla.core.xla_model as xm
from datetime import datetime

os.environ["PJRT_DEVICE"] = "TPU"


def dprint(*args, **kwargs):
    print(datetime.now(), *args, **kwargs)


def sample_greedy(logits):
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id


@torch.no_grad
def test_gemma_model():
    model_id = "google/gemma-2b"
    dprint("Setting up device.")
    device = xm.xla_device(devkind="TPU")
    torch_dtype = torch.bfloat16

    # Prepare inputs
    dprint("Preparing inputs.")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompts = ["", "", "", "It was a bright cold day in April, and the clocks were striking thirteen."]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    batch_size, seq_length = input_ids.shape
    position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    cache_position = torch.arange(seq_length, device=device)

    # Load model
    dprint("Loading and preparing model.")
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map=device)
    model = model.eval()
    max_cache_length = 1024
    # Setup cache
    model._setup_cache(StaticCache, batch_size, max_cache_length)
    # This is probably not useful here, but I do it in the original code
    xm.mark_step()

    # Run model prefill
    dprint("Running prefill.")
    outputs = model(
        input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=True,
    )
    logits = outputs.logits
    xm.mark_step()
    next_token_id = sample_greedy(logits)
    next_changing_token = next_token_id[3, 0].item()
    dprint(f"Prefill_output: {next_changing_token} {tokenizer.decode(next_changing_token)}")

    # Prepare next position_ids
    position_ids = position_ids.max(axis=-1)[0].unsqueeze(1) + 1
    cur_position = seq_length

    # Compile model
    model = torch.compile(model, backend="openxla")

    # Run for a few more steps
    for i in range(3):
        position_ids = torch.zeros(
            [batch_size, 1],
            dtype=torch.int64,
            device=device,
        )
        position_ids[3, 0] = cur_position
        cache_position = position_ids.max().unsqueeze(0)
        input_ids = torch.full(
            [batch_size, 1],
            fill_value=tokenizer.eos_token_id,
            dtype=torch.int64,
            device=device,
        )
        # The following input only has the third value set.
        # NOTE: here's where things go wrong: if I assign the value in the
        # tensor directly, the compiled call does not work as expected, as if
        # the assigned input was not taken into account.
        input_ids[3, 0] = next_changing_token
        # I found a workaround: using index_put_ instead of direct assignment.
        # It would be good to understand why direct assignment does not work.
        # The commented code below works as expected:
        #
        # input_ids.index_put_([torch.tensor([3])], torch.tensor(next_changing_token))
        xm.mark_step()
        outputs = model(
            input_ids,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=True,
        )
        logits = outputs.logits
        xm.mark_step()
        next_token_id = sample_greedy(logits)
        next_changing_token = next_token_id[3, 0].item()
        cur_position += 1
        dprint(f"Output {i} {next_changing_token} {tokenizer.decode(next_changing_token)}")
        assert next_changing_token != tokenizer.eos_token_id


if __name__ == "__main__":
    test_gemma_model()
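
For reference, here is a reduced sketch that isolates just the pattern in question: in-place element assignment on an XLA tensor that then feeds a torch.compile'd function with the "openxla" backend. This is not part of the gist; the step function and its doubling computation are illustrative stand-ins for the model call, and it assumes a TPU host with torch_xla installed.

import os
os.environ.setdefault("PJRT_DEVICE", "TPU")

import torch
import torch_xla.core.xla_model as xm


def step(x):
    # Stand-in for the compiled model call: any traced computation over x.
    return x * 2


if __name__ == "__main__":
    device = xm.xla_device()
    compiled_step = torch.compile(step, backend="openxla")

    x = torch.zeros([4, 1], dtype=torch.int64, device=device)
    for i in range(3):
        # Direct assignment into one element, mirroring input_ids[3, 0] = ... above.
        x[3, 0] = i + 1
        # Workaround used in the gist, kept here for comparison:
        # x.index_put_([torch.tensor([3])], torch.tensor(i + 1))
        xm.mark_step()
        out = compiled_step(x)
        xm.mark_step()
        # If the assignment is reflected in the compiled call, this prints 2, 4, 6.
        print(i, out[3, 0].item())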