@TimCargan
Last active November 20, 2021 15:52
Spark delta table merger
from __future__ import annotations

from dataclasses import dataclass, field, replace
from functools import reduce
from typing import Tuple

from delta.tables import DeltaTable
from pyspark.sql import DataFrame
from pyspark.sql.functions import coalesce, col, lit, max  # Spark's max, shadows the builtin

# `spark` is assumed to be the active SparkSession (provided by the Databricks runtime)


@dataclass
class Database:
    """
    Data representation of a database.
    This builds up all the paths needed to access data stored in the data lake.

    Args:
        db_name:str - Database name
        dl_base:str - (Optional) base datalake path, defaults to the raw data mount

    Attributes:
        dl_path:str - Datalake path, derived from dl_base and db_name
    """
    db_name: str
    dl_path: str = field(init=False)
    dl_base: str = "/mnt/adsl/raw_data"

    def __post_init__(self):
        self.dl_path = f"{self.dl_base}/{self.db_name}"

    def __str__(self):
        return self.db_name

    def table_dl_path(self, t: Table) -> str:
        """
        Get the datalake path for a table

        Args:
            t:Table - table to build the path for
        """
        return f"{self.dl_path}/{t.table_name}"

    def read_table(self, t: Table) -> DataFrame:
        """
        Read a table out of the database

        Args:
            t:Table - table to read
        """
        return spark.read.table(t.fq_name)
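
# Example: a minimal usage sketch of Database. The database name below is
# hypothetical and only illustrates how the paths are derived; dl_base defaults
# to the mount point hard-coded above.
# raw_db = Database(db_name="sales")
# raw_db.dl_path   # -> "/mnt/adsl/raw_data/sales"
# str(raw_db)      # -> "sales"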
# Once we are on Runtime > 8 we can use this code and remove the table-exists check:
# target_delta_table = (DeltaTable
#     .createIfNotExists(spark)
#     .tableName(target_table.fq_name)
#     .location(target_table.dl_path)
#     .addColumns(df.schema)
#     .partitionedBy(target_table.partition_cols)
# ).execute()
def mergeTableIntoDatabase(source_table: Table, target_database: Database, truncate_partition: bool = False, **kwargs) -> None:
    """
    Merge a source table into the target database, creating the target table if it doesn't exist.
    If the target table exists, data is merged using the specified primary keys, or appended if no primary keys are given.

    Args:
        source_table:Table - the source table to read from
        target_database:Database - the target database to merge data into
        truncate_partition:bool - if True, all rows are removed from the target table where a matching partition key value is found in the source table
    """
    # Create database if needed
    spark.sql(f"CREATE DATABASE IF NOT EXISTS `{target_database}`")

    # Read source table
    target_table = replace(source_table, database=target_database)
    df = source_table.read_df()

    if spark.sql(f"SHOW TABLES IN `{target_database}` LIKE '{target_table.table_name}'").count() == 0:
        # If the table doesn't exist, create it from the source table
        # (once we move to Runtime > 8 this can be removed and we can use createIfNotExists)
        df_out = df.write.format("delta").option("mergeSchema", "true").option("path", target_table.dl_path)
        df_out = df_out.partitionBy(target_table.partition_cols) if target_table.partition_cols else df_out
        df_out.mode("overwrite").saveAsTable(target_table.fq_name)
    else:
        target_delta_table = DeltaTable.forName(spark, target_table.fq_name)

        if truncate_partition:
            # Remove rows with matching partition column values (if there are no partition_cols, all rows are removed)
            pcs_dict = {c: [r[c] for r in df.select(c).distinct().collect()] for c in source_table.partition_cols}
            cond = reduce(lambda a, b: a & b, [col(c).isin(pcs_dict[c]) for c in source_table.partition_cols], lit(True))  # Base of lit(True) to handle no partition_cols
            print(f"Deleting from {target_table.fq_name} based on: {cond}")
            target_delta_table.delete(cond)

        if target_table.cdc_cols:
            # If CDC columns exist on the table, filter the source to rows newer than the latest value in the target
            last_update, = target_delta_table.toDF().select(max(coalesce(*[col(c) for c in target_table.cdc_cols]))).first()
            last_update = last_update if last_update else "1900-01-01"  # Handle nulls in the table
            date_filter = (coalesce(*[col(c) for c in target_table.cdc_cols]) > lit(last_update))
            print(f"CDC Filter {source_table.fq_name} on: {date_filter}")
            df = df.filter(date_filter)

        # Merge source into target using the primary keys. If no PKs are given, just append the data by matching no rows
        pk_match = reduce(lambda a, b: a & b, [col(f"t.{c}") == col(f"n.{c}") for c in target_table.primary_keys]) if target_table.primary_keys else lit(False)
        print(f"Merging {source_table.fq_name} into {target_table.fq_name} on: {pk_match}")
        target_delta_table.alias("t").merge(df.alias("n"), pk_match).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
@dataclass
class Table:
    """
    Data representation of a table

    Args:
        table_name:str - Name of the table
        database:Database - (Optional) database the table lives in
        df:DataFrame - (Optional) data for the table
        primary_keys:Tuple[str,...] - (Optional) the table's primary keys; it is assumed all PK predicates hold
        partition_cols:Tuple[str,...] - (Optional) columns used to partition the data
        cdc_cols:Tuple[str,...] - (Optional) columns used for change data capture, checked in the order given

    Note:
        Either a database or a df must be provided; if both are provided, the database is read from.

    Attributes:
        fq_name:str - Fully qualified table name, derived using the database name
        dl_path:str - Datalake path to the data
    """
    table_name: str
    database: Database = None
    primary_keys: Tuple[str, ...] = field(default_factory=tuple)
    partition_cols: Tuple[str, ...] = field(default_factory=tuple)
    cdc_cols: Tuple[str, ...] = field(default_factory=tuple)
    df: DataFrame = None
    fq_name: str = field(init=False)
    dl_path: str = field(init=False)

    def __post_init__(self):
        # To work with databaseless tables
        if self.database is None:
            assert self.df is not None, "Must provide a database or DataFrame to create a table"
            self.database = Database(db_name="")  # Create a null database
        self.fq_name = f"{self.database.db_name}.{self.table_name}"
        self.dl_path = self.database.table_dl_path(self)

    def read_df(self) -> DataFrame:
        # Read from the database if one was provided, otherwise fall back to the supplied DataFrame
        if self.database.db_name:
            return self.database.read_table(self)
        return self.df
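
# Example: a hedged end-to-end sketch of merging a staging table into a target
# database. All names below (the staging/curated databases, the "orders" table,
# its key, partition and CDC columns) are hypothetical and only illustrate the
# classes and function defined above.
# staging_db = Database(db_name="staging")
# curated_db = Database(db_name="curated")
# orders = Table(
#     table_name="orders",
#     database=staging_db,
#     primary_keys=("order_id",),
#     partition_cols=("order_date",),
#     cdc_cols=("updated_at", "created_at"),
# )
# # Creates curated.orders on the first run, then upserts on order_id,
# # filtering the source to rows newer than max(coalesce(updated_at, created_at))
# # already present in the target.
# mergeTableIntoDatabase(orders, curated_db)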