# Gist kingjr/301a1744d22fd878602a55a800c5c2bd ("charlotte"), last active October 3, 2019.
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that
# may be interpreted or compiled differently than it appears; review it in an
# editor that reveals hidden Unicode characters.
import os
import sys

import matplotlib.pyplot as plt  # noqa: F401  (not used in the visible script)
import numpy as np
import torch

# `utils` lives in ./sentence_embedding, so that directory must be on sys.path
# before the import below.
sys.path.append(os.path.abspath('./sentence_embedding'))
from utils import load_model, build_XLM_dictionary  # noqa

# Directory holding the pretrained XLM checkpoints. Set XLM_MODEL_PATH to run
# somewhere other than the original machine.
model_path = os.environ.get('XLM_MODEL_PATH',
                            '/private/home/jeanremi/data/mous/models/')
# Causal-LM checkpoint. The same directory also held an MLM alternative,
# 'BERT_512_best-valid_mlm_ppl.pth'; extract_proba runs the model with
# causal=True, so presumably only the CLM checkpoint fits this script.
model_fname = os.path.join(model_path, '1024_best-valid_clm_ppl.pth')
model = load_model('XLM', model_fname)
def extract_proba(model, token_sentences, lang='nl'):
    """Score every next-token prediction of a causal XLM model.

    Each non-empty sentence is wrapped as ``</s> sentence </s>`` and split on
    whitespace. The batch is run through the model with ``causal=True``, and
    the prediction layer scores the token that follows every position except
    the last one of each sentence.

    Parameters
    ----------
    model : XLM model loaded with ``load_model()``.
    token_sentences : list of str
        e.g. ``['The cat is on the mat', 'Please give me some bread']``.
        Empty or whitespace-only strings are skipped.
    lang : str
        Language tag attached to every sentence (default ``'nl'``).

    Returns
    -------
    x : tensor
        Token indices from ``build_XLM_dictionary``. Presumably shaped
        ``(max_len, n_sentences)``, as implied by how ``pred_mask`` is applied.
    word_scores : tensor of shape (n_pred, n_words)
        Projection onto the prediction layer
        (``self.proj(x).view(-1, self.n_words)``). Row ``i`` scores target
        ``y[i]``, the token following the i-th predicted position;
        ``n_pred == sum(lengths - 1)``.
    ranks : list
        For each row, the 0-based rank of the true next token among all
        vocabulary scores (0 = top prediction), or ``np.nan`` if not found.
    predict : str
        Space-joined argmax token of every row.

    Raises
    ------
    ValueError
        If every entry of ``token_sentences`` is empty.
    """
    # Drop empty inputs *before* wrapping: after wrapping with the delimiters,
    # the token list is never empty, so a filter there would be a no-op.
    sentences = [(('</s> %s </s>' % sent.strip()).split(), lang)
                 for sent in token_sentences if sent.strip()]
    if not sentences:
        raise ValueError('token_sentences contains no non-empty sentence')

    x, lengths, langs = build_XLM_dictionary(sentences, model)

    # Causal-LM targets: position t predicts token t + 1, for every position
    # except the last one of each sentence.
    alen = torch.arange(lengths.max(), dtype=torch.long, device=lengths.device)
    pred_mask = alen[:, None] < lengths[None] - 1
    y = x[1:].masked_select(pred_mask[:-1])
    assert pred_mask.sum().item() == y.size(0)

    tensor = model.fwd(x, lengths, causal=True, langs=langs)
    word_scores, _ = model.predict(tensor, pred_mask, y, get_scores=True)

    # Rank of the true target y[i] within row i (scores sorted descending).
    # Compare against y, not x[:, 0]: x[:, 0] is shifted one position
    # relative to the rows and only covers the first sentence.
    # Convert to numpy explicitly so this also works for non-CPU tensors.
    order = torch.argsort(word_scores, 1, True).cpu().numpy()
    targets = y.cpu().numpy()
    ranks = [np.where(row == truth)[0] for row, truth in zip(order, targets)]
    ranks = [r[0] if len(r) else np.nan for r in ranks]

    # Most likely token at every predicted position.
    predict = torch.argmax(word_scores, 1)
    predict = ' '.join([model.dico[int(i)] for i in predict])
    return (x, word_scores, ranks, predict)
# Prompt (Dutch): 'This boy is a very happy child. He went to X'
np.random.seed(42)
token_sentences = ['Deze jongen is een heel gelukkig kind. Hij ging naar X',]

# Tokens that must never be inserted into the running sentence.
SKIP_TOKENS = {'<unk>', '</s>'}
# Sample the next word uniformly among this many best-scoring candidates.
N_CANDIDATES = 10

# Manual recursion: predict the word for the X slot, substitute it, and append
# a fresh X placeholder for the next step.
for _ in range(10):
    _, word_scores, _, _ = extract_proba(model, token_sentences)
    # The input ends with '... word X </s>'. Row -1 is the output *at* X
    # (it predicts the closing '</s>'). Row -2 is the output at the last
    # real word, i.e. the next-word distribution for the X slot.
    slot_scores = word_scores.detach().cpu().numpy()[-2]
    best_first = np.argsort(slot_scores)[::-1]  # highest score first
    shortlist = [model.dico[int(i)]
                 for i in best_first[:N_CANDIDATES + len(SKIP_TOKENS)]]
    candidates = [w for w in shortlist if w not in SKIP_TOKENS][:N_CANDIDATES]
    guess = candidates[np.random.randint(len(candidates))]
    # Replace the trailing 'X' with the guess and re-append the placeholder.
    token_sentences[0] = token_sentences[0][:-1] + guess + ' X'

print(token_sentences[0])
# (End of gist. Commenting on the original gist requires a GitHub account.)