@ArtemisDicoTiar
Created August 20, 2021 14:58
BERT: next sentence prediction
from torch.nn import functional as F
from transformers import BertTokenizer, BertForNextSentencePrediction

# Sentence pairs for next-sentence prediction (NSP).
BATCH = [
    ("I understand Tesla's vision.", "Haha, that's a nice [MASK]."),            # pun?
    ("the man went to [MASK] store", "he bought a gallon [MASK] milk"),         # is next
    ("the man [MASK] to the store", "penguin [MASK] are flight ##less birds"),  # not next
]
BERT_MODEL = "bert-base-uncased"


def main():
    # Load the pretrained BERT checkpoint with its NSP head, plus the matching tokenizer.
    nsp = BertForNextSentencePrediction.from_pretrained(BERT_MODEL)
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
    print(nsp.config)

    # Passing a list of (sentence A, sentence B) tuples encodes each pair as one
    # sequence with the proper [CLS]/[SEP] special tokens and segment ids.
    encoded = tokenizer(BATCH,
                        add_special_tokens=True,
                        return_tensors="pt",
                        truncation=True,
                        padding=True)

    outputs = nsp(**encoded)
    # outputs.logits has shape [3, 2]:
    #   3 sentence pairs in BATCH
    #   2 classes: index 0 = "B follows A" (isNext), index 1 = "B is random" (isNotNext)
    probs = F.softmax(outputs.logits, dim=-1)

    preds = [
        f"isNext: {p_next:.4f}, isNotNext: {p_not:.4f}\n\t=> {pair}"
        for pair, (p_next, p_not) in zip(BATCH, probs.tolist())
    ]
    print(*preds, sep='\n')


if __name__ == '__main__':
    main()
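
For a single pair, the batching and mapping machinery above isn't needed. The sketch below is a minimal variant under the same assumptions (the same bert-base-uncased checkpoint, and the Hugging Face convention that logit index 0 means "sentence B follows sentence A"); the score_pair helper name is made up for illustration.

import torch
from torch.nn import functional as F
from transformers import BertTokenizer, BertForNextSentencePrediction

def score_pair(sent_a: str, sent_b: str, model_name: str = "bert-base-uncased") -> float:
    """Return the probability that sent_b is the actual next sentence after sent_a."""
    nsp = BertForNextSentencePrediction.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)
    enc = tokenizer(sent_a, sent_b, return_tensors="pt")  # one [CLS] A [SEP] B [SEP] sequence
    with torch.no_grad():  # inference only, no gradients needed
        logits = nsp(**enc).logits  # shape [1, 2]
    return F.softmax(logits, dim=-1)[0, 0].item()  # index 0 = "isNext"

# e.g. score_pair("the man went to the store", "he bought a gallon of milk")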