stateful partitions (scd2)
import os
from pathlib import Path
from typing import Union

import pandas
import pandas as pd
import pyspark.sql
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from sqlalchemy import create_engine, text

# NOTE: the exact Dagster import locations vary between releases; this targets
# the older API that still exposes MetadataEntry and `check`.
from dagster import (
    Field,
    IOManager,
    MetadataEntry,
    MetadataValue,
    OutputContext,
    StringSource,
    check,
    get_dagster_logger,
    io_manager,
)
from dagster.utils import get_system_temp_directory

# spark_columns_to_markdown and perform_rename are project-specific helpers
# that are not part of this gist.


class PartitionedParquetIOManagerScd2Stateful(IOManager):
    """
    Spark dataframe stored in a partitioned Deltalake table as well as a stateful (UPSERT MERGE INTO) SCD2 representation.
    The latter is also copied over into Postgres.

    - step 1: store partitioned table
    - step 2: SCD2 UPSERT using MERGE INTO
        a) if it does not exist, insert the initial state
        b) if it exists, UPSERT
        - local: check the local file system
        - S3: check whether the prefix exists in the object store
    - step 3: push over into Postgres (if tagged)
    - step 4: optimize (ZORDER) and delete small files
    - step 5: delete old history (0)
    """

    def __init__(
        self,
        base_path,
        username,
        password,
        jdbc_connection,
        s3_client=None,
    ):
        self._base_path = base_path
        self._username = username
        self._password = password
        self._jdbc_connection = jdbc_connection
        self._s3 = s3_client

    def handle_output(
        self,
        context: OutputContext,
        obj: Union[pandas.DataFrame, pyspark.sql.DataFrame],
    ):
        cleaned_path = self._get_path("cleaned", context)
        scd2_path = self._get_path("v2", context)
        context.log.info(
            f"partitions: {context.has_asset_partitions}, type: {context.dagster_type.typing_type}"
        )
        if "://" not in self._base_path:
            os.makedirs(os.path.dirname(cleaned_path), exist_ok=True)
            os.makedirs(os.path.dirname(scd2_path), exist_ok=True)

        if isinstance(obj, pyspark.sql.DataFrame):
            cleaned_path = cleaned_path.replace("s3", "s3a")
            scd2_path = scd2_path.replace("s3", "s3a")

            # step 1: store partitioned table
            # deliberately partitioned to keep a nice history
            context.log.debug(f"Storing cleaned to: {cleaned_path}")
            obj.write.partitionBy("valid_from").mode("append").format("delta").save(
                cleaned_path
            )

            # step 2: SCD2 UPSERT using MERGE INTO
            # - local: check the local file system
            # - S3: check whether the prefix exists in the object store
            if "://" not in self._base_path:
                # handle local
                context.log.info(f"checking whether path exists: {scd2_path}")
                path_exists = Path(scd2_path).exists()
            else:
                # handle s3
                bucket_postfix = self._base_path.removeprefix("s3://")
                prefix = scd2_path.removeprefix("s3a://").removeprefix(
                    f"{bucket_postfix}/"
                )
                context.log.info(
                    f"checking whether S3 bucket({bucket_postfix}) /prefix({prefix}) exists"
                )
                r = self._s3.list_objects_v2(Bucket=bucket_postfix, Prefix=prefix)
                path_exists = r["KeyCount"] > 0
            context.log.info(f"SCD2 path exists: {path_exists}")

            spark = context.resources.pyspark.spark_session
            # .op_def.output_defs[0]
            meta_config = context.metadata
            context.log.debug(f"asset metadata: {meta_config}")
            if not path_exists:
                # a) if it does not exist, insert the initial state
                # > Performance optimization: partition by the closing valid_to TS and only process/join/handle data which is still open.
                # Skipping for now as not needed.
                # A trial on half a year of data shows: our dataset is WAY TOO SMALL and does not change often enough!
                # Better NOT to partition and keep a smaller number of files.
                obj.write.mode("append").format("delta").save(scd2_path)
            else:
                # b) if it exists, UPSERT
                context.log.debug(
                    f"Soft delete no longer active records to SCD2 for path: {scd2_path}"
                )
                self.handle_deletions_for_merge(
                    scd2_path,
                    obj,
                    context.partition_key,
                    spark,
                    keys=meta_config["keys"]["list"],
                )
                self.handle_merge(
                    scd2_path,
                    obj,
                    spark,
                    keys=meta_config["keys"]["list"],
                    value_fields=meta_config["value_fields"]["list"],
                    insert_column_map=meta_config["insert_column_map"],
                )
                # run the merge twice so that updates to existing records are handled as well
                # (i.e. new values inserted and not only old ones invalidated)
                # https://github.com/delta-io/delta/issues/1364
                self.handle_merge(
                    scd2_path,
                    obj,
                    spark,
                    keys=meta_config["keys"]["list"],
                    value_fields=meta_config["value_fields"]["list"],
                    insert_column_map=meta_config["insert_column_map"],
                )
            # step 4: optimize (ZORDER) and delete small files
            cleaned = DeltaTable.forPath(spark, cleaned_path)
            scd2 = DeltaTable.forPath(spark, scd2_path)
            # as a performance improvement this could also be run only once per week or so
            part_to_optimize = f"valid_from=to_date('{context.partition_key}')"
            context.log.debug(
                f"Optimize cleaned only for partition of: {part_to_optimize}"
            )
            cleaned.optimize().where(part_to_optimize).executeZOrderBy(
                meta_config["keys"]["list"][0]
            )
            # if pd.to_datetime(context.partition_key).weekday() == 1:
            context.log.debug("Optimize Deltalake with ZORDER")
            scd2.optimize().executeZOrderBy(meta_config["keys"]["list"][0])

            # step 5: delete old history (retention 0)
            context.log.debug("Deleting history")
            spark.conf.set(
                "spark.databricks.delta.retentionDurationCheck.enabled", "false"
            )
            spark.conf.set(
                "spark.databricks.delta.vacuum.parallelDelete.enabled", "true"
            )
            spark.conf.set(
                "spark.sql.sources.parallelPartitionDiscovery.parallelism", "50"
            )
            context.log.debug(cleaned.vacuum(0))
            context.log.debug(scd2.vacuum(0))

            # step 6: compute some metadata
            df_materialized = spark.read.parquet(scd2_path)
            docstring_schema = spark_columns_to_markdown(df_materialized.schema)
            context.add_output_metadata({"schema": MetadataValue.md(docstring_schema)})
            row_count = df_materialized.count()
            # step 3: push over into postgres (if tagged)
            if "is_leaf_node_asset" in meta_config:
                # publish additionally over to Postgres!
                key = context.asset_key.path[-1]  # type: ignore
                schema = context.asset_key.path[-2]
                context.log.debug(
                    f"2nd materialization for leaf node data to: {schema}.{key}"
                )
                if "fixup_special_json_column" in meta_config:
                    fixup_col: str = str(meta_config["fixup_special_json_column"])
                    context.log.debug(
                        f"Special post-write DB type fixup to json for column: {fixup_col}"
                    )
                    mz_obj_jsonified = df_materialized.withColumn(
                        fixup_col, F.to_json(F.col(fixup_col))
                    )
                    # More generic solution: the first SCD2 output from Spark is NEVER used directly in the DBT graph E2E for building a view.
                    # The first DBT interaction ALWAYS builds a table - subsequent ones then a view/table. This eases the operations a lot.
                    # Otherwise multiple drop cascades would be required everywhere.
                    # self._handle_drop_cascade(schema=schema, table=key)
                    self._write_to_db(mz_obj_jsonified, schema=schema, table=key)
                    self._handle_db_type_fixup(
                        column_to_fix=fixup_col, schema=schema, table=key
                    )
                else:
                    self._write_to_db(df_materialized, schema=schema, table=key)
        else:
            raise Exception(f"Outputs of type {type(obj)} not supported.")

        yield MetadataEntry.int(value=row_count, label="row_count")
        yield MetadataEntry.path(path=scd2_path, label="path_scd2")
        yield MetadataEntry.path(path=cleaned_path, label="path_cleaned")

    def load_input(self, context) -> Union[pyspark.sql.DataFrame, str]:
        context.log.info(
            f"partitions: {context.has_asset_partitions}, type: {context.dagster_type.typing_type}"
        )
        path = self._get_path("v2", context.upstream_output)
        if context.dagster_type.typing_type == pyspark.sql.DataFrame:
            # return pyspark dataframe
            path = path.replace("s3", "s3a")
            return context.resources.pyspark.spark_session.read.format("delta").load(
                path
            )
        return check.failed(
            f"Inputs of type {context.dagster_type.typing_type} not supported. Please specify a valid type "
            "for this input on the argument of the @asset-decorated function."
        )

    def _get_path(self, path_type: str, context: OutputContext):
        key = context.asset_key.path[-1]  # type: ignore
        # we are always partitioned.
        # delta does NOT require the full file path (which changes randomly when calling optimize)
        return os.path.join(self._base_path, f"{key}_{path_type}.parquet")

    def _write_to_db(self, df: pyspark.sql.DataFrame, schema: str, table: str):
        self._handle_drop_cascade(schema=schema, table=table)
        df.write.mode("overwrite").format("jdbc").option(
            "url", self._jdbc_connection
        ).option("dbtable", f"{schema}.{table}").option("batchsize", "10000").option(
            "numPartitions", "4"
        ).option(
            "user", self._username
        ).option(
            "password", self._password
        ).option(
            "driver", "org.postgresql.Driver"
        ).save()

    def _handle_drop_cascade(self, schema: str, table: str):
        """Drop cascade the table prior to the write.

        > WARNING: this deletes the data cascadingly! Ensure the DBT pipelines run directly afterwards so any dependent steps are recreated.
        """
        db_con = self._jdbc_connection.replace(
            "jdbc:postgresql://",
            f"postgresql+psycopg2://{self._username}:{self._password}@",
        )
        engine = create_engine(db_con)
        with engine.connect() as con:
            con.execute(text(f"""DROP TABLE IF EXISTS {schema}.{table} CASCADE;"""))

    def _handle_db_type_fixup(self, column_to_fix: str, schema: str, table: str):
        """Fix up a specific column to be of type jsonb instead of text."""
        db_con = self._jdbc_connection.replace(
            "jdbc:postgresql://",
            f"postgresql+psycopg2://{self._username}:{self._password}@",
        )
        engine = create_engine(db_con)
        with engine.connect() as con:
            con.execute(
                text(
                    f"""
                    ALTER TABLE {schema}.{table}
                    ALTER COLUMN {column_to_fix} TYPE jsonb USING {column_to_fix}::jsonb;
                    """
                )
            )

    def handle_deletions(
        self,
        updates_df,
        keys,
        current_dt,
        columns_to_drop=[],
        is_current_col="is_current",
        valid_to_col="valid_to",
    ):
        def _(df_current_state):
            original_columns = df_current_state.columns
            mapping = dict(
                [
                    (c, f"{c}_c")
                    for c in updates_df.drop(*keys).drop(*columns_to_drop).columns
                ]
            )
            value_cols = updates_df.drop(*keys).drop(*columns_to_drop).columns
            r = df_current_state.join(
                perform_rename(updates_df, mapping), on=keys, how="left"
            )
            # it is good enough to look only at a single value (all will be NULL anyway)
            v = f"{value_cols[0]}_c"
            pdate = pd.to_datetime(current_dt).date() - pd.to_timedelta(1, unit="d")
            r = r.withColumn(
                is_current_col,
                F.when(F.col(v).isNull(), F.lit(False)).otherwise(
                    F.col(is_current_col)
                ),
            ).withColumn(
                valid_to_col,
                F.when(
                    (F.col(is_current_col) == F.lit(False))
                    & F.col(valid_to_col).isNull(),
                    F.lit(pdate),
                ).otherwise(F.col(valid_to_col)),
            )
            return r.select(original_columns)

        return _

    def handle_deletions_for_merge(self, path, updates, current_dt, spark, keys):
        delta_state = spark.read.format("delta").load(path)
        deletions_handled = delta_state.transform(
            self.handle_deletions(updates, keys, current_dt)
        )
        deletions_handled.write.mode("overwrite").format("delta").save(path)

    def construct_join_predicate(self, keys, comparator, merge_predicate="AND"):
        r = []
        for item in keys:
            result = f"s.{item} {comparator} c.{item}"
            r.append(result)
        r = f" {merge_predicate} ".join(r)
        get_dagster_logger().debug(f"Join predicate of: {r}")
        return r

    def handle_merge(self, path, updates, spark, keys, value_fields, insert_column_map):
        """
        WARNING: must be called TWICE to ensure UPSERTS work.
        https://github.com/delta-io/delta/issues/1364
        WARNING: multiple joins are used inside - this is not terribly efficient.
        """
        delta_state = DeltaTable.forPath(spark, path)
        delta_state.alias("s").merge(
            updates.alias("c"),
            self.construct_join_predicate(keys, comparator="=")
            + " AND s.is_current = true",
        ).whenMatchedUpdate(
            condition=f"s.is_current = true AND {self.construct_join_predicate(value_fields, merge_predicate='OR', comparator='<>')}",
            set={  # set is_current to false and valid_to to the day before the source's effective date
                "is_current": "false",
                "valid_to": F.date_sub(F.col("c.valid_from"), 1),
            },
        ).whenNotMatchedInsert(
            values=insert_column_map
        ).execute()

@io_manager(
    config_schema={
        "base_path": Field(str, is_required=False),
        "username": StringSource,
        "password": StringSource,
        "jdbc_connection": str,
    },
    required_resource_keys={"pyspark"},
)
def local_stateful_scd2(init_context):
    return PartitionedParquetIOManagerScd2Stateful(
        base_path=init_context.resource_config.get(
            "base_path", get_system_temp_directory()
        ),
        username=init_context.resource_config.get("username"),
        password=init_context.resource_config.get("password"),
        jdbc_connection=init_context.resource_config.get("jdbc_connection"),
    )


@io_manager(
    required_resource_keys={"s3_bucket", "s3", "pyspark"},
    config_schema={
        "username": StringSource,
        "password": StringSource,
        "jdbc_connection": str,
    },
)
def s3_stateful_scd2(init_context):
    s3 = init_context.resources.s3
    return PartitionedParquetIOManagerScd2Stateful(
        base_path="s3://" + init_context.resources.s3_bucket,
        username=init_context.resource_config.get("username"),
        password=init_context.resource_config.get("password"),
        jdbc_connection=init_context.resource_config.get("jdbc_connection"),
        s3_client=s3,
    )
The IO manager
The IO manager stores a partitioned Deltalake table of the raw cleaned data (useful for backfills in case the cleaning business logic changes) as well as a stateful (UPSERT MERGE INTO) SCD2 representation.
The latter is also copied over into Postgres (using external tables instead might be faster in the future, but Postgres can only talk to parquet, not Deltalake).
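Registering the IO manager boils down to supplying the pyspark, s3, and s3_bucket resources plus the Postgres credentials. A minimal wiring sketch follows; the bucket name, environment variable names, and JDBC URL are illustrative assumptions, not part of the gist.

from dagster import ResourceDefinition
from dagster_aws.s3 import s3_resource
from dagster_pyspark import pyspark_resource

resource_defs = {
    "pyspark": pyspark_resource,
    "s3": s3_resource,
    # plain resource holding the bucket name (assumed value)
    "s3_bucket": ResourceDefinition.hardcoded_resource("my-data-lake"),
    "io_manager": s3_stateful_scd2.configured(
        {
            "username": {"env": "WAREHOUSE_USER"},
            "password": {"env": "WAREHOUSE_PASSWORD"},
            "jdbc_connection": "jdbc:postgresql://postgres:5432/warehouse",
        }
    ),
}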
usage
Usage from the asset:
The incoming partition (a) of the daily dataset is MERGED INTO the SCD2 asset using the IO manager, as sketched below.
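A hypothetical daily-partitioned asset illustrating this: the metadata keys (keys, value_fields, insert_column_map, is_leaf_node_asset) match what handle_output reads from context.metadata, while the asset name, columns, source path, and start date are assumptions made up for the sketch.

import pyspark.sql
from dagster import DailyPartitionsDefinition, asset
from pyspark.sql import functions as F


@asset(
    partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"),
    io_manager_key="io_manager",
    required_resource_keys={"pyspark"},
    metadata={
        # consumed by PartitionedParquetIOManagerScd2Stateful.handle_output
        "keys": {"list": ["customer_id"]},
        "value_fields": {"list": ["status"]},
        "insert_column_map": {
            "customer_id": "c.customer_id",
            "status": "c.status",
            "valid_from": "c.valid_from",
            "valid_to": "c.valid_to",
            "is_current": "c.is_current",
        },
        "is_leaf_node_asset": True,
    },
)
def customers(context) -> pyspark.sql.DataFrame:
    spark = context.resources.pyspark.spark_session
    # `a` is the cleaned daily partition belonging to context.partition_key (assumed source path)
    a = (
        spark.read.format("delta")
        .load("/data/raw/customers")
        .where(F.col("valid_from") == F.lit(context.partition_key))
    )
    # the returned dataframe is then MERGED INTO the SCD2 table by the IO manager
    return a.withColumn("valid_to", F.lit(None).cast("date")).withColumn(
        "is_current", F.lit(True)
    )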