Calculate the perplexity of a sentence with OpenAI GPT using the transformers library.
import math
import torch
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
# Load pre-trained model (weights)
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()
# Load pre-trained model tokenizer (vocabulary)
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
def score(sentence):
    # Tokenize and map to vocabulary ids, as a batch of size 1
    tokenize_input = tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
    # Passing the input ids as their own labels makes the model return the
    # mean cross-entropy loss over the sequence (first element of the output)
    with torch.no_grad():
        loss = model(tensor_input, labels=tensor_input)[0]
    # Perplexity is the exponential of the cross-entropy
    return math.exp(loss.item())
# Lower perplexity means the model considers the sentence more plausible.
print(score('there is a book on the desk'))

# Scoring several candidate sentences at once:
# sentences = ['there is a book on the desk',
#              'there is a plane on the desk',
#              'there is a book in the desk']
# print([score(s) for s in sentences])
# Recorded output: 21.31652459381952, 61.45907380241148, 26.24923942649312
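
On newer releases of transformers (4.x and later), the model returns a ModelOutput object rather than a raw tuple, and the loss is exposed as an attribute. Below is a minimal sketch of the same perplexity computation against that API, assuming GPT-2 as the causal language model; the gpt2 checkpoint and the score_gpt2 name are illustrative, not part of the original gist.

import math
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Sketch assuming transformers >= 4.x; GPT-2 stands in for any causal LM
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
gpt2.eval()
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

def score_gpt2(sentence):
    # Calling the tokenizer directly returns input_ids as a batch of size 1
    inputs = gpt2_tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():
        # With labels == input_ids the model shifts the labels internally and
        # returns the mean cross-entropy over the predicted tokens
        outputs = gpt2(**inputs, labels=inputs['input_ids'])
    # Perplexity is again the exponential of the cross-entropy
    return math.exp(outputs.loss.item())

print(score_gpt2('there is a book on the desk'))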