Dynamically trim incoming data in Redshift table
import json
import logging as log
from typing import Dict, List, Tuple

import pyspark.sql.functions as F
import redshift_connector
import requests
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType
from redshift_connector.core import Connection
from requests.structures import CaseInsensitiveDict

from table_meta import TableMeta  # refer to table_meta.py (https://gist.github.com/filipamiralopes/06e5375198e04bfaccc64421031adf81)
spark = SparkSession.builder.master("local[1]") \
    .appName("my_spark_app") \
    .getOrCreate()
def _get_truncate_cols_cmd_list(
    spark: SparkSession,
    s3_path: str,
    table_name: str,
    db_table: str,
    table_meta: TableMeta,
    iam_role: str,
) -> Tuple[List[str], List[Dict]]:
    """Generates the set of commands to handle a varchar length error in a temporary table."""
    command_list = []
    off_cols = []
    log.info("Read data from S3")
    df = spark.read.parquet(s3_path, pathGlobFilter="*.parquet")
    # Compute the maximum string length per string-typed column in the incoming data
    df_strings = df.select(
        [F.length(col.name).alias(col.name) for col in df.schema.fields if isinstance(col.dataType, StringType)]
    )
    max_len = df_strings.groupby().max().na.fill(value=0).first().asDict()

    # Create empty temporary table from production table
    cmd_1 = f"CREATE TEMPORARY TABLE temp_{table_name} AS SELECT * FROM {db_table} LIMIT 0;"
    command_list.append(cmd_1)

    # Iterate through columns in YAML schema to compare with incoming data
    for col in table_meta.columns:
        if 'varchar' in col.type.lower():
            if col.type.lower() == 'varchar':
                len_yaml = 256  # Redshift default length for bare VARCHAR
            elif col.type.lower() == 'varchar(max)':
                len_yaml = 65535
            else:
                # Extract the numeric value, e.g. 'varchar(512)' -> 512
                len_yaml = int(col.type.lower().replace('varchar(', '').replace(')', ''))
            try:
                len_df = max_len[f"max({col.name})"]
            except KeyError as e:
                log.info(f'{e}: {col.name} does not exist. Process will continue without it.')
                continue
            if len_df > len_yaml:
                # Increase varchar length of the offending column in the temp table
                new_len = str(len_df * 2)
                cmd_2 = f"ALTER TABLE temp_{table_name} ALTER COLUMN {col.name} TYPE VARCHAR({new_len});"
                command_list.append(cmd_2)
                # Collect offending columns to later substring them and build the notification
                off_cols.append({"column_name": col.name, "current_length": len_yaml, "incoming_length": len_df})

    # Copy data from S3 into the temp table
    cmd_3 = f"COPY temp_{table_name} FROM '{s3_path}' IAM_ROLE '{iam_role}' FORMAT AS PARQUET SERIALIZETOJSON;"
    command_list.append(cmd_3)

    # Trim data in the temporary table. To account for multi-byte characters, keep only 90% of the declared VARCHAR length
    for d in off_cols:
        cmd_4 = (
            f"UPDATE temp_{table_name} SET {d['column_name']} = "
            f"substring({d['column_name']},1,{int(0.9 * d['current_length'])})"
        )
        command_list.append(cmd_4)
    return command_list, off_cols
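
# For illustration (hypothetical table "events" in schema "analytics", with a column
# "description" declared VARCHAR(256) in the YAML schema but arriving with
# 300-character values), the generated command list would look roughly like:
#   CREATE TEMPORARY TABLE temp_events AS SELECT * FROM analytics.events LIMIT 0;
#   ALTER TABLE temp_events ALTER COLUMN description TYPE VARCHAR(600);
#   COPY temp_events FROM 's3://my-bucket/events/' IAM_ROLE '<role-arn>' FORMAT AS PARQUET SERIALIZETOJSON;
#   UPDATE temp_events SET description = substring(description,1,230)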
def _truncate_cols(
    redshift_conn: Connection,
    command_list: List[str],
    table_name: str,
    db_table: str,
    off_cols: List[Dict],
) -> str:
    """
    Applies the set of commands to handle a varchar length error in a temporary table.
    Outputs a notification text (to send e.g. to Slack).
    """
    text = None
    if command_list:
        log.info("Running the following commands:")
        with redshift_conn.cursor() as cursor:
            for cmd in command_list:
                try:
                    redshift_conn.rollback()
                    redshift_conn.autocommit = True
                    log.info(f"{cmd}")
                    cursor.execute(cmd)
                    redshift_conn.autocommit = False
                except Exception as e:
                    if "target column size should be different" in str(e):
                        log.info("Column has already been altered, please adjust the YAML file")
                    else:
                        log.info(e)
            # Replace the production table's contents with the trimmed temp table
            pre_command = f"TRUNCATE {db_table}"
            command = f"INSERT INTO {db_table} SELECT * FROM temp_{table_name}"
            log.info(f"{pre_command}")
            cursor.execute(pre_command)
            log.info(f"{command}")
            cursor.execute(command)
            redshift_conn.commit()
        # Prepare text to send as Slack notification
        which_cols = ""
        for d in off_cols:
            prettified = (
                f"Column name: `{d['column_name']}`\nCurrent data length: {d['current_length']}\n"
                f"Incoming data length: {d['incoming_length']}\n\n"
            )
            which_cols = which_cols + prettified
        text = (
            f"Hello team!\n"
            f":large_yellow_circle: Some column(s) were truncated in table `{db_table}`.\n"
            f"Consider increasing the length in schema for the columns below to disable this notification:\n\n"
            f"{which_cols}"
        )
    else:
        log.info("All columns match lengths in dataframe")
    return text
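
# With the example column above, the resulting notification text would read roughly:
#   Hello team!
#   :large_yellow_circle: Some column(s) were truncated in table `analytics.events`.
#   Consider increasing the length in schema for the columns below to disable this notification:
#
#   Column name: `description`
#   Current data length: 256
#   Incoming data length: 300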
def _send_slack_message(text: str) -> None:
    """
    Set up your Slack App before using this function.
    Utility function to send messages to a channel.
    Args:
        :param text: Text to be sent to the Slack channel
    """
    if text:
        # your channel webhook url
        url = "https://hooks.slack.com/services/.../.../..."
        headers = CaseInsensitiveDict()
        headers["Content-Type"] = "application/json"
        # Serialize with json.dumps so quotes and newlines in the text are escaped properly
        data = json.dumps({"text": text})
        log.info(data)
        resp = requests.post(url, headers=headers, data=data.encode('utf-8'))
        log.info(resp.status_code)
def load_data_with_length_check(
    target_db: str,
    table_name: str,
    s3_path: str,
    iam_role: str,
    redshift_conn: Connection,
    spark: SparkSession,
    table_meta: TableMeta,
) -> None:
    log.info("Truncate production table and write data to Redshift")
    db_table = f"{target_db}.{table_name}"
    pre_command = f"TRUNCATE {db_table}"
    command = f"COPY {db_table} FROM '{s3_path}' IAM_ROLE '{iam_role}' FORMAT AS PARQUET SERIALIZETOJSON"
    with redshift_conn.cursor() as cursor:
        try:
            cursor.execute(pre_command)
            cursor.execute(command)
            redshift_conn.commit()
        except Exception as e:
            if "The length of the data column" in str(e):
                log.info(f"String length error: {e}")
                # Get the SQL command list to handle the error
                command_list, off_cols = _get_truncate_cols_cmd_list(
                    spark=spark,
                    s3_path=s3_path,
                    table_name=table_name,
                    db_table=db_table,
                    table_meta=table_meta,
                    iam_role=iam_role,
                )
                # Apply the SQL commands and generate the notification text
                text = _truncate_cols(redshift_conn, command_list, table_name, db_table, off_cols)
                _send_slack_message(text)
            else:
                # Any other failure is not recoverable here, so re-raise it
                raise
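
# Minimal usage sketch, assuming a reachable Redshift cluster. Host, credentials,
# bucket, and IAM role below are placeholders, and the TableMeta construction
# depends on the table_meta.py gist linked above.
if __name__ == "__main__":
    log.basicConfig(level=log.INFO)
    conn = redshift_connector.connect(
        host="my-cluster.abc123.eu-west-1.redshift.amazonaws.com",
        database="analytics",
        user="loader",
        password="...",
    )
    table_meta = TableMeta(...)  # hypothetical: build from your YAML schema (see table_meta.py gist)
    load_data_with_length_check(
        target_db="analytics",
        table_name="events",
        s3_path="s3://my-bucket/events/",
        iam_role="arn:aws:iam::123456789012:role/redshift-copy",
        redshift_conn=conn,
        spark=spark,
        table_meta=table_meta,
    )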