Skip to content

Instantly share code, notes, and snippets.

@MichaelCurrie
Last active September 7, 2023 23:09
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MichaelCurrie/b5ab978c0c0c1860bb5e75676775b43b to your computer and use it in GitHub Desktop.
Save MichaelCurrie/b5ab978c0c0c1860bb5e75676775b43b to your computer and use it in GitHub Desktop.
Fast pandas DataFrame read/write to mariadb
#!/usr/bin/env python3
"""
Drop-in replacement for pandas.DataFrame methods for reading and writing to database:
pandas.DataFrame.to_sql
pandas.read_sql_table
For some reason the current (Jan 2023) implementation in Pandas
using sqlalchemy is really slow. These methods are ~300x faster.
NOTE: Works only with mariadb's connector for now (pip3 install mariadb)
https://mariadb.com/resources/blog/how-to-connect-python-programs-to-mariadb/
"""
import pandas as pd
import warnings
class DataFrameFast(pd.DataFrame):
    """pandas.DataFrame subclass with a bulk-insert ``to_sql``.

    Rows are sent with a single ``cursor.executemany`` call using ``?``
    (qmark) placeholders, the parameter style of mariadb's connector.
    """

    @staticmethod
    def _quote_identifier(identifier):
        """Return *identifier* as a backtick-quoted SQL identifier.

        Embedded backticks are doubled so that any column label (including
        non-string labels, which are converted with ``str``) is a single,
        valid identifier.
        """
        return "`" + str(identifier).replace("`", "``") + "`"

    def _prepare_insert(self, name, index=False):
        """Build the INSERT statement and parameter rows for this frame.

        Parameters:
            name: target table name, inserted verbatim into the SQL text.
            index: if true, the index level(s) are written as leading columns.
        Returns:
            (cmd, rows): *cmd* is ``INSERT INTO name (`c1`, ...) VALUES (?, ...)``
            and *rows* is a list of tuples with missing values mapped to None.
        """
        # reset_index() moves the index level(s) in front of the data columns
        # and names unnamed levels ("index" / "level_N"), so an unnamed index
        # no longer produces a None column name.
        frame = self.reset_index() if index else self
        columns = ", ".join(self._quote_identifier(c) for c in frame.columns)
        placeholders = ", ".join("?" * len(frame.columns))
        cmd = f"INSERT INTO {name} ({columns}) VALUES ({placeholders})"
        # SQL drivers do not accept NaN/NaT/NA as NULL; map them to None.
        # The is_scalar guard keeps pd.isna from returning an array (and
        # raising in the conditional) for list-like cells.
        rows = [
            tuple(
                None if (pd.api.types.is_scalar(value) and pd.isna(value))
                else value
                for value in row
            )
            for row in frame.itertuples(index=False, name=None)
        ]
        return cmd, rows

    def to_sql(self, name, con, if_exists='append', index=False, *args, **kwargs):
        """Write this frame's rows into the existing table *name*.

        Parameters:
            name: table name (optionally ``schema.table``). It is interpolated
                into the SQL text, so it must come from trusted code.
            con: open connection whose ``cursor()`` supports ``with`` and the
                ``?`` parameter style (e.g. mariadb's connector).
            if_exists: 'append' inserts the rows; 'replace' first runs
                ``TRUNCATE TABLE`` (the table itself is not recreated).
            index: if true, the index level(s) are inserted as leading columns.
            *args, **kwargs: accepted for signature compatibility, ignored.
        Raises:
            AssertionError: if *if_exists* is not 'replace' or 'append'.

        NOTE: Users may have to perform to_sql in the correct sequence to
        avoid foreign key errors from the TRUNCATE.
        NOTE(review): no commit is issued here; if the connection is not in
        autocommit mode the caller must call ``con.commit()``.
        """
        if if_exists not in ('replace', 'append'):
            raise AssertionError(
                f"if_exists={if_exists!r} is not yet implemented; "
                "use 'replace' or 'append'")
        # Build the statement and data before touching the table, so a failure
        # here cannot leave a freshly truncated, empty table behind.
        cmd, rows = self._prepare_insert(name, index)
        if if_exists == 'replace' or rows:
            with con.cursor() as cursor:
                if if_exists == 'replace':
                    cursor.execute(f"TRUNCATE TABLE {name}")
                if rows:
                    cursor.executemany(cmd, rows)

    def column_info(self, table_name=None, con=None, database=None):
        """Return rows of INFORMATION_SCHEMA.COLUMNS as a DataFrame.

        Parameters:
            table_name: string. If None, returns column info for ALL tables
                of the schema.
            con: open connection whose ``cursor()`` supports ``with`` and the
                ``?`` parameter style. Required.
            database: schema to inspect. If None, the connection's current
                database (SQL ``DATABASE()``) is used.
        Returns:
            pandas.DataFrame with one row per column and the INFORMATION_SCHEMA
            field names as its columns.
        Raises:
            ValueError: if *con* is not given.
        """
        if con is None:
            raise ValueError("column_info requires an open connection: con=...")
        if database is None:
            clauses, params = ["TABLE_SCHEMA = DATABASE()"], []
        else:
            clauses, params = ["TABLE_SCHEMA = ?"], [database]
        if table_name is not None:
            clauses.append("TABLE_NAME = ?")
            params.append(table_name)
        # Values are bound as parameters rather than spliced into the SQL.
        query = ("SELECT * FROM INFORMATION_SCHEMA.COLUMNS "
                 f"WHERE {' AND '.join(clauses)}")
        with con.cursor() as cursor:
            cursor.execute(query, tuple(params))
            records = cursor.fetchall()
            columns = [field[0] for field in cursor.description]
        return pd.DataFrame.from_records(records, columns=columns)
def get_data(con, query):
    """Run *query* on *con* and return the result set as a list of dicts.

    Each dict maps a column name (from ``cursor.description``) to that row's
    value, i.e. "records" in the pandas sense.

    Parameters:
        con: open connection whose ``cursor()`` supports ``with``.
        query: SQL text executed verbatim (no parameter binding), so it must
            not contain untrusted input.
    Returns:
        list[dict]: one dict per row; an empty list when no rows match.
    """
    with con.cursor() as cursor:
        cursor.execute(query)
        rows = cursor.fetchall()
        # The column name is the first item of each description entry;
        # extract the names once instead of indexing per value.
        names = [field[0] for field in cursor.description]
    return [dict(zip(names, row)) for row in rows]
def read_sql_table(name, con, *args, **kwargs):
    """A drop-in replacement for pd.read_sql_table.

    Reads every row of table *name* (``SELECT *``) into a DataFrameFast.

    Parameters:
        name: table name, optionally qualified as ``schema.table``. It is
            interpolated into the SQL text (identifiers cannot be bound as
            parameters), so it must come from trusted code.
        con: open connection whose ``cursor()`` supports ``with``.
        *args, **kwargs: accepted for signature compatibility, ignored.
    Returns:
        DataFrameFast with one column per table column, in SELECT order. An
        empty table yields an empty frame that still carries the column names.
    """
    with con.cursor() as cursor:
        cursor.execute(f"SELECT * FROM {name};")
        rows = cursor.fetchall()
        # cursor.description is set for a SELECT even when no rows come back,
        # so the same query gives the column names in table order. This
        # replaces the INFORMATION_SCHEMA lookup, which was unordered and, for
        # unqualified names, merged same-named tables from other databases.
        columns = [field[0] for field in cursor.description]
    return DataFrameFast.from_records(rows, columns=columns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment