Skip to content

Instantly share code, notes, and snippets.

@kingjr
Last active October 3, 2019 12:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingjr/301a1744d22fd878602a55a800c5c2bd to your computer and use it in GitHub Desktop.
Save kingjr/301a1744d22fd878602a55a800c5c2bd to your computer and use it in GitHub Desktop.
charlotte
import sys
import os
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath('./sentence_embedding'))
from utils import load_model, build_XLM_dictionary # noqa
# Directory holding the pretrained checkpoints, and the checkpoint file that
# this script loads as an 'XLM' model.
# Alternative checkpoint file name: 'BERT_512_best-valid_mlm_ppl.pth'
model_path = '/private/home/jeanremi/data/mous/models/'
model_fname = os.path.join(model_path, '1024_best-valid_clm_ppl.pth')
model = load_model('XLM', model_fname)
def extract_proba(model, token_sentences):
    """Score the next-word predictions of an XLM model on some sentences.

    Parameters
    ----------
    model :
        XLM model loaded with the ``load_model()`` function.
    token_sentences : list of str
        Sentences such as
        ``['The cat is on the mat', 'Please give me some bread']``.
        Empty strings are skipped.

    Returns
    -------
    x :
        Token indices built by ``build_XLM_dictionary`` from the sentences,
        each wrapped in ``</s>`` delimiters.
    word_scores :
        Scores returned by ``model.predict(..., get_scores=True)``: one row
        per predicted word and one column per vocabulary entry (projection
        of the embeddings onto the prediction layer).
    ranks : list
        For each predicted word, the position (0 = highest score) of the
        word that actually comes next within the scores sorted in
        decreasing order; ``np.nan`` if it cannot be found.
    predict : str
        The highest-scoring word of every row, joined with spaces.

    Notes
    -----
    The mean cross-entropy loss returned by ``model.predict`` is computed
    but not returned.
    """
    # Format sentences: tag each one with the 'nl' language code, wrap it in
    # sentence delimiters and split it into tokens.
    sentences = [(('</s> %s </s>' % sent.strip()).split(), 'nl')
                 for sent in token_sentences if len(sent) > 0]

    # Build dictionary indices.
    x, lengths, langs = build_XLM_dictionary(sentences, model)

    # Compute the causal prediction mask: every position except the last one
    # of each sentence has a next word to predict; y holds those next words.
    alen = torch.arange(lengths.max(), dtype=torch.long,
                        device=lengths.device)
    pred_mask = alen[:, None] < lengths[None] - 1
    y = x[1:].masked_select(pred_mask[:-1])
    assert pred_mask.sum().item() == y.size(0)

    # Generate embeddings.
    tensor = model.fwd(x, lengths, causal=True, langs=langs)

    # Generate scores (the loss is not needed here).
    word_scores, _ = model.predict(tensor, pred_mask, y, get_scores=True)

    # Compute ranks. The scores are computed against the next-word targets
    # ``y``, so the true word of row i is y[i] (not the input token x[i]).
    order = torch.argsort(word_scores, dim=1, descending=True)
    hits = order == y.unsqueeze(1)
    ranks = [int(row.nonzero()[0]) if row.any() else np.nan for row in hits]

    # Predict: most likely word for every row.
    best = torch.argmax(word_scores, 1)
    predict = ' '.join(model.dico[int(i)] for i in best)
    return (x, word_scores, ranks, predict)
# English gloss: 'This boy is a very happy child. He went to X'
np.random.seed(42)
token_sentences = ['Deze jongen is een heel gelukkig kind. Hij ging naar X']

# Manual recursion to generate a plausible sentence, one word per iteration.
for _ in range(10):
    x, word_scores, ranks, predict = extract_proba(model, token_sentences)

    # Vocabulary indices of the last row of scores, sorted by increasing
    # score. `.cpu()` keeps this working when the scores live on a GPU.
    # NOTE(review): if rows follow input positions, the last row is the
    # prediction made *after* the 'X' placeholder; confirm that [-1] rather
    # than [-2] is the intended row.
    score = np.argsort(word_scores.detach().cpu().numpy()[-1])

    # Make a guess on what the last word should be, avoiding 'unknown'.
    guess = '<unk>'
    while guess == '<unk>':
        # Take one of the 10 best words: `score` is ascending, so they sit at
        # indices -1 .. -10. (Index 0 would be the *worst* word, hence the
        # draw starts at 1.)
        idx = -np.random.randint(1, 11)
        guess = model.dico[int(score[idx])]

    # Add the guess last. X is here to comply with the mask API.
    token_sentences[0] = token_sentences[0][:-1] + guess + ' X'

print(token_sentences[0])  # not a very plausible sentence
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment