@priya-dwivedi
Created September 30, 2020 18:31
Inference with the TriviaQA T5 model using AutoModelWithLMHead
import torch
from transformers import AutoTokenizer, AutoModelWithLMHead

# Load the fine-tuned T5 model and its tokenizer from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("deep-learning-analytics/triviaqa-t5-base")
model = AutoModelWithLMHead.from_pretrained("deep-learning-analytics/triviaqa-t5-base")

# Move the model to GPU if one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

text = "Who directed the movie Jaws?"
preprocess_text = text.strip().replace("\n", "")
tokenized_text = tokenizer.encode(preprocess_text, return_tensors="pt").to(device)

# Generate the answer with beam search
outs = model.generate(
    tokenized_text,
    max_length=10,
    num_beams=2,
    early_stopping=True,
)
dec = [tokenizer.decode(ids) for ids in outs]
print("Predicted Answer: ", dec)
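Note that plain `tokenizer.decode(ids)` leaves T5's special tokens (such as `<pad>` and `</s>`) in the decoded string. A minimal sketch of stripping them by hand, assuming those two token strings and using a made-up decoded string for illustration (the real output depends on the model):

def clean_t5_output(text):
    # Remove T5's pad and end-of-sequence markers, then trim whitespace
    for token in ("<pad>", "</s>"):
        text = text.replace(token, "")
    return text.strip()

# Hypothetical decoded string, for illustration only
print(clean_t5_output("<pad> Steven Spielberg</s>"))  # → Steven Spielberg

In practice, passing `skip_special_tokens=True` to `tokenizer.decode` achieves the same result without manual string handling.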