Skip to content

Instantly share code, notes, and snippets.

@Krilecy
Last active April 5, 2021 21:58
Show Gist options
  • Save Krilecy/b9242f394e5b4131744a71a7fbf6a908 to your computer and use it in GitHub Desktop.
Save Krilecy/b9242f394e5b4131744a71a7fbf6a908 to your computer and use it in GitHub Desktop.
Hugging Face Sequence Classification Example with XLNet
from transformers import AutoTokenizer, AutoModelForSequenceClassification, XLNetTokenizer
import torch
# Sequence-pair classification demo with an XLNet NLI checkpoint.
#
# Two sentence pairs are scored and the probability of every class is printed
# as a whole-number percentage.

# Single source of truth for the checkpoint so the tokenizer and the model
# are always loaded from the same place.
MODEL_NAME = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"
tokenizer = XLNetTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# This checkpoint is a three-way natural-language-inference classifier, not a
# two-way paraphrase detector. Label order follows the model card:
# index 0 = entailment, 1 = neutral, 2 = contradiction.
classes = ["entailment", "neutral", "contradiction"]
if model.config.num_labels != len(classes):
    # Fail loudly rather than silently printing mislabeled probabilities.
    raise ValueError(
        f"Expected {len(classes)} labels but the model reports "
        f"{model.config.num_labels}"
    )

sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

# Each pair is encoded as a single two-segment input.
paraphrase = tokenizer(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer(sequence_0, sequence_1, return_tensors="pt")

# Inference only: skip building the autograd graph.
with torch.no_grad():
    paraphrase_classification_logits = model(**paraphrase).logits
    not_paraphrase_classification_logits = model(**not_paraphrase).logits

# Softmax over the class dimension; [0] selects the only item in the batch.
paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]

# Related pair: expected to lean towards "entailment" (not verified here).
for label, probability in zip(classes, paraphrase_results):
    print(f"{label}: {round(probability * 100)}%")

# Unrelated pair: expected to lean away from "entailment" (not verified here).
for label, probability in zip(classes, not_paraphrase_results):
    print(f"{label}: {round(probability * 100)}%")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment