@mohdsanadzakirizvi
Created July 18, 2019 10:00
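import torch
# Assumption: the Hugging Face transformers package. In mid-2019 the same
# class shipped in its predecessor, pytorch_transformers.
from transformers import BertForMaskedLM

# tokenizer, tokens_tensor, segments_tensors and masked_index are not defined
# in this snippet; they are assumed to come from an earlier tokenization step.

# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()  # evaluation mode: disables dropout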
# Put everything on the GPU if one is available, otherwise stay on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokens_tensor = tokens_tensor.to(device)
segments_tensors = segments_tensors.to(device)
model.to(device)
# Predict all tokens; no_grad() skips gradient tracking during inference
with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    predictions = outputs[0]  # logits, shape (batch, seq_len, vocab_size)
# Pick the highest-scoring vocabulary entry at the masked position
# and confirm it is 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
print('Predicted token is:', predicted_token)
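The final step is an argmax over the vocabulary dimension at the masked position. The sketch below reproduces that decoding in plain Python and extends it to the top-k candidates with softmax probabilities, so it runs without PyTorch or a model download. The vocabulary and logit values are invented for illustration and are not BERT output; the helper names `softmax` and `top_k_tokens` are hypothetical. With PyTorch, a rough counterpart would be `torch.topk(predictions[0, masked_index], k)` followed by `tokenizer.convert_ids_to_tokens` on the returned indices.

```python
import math
from typing import Dict, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_tokens(
    logits: Sequence[float], id_to_token: Dict[int, str], k: int = 3
) -> List[Tuple[str, float]]:
    """Return the k most likely (token, probability) pairs for one masked position.

    `logits` plays the role of predictions[0, masked_index]: one score per
    vocabulary id. Softmax is monotonic, so the first entry is the same token
    that an argmax over the raw logits would pick.
    """
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id_to_token[i], probs[i]) for i in ranked[:k]]


# Hypothetical toy vocabulary and scores, for illustration only (not BERT output).
vocab = {0: "[PAD]", 1: "henson", 2: "puppeteer", 3: "was", 4: "jim"}
masked_logits = [-2.0, 6.5, 3.1, 0.4, 2.2]

for token, prob in top_k_tokens(masked_logits, vocab, k=3):
    print(f"{token}: {prob:.3f}")
```

Looking at the top few candidates rather than only the argmax is a cheap way to see how confident the model is about the masked word.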