# BertPy via PyTorch 2.0 TT-NN Backend
#
# Source: GitHub Gist ayerofieiev-tt/c0fa2814483401205b7eb4929fa207ab
# (author @ayerofieiev-tt, last active February 9, 2025 05:13).
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch_ttnn
import ttnn
# Extractive question answering with a BERT-large SQuAD2 checkpoint,
# compiled with torch.compile using the torch_ttnn backend so that the model
# executes on a device opened through ttnn.
model_name = "phiyodr/bert-large-finetuned-squad2"
# `torch_dtype` is a model-loading option, so it is applied to the model only;
# passing it to the tokenizer (as before) had no effect on the token ids.
# NOTE(review): BERT uses absolute position embeddings, so left padding moves
# the real tokens to later positions than the usual right padding would —
# confirm left padding is intentional for this backend.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForQuestionAnswering.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Prepare a sample context and question.
context = (
    "Johann Joachim Winckelmann was a German art historian and archaeologist. "
    "He was a pioneering Hellenist who first articulated the difference between Greek, "
    "Greco-Roman and Roman art. 'The prophet and founding hero of modern archaeology', "
    "Winckelmann was one of the founders of scientific archaeology and first applied the "
    "categories of style on a large, systematic basis to the history of art."
)
question = "What discipline did Winkelmann create?"

# Encode the (question, context) pair, padded/truncated to exactly 256 tokens
# (presumably so the compiled graph always sees one static input shape).
inputs = tokenizer.encode_plus(
    question,
    context,
    add_special_tokens=True,
    return_tensors="pt",
    max_length=256,
    padding="max_length",
    truncation=True,
)

device = ttnn.open_device(device_id=0)
# Always release the device, even if compilation or inference raises.
try:
    ttnn.SetDefaultDevice(device)
    option = torch_ttnn.TorchTtnnOption(
        device=device,
        gen_graphviz=False,
        run_mem_analysis=False,
        metrics_path=model_name,
        verbose=True,
        gen_op_accuracy_tests=False,
    )

    model.eval()
    with torch.no_grad():
        compiled_model = torch.compile(model, backend=torch_ttnn.backend, options=option)
        outputs = compiled_model(**inputs)

        # argmax over the whole logits tensor; with a single example this is
        # the token position inside the sequence.
        start_index = torch.argmax(outputs.start_logits)
        # +1 turns the predicted end token into an exclusive slice bound.
        # If the end lands before the start, the slice (and answer) is empty.
        end_index = torch.argmax(outputs.end_logits) + 1
        answer_tokens = inputs["input_ids"][0, start_index:end_index]
        answer = tokenizer.decode(answer_tokens)
        print("Question:", question)
        print("Answer:", answer)

    ttnn.synchronize_device(device)
finally:
    ttnn.close_device(device)
# Comments on this snippet are available on the original GitHub Gist page.