RDS Export to S3 dumps the data as Parquet files. This script assumes the Parquet files have already been downloaded to the local on-prem server and loads them into MySQL.
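Based on how the script builds its paths, the layout expected under DB_PATH is one folder per table named <database>.<table>, each holding a 1/ sub-folder with the exported .parquet files. An illustrative (hypothetical) layout:

    staging/
        mydb.customers/1/part-00000.parquet
        mydb.orders/1/part-00000.parquet

The database, table and file names above are examples only; the script derives the MySQL table name from the part of the folder name after the dot.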
import os
import pyarrow.parquet as pq
import pandas as pd
import mysql.connector
import mysql.connector.pooling
import concurrent.futures
import logging
import time

# Log to the console (basicConfig) and to db.log
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('DB')
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler('db.log')
fh.setFormatter(formatter)
logger.addHandler(fh)
# Define your MySQL configuration
MYSQL_HOST = "localhost"
MYSQL_USER = "xxxx"
MYSQL_PASSWORD = "xxxxxxxx#"
DATABASE = "xxxx"
# Define the path to the Parquet datasets
DB_PATH = "/local/db/db-dump-26-10-23/staging"
# Define the number of CPU cores to use
NUM_CORES = 8
# Create a connection pool
connection_pool = mysql.connector.pooling.MySQLConnectionPool(
    pool_name="conn_pool",
    pool_size=NUM_CORES,
    pool_reset_session=True,
    host=MYSQL_HOST,
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    database=DATABASE,
)
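
# Note: multi_process_main() runs the workers with ProcessPoolExecutor. On
# platforms that fork (the Linux default), each worker inherits a copy of this
# already-created pool and its open connections; if that leads to connection
# errors, creating the pool inside each worker process is a possible alternative.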


def process_one_parquet_dataset(parquet_dir):
    """Read every .parquet file in a directory and concatenate them into one DataFrame."""
    parquet_data = []
    for filename in os.listdir(parquet_dir):
        if filename.endswith(".parquet"):
            file_path = os.path.join(parquet_dir, filename)
            table = pq.read_table(file_path)
            df = table.to_pandas()
            parquet_data.append(df)
    return pd.concat(parquet_data, ignore_index=True)
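
# Note: pyarrow.parquet.read_table also accepts a directory path and reads all
# Parquet files inside it, which could replace the per-file loop above.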


def insert_df_to_db(cursor, table_name, result_df):
    # Build the INSERT statement once and execute it per row
    insert_query = f"""
        INSERT INTO {table_name} ({', '.join(result_df.columns)})
        VALUES ({', '.join(['%s'] * len(result_df.columns))})
    """
    for _, row in result_df.iterrows():
        cursor.execute(insert_query, tuple(row))
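
# Optional sketch (not called anywhere in this script): cursor.executemany()
# sends the same INSERT for many rows in fewer round trips, which may speed up
# large loads. It assumes the same column/placeholder layout as insert_df_to_db.
def insert_df_to_db_batched(cursor, table_name, result_df):
    insert_query = (
        f"INSERT INTO {table_name} ({', '.join(result_df.columns)}) "
        f"VALUES ({', '.join(['%s'] * len(result_df.columns))})"
    )
    rows = [tuple(row) for row in result_df.itertuples(index=False, name=None)]
    cursor.executemany(insert_query, rows)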


def process_parquet_dataset(dataset):
    # Dataset paths look like "<database>.<table>/1"; the table name is the
    # part of the folder name after the dot
    table_name = dataset.split("/")[0].split(".")[1]
    parquet_dir = f"{DB_PATH}/{dataset}"
    connection = connection_pool.get_connection()
    cursor = connection.cursor()
    try:
        # Read the Parquet files for this table into a single DataFrame
        start_time = time.time()
        logger.info(f"Start processing PARQUET {parquet_dir}")
        result_df = process_one_parquet_dataset(parquet_dir)
        end_time = time.time()
        logger.info(f"Done processing PARQUET {parquet_dir} in {end_time - start_time} seconds")

        # Insert the data into the MySQL table inside a single transaction
        start_time = time.time()
        logger.info(f"Start inserting data to: {table_name}")
        cursor.execute("START TRANSACTION")
        insert_df_to_db(cursor, table_name, result_df)
        cursor.execute("COMMIT")
        end_time = time.time()
        logger.info(f"Done inserting data to: {table_name} in {end_time - start_time} seconds")
    except Exception as ex:
        logger.error(f"Error processing PARQUET {parquet_dir}: {ex}. Rolling back transaction.")
        connection.rollback()
    finally:
        cursor.close()
        # close() on a pooled connection returns it to the pool
        connection.close()


def main():
    parquet_datasets = sorted([f"{folder}/1" for folder in os.listdir(DB_PATH)])
    logger.info("Start processing Data.")
    for i, dataset in enumerate(parquet_datasets):
        logger.info(f"Processing dataset {i}")
        process_parquet_dataset(dataset)
    logger.info("Done processing Data.")


def multi_process_main():
    parquet_datasets = sorted([f"{folder}/1" for folder in os.listdir(DB_PATH)])
    logger.info("Start concurrent execution of processes.")
    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_CORES) as executor:
        executor.map(process_parquet_dataset, parquet_datasets)
    logger.info("Done concurrent execution of processes.")


if __name__ == '__main__':
    # No concurrency
    # main()
    multi_process_main()