Skip to content

Instantly share code, notes, and snippets.

@PaoloLeonard
Last active October 13, 2021 12:48
Show Gist options
  • Save PaoloLeonard/3e5aa714397147d516778660573de023 to your computer and use it in GitHub Desktop.
Save PaoloLeonard/3e5aa714397147d516778660573de023 to your computer and use it in GitHub Desktop.
Expectation implementation for the Great Expectations (GE) table expectation tutorial.
from copy import deepcopy
from typing import Dict, Tuple, Any, Optional, Callable, List
from great_expectations.core import ExpectationConfiguration
from great_expectations.execution_engine import (
ExecutionEngine
)
from great_expectations.expectations.expectation import TableExpectation
from great_expectations.exceptions.exceptions import InvalidExpectationKwargsError
class ExpectTableRowCountToBeMoreThanOthers(TableExpectation):
    """TableExpectation class to compare the row count of the current dataset to other dataset(s).

    Success kwargs:
        other_table_filenames_list: a single filename (str) or a non-empty list
            of filenames whose row counts are compared with the current table's.
            Required.
        comparison_key: case-insensitive name of a ``SupportedComparisonEnum``
            member selecting the comparison. Defaults to ``"MEAN"``.
        lower_percentage_threshold: strictly positive integer forwarded to the
            selected comparison. Defaults to ``100``.

    NOTE(review): ``SupportedComparisonEnum`` is defined outside this view; its
    members are looked up by name and called as
    ``member(current_count, other_counts, threshold)``.
    """

    metric_dependencies = ("table.row_count", "table.row_count_other")
    success_keys = (
        "other_table_filenames_list",
        "comparison_key",
        "lower_percentage_threshold",
    )
    default_kwarg_values = {
        "other_table_filenames_list": None,
        "comparison_key": "MEAN",
        "lower_percentage_threshold": 100,
    }

    @staticmethod
    def _normalize_filenames(other_table_filenames) -> List[str]:
        """Return the configured filenames as a list, wrapping a lone string.

        Used by both ``get_validation_dependencies`` and ``_validate`` so the
        metric names they build and read always agree.
        """
        if isinstance(other_table_filenames, str):
            return [other_table_filenames]
        return list(other_table_filenames)

    @staticmethod
    def _validate_success_key(
        param: str,
        required: bool,
        configuration: Optional[ExpectationConfiguration],
        validation_rules: Dict[Callable, str],
    ) -> None:
        """Apply ``validation_rules`` to ``configuration.kwargs[param]``.

        Rules run in dict insertion order and stop at the first failure, so a
        type check listed first guards the rules that follow it.

        Args:
            param: name of the kwarg to validate.
            required: whether a missing ``param`` is an error (otherwise a
                missing param is silently accepted).
            configuration: configuration whose ``kwargs`` hold ``param``.
            validation_rules: predicate -> error message; the message is raised
                when the predicate returns a falsy value.

        Raises:
            InvalidExpectationKwargsError: if no configuration is given, a
                required param is missing, or any rule fails.
        """
        if configuration is None:
            raise InvalidExpectationKwargsError(
                f"Cannot validate param {param}: no configuration was provided."
            )
        if param not in configuration.kwargs:
            if required:
                raise InvalidExpectationKwargsError(
                    f"Param {param} is required but was not found in configuration."
                )
            return
        param_value = configuration.kwargs[param]
        for rule, error_message in validation_rules.items():
            if not rule(param_value):
                raise InvalidExpectationKwargsError(error_message)

    def validate_configuration(
        self, configuration: Optional[ExpectationConfiguration]
    ) -> bool:
        """Validate this expectation's success kwargs.

        Falls back to ``self.configuration`` when ``configuration`` is None.

        Returns:
            True when every check passes.

        Raises:
            InvalidExpectationKwargsError: on the first invalid kwarg.
        """
        super().validate_configuration(configuration=configuration)
        if configuration is None:
            configuration = self.configuration
        self._validate_success_key(
            param="other_table_filenames_list",
            required=True,
            configuration=configuration,
            validation_rules={
                (lambda x: isinstance(x, (str, list))): (
                    "other_table_filenames_list should either be a list or a string."
                ),
                (lambda x: bool(x)): "other_table_filenames_list should not be empty",
            },
        )
        self._validate_success_key(
            param="comparison_key",
            required=False,
            configuration=configuration,
            validation_rules={
                (lambda x: isinstance(x, str)): "comparison_key should be a string.",
                (lambda x: x.upper() in SupportedComparisonEnum.__members__): (
                    "Given comparison_key is not supported."
                ),
            },
        )
        self._validate_success_key(
            param="lower_percentage_threshold",
            required=False,
            configuration=configuration,
            validation_rules={
                # bool is a subclass of int, so True/False must be excluded explicitly.
                (lambda x: isinstance(x, int) and not isinstance(x, bool)): (
                    "lower_percentage_threshold should be an integer."
                ),
                (lambda x: x > 0): (
                    "lower_percentage_threshold should be strictly greater than 0."
                ),
            },
        )
        return True

    def get_validation_dependencies(
        self,
        configuration: Optional[ExpectationConfiguration] = None,
        execution_engine: Optional[ExecutionEngine] = None,
        runtime_configuration: Optional[dict] = None,
    ) -> dict:
        """Build the metric dependencies: one row count per table.

        Renames ``table.row_count`` to ``table.row_count.self``. Replaces the
        template ``table.row_count_other`` metric with one copy per other
        filename, keyed ``table.row_count_other.<filename>``, whose
        ``metric_domain_kwargs["table_filename"]`` is set to that filename.
        """
        dependencies = super().get_validation_dependencies(
            configuration, execution_engine, runtime_configuration
        )
        if configuration is None:
            configuration = self.configuration
        metrics = dependencies["metrics"]
        # Template metric config; cloned once per other table, then discarded.
        other_template = metrics.pop("table.row_count_other")
        for other_table_filename in self._normalize_filenames(
            configuration.kwargs.get("other_table_filenames_list")
        ):
            metric_config = deepcopy(other_template)
            metric_config.metric_domain_kwargs["table_filename"] = other_table_filename
            metrics[f"table.row_count_other.{other_table_filename}"] = metric_config
        metrics["table.row_count.self"] = metrics.pop("table.row_count")
        return dependencies

    def _validate(
        self,
        configuration: ExpectationConfiguration,
        metrics: Dict,
        runtime_configuration: Optional[dict] = None,
        execution_engine: Optional[ExecutionEngine] = None,
    ) -> Dict:
        """Compare the current row count against the other tables' row counts.

        Returns:
            ``{"success": <comparison result>, "result": {"self": <current
            count>, "other": [<other counts in filename order>],
            "comparison_key": <enum member name>}}``
        """
        success_kwargs = self.get_success_kwargs(configuration)
        comparison_key = success_kwargs["comparison_key"]
        lower_percentage_threshold = success_kwargs["lower_percentage_threshold"]
        # A single filename may be given as a plain string. Normalize it exactly
        # as get_validation_dependencies did; iterating the raw string would
        # look up one metric per character and raise KeyError.
        other_table_filenames = self._normalize_filenames(
            success_kwargs["other_table_filenames_list"]
        )

        current_row_count = metrics["table.row_count.self"]
        previous_row_count_list = [
            metrics[f"table.row_count_other.{other_table_filename}"]
            for other_table_filename in other_table_filenames
        ]

        comparison_key_fn = SupportedComparisonEnum[comparison_key.upper()]
        success_flag = comparison_key_fn(
            current_row_count, previous_row_count_list, lower_percentage_threshold
        )
        return {
            "success": success_flag,
            "result": {
                "self": current_row_count,
                "other": previous_row_count_list,
                "comparison_key": comparison_key_fn.name,
            },
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment