Last active
February 15, 2024 10:38
-
-
Save filipamiralopes/5093889464a52092cbf53c29840a5196 to your computer and use it in GitHub Desktop.
Dynamically trim incoming data in Redshift table
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import logging as log
from typing import List

import pyspark
import pyspark.sql.functions as F
import redshift_connector
import requests
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType
from redshift_connector.core import Connection
from requests.structures import CaseInsensitiveDict

from table_meta import TableMeta  # see table_meta.py (https://gist.github.com/filipamiralopes/06e5375198e04bfaccc64421031adf81)
# Local, single-threaded Spark session used when this module runs standalone.
spark = (
    SparkSession.builder
    .master("local[1]")
    .appName('my_spark_app')
    .getOrCreate()
)
def _get_truncate_cols_cmd_list(
    spark: SparkSession,
    s3_path: str,
    table_name: str,
    db_table: str,
    table_meta: TableMeta,
    iam_role: str
):
    """Generates set of commands to handle varchar length error in temporary table.

    Compares the max character length of each incoming string column (read from
    parquet files under ``s3_path``) against the VARCHAR length declared in the
    YAML schema (``table_meta``), then builds the SQL needed to load the data
    into a widened temporary table and trim the offending columns.

    :param spark: Active SparkSession used to read the parquet files.
    :param s3_path: S3 prefix containing the incoming ``*.parquet`` files.
    :param table_name: Bare table name (used to name the temp table).
    :param db_table: Fully qualified production table (``schema.table``).
    :param table_meta: Parsed YAML schema exposing ``columns`` with name/type.
    :param iam_role: IAM role ARN used by the Redshift COPY command.
    :return: Tuple ``(command_list, off_cols)``: the SQL statements to run and a
        list of dicts describing every column that exceeds its declared length.
    """
    command_list = []
    off_cols = []

    log.info("Read data from S3")
    df = spark.read.parquet(s3_path, pathGlobFilter="*.parquet")
    # Character length of every string-typed column in the incoming data.
    df_strings = df.select(
        [F.length(f.name).alias(f.name) for f in df.schema.fields if isinstance(f.dataType, StringType)]
    )
    # BUG FIX: aggregate the *length* dataframe, not the raw one — the original
    # called df.groupby().max(), so string columns never yielded a "max(col)" key.
    max_len = df_strings.groupby().max().na.fill(value=0).first().asDict()

    # Create empty temporary table cloned from the production table's schema.
    command_list.append(f"CREATE TEMPORARY TABLE temp_{table_name} AS SELECT * FROM {db_table} LIMIT 0;")

    # Iterate through columns in YAML schema to compare with incoming data.
    for col in table_meta.columns:
        col_type = col.type.lower()
        if 'varchar' not in col_type:
            continue
        if col_type == 'varchar':
            len_yaml = 256  # Redshift's default VARCHAR length
        elif col_type == 'varchar(max)':
            len_yaml = 65535  # VARCHAR(MAX) upper bound
        else:
            # Extract the numeric value from e.g. "varchar(512)".
            len_yaml = int(col_type.replace('varchar(', '').replace(')', ''))
        try:
            len_df = max_len[f"max({col.name})"]
        except KeyError as e:
            # BUG FIX: skip columns absent from the incoming data. Previously a
            # stale/None len_df leaked into the comparison below (TypeError on
            # the first miss, wrong result on subsequent ones).
            log.info(f'{e}: {col.name} does not exist. Process will continue without it.')
            continue
        if len_df > len_yaml:
            # Increase varchar length of the offending column in the temp table.
            new_len = str(len_df * 2)
            command_list.append(
                f"ALTER TABLE temp_{table_name} ALTER COLUMN {col.name} TYPE VARCHAR({new_len});"
            )
            # Collect offending columns to later substring them and build the notification.
            off_cols.append({"column_name": col.name, "current_length": len_yaml, "incoming_length": len_df})

    # Copy data from S3 into the (widened) temp table.
    command_list.append(
        f"COPY temp_{table_name} FROM '{s3_path}' IAM_ROLE '{iam_role}' FORMAT AS PARQUET SERIALIZETOJSON;"
    )
    # Trim data in the temp table. Account for special characters: keep only
    # 90% of the declared VARCHAR length.
    for d in off_cols:
        # BUG FIX: the original emitted "UPDATE t <col> SET ..." — the stray
        # column name between the table and SET is not a valid alias usage here.
        command_list.append(
            f"UPDATE temp_{table_name} SET {d['column_name']} = "
            f"substring({d['column_name']},1,{int(0.9 * d['current_length'])})"
        )
    return command_list, off_cols
