""" | |
A script which uploads validation results and a data validation | |
report to S3 for the FitbitSleepLogs data type. This was run in | |
Glue 4.0 while specifying --additional-python-modules great_expectations==0.18.11,boto3==1.24.70 | |
""" | |
import json
import logging
import os
import subprocess
import sys
from urllib.parse import urlparse

import boto3
import great_expectations as gx
import yaml
from awsglue.context import GlueContext
from great_expectations.core.batch import RuntimeBatchRequest
from great_expectations.core.expectation_configuration import ExpectationConfiguration
from pyspark.context import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType
# Log to stdout so messages show up in the Glue job's CloudWatch output
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
NAMESPACE = "main"
TABLE_NAME = "FitbitSleepLogs"
PARQUET_BUCKET = "recover-dev-processed-data"
CLOUDFORMATION_BUCKET = "recover-dev-cloudformation"
EXPECTATION_SUITE_NAME = "my_expectation_suite"

def get_spark_df(glue_context, parquet_bucket, namespace, datatype):
    """Load the dataset's Parquet files from S3 as a Spark DataFrame."""
    s3_parquet_path = \
        f"s3://{parquet_bucket}/{namespace}/parquet/dataset_{datatype}/"
    dynamic_frame = glue_context.create_dynamic_frame_from_options(
        connection_type="s3",
        connection_options={
            "paths": [s3_parquet_path]
        },
        format="parquet"
    )
    spark_df = dynamic_frame.toDF()
    return spark_df

def get_table_schema(table_name: str) -> dict[str, list]:
    """
    Get a table schema from table_columns.yaml

    Args:
        table_name (str): The name of the table

    Returns:
        (dict[str,list]) formatted as:
            {
                "columns": [
                    {"Name": "str", "Type": "str"},
                    ...
                ],
                "partition_keys": [
                    {"Name": "str", "Type": "str"},
                    ...
                ],
            }
    """
    # Prefer a local copy of the schema file; otherwise fall back to the
    # copy deployed to the CloudFormation bucket.
    if not os.path.exists("table_columns.yaml"):
        s3_client = boto3.client("s3")
        s3_object = s3_client.get_object(
            Bucket=CLOUDFORMATION_BUCKET,
            Key=f"{NAMESPACE}/src/glue/resources/table_columns.yaml"
        )
        schemas = yaml.safe_load(s3_object["Body"].read())
    else:
        with open("table_columns.yaml", "r") as yaml_file:
            schemas = yaml.safe_load(yaml_file)
    table_schema = schemas["tables"][table_name]
    return table_schema
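
# For reference, a hypothetical table_columns.yaml entry matching the shape
# that get_table_schema() expects (column names here are illustrative only):
#
#   tables:
#     FitbitSleepLogs:
#       columns:
#         - Name: LogId
#           Type: string
#         - Name: IsMainSleep
#           Type: boolean
#       partition_keys:
#         - Name: cohort
#           Type: string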

def create_expectations_from_schema(
        table_schema: dict[str, list],
        actual_schema: StructType
) -> list[ExpectationConfiguration]:
    """Create a column type expectation for each scalar column in the schema."""
    expectations = []
    present_fields = [field.name for field in actual_schema]
    # Maps Glue column types to Spark type names. Only the scalar types that
    # appear in this table's schema are mapped; array columns are skipped below.
    type_mapping = {
        "string": "StringType",
        "boolean": "BooleanType"
    }
    for field in table_schema["columns"] + table_schema["partition_keys"]:
        if field["Name"] in present_fields and not field["Type"].startswith("array"):
            expectation = ExpectationConfiguration(
                expectation_type="expect_column_values_to_be_of_type",
                kwargs={
                    "column": field["Name"],
                    "type_": type_mapping[field["Type"]]
                }
            )
            expectations.append(expectation)
    return expectations

def create_range_expectations():
    """Create range expectations for the numeric columns."""
    expectations = []
    range_expectations = [
        # Duration is in milliseconds; 86400000 ms = 24 hours
        {"column": "Duration", "min_value": 0, "max_value": 86400000},
        # Efficiency is a percentage
        {"column": "Efficiency", "min_value": 0, "max_value": 100}
    ]
    for expectation in range_expectations:
        range_expectation = ExpectationConfiguration(
            expectation_type="expect_column_values_to_be_between",
            kwargs=expectation
        )
        expectations.append(range_expectation)
    return expectations

def create_record_expectation(spark_df):
    """
    Expect the distinct LogId values to exactly equal the set of LogId values
    we collected, plus a sentinel ("IAmAMissingRecord") which does not occur
    in the data -- so this expectation is expected to fail, demonstrating how
    a missing record is reported.
    """
    all_log_ids = spark_df.select("LogId").collect()
    all_log_id_values = [
        row["LogId"] for row in all_log_ids if row["LogId"] is not None
    ]
    all_log_id_values.append("IAmAMissingRecord")
    record_expectation = ExpectationConfiguration(
        expectation_type="expect_column_distinct_values_to_equal_set",
        kwargs={
            "column": "LogId",
            "value_set": all_log_id_values,
        },
    )
    return record_expectation

def add_column_for_log_id_ref(spark_df):
    """
    Append a LogIdRef column containing every LogId plus a sentinel value.
    The full outer join gives the sentinel its own row, with nulls in every
    original column -- simulating a record that is referenced but missing
    from the data.
    """
    all_log_ids = spark_df.select("LogId").collect()
    all_log_id_values = [row["LogId"] for row in all_log_ids]
    all_log_id_values.append("IAmAMissingRecord")
    spark = SparkSession.getActiveSession()
    new_column_df = spark.createDataFrame(
        [(value,) for value in all_log_id_values], ["LogIdRef"]
    )
    spark_df = spark_df.withColumn("id", spark_df["LogId"])
    new_column_df = new_column_df.withColumn("id", new_column_df["LogIdRef"])
    joined_df = spark_df.join(new_column_df, "id", "fullouter").drop("id")
    return joined_df
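
# To illustrate, with hypothetical LogId values "a" and "b" the join above
# yields:
#
#   LogId | LogIdRef          | ...other columns...
#   ------+-------------------+--------------------
#   a     | a                 | ...
#   b     | b                 | ...
#   null  | IAmAMissingRecord | null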

def create_smart_record_expectation(spark_df):
    """
    Expect every LogIdRef value to be an actual LogId. The sentinel row added
    by add_column_for_log_id_ref() fails this check, so the validation result
    pinpoints which "referenced" records are missing.
    """
    all_log_ids = spark_df.select("LogId").collect()
    all_log_id_values = [
        row["LogId"] for row in all_log_ids if row["LogId"] is not None
    ]
    record_expectation = ExpectationConfiguration(
        expectation_type="expect_column_values_to_be_in_set",
        kwargs={
            "column": "LogIdRef",
            "value_set": all_log_id_values,
        },
    )
    return record_expectation

def get_batch_request(spark_dataset):
    """
    Build a block-style RuntimeBatchRequest. Note: unused below -- main()
    uses the Fluent datasource API (add_spark / add_dataframe_asset) instead.
    """
    batch_request = RuntimeBatchRequest(
        datasource_name="my_spark_datasource",
        data_connector_name="my_runtime_data_connector",
        data_asset_name="my-parquet-data-asset",
        runtime_parameters={"batch_data": spark_dataset},
        batch_identifiers={"my_batch_identifier": "okaybatchidentifier"}
    )
    return batch_request

def upload_results(context):
    """
    Upload the validation results as JSON and sync the local data docs site
    to S3.
    """
    s3_client = boto3.client("s3")
    validation_results = context.validations_store.get_all()[0].to_json_dict()
    s3_client.put_object(
        Body=json.dumps(validation_results),
        Bucket=PARQUET_BUCKET,
        Key=f"{NAMESPACE}/great_expectations/validation_results.json"
    )
    # build_data_docs() returns a mapping of site name to a file:// URL of
    # the rendered site; sync its parent directory to S3 with the AWS CLI.
    data_docs = context.build_data_docs()
    data_docs_path = urlparse(data_docs["local_site"]).path
    data_docs_dir = os.path.split(data_docs_path)[0]
    subprocess.run([
        "aws", "s3", "sync", data_docs_dir,
        f"s3://{PARQUET_BUCKET}/{NAMESPACE}/great_expectations/data_docs/"
    ])
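
# With the constants above, the results land at:
#   s3://recover-dev-processed-data/main/great_expectations/validation_results.json
#   s3://recover-dev-processed-data/main/great_expectations/data_docs/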

def main():
    context = gx.get_context()
    glue_context = GlueContext(SparkContext.getOrCreate())
    logger.info("get_spark_df")
    spark_df = get_spark_df(
        glue_context=glue_context,
        parquet_bucket=PARQUET_BUCKET,
        namespace=NAMESPACE,
        datatype=TABLE_NAME.lower()
    )
    # Assemble the expectation suite: column types, numeric ranges, and the
    # two record-level expectations.
    expectation_suite = context.add_expectation_suite(EXPECTATION_SUITE_NAME)
    table_schema = get_table_schema(table_name=TABLE_NAME)
    schema_expectations = create_expectations_from_schema(
        table_schema=table_schema,
        actual_schema=spark_df.schema
    )
    range_expectations = create_range_expectations()
    spark_df = add_column_for_log_id_ref(spark_df)
    record_expectation = create_record_expectation(spark_df)
    smart_record_expectation = create_smart_record_expectation(spark_df)
    expectation_suite.add_expectation_configurations(
        schema_expectations + range_expectations
        + [record_expectation, smart_record_expectation]
    )
    context.save_expectation_suite(expectation_suite)
    # Register the DataFrame with the Fluent datasource API and validate it
    # via a checkpoint.
    spark_datasource = context.sources.add_spark("my_spark_datasource")
    spark_data_asset = spark_datasource.add_dataframe_asset(
        name="my_spark_data_asset",
        dataframe=spark_df
    )
    batch_request = spark_data_asset.build_batch_request()
    checkpoint = context.add_checkpoint(
        name="my_checkpoint",
        validations=[
            {
                "batch_request": batch_request,
                "expectation_suite_name": EXPECTATION_SUITE_NAME,
            }
        ],
        runtime_configuration={
            "result_format": {
                "result_format": "SUMMARY",
                "partial_unexpected_count": 100,
            }
        }
    )
    checkpoint_results = checkpoint.run()
    upload_results(context)
    return context


if __name__ == "__main__":
    main()