Skip to content

Instantly share code, notes, and snippets.

@rmazzine
Last active July 13, 2021 11:16
Show Gist options
  • Save rmazzine/89709e966de04719d90097db85d5882a to your computer and use it in GitHub Desktop.
Save rmazzine/89709e966de04719d90097db85d5882a to your computer and use it in GitHub Desktop.
Simple Counterfactual Generator
import numpy as np
def simple_cf_generator(factual_oh, adapted_nn, tolerance=100):
    """Greedily search for a counterfactual of a binary (one-hot) input.

    Starting from ``factual_oh``, every iteration builds all candidates that
    differ from the current point by exactly one flipped feature (1 -> 0 or
    0 -> 1), scores them with the model and keeps the best one. The search
    stops as soon as the kept candidate crosses the 0.5 decision threshold
    or after ``tolerance`` iterations.

    Args:
        factual_oh: 1-D sequence (anything with ``.copy()`` and ``len()``)
            holding the factual instance. Only entries equal to 0 or 1 are
            ever modified; any other value is left untouched.
        adapted_nn: Model object exposing ``predict(batch)``. The result is
            indexed as a 2-D array whose first column is the score of each
            row (presumably a Keras-style binary classifier -- confirm
            against the caller).
        tolerance: Maximum number of single-feature modifications to apply.

    Returns:
        A ``(possible_cf, current_pred)`` tuple. ``possible_cf`` is the last
        point reached by the search (a copy of ``factual_oh`` if no step was
        taken). ``current_pred`` is the score of that point *oriented toward
        the class opposite to the factual prediction*: it is the model output
        when the factual scored below 0.5 and ``1 - output`` otherwise, so a
        value ``>= 0.5`` always means the classification was flipped.
    """
    # Get the current prediction result of the factual class
    factual_pred = adapted_nn.predict(np.array([factual_oh]))[0][0]

    # Orient the score so the search always *maximises* it: if the factual
    # is already on the >= 0.5 side, work with the complementary probability.
    if factual_pred < 0.5:
        predictor = adapted_nn.predict
        current_pred = factual_pred
    else:
        def predictor(batch):
            return 1 - np.asarray(adapted_nn.predict(batch))

        current_pred = 1 - factual_pred

    # Create a variable to store the possible CF explanation
    possible_cf = factual_oh.copy()

    # In each iteration, modify every feature in turn and keep the single
    # modification that yields the largest increase in (oriented) probability
    for _ in range(tolerance):
        values = np.asarray(possible_cf, dtype=float)
        # Row i of the identity matrix selects feature i for modification
        flip_rows = np.eye(len(values))

        activated = values == 1
        unactivated = values == 0

        # Unactivated features are switched on (0 -> 1), activated ones are
        # switched off (1 -> 0); each row is one candidate counterfactual
        cfs_unactivated = flip_rows[unactivated] + values
        cfs_activated = (1 - flip_rows[activated]) * values

        # Score every candidate in a single batch. Unactivated candidates go
        # first so that, on ties, argmax prefers switching a feature on.
        candidates = np.concatenate([cfs_unactivated, cfs_activated])
        if len(candidates) == 0:
            # Nothing can be flipped (no 0/1 entries): stop searching
            break

        scores = np.asarray(predictor(candidates))[:, 0]
        best = int(np.argmax(scores))

        # Update the possible CF to the best modification found
        current_pred = scores[best]
        possible_cf = candidates[best]

        # If it flipped the classification, stop the loop
        if current_pred >= 0.5:
            break

    return possible_cf, current_pred
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment