@jmsquare
Last active December 20, 2022 16:10
"""
Summary: Tests whether the model prediction is invariant when the feature values are perturbed
Description: Tests whether the predicted classification label remains the same
after the feature values are perturbed. The test passes when the ratio of
unchanged rows is higher than the threshold.
Args:
    df(GiskardDataset):
        Dataset used to compute the test
    model(GiskardModel):
        Model used to compute the test
    perturbation_dict(dict):
        Dictionary of the perturbations. It maps each perturbed feature (key)
        to a perturbation lambda function (value)
    threshold(float):
        Threshold for the ratio of invariant rows
Returns:
    actual_slices_size:
        Total number of rows in the actual dataset
    number_of_perturbed_rows:
        Number of perturbed rows
    metric:
        The ratio of unchanged rows over the perturbed rows
    passed:
        True if metric > threshold
    output_df:
        Dataframe of the rows where the prediction changes due to the perturbation
"""
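import nlpaug.augmenter.word as naw

# Augmenter that replaces words with WordNet synonyms
aug = naw.SynonymAug(aug_src='wordnet')

# Perturbation: words of the email are substituted by WordNet synonyms
# (by default SynonymAug replaces only a fraction of the words, not all of them).
# Note: depending on the nlpaug version, augment() may return a list rather
# than a string; if so, take its first element.
perturbation = {
    "Content": lambda x: aug.augment(x["Content"])
}

# `tests`, `actual_ds` and `model` are expected to be provided by the Giskard test environment
tests.metamorphic.test_metamorphic_invariance(
    df=actual_ds,
    model=model,
    perturbation_dict=perturbation,
    threshold=0.5
)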
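
To make the pass criterion from the docstring concrete, here is a minimal, dependency-free sketch of how the invariance metric could be computed. It is not Giskard's implementation: `invariance_check`, `toy_predict`, `toy_rows` and `toy_perturbation` are hypothetical names, rows are plain dicts instead of a `GiskardDataset`, and returning a metric of 0.0 when no row is actually perturbed is an arbitrary assumption of this sketch.

```python
from typing import Any, Callable, Dict, List

Row = Dict[str, Any]


def invariance_check(
    rows: List[Row],
    predict: Callable[[Row], Any],
    perturbation_dict: Dict[str, Callable[[Row], Any]],
    threshold: float,
) -> Dict[str, Any]:
    """Hypothetical re-creation of the metric described in the docstring."""
    perturbed_count = 0
    unchanged_count = 0
    changed_rows: List[Row] = []

    for row in rows:
        # Apply every perturbation function to a copy of the row.
        perturbed = dict(row)
        for feature, perturb in perturbation_dict.items():
            perturbed[feature] = perturb(row)

        # Rows left untouched by the perturbation are not counted.
        if perturbed == row:
            continue
        perturbed_count += 1

        # Compare the prediction before and after the perturbation.
        if predict(perturbed) == predict(row):
            unchanged_count += 1
        else:
            changed_rows.append(row)

    # Assumption of this sketch: the metric is 0.0 when nothing was perturbed.
    metric = unchanged_count / perturbed_count if perturbed_count else 0.0
    return {
        "actual_slices_size": len(rows),
        "number_of_perturbed_rows": perturbed_count,
        "metric": metric,
        "passed": metric > threshold,
        "output_df": changed_rows,  # a plain list here, not a dataframe
    }


def toy_predict(row: Row) -> str:
    # Toy stand-in for a classifier: flags any text containing "free".
    return "spam" if "free" in row["Content"].lower() else "ham"


toy_rows = [
    {"Content": "Get your free prize"},
    {"Content": "Meeting at noon"},
    {"Content": "A free lunch, totally FREE"},
    {"Content": "Feel free to call"},
]

# Toy stand-in for the synonym augmenter: swaps one word for a synonym.
toy_perturbation = {"Content": lambda x: x["Content"].replace("free", "gratis")}

result = invariance_check(toy_rows, toy_predict, toy_perturbation, threshold=0.5)
print(result["metric"], result["passed"])
```

In this toy run three of the four rows are perturbed and only one of them keeps its label, so the ratio falls below the 0.5 threshold and the check fails. A brittle keyword-based model like `toy_predict` is exactly the kind of behavior a metamorphic invariance test is meant to surface.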