Last active: October 23, 2021 11:11
-
-
Save dvsrepo/30c1413ae70e90c1069a3b8714b938b0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Explain SST test-set sentiment predictions with transformers_interpret and
# log them to Rubrix.
#
# For every sentence in the SST test split, a DistilBERT classifier fine-tuned
# on SST-2 is run through a SequenceClassificationExplainer. One Rubrix
# TextClassificationRecord per sentence is built, carrying the predicted label,
# its probability, and per-token attribution scores. All records are then
# logged to the Rubrix dataset "transformers_interpret_example".
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
from datasets import load_dataset
import rubrix as rb
from rubrix import TokenAttributions

# Load Stanford sentiment treebank test set
dataset = load_dataset("sst", "default", split="test")

# Let's use a sentiment classifier fine-tuned on sst
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the explainer using transformers_interpret
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

records = []
for example in dataset:
    sentence = example["sentence"]

    # Calling the explainer runs the model on this sentence. The explainer's
    # predicted_class_name / pred_probs read below describe this call's
    # prediction, so they must be read after it.
    word_attributions = cls_explainer(sentence)
    label = cls_explainer.predicted_class_name

    # Build token attribution objects, dropping the first ([CLS]) and
    # last ([SEP]) special tokens added by the tokenizer.
    token_attributions = [
        TokenAttributions(token=token, attributions={label: score})
        for token, score in word_attributions[1:-1]
    ]

    # Build the text classification record. pred_probs is presumably a 0-d
    # torch tensor (TODO confirm against the transformers_interpret version
    # in use); float() hands Rubrix a plain Python score.
    record = rb.TextClassificationRecord(
        inputs=sentence,
        prediction=[(label, float(cls_explainer.pred_probs))],
        explanation={"text": token_attributions},
    )
    records.append(record)

# Build Rubrix dataset with interpretations for each record
rb.log(records, name="transformers_interpret_example")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.