Calculate the perplexity of a sentence with OpenAI GPT (via Hugging Face transformers). Perplexity is the exponential of the model's average per-token negative log-likelihood, so a lower score means the sentence is more probable under the model.
import math

import torch
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

# Load the pre-trained model weights and switch to evaluation mode
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()

# Load the matching pre-trained tokenizer (vocabulary)
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')


def score(sentence):
    """Return the perplexity of `sentence` under the GPT language model."""
    tokenize_input = tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
    with torch.no_grad():
        # Passing labels (`lm_labels` in older transformers releases) makes the
        # model return the shifted cross-entropy loss, i.e. the average
        # negative log-likelihood per token; [0] picks the loss in any version
        loss = model(tensor_input, labels=tensor_input)[0]
    # Perplexity is the exponential of the average negative log-likelihood
    return math.exp(loss.item())


a = ['there is a book on the desk',
     'there is a plane on the desk',
     'there is a book in the desk']
print(score('there is a book on the desk'))
print([score(i) for i in a])
# 21.31652459381952, 61.45907380241148, 26.24923942649312
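
The same recipe carries over to any causal language model in transformers. As a minimal sketch (not part of the original gist, and assuming a reasonably recent transformers release where models return a .loss attribute), here is the identical computation with the small 'gpt2' checkpoint, whose tokenizer encodes straight to ids:

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')
gpt2_model.eval()
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')


def gpt2_score(sentence):
    """Return the perplexity of `sentence` under GPT-2 (same method as above)."""
    # encode() returns token ids directly, so no separate convert step is needed
    input_ids = torch.tensor([gpt2_tokenizer.encode(sentence)])
    with torch.no_grad():
        # GPT-2 also shifts the labels internally and returns the mean loss
        loss = gpt2_model(input_ids, labels=input_ids).loss
    return math.exp(loss.item())


print(gpt2_score('there is a book on the desk'))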