Last active
September 22, 2022 09:29
-
-
Save dvsrepo/62a15a224a7f5da8603e0048d15a1045 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example script: explain a Hugging Face sentiment-analysis pipeline with SHAP
# and log the predictions, together with per-token attributions, to Rubrix.
import transformers
from datasets import load_dataset
from sklearn.preprocessing import MinMaxScaler
import shap
from rubrix import TextClassificationRecord, TokenAttributions
import rubrix as rb

# Transformers pipeline model; return_all_scores=True asks for a score per label
model = transformers.pipeline("sentiment-analysis", return_all_scores=True)

# Load the first five test sentences of the Stanford Sentiment Treebank
sst = load_dataset("sst", split="test[0:5]")
# Read the text column once instead of re-fetching it on every use
sentences = sst["sentence"]

# Use shap's library text explainer
explainer = shap.Explainer(model)
shap_values = explainer(sentences)

# Instantiate the scaler. feature_range is documented as a tuple; recent
# scikit-learn releases validate parameter types and reject a list here.
scaler = MinMaxScaler(feature_range=(-1, 1))

predictions = model(sentences)

# Label names in the order of the output columns of the SHAP values
# (list() so that .index() works whatever sequence type shap returns)
output_names = list(shap_values.output_names)

records = []
for sentence, prediction, values, tokens in zip(
    sentences, predictions, shap_values.values, shap_values.data
):
    # Scale shap values between -1 and 1 (using e.g. scikit-learn's MinMaxScaler).
    # The scaler is re-fitted for every sentence and rescales each output
    # column on its own.
    # NOTE(review): min-max scaling maps the smallest value to -1 and the
    # largest to 1, so the sign of the raw SHAP value is not preserved.
    scaled = scaler.fit_transform(values)

    # Label with the highest score; its position among the SHAP output names
    # selects the attribution column (no hard-coded NEGATIVE/POSITIVE order).
    predicted_label = max(prediction, key=lambda d: d["score"])["label"]
    label_idx = output_names.index(predicted_label)

    # Build token attributions for the predicted label
    token_attributions = [
        TokenAttributions(token=token, attributions={predicted_label: score})
        for token, score in zip(tokens, scaled[:, label_idx])
    ]

    # Build Rubrix record
    records.append(
        TextClassificationRecord(
            inputs=sentence,
            prediction=[(pred["label"], pred["score"]) for pred in prediction],
            explanation={"text": token_attributions},
        )
    )

# Log all records in a single call
rb.log(records, name="rubrix_shap_example")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment