@geoHeil
Last active December 28, 2022 20:05
stateful partitions (scd2)
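
The gist contains only the IO manager itself; the imports below are a best-guess reconstruction (the module paths for the dagster-internal helpers are assumptions, and spark_columns_to_markdown and perform_rename are project-local helpers that are not part of the gist):

# Assumed imports (not part of the gist); adjust module paths to your dagster version.
import os
from pathlib import Path
from typing import Union

import pandas
import pandas as pd
import pyspark.sql
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from sqlalchemy import create_engine, text

import dagster._check as check  # assumption: location of `check` in recent dagster versions
from dagster import (
    Field,
    IOManager,
    MetadataEntry,
    MetadataValue,
    OutputContext,
    StringSource,
    get_dagster_logger,
    io_manager,
)
from dagster._seven.temp_dir import get_system_temp_directory  # assumption: exact path varies by version

# `spark_columns_to_markdown` and `perform_rename` are project-local helpers referenced
# below but not included in the gist.
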
class PartitionedParquetIOManagerScd2Stateful(IOManager):
    """
    Spark dataframe stored in a partitioned Deltalake table as well as a stateful (UPSERT MERGE INTO) SCD2 representation.
    The latter is also copied over into Postgres.

    - step 1: store the partitioned table
    - step 2: SCD2 UPSERT using MERGE INTO
        a) if the table does not exist yet: insert the initial load
        b) if it exists: UPSERT
            - local: check the local file system
            - S3: check whether the prefix exists in the object store bucket
    - step 3: push over into Postgres (if tagged)
    - step 4: optimize (ZORDER) and delete small files
    - step 5: delete old history (vacuum with retention 0)
    """

    def __init__(
        self,
        base_path,
        username,
        password,
        jdbc_connection,
        s3_client=None,
    ):
        self._base_path = base_path
        self._username = username
        self._password = password
        self._jdbc_connection = jdbc_connection
        self._s3 = s3_client

    def handle_output(
        self,
        context: OutputContext,
        obj: Union[pandas.DataFrame, pyspark.sql.DataFrame],
    ):
        cleaned_path = self._get_path("cleaned", context)
        scd2_path = self._get_path("v2", context)
        context.log.info(
            f"partitions: {context.has_asset_partitions}, type: {context.dagster_type.typing_type}"
        )
        if "://" not in self._base_path:
            os.makedirs(os.path.dirname(cleaned_path), exist_ok=True)
            os.makedirs(os.path.dirname(scd2_path), exist_ok=True)

        if isinstance(obj, pyspark.sql.DataFrame):
            cleaned_path = cleaned_path.replace("s3", "s3a")
            scd2_path = scd2_path.replace("s3", "s3a")

            # step 1: store the partitioned table
            # deliberately partitioned to keep a nice history
            context.log.debug(f"Storing cleaned to: {cleaned_path}")
            obj.write.partitionBy("valid_from").mode("append").format("delta").save(
                cleaned_path
            )

            # step 2: SCD2 UPSERT using MERGE INTO
            # - local: check the local file system
            # - S3: check whether the prefix exists in the object store bucket
            if "://" not in self._base_path:
                # handle local
                context.log.info(f"checking if path exists: {scd2_path}")
                path_exists = Path(scd2_path).exists()
            else:
                # handle s3
                bucket_postfix = self._base_path.removeprefix("s3://")
                prefix = scd2_path.removeprefix("s3a://").removeprefix(
                    f"{bucket_postfix}/"
                )
                context.log.info(
                    f"checking if S3 bucket({bucket_postfix}) /prefix({prefix}) exists"
                )
                r = self._s3.list_objects_v2(Bucket=bucket_postfix, Prefix=prefix)
                path_exists = r["KeyCount"] > 0
            context.log.info(f"SCD2 path exists: {path_exists}")
            spark = context.resources.pyspark.spark_session
            # .op_def.output_defs[0]
            meta_config = context.metadata
            context.log.debug(f"asset metadata: {meta_config}")
            if not path_exists:
                # a) if the table does not exist yet: insert the initial load
                # > Performance optimization: partition by the closing valid_to TS and only process/join/handle data which is still open.
                # Skipping for now as it is not needed.
                # A trial on half a year of data shows: our dataset is WAY TOO SMALL and not changing often enough! Better to NOT partition and have a smaller number of files.
                obj.write.mode("append").format("delta").save(scd2_path)
            else:
                # b) if it exists: UPSERT
                context.log.debug(
                    f"Soft delete no longer active records to SCD2 for path: {scd2_path}"
                )
                self.handle_deletions_for_merge(
                    scd2_path,
                    obj,
                    context.partition_key,
                    spark,
                    keys=meta_config["keys"]["list"],
                )
                self.handle_merge(
                    scd2_path,
                    obj,
                    spark,
                    keys=meta_config["keys"]["list"],
                    value_fields=meta_config["value_fields"]["list"],
                    insert_column_map=meta_config["insert_column_map"],
                )
                # called a second time to ensure updates to existing records are handled as well
                # (i.e. new values inserted and not only old ones invalidated)
                # https://github.com/delta-io/delta/issues/1364
                self.handle_merge(
                    scd2_path,
                    obj,
                    spark,
                    keys=meta_config["keys"]["list"],
                    value_fields=meta_config["value_fields"]["list"],
                    insert_column_map=meta_config["insert_column_map"],
                )

            # step 4: optimize (ZORDER) and delete small files
            cleaned = DeltaTable.forPath(spark, cleaned_path)
            scd2 = DeltaTable.forPath(spark, scd2_path)
            # as a performance improvement this could also be run only once per week or so
            part_to_optimize = f"valid_from=to_date('{context.partition_key}')"
            context.log.debug(
                f"Optimize cleaned only for partition of: {part_to_optimize}"
            )
            cleaned.optimize().where(part_to_optimize).executeZOrderBy(
                meta_config["keys"]["list"][0]
            )
            # if pd.to_datetime(context.partition_key).weekday() == 1:
            context.log.debug("Optimize Deltalake with ZORDER")
            scd2.optimize().executeZOrderBy(meta_config["keys"]["list"][0])

            # step 5: delete old history (vacuum with retention 0)
            context.log.debug("Deleting history")
            spark.conf.set(
                "spark.databricks.delta.retentionDurationCheck.enabled", "false"
            )
            spark.conf.set(
                "spark.databricks.delta.vacuum.parallelDelete.enabled", "true"
            )
            spark.conf.set(
                "spark.sql.sources.parallelPartitionDiscovery.parallelism", "50"
            )
            context.log.debug(cleaned.vacuum(0))
            context.log.debug(scd2.vacuum(0))

            # step 6: compute some metadata
            df_materialized = spark.read.parquet(scd2_path)
            docstring_schema = spark_columns_to_markdown(df_materialized.schema)
            context.add_output_metadata({"schema": MetadataValue.md(docstring_schema)})
            row_count = df_materialized.count()

            # step 3: push over into Postgres (if tagged)
            if "is_leaf_node_asset" in meta_config:
                # additionally publish over to Postgres!
                key = context.asset_key.path[-1]  # type: ignore
                schema = context.asset_key.path[-2]
                context.log.debug(
                    f"2nd materialization for leaf node data to: {schema}.{key}"
                )
                if "fixup_special_json_column" in meta_config:
                    fixup_col: str = str(meta_config["fixup_special_json_column"])
                    context.log.debug(
                        f"Special post-write DB type fixup to json for column: {fixup_col}"
                    )
                    mz_obj_jsonified = df_materialized.withColumn(
                        fixup_col, F.to_json(F.col(fixup_col))
                    )
                    # More generic solution: the first SCD2 output from Spark is NEVER used directly in the DBT graph E2E for building a view.
                    # The first DBT interaction ALWAYS builds a table - subsequent ones then a view/table. This eases operations a lot.
                    # Otherwise multiple drop cascades would be required everywhere.
                    # self._handle_drop_cascade(schema=schema, table=key)
                    self._write_to_db(mz_obj_jsonified, schema=schema, table=key)
                    self._handle_db_type_fixup(
                        column_to_fix=fixup_col, schema=schema, table=key
                    )
                else:
                    self._write_to_db(df_materialized, schema=schema, table=key)
        else:
            raise Exception(f"Outputs of type {type(obj)} not supported.")

        yield MetadataEntry.int(value=row_count, label="row_count")
        yield MetadataEntry.path(path=scd2_path, label="path_scd2")
        yield MetadataEntry.path(path=cleaned_path, label="path_cleaned")

    def load_input(self, context) -> Union[pyspark.sql.DataFrame, str]:
        context.log.info(
            f"partitions: {context.has_asset_partitions}, type: {context.dagster_type.typing_type}"
        )
        path = self._get_path("v2", context.upstream_output)
        if context.dagster_type.typing_type == pyspark.sql.DataFrame:
            # return a pyspark dataframe
            path = path.replace("s3", "s3a")
            return context.resources.pyspark.spark_session.read.format("delta").load(
                path
            )
        return check.failed(
            f"Inputs of type {context.dagster_type.typing_type} not supported. Please specify a valid type "
            "for this input on the argument of the @asset-decorated function."
        )

    def _get_path(self, path_type: str, context: OutputContext):
        key = context.asset_key.path[-1]  # type: ignore
        # we are always partitioned.
        # delta does NOT require the full partition path (which changes randomly when optimize is called)
        return os.path.join(self._base_path, f"{key}_{path_type}.parquet")

    def _write_to_db(self, df: pyspark.sql.DataFrame, schema: str, table: str):
        self._handle_drop_cascade(schema=schema, table=table)
        (
            df.write.mode("overwrite")
            .format("jdbc")
            .option("url", self._jdbc_connection)
            .option("dbtable", f"{schema}.{table}")
            .option("batchsize", "10000")
            .option("numPartitions", "4")
            .option("user", self._username)
            .option("password", self._password)
            .option("driver", "org.postgresql.Driver")
            .save()
        )

    def _handle_drop_cascade(self, schema: str, table: str):
        """Drop the table (CASCADE) prior to the write.

        > WARNING: this cascadingly deletes the data! Ensure the DBT pipelines run directly afterwards so any dependent steps are recreated.
        """
        db_con = self._jdbc_connection.replace(
            "jdbc:postgresql://",
            f"postgresql+psycopg2://{self._username}:{self._password}@",
        )
        engine = create_engine(db_con)
        with engine.connect() as con:
            con.execute(text(f"""DROP TABLE IF EXISTS {schema}.{table} CASCADE;"""))

    def _handle_db_type_fixup(self, column_to_fix: str, schema: str, table: str):
        """Fix up a specific column to be of type jsonb instead of text."""
        db_con = self._jdbc_connection.replace(
            "jdbc:postgresql://",
            f"postgresql+psycopg2://{self._username}:{self._password}@",
        )
        engine = create_engine(db_con)
        with engine.connect() as con:
            con.execute(
                text(
                    f"""
                    ALTER TABLE {schema}.{table}
                    ALTER COLUMN {column_to_fix} TYPE jsonb USING {column_to_fix}::jsonb;
                    """
                )
            )

    def handle_deletions(
        self,
        updates_df,
        keys,
        current_dt,
        columns_to_drop=[],
        is_current_col="is_current",
        valid_to_col="valid_to",
    ):
        def _(df_current_state):
            original_columns = df_current_state.columns
            mapping = dict(
                [
                    (c, f"{c}_c")
                    for c in updates_df.drop(*keys).drop(*columns_to_drop).columns
                ]
            )
            value_cols = updates_df.drop(*keys).drop(*columns_to_drop).columns
            r = df_current_state.join(
                perform_rename(updates_df, mapping), on=keys, how="left"
            )
            # it is good enough to look at a single value only (all will be NULL anyway)
            v = f"{value_cols[0]}_c"
            pdate = pd.to_datetime(current_dt).date() - pd.to_timedelta(1, unit="d")
            r = r.withColumn(
                is_current_col,
                F.when(F.col(v).isNull(), F.lit(False)).otherwise(
                    F.col(is_current_col)
                ),
            ).withColumn(
                valid_to_col,
                F.when(
                    (F.col(is_current_col) == F.lit(False))
                    & F.col(valid_to_col).isNull(),
                    F.lit(pdate),
                ).otherwise(F.col(valid_to_col)),
            )
            return r.select(original_columns)

        return _

    def handle_deletions_for_merge(self, path, updates, current_dt, spark, keys):
        delta_state = spark.read.format("delta").load(path)
        deletions_handled = delta_state.transform(
            self.handle_deletions(updates, keys, current_dt)
        )
        deletions_handled.write.mode("overwrite").format("delta").save(path)

    def construct_join_predicate(self, keys, comparator, merge_predicate="AND"):
        r = []
        for item in keys:
            result = f"s.{item} {comparator} c.{item}"
            r.append(result)
        r = f" {merge_predicate} ".join(r)
        get_dagster_logger().debug(f"Join predicate of: {r}")
        return r

    def handle_merge(self, path, updates, spark, keys, value_fields, insert_column_map):
        """
        WARNING: must be called TWICE to ensure UPSERTs work:
        https://github.com/delta-io/delta/issues/1364
        WARNING: multiple joins are used inside - this is not terribly efficient.
        """
        delta_state = DeltaTable.forPath(spark, path)
        delta_state.alias("s").merge(
            updates.alias("c"),
            self.construct_join_predicate(keys, comparator="=")
            + " AND s.is_current = true",
        ).whenMatchedUpdate(
            condition=f"s.is_current = true AND ({self.construct_join_predicate(value_fields, merge_predicate='OR', comparator='<>')})",
            set={  # Set is_current to false and valid_to to the day before the source's effective date.
                "is_current": "false",
                "valid_to": F.date_sub(F.col("c.valid_from"), 1),
            },
        ).whenNotMatchedInsert(
            values=insert_column_map
        ).execute()


@io_manager(
    config_schema={
        "base_path": Field(str, is_required=False),
        "username": StringSource,
        "password": StringSource,
        "jdbc_connection": str,
    },
    required_resource_keys={"pyspark"},
)
def local_stateful_scd2(init_context):
    return PartitionedParquetIOManagerScd2Stateful(
        base_path=init_context.resource_config.get(
            "base_path", get_system_temp_directory()
        ),
        username=init_context.resource_config.get("username"),
        password=init_context.resource_config.get("password"),
        jdbc_connection=init_context.resource_config.get("jdbc_connection"),
    )


@io_manager(
    required_resource_keys={"s3_bucket", "s3", "pyspark"},
    config_schema={
        "username": StringSource,
        "password": StringSource,
        "jdbc_connection": str,
    },
)
def s3_stateful_scd2(init_context):
    s3 = init_context.resources.s3
    return PartitionedParquetIOManagerScd2Stateful(
        base_path="s3://" + init_context.resources.s3_bucket,
        username=init_context.resource_config.get("username"),
        password=init_context.resource_config.get("password"),
        jdbc_connection=init_context.resource_config.get("jdbc_connection"),
        s3_client=s3,
    )
geoHeil commented Dec 28, 2022

The IO manager

The IO manager stores a partitioned Deltalake table of the raw cleaned data (kept for backfills in case the cleaning business logic changes) as well as a stateful (UPSERT via MERGE INTO) SCD2 representation.
The latter is also copied over into Postgres (using external tables instead could be faster in the future, but Postgres can only talk to Parquet, not Deltalake).
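
For an asset key ending in a_scd2 and a base_path of /tmp/warehouse (both purely illustrative), _get_path resolves the two Delta tables to:

/tmp/warehouse/a_scd2_cleaned.parquet/valid_from=<partition>/   (append-only raw-cleaned history, partitioned by valid_from)
/tmp/warehouse/a_scd2_v2.parquet/                               (the stateful SCD2 table that is merged into)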

- step 1: store the partitioned table
- step 2: SCD2 UPSERT using MERGE INTO (see the worked example after this list)
    a) if the table does not exist yet: insert the initial load
    b) if it exists: UPSERT
        - local: check the local file system
        - S3: check whether the prefix exists in the object store bucket
- step 3: push over into Postgres (if tagged) and potentially fix up data types
- step 4: optimize (ZORDER) and delete small files
- step 5: delete old delta history (vacuum 0)
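
A worked illustration of step 2, assuming keys = ["id"] and value_fields = ["value1", "value2"] as in the usage example below: construct_join_predicate builds the merge condition s.id = c.id AND s.is_current = true and the change-detection condition s.value1 <> c.value1 OR s.value2 <> c.value2. Suppose the SCD2 table holds the current row (id=1, value1=A, valid_from=2022-01-01, valid_to=null, is_current=true) and the incoming partition 2022-06-01 delivers (id=1, value1=B). handle_deletions_for_merge leaves the row untouched (its key is still present), the first handle_merge pass only closes the old row (is_current=false, valid_to=2022-05-31) because a single MERGE cannot both update and insert for the same source row, and the second pass no longer finds a matching current row and therefore inserts the new version (value1=B, valid_from=2022-06-01, valid_to=null, is_current=true).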

Usage

Usage from the asset:

@asset(
    io_manager_key="parquet_io_scd2_stateful",
    compute_kind="spark_scd2",
    required_resource_keys={"pyspark"},
    metadata={
        "keys": {"list": ["id"]},
        "value_fields": {"list": ["value1", "value2"]},
        "insert_column_map": {
            "id": "c.id",
            "value1": "c.value1",
            "value2": "c.value2",
            # the new record is inserted as current along with its effective date
            "valid_from": "c.valid_from",
            "is_current": "true",
            "valid_to": "null",
            "__run_id": "c.__run_id",
        },
    },
)
def a_scd2(context, a: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
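    # body elided in the gist; typically the cleaned daily dataframe is simply returned
    # and the IO manager then performs the SCD2 MERGE INTO on materialization
    ...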

As you can see, the incoming partition (a) of the daily dataset is MERGED INTO the SCD2 asset by the IO manager.
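
For completeness, here is a minimal wiring sketch (not part of the gist - the upstream asset a, the module paths, resource values, and environment variable names are illustrative) showing how the IO manager and the pyspark resource could be attached under the io_manager_key used above:

from dagster import repository, with_resources
from dagster_pyspark import pyspark_resource

# hypothetical module paths: `a` produces the daily partition, `a_scd2` is the asset shown above
from my_project.assets import a, a_scd2
from my_project.scd2_io_manager import local_stateful_scd2


@repository
def scd2_repo():
    return with_resources(
        [a, a_scd2],
        resource_defs={
            "pyspark": pyspark_resource,
            "parquet_io_scd2_stateful": local_stateful_scd2.configured(
                {
                    "base_path": "/tmp/warehouse",  # omit to fall back to the system temp directory
                    "username": {"env": "POSTGRES_USER"},
                    "password": {"env": "POSTGRES_PASSWORD"},
                    "jdbc_connection": "jdbc:postgresql://localhost:5432/warehouse",
                }
            ),
        },
    )

The s3_stateful_scd2 variant would additionally require the s3 and s3_bucket resources.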
