Last active
August 22, 2018 16:57
-
-
Save suzusuzu/dfb370aef34e96c93a93b0b7be74c978 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import math
from collections import Counter

import chainer
import numpy as np
def ngram_sentences(sentences, n=1):
    """Collect every n-gram from a list of token sequences.

    Args:
        sentences: iterable of sequences (lists) of tokens.
        n: size of the n-gram window (default 1 = unigrams).

    Returns:
        A flat list of n-grams, each as a tuple of n consecutive tokens.
        Sentences shorter than n contribute nothing.
    """
    grams = []
    for sentence in sentences:
        # zip over n staggered views of the sentence yields all
        # length-n windows; avoids the PEP 8 E731 named lambda.
        grams.extend(zip(*(sentence[i:] for i in range(n))))
    return grams
def freq(ngram):
    """Count occurrences of each n-gram.

    Args:
        ngram: iterable of hashable items (tuples from ngram_sentences).

    Returns:
        A Counter (dict subclass) mapping each n-gram to its count;
        callers may use plain dict access or ``.get`` as before.
    """
    # Counter replaces the hand-rolled two-lookup membership test.
    return Counter(ngram)
def prob(words, ngram_freq, n_1gram_freq, v, alpha):
    """Additively smoothed conditional n-gram probability.

    Computes (count(w_1..w_n) + alpha) / (count(w_1..w_{n-1}) + v * alpha),
    i.e. Lidstone smoothing over a vocabulary of size v.

    Args:
        words: the n tokens of the n-gram (last token is the predicted one).
        ngram_freq: mapping from n-gram tuples to counts.
        n_1gram_freq: mapping from (n-1)-gram tuples to counts.
        v: vocabulary size.
        alpha: additive smoothing constant.

    Returns:
        The smoothed probability as a float.
    """
    key = tuple(words)
    numerator = ngram_freq.get(key, 0) + alpha
    denominator = n_1gram_freq.get(key[:-1], 0) + v * alpha
    return numerator / denominator
def perplexity(n, sentences, ngram_freq, n_1gram_freq, v, alpha=0.01):
    """Compute the n-gram model perplexity over a corpus.

    Perplexity is exp of the mean negative log probability of every
    n-gram window in every sentence.

    Args:
        n: n-gram order.
        sentences: iterable of token sequences.
        ngram_freq: counts of n-grams from the training corpus.
        n_1gram_freq: counts of (n-1)-grams from the training corpus.
        v: vocabulary size (for additive smoothing).
        alpha: additive smoothing constant (default 0.01).

    Returns:
        The perplexity as a float.

    Raises:
        ValueError: if no sentence is long enough to yield an n-gram
            (the original code crashed here with ZeroDivisionError).
    """
    neg_log_prob = 0.0
    count = 0
    for sentence in sentences:
        for i in range(len(sentence) - n + 1):
            p = prob(sentence[i:i + n], ngram_freq, n_1gram_freq, v, alpha)
            neg_log_prob -= math.log(p)
            count += 1
    if count == 0:
        raise ValueError('no n-grams found: every sentence is shorter than n')
    return math.exp(neg_log_prob / count)
def sentences(text, eos):
    """Split a flat token stream into sentences.

    Each sentence is the run of tokens up to and including the next
    ``eos`` token. Tokens after the final ``eos`` are discarded
    (in PTB every sentence is terminated by ``<eos>``).

    Args:
        text: iterable of tokens (e.g. PTB word ids).
        eos: the end-of-sentence token value.

    Returns:
        A list of sentences, each a list of tokens ending with ``eos``.
    """
    result = []
    current = []
    for token in text:
        current.append(token)
        if token != eos:
            continue
        # sentence terminator reached: flush and start a new one
        result.append(current)
        current = []
    return result
if __name__ == '__main__':
    # arg parser
    parser = argparse.ArgumentParser(description='n-gram perplexity in ptb')
    parser.add_argument('-n', type=int, default=2, help='n in n-gram (default: 2)')
    # BUG FIX: alpha is a fractional smoothing constant; the original
    # declared type=int, which made any explicit value like "-a 0.5"
    # fail to parse (int('0.5') raises ValueError).
    parser.add_argument('-a', type=float, default=0.01, help='alpha in additive smoothing (default: 0.01)')
    args = parser.parse_args()
    n = args.n
    alpha = args.a

    # data: PTB word-id streams and the word->id vocabulary
    train, val, test = chainer.datasets.get_ptb_words()
    ptb_dict = chainer.datasets.get_ptb_words_vocabulary()
    V = len(ptb_dict)
    eosid = ptb_dict['<eos>']
    trains = sentences(train, eosid)
    tests = sentences(test, eosid)

    # ngram parse: count n-grams and (n-1)-grams on the training split
    ngram = ngram_sentences(trains, n)
    n_1gram = ngram_sentences(trains, n - 1)
    ngram_freq = freq(ngram)
    n_1gram_freq = freq(n_1gram)

    # train perplexity
    train_ppl = perplexity(n, trains, ngram_freq, n_1gram_freq, V, alpha)
    print(str(n) + '-gram perplexity in train:', train_ppl)

    # test perplexity
    test_ppl = perplexity(n, tests, ngram_freq, n_1gram_freq, V, alpha)
    print(str(n) + '-gram perplexity in test:', test_ppl)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment