Skip to content

Instantly share code, notes, and snippets.

@suzusuzu
Last active August 22, 2018 16:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save suzusuzu/dfb370aef34e96c93a93b0b7be74c978 to your computer and use it in GitHub Desktop.
Save suzusuzu/dfb370aef34e96c93a93b0b7be74c978 to your computer and use it in GitHub Desktop.
import argparse
import math
from collections import Counter

import chainer
import numpy as np
def ngram_sentences(sentences, n=1):
    """Collect every contiguous n-gram from each sentence.

    N-grams never span a sentence boundary: each sentence is windowed on
    its own and the results are concatenated in order.

    Args:
        sentences: Iterable of sliceable token sequences (e.g. lists).
        n: Order of the n-grams to extract.

    Returns:
        A list of n-tuples of tokens. Sentences shorter than ``n``
        contribute nothing, and ``n <= 0`` yields an empty list.
    """
    grams = []
    for sentence in sentences:
        # Zipping the sentence against its own 1..n-1 shifted copies yields
        # the sliding windows; zip stops at the shortest (most shifted) copy,
        # so only complete n-grams are produced.
        shifted = (sentence[offset:] for offset in range(n))
        grams.extend(zip(*shifted))
    return grams
def freq(ngram):
    """Count how many times each item occurs.

    Args:
        ngram: Iterable of hashable items (here: n-gram tuples).

    Returns:
        A plain ``dict`` mapping each distinct item to its occurrence
        count, with keys in first-seen order.
    """
    # Convert back to a plain dict so callers keep ordinary dict semantics
    # (e.g. KeyError on a missing key rather than Counter's implicit 0).
    return dict(Counter(ngram))
def prob(words, ngram_freq, n_1gram_freq, v, alpha):
    """Additively smoothed conditional probability of an n-gram.

    Computes ``(count(words) + alpha) / (count(context) + v * alpha)``,
    where ``context`` is ``words`` without its last token.

    Args:
        words: Sequence of tokens forming the n-gram.
        ngram_freq: Mapping from n-gram tuple to its count.
        n_1gram_freq: Mapping from context ((n-1)-gram) tuple to its count.
        v: Multiplier for ``alpha`` in the denominator (the caller passes
            the vocabulary size).
        alpha: Additive smoothing constant.

    Returns:
        The smoothed probability as a float. When the denominator is zero
        (no smoothing and an unseen context) the result is ``0.0`` instead
        of raising ``ZeroDivisionError``.
    """
    gram = tuple(words)
    context = gram[:-1]
    numerator = ngram_freq.get(gram, 0) + alpha
    denominator = n_1gram_freq.get(context, 0) + v * alpha
    if denominator == 0:
        # Unseen context with alpha == 0: the estimate is undefined; treat
        # the event as impossible rather than crashing.
        return 0.0
    return numerator / denominator
def perplexity(n, sentences, ngram_freq, n_1gram_freq, v, alpha=0.01):
    """Perplexity of an additively smoothed n-gram model over sentences.

    Every contiguous n-gram inside each sentence is scored with ``prob``;
    the result is ``exp`` of the mean negative log-probability.

    Args:
        n: Order of the n-gram model.
        sentences: Iterable of sliceable token sequences to evaluate.
        ngram_freq: Mapping from n-gram tuple to its count.
        n_1gram_freq: Mapping from context ((n-1)-gram) tuple to its count.
        v: Passed through to ``prob`` (the caller supplies the vocabulary
            size).
        alpha: Additive smoothing constant passed through to ``prob``.

    Returns:
        The perplexity as a float, or ``math.inf`` if any evaluated n-gram
        has probability zero.

    Raises:
        ZeroDivisionError: If no n-gram could be evaluated (no sentences,
            or every sentence is shorter than ``n``).
    """
    neg_log_likelihood = 0.0
    evaluated = 0
    for sentence in sentences:
        for start in range(len(sentence) - n + 1):
            p_ = prob(sentence[start:start + n], ngram_freq, n_1gram_freq, v, alpha)
            if p_ == 0:
                # log(0) is undefined; a model that assigns zero probability
                # to an observed event has infinite perplexity.
                return math.inf
            neg_log_likelihood -= math.log(p_)
            evaluated += 1
    if evaluated == 0:
        # Same exception type the bare division would raise, but with a
        # message that explains the cause.
        raise ZeroDivisionError(
            'perplexity is undefined: no %d-grams to evaluate' % n
        )
    return math.exp(neg_log_likelihood / evaluated)
def sentences(text, eos):
    """Split a flat token stream into sentences terminated by ``eos``.

    Args:
        text: Iterable of tokens.
        eos: End-of-sentence token; compared to each token with ``==``.

    Returns:
        A list of sentences, each a list of tokens that ends with ``eos``.
        Tokens after the final ``eos`` do not form a complete sentence and
        are not included in the result.
    """
    # Local names deliberately differ from the function name to avoid
    # shadowing it inside its own body.
    completed = []
    current = []
    for token in text:
        current.append(token)
        if token == eos:
            completed.append(current)
            current = []
    return completed
if __name__ == '__main__':
    # Script entry point: fit an additively smoothed n-gram model on the PTB
    # training split and report its perplexity on the train and test splits.

    # arg parser
    parser = argparse.ArgumentParser(description='n-gram perplexity in ptb')
    parser.add_argument('-n', type=int, default=2, help='n in n-gram (default: 2)')
    # alpha is a fractional smoothing constant, so it must be parsed as a
    # float; with type=int a value such as "-a 0.5" is rejected.
    parser.add_argument('-a', type=float, default=0.01, help='alpha in additive smoothing (default: 0.01)')
    args = parser.parse_args()
    n = args.n
    alpha = args.a
    if n < 1:
        parser.error('n must be a positive integer')

    # data
    train, val, test = chainer.datasets.get_ptb_words()
    ptb_dict = chainer.datasets.get_ptb_words_vocabulary()
    V = len(ptb_dict)
    eosid = ptb_dict['<eos>']
    trains = sentences(train, eosid)
    tests = sentences(test, eosid)

    # ngram parse
    ngram = ngram_sentences(trains, n)
    ngram_freq = freq(ngram)
    if n == 1:
        # A unigram has the empty context (). ngram_sentences(..., 0) yields
        # nothing, which would leave the denominator at V * alpha only and
        # produce "probabilities" above 1. The correct context count is the
        # total number of training tokens.
        n_1gram_freq = {(): len(ngram)}
    else:
        n_1gram = ngram_sentences(trains, n - 1)
        n_1gram_freq = freq(n_1gram)

    # train perplexity
    train_ppl = perplexity(n, trains, ngram_freq, n_1gram_freq, V, alpha)
    print(str(n) + '-gram perplexity in train:', train_ppl)

    # test perplexity
    test_ppl = perplexity(n, tests, ngram_freq, n_1gram_freq, V, alpha)
    print(str(n) + '-gram perplexity in test:', test_ppl)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment