@MLWhiz · Created May 16, 2021 20:14
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Load the fine-tuned model; the tokenizer is assumed to have been saved in the same
# "test-squad-trained" directory during training.
tokenizer = AutoTokenizer.from_pretrained("test-squad-trained")
model = AutoModelForQuestionAnswering.from_pretrained("test-squad-trained")

text = r"""
🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
TensorFlow 2.0 and PyTorch
"""

questions = [
    "How many pretrained models are available in Transformers?",
    "What does Transformers provide?",
    "Transformers provides interoperability between which frameworks?",
]

for question in questions:
    # Encode the question/context pair as one sequence with the special tokens ([CLS]/[SEP]) the model expects
    inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]
    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)  # token strings, handy for inspecting the input

    # Forward pass: the model predicts a start logit and an end logit for every token
    pred = model(**inputs)
    answer_start_scores, answer_end_scores = pred["start_logits"][0], pred["end_logits"][0]

    # The most likely answer span runs from the argmax of the start logits to the argmax of the end logits
    answer_start = torch.argmax(answer_start_scores)
    answer_end = torch.argmax(answer_end_scores) + 1  # +1 because the end index is exclusive when slicing

    # Decode the predicted token span back into a readable string
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

    print(f"Question: {question}")
    print(f"Answer: {answer}\n")
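For comparison, the same extraction can be done with the high-level pipeline API, which wraps the tokenization, forward pass, and argmax span decoding shown above. This is a minimal sketch, assuming the tokenizer was saved alongside the model in "test-squad-trained":

from transformers import pipeline

# Build a question-answering pipeline from the fine-tuned checkpoint
qa = pipeline("question-answering", model="test-squad-trained", tokenizer="test-squad-trained")

for question in questions:
    # The pipeline returns a dict with the answer text, a confidence score, and character offsets
    result = qa(question=question, context=text)
    print(f"Question: {question}")
    print(f"Answer: {result['answer']} (score: {result['score']:.4f})\n")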