Skip to content

Instantly share code, notes, and snippets.

@roeeaharoni
Last active November 17, 2020 04:06
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save roeeaharoni/827f98cf899727e1b9bbb9d54e8357ed to your computer and use it in GitHub Desktop.
Save roeeaharoni/827f98cf899727e1b9bbb9d54e8357ed to your computer and use it in GitHub Desktop.
Generate the sciences of the future using BERT! (as seen on https://twitter.com/roeeaharoni/status/1089089393745371136)
import torch
from pytorch_pretrained_bert import BertForMaskedLM, BertTokenizer
import random
# Requires pytorch_pretrained_bert: https://github.com/huggingface/pytorch-pretrained-BERT
def get_preds(sent, *, tokenizer=None, model=None):
    """Return the model's vocabulary scores for every masked word in *sent*.

    Every token of *sent* that tokenizes to the literal word ``mask`` is
    replaced by BERT's ``[MASK]`` token. The sequence is wrapped in
    ``[CLS]`` ... ``[SEP]`` and fed to the model as a batch of one.

    Args:
        sent: Input sentence; write ``mask`` wherever a word should be
            predicted. Any genuine occurrence of the word "mask" is also
            treated as a placeholder.
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``
            (e.g. a ``BertTokenizer``). Defaults to the module-level
            ``bert_tokenizer`` created in the ``__main__`` block.
        model: Callable mapping a ``(1, seq_len)`` LongTensor of token ids to
            per-position scores (e.g. ``BertForMaskedLM``). Defaults to the
            module-level ``bert_model``.

    Returns:
        A tensor with one row per masked position, in sentence order. For
        ``BertForMaskedLM`` each row holds unnormalized logits over the
        vocabulary, not probabilities; apply ``softmax(dim=-1)`` if real
        probabilities are needed. Rankings such as ``torch.topk`` are the
        same either way.
    """
    if tokenizer is None:
        tokenizer = bert_tokenizer
    if model is None:
        model = bert_model

    tokens = ['[MASK]' if tok == 'mask' else tok for tok in tokenizer.tokenize(sent)]
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    mask_idx = [i for i, tok in enumerate(tokens) if tok == '[MASK]']
    token_ids = torch.LongTensor(tokenizer.convert_tokens_to_ids(tokens)).unsqueeze(0)

    # Inference only: disabling autograd avoids recording a gradient graph
    # (and retaining intermediate activations) for a forward pass that is
    # never back-propagated.
    with torch.no_grad():
        preds = model(token_ids)
        return preds[0, mask_idx]
if __name__ == '__main__':
    import itertools

    # Number of candidate words kept for each masked position.
    top_k = 100

    # Load the pretrained tokenizer and masked-LM model. get_preds() reads
    # these module-level names.
    model_name = "bert-large-uncased"
    bert_tokenizer = BertTokenizer.from_pretrained(model_name)
    bert_model = BertForMaskedLM.from_pretrained(model_name)
    # Disable dropout: without eval() every run samples different dropout
    # masks and the predicted words change from run to run.
    bert_model.eval()

    # Score the vocabulary at every "mask" position of the prompt; no
    # gradients are needed for inference.
    with torch.no_grad():
        y = get_preds('i did my phd in mask mask for the last four years .')

    # Keep the top_k highest-scoring words for each masked position, in
    # sentence order (one list per row of y).
    candidates = []
    for scores in y:
        _, top_ids = torch.topk(scores, top_k)
        candidates.append(bert_tokenizer.convert_ids_to_tokens(top_ids.tolist()))

    # Every combination of one candidate per masked position, joined with
    # spaces, printed in random order.
    sciences = [" ".join(words) for words in itertools.product(*candidates)]
    random.shuffle(sciences)
    for p in sciences:
        print(p)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment