base drift parameters class - public gist
# NOTE: imports assume Evidently's legacy test-suite API (pre-0.7 releases).
from evidently.metrics import ColumnSummaryMetric, DataDriftTable
from evidently.report import Report
from evidently.test_suite import TestSuite
from evidently.tests import (
    TestColumnDrift,
    TestColumnShareOfMissingValues,
    TestMeanInNSigmas,
    TestShareOfOutListValues,
)
from evidently.tests.base_test import generate_column_tests


class BaseDriftParameters:
    """A base class providing a template and common methods for all feature drift pipelines.

    This class can be copy/pasted into an existing codebase and will by default run all tests set in
    global_tests, numerical_tests, and categorical_tests. The only variables that need to be modified are
    high_risk_features, low_risk_features, categorical_cols, and numeric_cols.
    """

    response_fields = []

    # default failure thresholds for tests run as part of drift detection
    global_n_sigmas_threshold = 3
    global_missing_values_threshold = 0.5
    global_share_ool_threshold = 0.25

    # default tests to run as part of drift detection
    global_tests = [TestColumnShareOfMissingValues]
    numerical_tests = [TestMeanInNSigmas]
    categorical_tests = [TestShareOfOutListValues]
    response_tests = [TestColumnDrift]

    # variables needed to customize drift tests and their parameters
    # (these class-level containers are shared defaults; pass per-instance
    # overrides to __init__ rather than mutating them in place)
    custom_column_tests = {}
    custom_n_sigmas_mapping: dict[str, float] = {}
    custom_missing_vals_mapping: dict[str, float] = {}
    custom_share_ool_mapping: dict[str, float] = {}
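    # Illustrative shapes for the overrides above (the column name and values are
    # hypothetical, not part of the original gist):
    #   custom_n_sigmas_mapping = {"age": 4.0}  # widen the allowed band for "age"
    #   custom_column_tests = {
    #       "age": {"test": TestColumnDrift, "parameters": {"stattest_threshold": 0.05}},
    #   }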

    def __init__(
        self,
        high_risk_features: list[str],
        low_risk_features: list[str],
        categorical_cols: list[str],
        numeric_cols: list[str],
        text_cols: list[str],
        global_n_sigmas_threshold: int = global_n_sigmas_threshold,
        global_missing_values_threshold: float = global_missing_values_threshold,
        global_share_ool_threshold: float = global_share_ool_threshold,
        global_tests: list = global_tests,
        numerical_tests: list = numerical_tests,
        categorical_tests: list = categorical_tests,
        response_tests: list = response_tests,
        custom_column_tests: dict = custom_column_tests,
        response_fields: list[str] = response_fields,
        # custom parameter mappings
        custom_missing_vals_mapping: dict[str, float] = custom_missing_vals_mapping,
        custom_n_sigmas_mapping: dict[str, float] = custom_n_sigmas_mapping,
        custom_share_ool_mapping: dict[str, float] = custom_share_ool_mapping,
    ) -> None:
"""Instantiate drift object. | |
high_risk_features, low_risk_features are the only attributes that need to be set in order to run drift | |
detection. | |
Args: | |
high_risk_features (list[str]): list of features that require immediate attention if any drift test fails | |
low_risk_features (list[str]): list of features that do not require immediate attention if a drift test fails | |
categorical_cols (list[str]): list of categorical cols in dataset | |
numeric_cols (list[str]): list of numeric cols in dataset | |
text_cols (list[str]): list of text cols in dataset | |
global_n_sigmas_threshold (int, optional): TestMeanInNSigmas will fail if mean of current data is > | |
global_n_sigmas_threshold of mean from reference data | |
global_missing_values_threshold (float, optional): default failure threshold for TestColumnShareOfMissingValues | |
global_share_ool_threshold (float, optional): default failure threshold for TestShareOfOutListValues, test | |
will fail if the % of unseen categorical values in the current data is > global_share_ool_threshold | |
global_tests (list[dict], optional): evidently tests/configs to run against all columns in dataset | |
numerical_tests (list[dict], optional): evidently tests/configs to run against numeric columns in dataset | |
categorical_tests (list[dict], optional): evidently tests/configs to run against categorical columns in dataset. | |
response_tests (list[dict], optional): evidently tests/configs to run against response variables in dataset. | |
custom_column_tests (list[dict], optional): evidently tests/configs to run against specified columns in dataset. | |
response_fields (list[str], optional): list of response fields in data | |
custom_missing_vals_mapping (dict): dict of {column: failure_threshold} pairs to set for the test | |
TestColumnShareOfMissingValues. if a col isnt in this dict, global_missing_values_threshold will be used | |
custom_n_sigmas_mapping (dict): dict of {column: failure_threshold} pairs to set for the test | |
TestMeanInNSigmas. if a col isnt in this dict, global_n_sigmas_threshold will be used | |
custom_share_ool_mapping (dict): dict of {column: failure_threshold} pairs to set for the test | |
TestShareOfOutListValues. if a col isnt in this dict, global_share_ool_threshold will be used | |
""" | |
        self.high_risk_features = high_risk_features
        self.low_risk_features = low_risk_features
        self.categorical_cols = categorical_cols
        self.numeric_cols = numeric_cols
        self.text_cols = text_cols
        self.global_n_sigmas_threshold = global_n_sigmas_threshold
        self.global_missing_values_threshold = global_missing_values_threshold
        self.global_share_ool_threshold = global_share_ool_threshold
        self.global_tests = global_tests
        self.numerical_tests = numerical_tests
        self.categorical_tests = categorical_tests
        self.response_tests = response_tests
        self.custom_column_tests = custom_column_tests
        self.response_fields = response_fields
        self.custom_n_sigmas_mapping = custom_n_sigmas_mapping
        self.custom_missing_vals_mapping = custom_missing_vals_mapping
        self.custom_share_ool_mapping = custom_share_ool_mapping

        # the two calls below must run in order to properly build the test suite and drift report
        self.drift_features = self.high_risk_features + self.low_risk_features + self.response_fields
        self.set_drift_features(categorical_cols=categorical_cols, numeric_cols=numeric_cols)
        self.drift_feature_mapping = self.build_feature_mapping()

    def __repr__(self) -> str:
        return (
            f"Features that will be tested for drift:\n"
            f"\thigh_risk_features: {self.high_risk_features}\n"
            f"\tlow_risk_features: {self.low_risk_features}"
        )

    def set_drift_features(self, categorical_cols: list[str], numeric_cols: list[str]) -> None:
        """Set the risk level and dtype of numeric and categorical features for the given feature sets."""
        assert set(self.high_risk_features).isdisjoint(
            set(self.low_risk_features)
        ), "high and low risk features overlap"
        self.high_risk_cat_features = [feature for feature in categorical_cols if feature in self.high_risk_features]
        self.high_risk_num_features = [feature for feature in numeric_cols if feature in self.high_risk_features]
        self.low_risk_cat_features = [feature for feature in categorical_cols if feature in self.low_risk_features]
        self.low_risk_num_features = [feature for feature in numeric_cols if feature in self.low_risk_features]
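    # Illustrative example (hypothetical column names, not from the original gist):
    # with high_risk_features=["age"], low_risk_features=["state"],
    # numeric_cols=["age"], categorical_cols=["state"], this yields
    # high_risk_num_features == ["age"] and low_risk_cat_features == ["state"].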

    def build_feature_mapping(self) -> dict[str, dict]:
        """Build a {feature: config} mapping of per-column criticality and test parameters."""
        drift_mapping = {}
        for feature in self.drift_features:
            feature_config = {}
            feature_config["is_critical"] = (feature in self.high_risk_features) or (feature in self.response_fields)
            feature_config["missing_vals_threshold"] = self.custom_missing_vals_mapping.get(
                feature, self.global_missing_values_threshold
            )
            if feature in self.numeric_cols:
                feature_config["n_sigmas"] = self.custom_n_sigmas_mapping.get(feature, self.global_n_sigmas_threshold)
            if feature in self.categorical_cols:
                # fall back to the global threshold when no per-column override is set
                feature_config["lt_ool_threshold"] = self.custom_share_ool_mapping.get(
                    feature, self.global_share_ool_threshold
                )
            drift_mapping[feature] = feature_config
        return drift_mapping
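    # Example (hypothetical) result, assuming "age" is a high-risk numeric feature
    # and "state" a low-risk categorical feature, with default thresholds:
    #   {
    #       "age": {"is_critical": True, "missing_vals_threshold": 0.5, "n_sigmas": 3},
    #       "state": {"is_critical": False, "missing_vals_threshold": 0.5, "lt_ool_threshold": 0.25},
    #   }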

    def build_test_suite(self) -> TestSuite:
        """Instantiate the tests to run during drift detection, based on the corresponding class variables.

        Instantiates all tests set in the global_tests, numerical_tests, categorical_tests, and response_tests
        class variables, using Evidently's generate_column_tests() function to aggregate the individual tests.
        generate_column_tests() accepts an Evidently test to run, the columns to run it against, and test
        parameters. Tests are also instantiated based on their risk level: important features keep
        `is_critical=True` (the default) so a failure raises an error, while non-important features can set
        `is_critical=False` to raise a warning instead.

        Returns:
            TestSuite object with the list of all tests to run at drift-calculation time.
        """
        suite_tests = []

        # custom per-column tests with explicit parameters
        for column, test_config in self.custom_column_tests.items():
            suite_tests.append(
                generate_column_tests(test_config["test"], columns=[column], parameters=test_config["parameters"])
            )

        # numeric-column tests
        for test in self.numerical_tests:
            accepted_test_args = test.__fields__.keys()
            for col in self.numeric_cols:
                col_config = self.drift_feature_mapping[col]
                col_args = {k: v for k, v in col_config.items() if k in accepted_test_args}
                suite_tests.append(generate_column_tests(test, columns=[col], parameters=col_args))

        # categorical-column tests
        for test in self.categorical_tests:
            accepted_test_args = test.__fields__.keys()
            for col in self.categorical_cols:
                # copy so the shared drift_feature_mapping is not mutated as a side effect
                col_config = dict(self.drift_feature_mapping[col])
                if test == TestShareOfOutListValues:
                    col_config["lt"] = col_config["lt_ool_threshold"]
                if test == TestColumnShareOfMissingValues:
                    col_config["lt"] = col_config["missing_vals_threshold"]
                col_args = {k: v for k, v in col_config.items() if k in accepted_test_args}
                suite_tests.append(generate_column_tests(test, columns=[col], parameters=col_args))

        # response-variable tests
        for test in self.response_tests:
            suite_tests.append(generate_column_tests(test, columns=self.response_fields))

        # global tests, run against every feature in the drift mapping
        for test in self.global_tests:
            accepted_test_args = test.__fields__.keys()
            for col in self.drift_feature_mapping:
                col_config = dict(self.drift_feature_mapping[col])
                if test == TestColumnShareOfMissingValues:
                    col_config["lt"] = col_config["missing_vals_threshold"]
                col_args = {k: v for k, v in col_config.items() if k in accepted_test_args}
                suite_tests.append(generate_column_tests(test, columns=[col], parameters=col_args))

        return TestSuite(tests=suite_tests)
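    # A suite built this way is run like any other Evidently TestSuite, e.g.
    # (illustrative, assuming pandas DataFrames `reference_df` and `current_df`):
    #   suite = params.build_test_suite()
    #   suite.run(reference_data=reference_df, current_data=current_df)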

    def extract_failed_features(
        self, alert_tests: list[dict], verbose: bool = True
    ) -> tuple[dict[str, list[str]], list[str]]:
        """Given a list of failed tests, extract the unique features that failed any drift test.

        Args:
            alert_tests (list[dict]): failed tests pulled from the Evidently TestSuite results
            verbose (bool, optional): print the failed features per dtype

        Returns:
            tuple[dict[str, list[str]], list[str]]: failed features grouped by dtype, and the list of
                unique features that failed a test
        """
        unique_alert_features = set()
        for test in alert_tests:
            if test["parameters"] and "column_name" in test["parameters"]:
                unique_alert_features.add(test["parameters"]["column_name"])
            else:
                # the column name must be parsed from the description, where it is embedded between **
                # characters, e.g.: the mean value of column **{column_name}** is x. the expected range is y to z.
                failed_feature = test["description"].split("**")[1]
                unique_alert_features.add(failed_feature)

        # group the failed features by their dtype
        failed_feature_dtypes = {}
        failed_feature_dtypes["numeric"] = list(unique_alert_features.intersection(set(self.numeric_cols)))
        failed_feature_dtypes["categorical"] = list(unique_alert_features.intersection(set(self.categorical_cols)))
        failed_feature_dtypes["text"] = list(unique_alert_features.intersection(set(self.text_cols)))
        if verbose:
            print(f"failed categorical features: {failed_feature_dtypes['categorical']}")
            print(f"failed numeric features: {failed_feature_dtypes['numeric']}")
            print(f"failed text features: {failed_feature_dtypes['text']}")
        return failed_feature_dtypes, list(unique_alert_features)
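    # Example (assumed) wiring: pull the failed tests out of a completed suite, e.g.
    #   results = suite.as_dict()["tests"]
    #   failed = [t for t in results if t["status"] == "FAIL"]
    #   failed_dtypes, failed_features = params.extract_failed_features(failed)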

    def generate_drift_report(
        self, categorical_features: list[str], numeric_features: list[str], text_features: list[str]
    ) -> Report:
        """Generate a Report object based on the input features and their respective dtypes.

        Evidently's DataDriftTable creates useful visualizations for numeric and categorical features. After
        testing various other reports, ColumnSummaryMetric proved best for a high-level look at text features.

        Args:
            categorical_features (list[str]): categorical features to include in the report
            numeric_features (list[str]): numeric features to include in the report
            text_features (list[str]): text features to include in the report

        Returns:
            Evidently Report
        """
        report_metrics = []
        if len(categorical_features + numeric_features) > 0:
            report_metrics.append(DataDriftTable(columns=categorical_features + numeric_features))
        if len(text_features) > 0:
            report_metrics.extend([ColumnSummaryMetric(column_name=text_col) for text_col in text_features])
        return Report(metrics=report_metrics)
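

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original gist: exercises the class
    # on a tiny synthetic dataset. Column names, values, and thresholds are illustrative.
    import pandas as pd

    reference_df = pd.DataFrame({"age": [30, 40, 50, 60], "state": ["CA", "NY", "CA", "TX"]})
    current_df = pd.DataFrame({"age": [31, 39, 52, 95], "state": ["CA", "NY", "WA", "TX"]})

    params = BaseDriftParameters(
        high_risk_features=["age"],
        low_risk_features=["state"],
        categorical_cols=["state"],
        numeric_cols=["age"],
        text_cols=[],
    )

    # build and run the test suite against reference vs. current data
    suite = params.build_test_suite()
    suite.run(reference_data=reference_df, current_data=current_df)

    # collect failures and report on only the drifting features
    failed = [t for t in suite.as_dict()["tests"] if t["status"] == "FAIL"]
    failed_dtypes, failed_features = params.extract_failed_features(failed)
    report = params.generate_drift_report(
        categorical_features=failed_dtypes["categorical"],
        numeric_features=failed_dtypes["numeric"],
        text_features=failed_dtypes["text"],
    )
    report.run(reference_data=reference_df, current_data=current_df)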