def _truncate_cols(
    redshift_conn: Connection,
    command_list: List,
    table_name: str,
    db_table: str,
    off_cols: List
) -> str:
    """
    Applies set of commands to handle varchar length error in temporary table.
    Outputs notification text (to send e.g. to Slack).

    :param redshift_conn: Open Redshift connection; all SQL runs on one cursor from it.
    :param command_list: SQL statements (create temp, ALTERs, COPY, UPDATEs).
    :param table_name: Bare table name; the temp table is ``temp_<table_name>``.
    :param db_table: Fully qualified production table (``schema.table``).
    :param off_cols: Dicts with column_name/current_length/incoming_length keys.
    :return: Notification text, or None when command_list is empty.
    """
    text = None
    if command_list:
        log.info("Running the following commands:")
        with redshift_conn.cursor() as cursor:
            for cmd in command_list:
                try:
                    # Clear any aborted-transaction state before each command,
                    # then run the command with autocommit enabled so each
                    # statement takes effect immediately.
                    redshift_conn.rollback()
                    redshift_conn.autocommit = True
                    log.info(f"{cmd}")
                    cursor.execute(cmd)
                    redshift_conn.autocommit = False
                except Exception as e:
                    if "target column size should be different" in str(e):
                        # ALTER COLUMN to the size it already has — the temp
                        # column was widened in a previous run.
                        log.info("Column has already been altered, please adjust yaml file")
                    else:
                        # NOTE(review): failures here are logged and skipped,
                        # not re-raised — presumably deliberate best-effort;
                        # confirm this is intended.
                        log.info(e)
            # Swap the trimmed temp data into the production table atomically
            # within one committed transaction.
            pre_command = f"TRUNCATE {db_table}"
            command = f"INSERT INTO {db_table} SELECT * FROM temp_{table_name}"
            log.info(f"{pre_command}")
            cursor.execute(pre_command)
            log.info(f"{command}")
            cursor.execute(command)
            redshift_conn.commit()
        # Prepare text to send Slack notification
        which_cols = ""
        for d in off_cols:
            prettified = (f"Column name: `{d['column_name']}`\nCurrent data length: {d['current_length']}\n"
                          f"Incoming data length: {d['incoming_length']}\n\n")
            which_cols = which_cols + prettified
        text = (f"Hello team!\n"
                f":large_yellow_circle: Some column(s) were truncated in table `{db_table}`.\n"
                f"Consider increasing the length in schema for below columns to disable this notification:\n\n"
                f"{which_cols}"
                )
    else:
        log.info("All columns match lengths in dataframe")
    return text
def _send_slack_message(text: str) -> None: | |
""" | |
Set up your Slack App before using this function. | |
Utility function to send messages in a channel | |
Args: | |
:param text: Text so be sent to Slack Channel | |
:return: Status response code from curl command | |
""" | |
if text: | |
# your channel webhook url | |
url = "https://hooks.slack.com/services/.../.../..." | |
headers = CaseInsensitiveDict() | |
headers["Content-Type"] = "application/json" | |
data = '{"text":"' + text + '"}' | |
log.info(data) | |
resp = requests.post(url, headers=headers, data=data.encode('utf-8')) | |
log.info(resp.status_code) | |
def load_data_with_length_check(
    target_db: str,
    table_name: str,
    s3_path: str,
    iam_role: str,
    redshift_conn: Connection,
    spark: SparkSession,
    table_meta: TableMeta
) -> None:
    """Truncate the production table and COPY parquet data from S3 into it.

    If the COPY fails with a VARCHAR length error, fall back to loading the
    data through a widened temporary table, trimming the offending columns and
    sending a Slack notification.

    :param target_db: Target schema/database name.
    :param table_name: Bare table name inside ``target_db``.
    :param s3_path: S3 prefix containing the incoming parquet files.
    :param iam_role: IAM role ARN used by the Redshift COPY command.
    :param redshift_conn: Open Redshift connection (committed on success).
    :param spark: Active SparkSession (used only on the fallback path).
    :param table_meta: Parsed YAML schema for the table.
    """
    log.info("Truncate production table and write data to Redshift")
    db_table = f"{target_db}.{table_name}"
    pre_command = f"TRUNCATE {db_table}"
    command = f"COPY {db_table} FROM '{s3_path}' IAM_ROLE '{iam_role}' FORMAT AS PARQUET SERIALIZETOJSON"
    with redshift_conn.cursor() as cursor:
        try:
            cursor.execute(pre_command)
            cursor.execute(command)
            redshift_conn.commit()
        except Exception as e:
            if "The length of the data column" not in str(e):
                # BUG FIX: the original silently swallowed every non-length
                # failure, leaving the table truncated with no data and no error.
                raise
            log.info(f"String length error: {e}")
            # Get SQL command list to handle the error.
            # BUG FIX: the original passed these positionally in the wrong
            # order (spark, table_meta, db_table, table_name, s3_path, ...).
            command_list, off_cols = _get_truncate_cols_cmd_list(
                spark=spark,
                s3_path=s3_path,
                table_name=table_name,
                db_table=db_table,
                table_meta=table_meta,
                iam_role=iam_role,
            )
            # Apply SQL commands and generate notification text.
            text = _truncate_cols(redshift_conn, command_list, table_name, db_table, off_cols)
            _send_slack_message(text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.