Full implementation of a custom table expectation that compares the row count of the considered dataset to the row counts of other datasets, with the possibility of using different comparison keys.
"""
Custom table expectation which checks whether the row count is greater than the row count of other tables.
There are different ways to compare the row counts:
* With absolute values, if one row count value of the other tables is greater than the current then the validation
fails,
* With mean values, if the mean of value of the other tables row count is greater than the current row count then
the validation fails.
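For example, with the default lower_percentage_threshold of 100, a current row count of 100 compared against other
row counts of [80, 120] fails the ABSOLUTE comparison (since 120 > 100) but passes the MEAN comparison (the mean is 100).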
"""
from copy import deepcopy
from enum import Enum, auto
from statistics import mean
from typing import Dict, Tuple, Any, Optional, Callable, List
from great_expectations.core import ExpectationConfiguration
from great_expectations.core.batch_spec import PathBatchSpec
from great_expectations.execution_engine import (
    SparkDFExecutionEngine,
    PandasExecutionEngine,
    ExecutionEngine,
)
from great_expectations.expectations.metrics.metric_provider import metric_value
from great_expectations.expectations.metrics.table_metric_provider import (
    TableMetricProvider,
)
from great_expectations.expectations.expectation import TableExpectation
from great_expectations.exceptions.exceptions import InvalidExpectationKwargsError


class SupportedComparisonEnum(Enum):
    """Enum class with the currently supported comparison types."""

    ABSOLUTE = auto()
    MEAN = auto()

    def __call__(self, *args, **kwargs):
        # args[0]: current row count, args[1]: list of other row counts,
        # args[2]: lower percentage threshold used to scale the comparison.
        if self.name == "ABSOLUTE":
            return all(args[0] >= i * args[2] / 100 for i in args[1])
        elif self.name == "MEAN":
            return args[0] >= mean(args[1]) * args[2] / 100
        else:
            raise NotImplementedError("Comparison key is not supported.")


class OtherTableRowCount(TableMetricProvider):
    """MetricProvider class to get the row count of tables other than the current one."""

    metric_name = "table.row_count_other"

    @metric_value(engine=PandasExecutionEngine)
    def _pandas(
        cls,
        execution_engine: "PandasExecutionEngine",
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[Tuple, Any],
        runtime_configuration: Dict,
    ) -> int:
        other_table_filename = metric_domain_kwargs.get("table_filename")
        batch_spec = PathBatchSpec(
            {"path": other_table_filename, "reader_method": "read_csv"}
        )
        batch_data = execution_engine.get_batch_data(batch_spec=batch_spec)
        df = batch_data.dataframe
        return df.shape[0]

    @metric_value(engine=SparkDFExecutionEngine)
    def _spark(
        cls,
        execution_engine: "SparkDFExecutionEngine",
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[Tuple, Any],
        runtime_configuration: Dict,
    ) -> int:
        other_table_filename = metric_domain_kwargs.get("table_filename")
        batch_spec = PathBatchSpec(
            {"path": other_table_filename, "reader_method": "csv"}
        )
        batch_data = execution_engine.get_batch_data(batch_spec=batch_spec)
        df = batch_data.dataframe
        return df.count()


class ExpectTableRowCountToBeMoreThanOthers(TableExpectation):
    """TableExpectation class to compare the row count of the current dataset to other dataset(s)."""

    metric_dependencies = ("table.row_count", "table.row_count_other")
    success_keys = (
        "other_table_filenames_list",
        "comparison_key",
        "lower_percentage_threshold",
    )
    default_kwarg_values = {
        "other_table_filenames_list": None,
        "comparison_key": "MEAN",
        "lower_percentage_threshold": 100,
    }

    @staticmethod
    def _validate_success_key(
        param: str,
        required: bool,
        configuration: Optional[ExpectationConfiguration],
        validation_rules: Dict[Callable, str],
    ) -> None:
        """Simple method to aggregate and apply validation rules to the `param`."""
        if param not in configuration.kwargs:
            if required:
                raise InvalidExpectationKwargsError(
                    f"Param {param} is required but was not found in configuration."
                )
            return

        param_value = configuration.kwargs[param]

        for rule, error_message in validation_rules.items():
            if not rule(param_value):
                raise InvalidExpectationKwargsError(error_message)

    def validate_configuration(
        self, configuration: Optional[ExpectationConfiguration]
    ) -> bool:
        super().validate_configuration(configuration=configuration)
        if configuration is None:
            configuration = self.configuration

        self._validate_success_key(
            param="other_table_filenames_list",
            required=True,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(x, str)
                or isinstance(
                    x, List
                ): "other_table_filenames_list should either be a list or a string.",
                lambda x: x: "other_table_filenames_list should not be empty",
            },
        )

        self._validate_success_key(
            param="comparison_key",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(x, str): "comparison_key should be a string.",
                lambda x: x.upper()
                in SupportedComparisonEnum.__members__: "Given comparison_key is not supported.",
            },
        )

        self._validate_success_key(
            param="lower_percentage_threshold",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(
                    x, int
                ): "lower_percentage_threshold should be an integer.",
                lambda x: x
                > 0: "lower_percentage_threshold should be strictly greater than 0.",
            },
        )

        return True
    def get_validation_dependencies(
        self,
        configuration: Optional[ExpectationConfiguration] = None,
        execution_engine: Optional[ExecutionEngine] = None,
        runtime_configuration: Optional[dict] = None,
    ) -> dict:
        dependencies = super().get_validation_dependencies(
            configuration, execution_engine, runtime_configuration
        )

        other_table_filenames_list = configuration.kwargs.get(
            "other_table_filenames_list"
        )
        if isinstance(other_table_filenames_list, str):
            other_table_filenames_list = [other_table_filenames_list]

        # Register one "table.row_count_other" metric configuration per other table,
        # keyed by filename, so each table's row count is resolved separately.
        for other_table_filename in other_table_filenames_list:
            table_row_count_metric_config_other = deepcopy(
                dependencies["metrics"]["table.row_count_other"]
            )
            table_row_count_metric_config_other.metric_domain_kwargs[
                "table_filename"
            ] = other_table_filename

            dependencies["metrics"][
                f"table.row_count_other.{other_table_filename}"
            ] = table_row_count_metric_config_other

        # Rename the current table's row count metric and drop the template metric
        # that was only used to build the per-table configurations above.
        dependencies["metrics"]["table.row_count.self"] = dependencies["metrics"].pop(
            "table.row_count"
        )
        dependencies["metrics"].pop("table.row_count_other")

        return dependencies
    def _validate(
        self,
        configuration: ExpectationConfiguration,
        metrics: Dict,
        runtime_configuration: dict = None,
        execution_engine: ExecutionEngine = None,
    ) -> Dict:
        comparison_key = self.get_success_kwargs(configuration)["comparison_key"]
        other_table_filename_list = self.get_success_kwargs(configuration)[
            "other_table_filenames_list"
        ]
        lower_percentage_threshold = self.get_success_kwargs(configuration)[
            "lower_percentage_threshold"
        ]

        # Mirror the handling in get_validation_dependencies: a single filename may be
        # passed as a plain string.
        if isinstance(other_table_filename_list, str):
            other_table_filename_list = [other_table_filename_list]

        current_row_count = metrics["table.row_count.self"]
        previous_row_count_list = []
        for other_table_filename in other_table_filename_list:
            previous_row_count_list.append(
                metrics[f"table.row_count_other.{other_table_filename}"]
            )

        comparison_key_fn = SupportedComparisonEnum[comparison_key.upper()]

        success_flag = comparison_key_fn(
            current_row_count, previous_row_count_list, lower_percentage_threshold
        )

        return {
            "success": success_flag,
            "result": {
                "self": current_row_count,
                "other": previous_row_count_list,
                "comparison_key": comparison_key_fn.name,
            },
        }
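
Once the module above is imported, Great Expectations should register the expectation under its snake_case name, so it can be called from a Validator like a built-in expectation. A minimal usage sketch, assuming validator is a Validator built on the current CSV batch; the other file paths are hypothetical placeholders:

# Illustrative call; `validator` and the CSV paths below are placeholders.
validator.expect_table_row_count_to_be_more_than_others(
    other_table_filenames_list=[
        "./data/previous_week.csv",
        "./data/previous_month.csv",
    ],
    comparison_key="MEAN",
    lower_percentage_threshold=90,
)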