@elutins
Created March 28, 2024 14:35
base drift parameters class - public gist
# The imports below assume the evidently 0.4.x API.
from evidently.metrics import ColumnSummaryMetric, DataDriftTable
from evidently.report import Report
from evidently.test_suite import TestSuite
from evidently.tests import (
    TestColumnDrift,
    TestColumnShareOfMissingValues,
    TestMeanInNSigmas,
    TestShareOfOutListValues,
)
from evidently.utils.generators import generate_column_tests


class BaseDriftParameters:
    """A base class providing a template and common methods for all feature drift pipelines.

    This class can be copy/pasted into an existing codebase and will by default run all tests set in
    global_tests, numerical_tests, and categorical_tests. The only variables that need to be modified
    are high_risk_features, low_risk_features, categorical_cols, and numeric_cols.
    """
    response_fields: list[str] = []
    # Default failure thresholds for the tests run as part of drift detection.
    global_n_sigmas_threshold = 3
    global_missing_values_threshold = 0.5
    global_share_ool_threshold = 0.25
    # Default tests to run as part of drift detection.
    global_tests = [TestColumnShareOfMissingValues]
    numerical_tests = [TestMeanInNSigmas]
    categorical_tests = [TestShareOfOutListValues]
    response_tests = [TestColumnDrift]
    # Variables used to customize drift tests and their corresponding parameters.
    custom_column_tests: dict[str, dict] = {}
    custom_n_sigmas_mapping: dict[str, float] = {}
    custom_missing_vals_mapping: dict[str, float] = {}
    custom_share_ool_mapping: dict[str, float] = {}
    def __init__(
        self,
        high_risk_features: list[str],
        low_risk_features: list[str],
        categorical_cols: list[str],
        numeric_cols: list[str],
        text_cols: list[str],
        global_n_sigmas_threshold: int = global_n_sigmas_threshold,
        global_missing_values_threshold: float = global_missing_values_threshold,
        global_share_ool_threshold: float = global_share_ool_threshold,
        global_tests: list = global_tests,
        numerical_tests: list = numerical_tests,
        categorical_tests: list = categorical_tests,
        response_tests: list = response_tests,
        custom_column_tests: dict = custom_column_tests,
        response_fields: list[str] = response_fields,
        # Custom parameter mappings.
        custom_missing_vals_mapping: dict[str, float] = custom_missing_vals_mapping,
        custom_n_sigmas_mapping: dict[str, float] = custom_n_sigmas_mapping,
        custom_share_ool_mapping: dict[str, float] = custom_share_ool_mapping,
    ) -> None:
"""Instantiate drift object.
high_risk_features, low_risk_features are the only attributes that need to be set in order to run drift
detection.
Args:
high_risk_features (list[str]): list of features that require immediate attention if any drift test fails
low_risk_features (list[str]): list of features that do not require immediate attention if a drift test fails
categorical_cols (list[str]): list of categorical cols in dataset
numeric_cols (list[str]): list of numeric cols in dataset
text_cols (list[str]): list of text cols in dataset
global_n_sigmas_threshold (int, optional): TestMeanInNSigmas will fail if mean of current data is >
global_n_sigmas_threshold of mean from reference data
global_missing_values_threshold (float, optional): default failure threshold for TestColumnShareOfMissingValues
global_share_ool_threshold (float, optional): default failure threshold for TestShareOfOutListValues, test
will fail if the % of unseen categorical values in the current data is > global_share_ool_threshold
global_tests (list[dict], optional): evidently tests/configs to run against all columns in dataset
numerical_tests (list[dict], optional): evidently tests/configs to run against numeric columns in dataset
categorical_tests (list[dict], optional): evidently tests/configs to run against categorical columns in dataset.
response_tests (list[dict], optional): evidently tests/configs to run against response variables in dataset.
custom_column_tests (list[dict], optional): evidently tests/configs to run against specified columns in dataset.
response_fields (list[str], optional): list of response fields in data
custom_missing_vals_mapping (dict): dict of {column: failure_threshold} pairs to set for the test
TestColumnShareOfMissingValues. if a col isnt in this dict, global_missing_values_threshold will be used
custom_n_sigmas_mapping (dict): dict of {column: failure_threshold} pairs to set for the test
TestMeanInNSigmas. if a col isnt in this dict, global_n_sigmas_threshold will be used
custom_share_ool_mapping (dict): dict of {column: failure_threshold} pairs to set for the test
TestShareOfOutListValues. if a col isnt in this dict, global_share_ool_threshold will be used
"""
        self.high_risk_features = high_risk_features
        self.low_risk_features = low_risk_features
        self.categorical_cols = categorical_cols
        self.numeric_cols = numeric_cols
        self.text_cols = text_cols
        self.global_n_sigmas_threshold = global_n_sigmas_threshold
        self.global_missing_values_threshold = global_missing_values_threshold
        self.global_share_ool_threshold = global_share_ool_threshold
        self.global_tests = global_tests
        self.numerical_tests = numerical_tests
        self.categorical_tests = categorical_tests
        self.response_tests = response_tests
        self.custom_column_tests = custom_column_tests
        self.response_fields = response_fields
        self.custom_n_sigmas_mapping = custom_n_sigmas_mapping
        self.custom_missing_vals_mapping = custom_missing_vals_mapping
        self.custom_share_ool_mapping = custom_share_ool_mapping
        # The two calls below must run in order to properly build the Test Suite and Drift Report.
        self.drift_features = self.high_risk_features + self.low_risk_features + self.response_fields
        self.set_drift_features(categorical_cols=categorical_cols, numeric_cols=numeric_cols)
        self.drift_feature_mapping = self.build_feature_mapping()
    def __repr__(self) -> str:
        return (
            f"Features that will be tested for drift:\n"
            f"\thigh_risk_features: {self.high_risk_features}\n"
            f"\tlow_risk_features: {self.low_risk_features}"
        )
    def set_drift_features(self, categorical_cols: list[str], numeric_cols: list[str]) -> None:
        """Set the risk level and dtype of numeric and categorical features for the given feature sets."""
        assert set(self.high_risk_features).isdisjoint(
            set(self.low_risk_features)
        ), "high and low risk features overlap"
        self.high_risk_cat_features = [feature for feature in categorical_cols if feature in self.high_risk_features]
        self.high_risk_num_features = [feature for feature in numeric_cols if feature in self.high_risk_features]
        self.low_risk_cat_features = [feature for feature in categorical_cols if feature in self.low_risk_features]
        self.low_risk_num_features = [feature for feature in numeric_cols if feature in self.low_risk_features]
    def build_feature_mapping(self) -> dict[str, dict]:
        """Build a {feature: config} mapping of the test parameters to apply to each drift feature."""
        drift_mapping = {}
        for feature in self.drift_features:
            feature_config = {}
            feature_config["is_critical"] = (feature in self.high_risk_features) or (feature in self.response_fields)
            feature_config["missing_vals_threshold"] = self.custom_missing_vals_mapping.get(
                feature, self.global_missing_values_threshold
            )
            if feature in self.numeric_cols:
                feature_config["n_sigmas"] = self.custom_n_sigmas_mapping.get(feature, self.global_n_sigmas_threshold)
            if feature in self.categorical_cols:
                # Fall back to the global threshold when no custom value is set for the column.
                feature_config["lt_ool_threshold"] = self.custom_share_ool_mapping.get(
                    feature, self.global_share_ool_threshold
                )
            drift_mapping[feature] = feature_config
        return drift_mapping
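    # For illustration, a feature mapping built with the default thresholds from two hypothetical
    # columns ("price" high-risk numeric, "color" low-risk categorical) would look like:
    #   {
    #       "price": {"is_critical": True, "missing_vals_threshold": 0.5, "n_sigmas": 3},
    #       "color": {"is_critical": False, "missing_vals_threshold": 0.5, "lt_ool_threshold": 0.25},
    #   }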
    def build_test_suite(self) -> TestSuite:
        """Instantiate the tests to run during drift detection, based on the corresponding class variables.

        Instantiates all tests set in the global_tests, numerical_tests, categorical_tests, and
        response_tests class variables. Uses Evidently's generate_column_tests() function to aggregate the
        individual tests. generate_column_tests() accepts an Evidently Test to run, the columns to run the
        test against, and the test parameters.

        Tests are also instantiated based on their risk level: important features can set
        `is_critical=True` in the parameters (the default), and non-important features can set
        `is_critical=False`.

        Returns:
            TestSuite object with the list of all tests to run at drift-calculation time.
        """
        # If you want to get a Warning instead, use the is_critical parameter and set it to False.
        suite_tests = []
        for column, test_config in self.custom_column_tests.items():
            suite_tests.append(
                generate_column_tests(test_config["test"], columns=[column], parameters=test_config["parameters"])
            )
        for test in self.numerical_tests:
            accepted_test_args = test.__fields__.keys()
            for col in self.numeric_cols:
                col_config = dict(self.drift_feature_mapping[col])  # copy to avoid mutating the shared mapping
                col_args = {k: v for k, v in col_config.items() if k in accepted_test_args}
                suite_tests.append(generate_column_tests(test, columns=[col], parameters=col_args))
        for test in self.categorical_tests:
            accepted_test_args = test.__fields__.keys()
            for col in self.categorical_cols:
                col_config = dict(self.drift_feature_mapping[col])  # copy to avoid mutating the shared mapping
                if test == TestShareOfOutListValues:
                    col_config["lt"] = col_config["lt_ool_threshold"]
                if test == TestColumnShareOfMissingValues:
                    col_config["lt"] = col_config["missing_vals_threshold"]
                col_args = {k: v for k, v in col_config.items() if k in accepted_test_args}
                suite_tests.append(generate_column_tests(test, columns=[col], parameters=col_args))
        for test in self.response_tests:
            suite_tests.append(generate_column_tests(test, columns=self.response_fields))
        # Adding the global tests.
        for test in self.global_tests:
            accepted_test_args = test.__fields__.keys()
            for col in self.drift_feature_mapping:
                col_config = dict(self.drift_feature_mapping[col])  # copy to avoid mutating the shared mapping
                if test == TestColumnShareOfMissingValues:
                    col_config["lt"] = col_config["missing_vals_threshold"]
                col_args = {k: v for k, v in col_config.items() if k in accepted_test_args}
                suite_tests.append(generate_column_tests(test, columns=[col], parameters=col_args))
        return TestSuite(tests=suite_tests)
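    # A minimal sketch of running the suite (the `params` instance and dataframes here are hypothetical):
    #   suite = params.build_test_suite()
    #   suite.run(reference_data=reference_df, current_data=current_df)
    #   results = suite.as_dict()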
    def extract_failed_features(
        self, alert_tests: list[dict], verbose: bool = True
    ) -> tuple[dict[str, list[str]], list[str]]:
        """Given a list of failed tests, extract the unique features that failed any drift test.

        Args:
            alert_tests (list[dict]): list of failed tests from the Evidently TestSuite object
            verbose (bool, optional): if True, print the failed features grouped by dtype

        Returns:
            tuple[dict[str, list[str]], list[str]]: the failed features grouped by dtype, and the unique
                features that failed a test
        """
        unique_alert_features = set()
        for test in alert_tests:
            if test["parameters"] and "column_name" in test["parameters"]:
                unique_alert_features.add(test["parameters"]["column_name"])
            else:
                # The column name must be parsed from the description, where it is embedded between **
                # characters, e.g.: "the mean value of column **{column_name}** is x. the expected range
                # is y to z."
                failed_feature = test["description"].split("**")[1]
                unique_alert_features.add(failed_feature)
        # Return a dict of failed features grouped by their dtype:
        failed_feature_dtypes = {}
        failed_feature_dtypes["numeric"] = list(unique_alert_features.intersection(set(self.numeric_cols)))
        failed_feature_dtypes["categorical"] = list(unique_alert_features.intersection(set(self.categorical_cols)))
        failed_feature_dtypes["text"] = list(unique_alert_features.intersection(set(self.text_cols)))
        if verbose:
            print(f"failed categorical features: {failed_feature_dtypes['categorical']}")
            print(f"failed numeric features: {failed_feature_dtypes['numeric']}")
            print(f"failed text features: {failed_feature_dtypes['text']}")
        return failed_feature_dtypes, list(unique_alert_features)
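    # A minimal sketch of building the alert_tests input from suite results (the keys assume the
    # evidently 0.4.x as_dict() result schema):
    #   results = suite.as_dict()
    #   alert_tests = [t for t in results["tests"] if t["status"] == "FAIL"]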
    def generate_drift_report(
        self, categorical_features: list[str], numeric_features: list[str], text_features: list[str]
    ) -> Report:
        """Generate a Report object based on the input features and their respective dtypes.

        Evidently's DataDriftTable creates useful visualizations for numeric and categorical features.
        After testing various other reports, ColumnSummaryMetric proved best for a high-level look at text
        features.

        Args:
            categorical_features (list[str]): categorical features to include in the report
            numeric_features (list[str]): numeric features to include in the report
            text_features (list[str]): text features to include in the report

        Returns:
            Evidently Report
        """
        report_metrics = []
        if len(categorical_features + numeric_features) > 0:
            report_metrics.append(DataDriftTable(columns=categorical_features + numeric_features))
        if len(text_features) > 0:
            report_metrics.extend([ColumnSummaryMetric(column_name=text_col) for text_col in text_features])
        return Report(metrics=report_metrics)
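
# Example usage: a minimal sketch of the end-to-end flow. The column names and dataframes below are
# hypothetical, and the result-dict keys assume the evidently 0.4.x schema.
if __name__ == "__main__":
    import pandas as pd

    reference_df = pd.DataFrame({"price": [1.0, 2.0, 3.0], "color": ["red", "blue", "red"]})
    current_df = pd.DataFrame({"price": [1.5, 2.5, 9.0], "color": ["red", "green", "green"]})

    params = BaseDriftParameters(
        high_risk_features=["price"],
        low_risk_features=["color"],
        categorical_cols=["color"],
        numeric_cols=["price"],
        text_cols=[],
    )
    # Build and run the test suite against the reference/current data.
    suite = params.build_test_suite()
    suite.run(reference_data=reference_df, current_data=current_df)

    # Collect failures and build a drift report for only the failed features.
    failed = [t for t in suite.as_dict()["tests"] if t["status"] == "FAIL"]
    failed_by_dtype, failed_features = params.extract_failed_features(failed)
    report = params.generate_drift_report(
        categorical_features=failed_by_dtype["categorical"],
        numeric_features=failed_by_dtype["numeric"],
        text_features=failed_by_dtype["text"],
    )
    report.run(reference_data=reference_df, current_data=current_df)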