@dvsrepo
Last active October 23, 2021 11:11
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
from datasets import load_dataset
import rubrix as rb
from rubrix import TokenAttributions
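# Assumed environment (not stated in the gist): these imports map to the pip packages
# transformers, transformers-interpret, datasets, and rubrix, and rb.log below expects
# a reachable Rubrix server (by default on localhost).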
# Load the Stanford Sentiment Treebank (SST) test set
dataset = load_dataset("sst", "default", split="test")
# Use a sentiment classifier fine-tuned on SST-2
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Define the explainer using transformers_interpret
cls_explainer = SequenceClassificationExplainer(model, tokenizer)
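# Optional sanity check (not part of the original gist): call the explainer on a single
# sentence to see what it returns, a list of (token, attribution_score) pairs, and
# inspect the predicted label and its probability via the explainer's attributes.
sample_attributions = cls_explainer(dataset[0]["sentence"])
print(cls_explainer.predicted_class_name, cls_explainer.pred_probs)
print(sample_attributions[:5])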
records = []
for example in dataset:
    # Build TokenAttributions objects, skipping the first ([CLS]) and last ([SEP]) tokens
    word_attributions = cls_explainer(example["sentence"])
    token_attributions = [
        TokenAttributions(
            token=token,
            attributions={cls_explainer.predicted_class_name: score},
        )
        for token, score in word_attributions[1:-1]
    ]
    # Build a text classification record with the prediction and its explanation
    record = rb.TextClassificationRecord(
        inputs=example["sentence"],
        prediction=[(cls_explainer.predicted_class_name, cls_explainer.pred_probs)],
        explanation={"text": token_attributions},
    )
    records.append(record)
# Build Rubrix dataset with interpretations for each record
rb.log(records, name="transformers_interpret_example")
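# Follow-up sketch (not in the original gist), assuming a local Rubrix instance:
# the logged records, including their token attributions, can be pulled back for
# inspection with rb.load, which returns a pandas DataFrame.
df = rb.load("transformers_interpret_example")
print(df.head())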