Last active
July 13, 2021 11:16
-
-
Save rmazzine/89709e966de04719d90097db85d5882a to your computer and use it in GitHub Desktop.
Simple Counterfactual Generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def simple_cf_generator(factual_oh, adapted_nn, tolerance=100):
    """Greedily toggle binary features until the model's score reaches 0.5.

    Starting from ``factual_oh``, each iteration evaluates every
    single-feature toggle (0 -> 1 and 1 -> 0) of the current candidate. It
    then moves to the toggle with the highest score, stopping once that score
    is >= 0.5 or after ``tolerance`` iterations.

    Args:
        factual_oh: 1-D sequence of 0/1 values (presumably a one-hot encoded
            instance -- confirm with callers). Entries that are neither 0 nor
            1 are never toggled.
        adapted_nn: model exposing ``predict(batch)``, which must return a 2-D
            array; column 0 is read as the score of each row.
        tolerance: maximum number of greedy steps.

    Returns:
        ``(cf, score)``. ``cf`` is the last candidate as a float ndarray.
        ``score`` is oriented toward the class opposite to the factual: it is
        ``predict`` when the factual scored < 0.5, and ``1 - predict``
        otherwise. ``score >= 0.5`` means the classification flipped;
        otherwise the search ran out of steps or of toggleable features.
    """
    possible_cf = np.array(factual_oh, dtype=float)

    def model_scores(batch):
        # Column 0 of ``adapted_nn.predict`` is the score for each row.
        return np.asarray(adapted_nn.predict(batch))[:, 0]

    factual_score = model_scores(possible_cf[np.newaxis, :])[0]
    # Orient the score so the search always climbs from below 0.5 toward it.
    if factual_score < 0.5:
        predictor = model_scores
        current_pred = factual_score
    else:
        def predictor(batch):
            return 1.0 - model_scores(batch)
        current_pred = 1.0 - factual_score

    identity = np.eye(len(possible_cf))
    for _ in range(tolerance):
        # Row i is possible_cf with feature i toggled (valid for 0/1 values).
        toggled = np.abs(identity - possible_cf)
        best_pred, best_cf = None, None
        # "Switch on" (0 -> 1) moves are evaluated first and are only replaced
        # by a strictly better "switch off" (1 -> 0) move, so ties favour
        # switching a feature on.
        for movable in (possible_cf == 0, possible_cf == 1):
            candidates = toggled[movable]
            if len(candidates) == 0:
                continue  # nothing to toggle in this direction
            preds = predictor(candidates)
            idx_best = int(np.argmax(preds))  # first maximum on ties
            if best_pred is None or preds[idx_best] > best_pred:
                best_pred, best_cf = preds[idx_best], candidates[idx_best]
        if best_cf is None:
            break  # no 0/1 feature left to toggle
        # NOTE: the best move is accepted even if it lowers the score, so the
        # search may revisit earlier states until `tolerance` is exhausted.
        possible_cf, current_pred = best_cf, best_pred
        # Stop as soon as the classification flips.
        if current_pred >= 0.5:
            break
    return possible_cf, current_pred
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment