Skip to content

Instantly share code, notes, and snippets.

@dvsrepo
Last active September 22, 2022 09:29
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dvsrepo/62a15a224a7f5da8603e0048d15a1045 to your computer and use it in GitHub Desktop.
Save dvsrepo/62a15a224a7f5da8603e0048d15a1045 to your computer and use it in GitHub Desktop.
import rubrix as rb
import shap
import transformers
from datasets import load_dataset
from rubrix import TextClassificationRecord, TokenAttributions
from sklearn.preprocessing import MaxAbsScaler
from sklearn.preprocessing import MinMaxScaler
def _top_label_index(label_scores, output_names):
    """Return the position in ``output_names`` of the highest-scoring label.

    Args:
        label_scores: One pipeline result for a single sentence, i.e. a list of
            ``{"label": ..., "score": ...}`` dicts, one per label.
        output_names: The SHAP explanation's label names, in the order of the
            columns of its per-token values.

    Returns:
        The column index into the SHAP values for the predicted label.

    Raises:
        ValueError: If the predicted label is not one of ``output_names``.
    """
    # max() keeps the first of several equal scores, just like taking element 0
    # of a stable descending sort. Matching by label name (instead of
    # hard-coding NEGATIVE -> 0, anything else -> 1) works for any label set
    # and does not depend on the order in which the pipeline lists scores.
    top_prediction = max(label_scores, key=lambda d: d["score"])
    return list(output_names).index(top_prediction["label"])


def _build_record(sentence, tokens, values, label_scores, output_names, scaler):
    """Build a Rubrix record with SHAP token attributions for one sentence.

    Args:
        sentence: The raw input text.
        tokens: The tokens SHAP attributed, aligned with the rows of ``values``.
        values: SHAP values for this sentence, one row per token and one column
            per label in ``output_names``.
        label_scores: The pipeline's ``{"label", "score"}`` dicts for the sentence.
        output_names: The SHAP explanation's label names.
        scaler: A scikit-learn scaler; it is re-fitted on ``values`` for every
            sentence and rescales each label column independently.

    Returns:
        A ``TextClassificationRecord`` whose explanation maps every token to its
        scaled attribution towards the predicted label.
    """
    scaled = scaler.fit_transform(values)
    label_idx = _top_label_index(label_scores, output_names)
    label_name = output_names[label_idx]
    token_attributions = [
        TokenAttributions(token=token, attributions={label_name: float(score)})
        for token, score in zip(tokens, scaled[:, label_idx])
    ]
    return TextClassificationRecord(
        inputs=sentence,
        prediction=[(pred["label"], pred["score"]) for pred in label_scores],
        explanation={"text": token_attributions},
    )


# Transformers sentiment-analysis pipeline; return_all_scores=True yields a
# score for every label rather than only the top one.
# NOTE(review): newer transformers releases deprecate return_all_scores in
# favour of top_k=None (which also sorts the scores by value); the label lookup
# above does not depend on score order, but confirm when upgrading.
model = transformers.pipeline("sentiment-analysis", return_all_scores=True)

# Load the first five sentences of the Stanford Sentiment Treebank test split.
sst = load_dataset("sst", split="test[0:5]")
# Read the column once: indexing a datasets.Dataset by column name can
# materialise the whole column on each access, which made the old loop O(n^2).
sentences = sst["sentence"]

# Use shap's text explainer on the pipeline to get per-token SHAP values.
explainer = shap.Explainer(model)
shap_values = explainer(sentences)

# Rescale SHAP values into [-1, 1] by each label column's largest absolute
# value. Unlike MinMaxScaler(feature_range=[-1, 1]) this keeps 0 at 0 and
# preserves every value's sign, so a token pushing towards the label is never
# displayed as pushing against it (MinMax maps each column's minimum to -1, and
# a column of identical values entirely to -1).
scaler = MaxAbsScaler()

predictions = model(sentences)

# One record per sentence, built from the aligned SHAP rows and predictions.
records = [
    _build_record(
        sentence, tokens, values, label_scores, shap_values.output_names, scaler
    )
    for sentence, tokens, values, label_scores in zip(
        sentences, shap_values.data, shap_values.values, predictions
    )
]

# Log every record to the "rubrix_shap_example" dataset in a single call
# instead of one request per record.
if records:
    rb.log(records, name="rubrix_shap_example")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment