Last active
April 5, 2021 21:58
-
-
Save Krilecy/b9242f394e5b4131744a71a7fbf6a908 to your computer and use it in GitHub Desktop.
Hugging Face Sequence Classification Example with XLNet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer, AutoModelForSequenceClassification, XLNetTokenizer
import torch

# Sentence-pair classification example with an XLNet checkpoint.
#
# Two sentence pairs are scored: one whose sentences say roughly the same
# thing and one whose sentences are unrelated. The class probabilities of
# each pair are printed as whole percentages.
MODEL_NAME = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"

tokenizer = XLNetTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Take the label names from the checkpoint's own config rather than
# hard-coding them. This checkpoint is an NLI model (see its model card), so
# a two-entry "not paraphrase" / "is paraphrase" list would mislabel the
# first two probabilities and silently drop the remaining ones.
classes = [model.config.id2label[i] for i in range(model.config.num_labels)]

sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

# Passing two strings makes the tokenizer encode them as a single
# sentence-pair input for the model.
paraphrase = tokenizer(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer(sequence_0, sequence_1, return_tensors="pt")

# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    paraphrase_classification_logits = model(**paraphrase).logits
    not_paraphrase_classification_logits = model(**not_paraphrase).logits

# Logits are indexed (batch, label); softmax over the label axis, then take
# the single batch row as a plain list of floats.
paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]

# Closely related pair (sequence_0 vs sequence_2)
for label, probability in zip(classes, paraphrase_results):
    print(f"{label}: {int(round(probability * 100))}%")

# Unrelated pair (sequence_0 vs sequence_1)
for label, probability in zip(classes, not_paraphrase_results):
    print(f"{label}: {int(round(probability * 100))}%")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment