import datetime
import logging
from collections import OrderedDict
from typing import Iterator, List, Sized, Union

import numpy as np
import pandas as pd
from psycopg2.extensions import QuotedString
from sqlalchemy import and_, exists, MetaData, Table, Column as SAColumn

log = logging.getLogger()

#: Number of rows to insert per batch transaction
BATCH_SIZE = 50000


def to_python_type(column):
    """Return the Python type for a SQLAlchemy column, treating UUID as str."""
    if str(column.type) == 'UUID':
        return str
    return column.type.python_type


def to_str(val):
    """Quote a value as a Postgres string literal."""
    if isinstance(val, bytes):
        val = val.decode('utf-8')
    return QuotedString(str(val).encode('utf-8')).getquoted().decode('utf-8')


def to_date(value):
    """Quote a value as an ISO-8601 date literal."""
    dt = pd.to_datetime(value)
    return to_str(dt.date().isoformat() if dt else None)


def to_datetime(value):
    """Quote a value as an ISO-8601 timestamp literal."""
    dt = pd.to_datetime(value)
    return to_str(dt.to_pydatetime().isoformat() if dt else None)
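

# A minimal sketch of what the helpers above produce (expected literals
# shown in comments; exact quoting is psycopg2's default behaviour).
def _demo_escaping():
    print(to_str("O'Reilly"))               # 'O''Reilly'  (quotes doubled)
    print(to_date('2023-09-09'))            # '2023-09-09'
    print(to_datetime('2023-09-09 23:56'))  # '2023-09-09T23:56:00'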


class Column:

    def __init__(self, name: str, python_type: type):
        """Wrapper to cast Python values for use in ad-hoc SQL.

        Example::

            columns = [Column('id', int), Column('amount', float)]

        :param name: Name of the column.
        :param python_type: Python type, e.g. int, str, float.
        """
        self.name = name
        self.python_type = python_type

    def escape(self, value) -> str:
        """Escape a value for use in a Postgres ad-hoc SQL statement."""
        if pd.isnull(value):
            return 'NULL'
        func = self.python_type
        if isinstance(value, (datetime.datetime, np.datetime64, pd.Timestamp)) or \
                func in (datetime.date, datetime.datetime):
            func = to_datetime
        elif isinstance(value, datetime.date):
            func = to_date
        elif issubclass(self.python_type, str):
            func = to_str
        return func(value)

    def __eq__(self, b):
        return self.name == b.name and self.python_type == b.python_type

    def __repr__(self):
        return '{}<name={}, type={}>'.format(
            self.__class__.__name__, self.name, self.python_type.__name__)
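

# Illustrative only: Column.escape routes values through the helpers
# above; pd.isnull maps both None and NaN to the SQL keyword NULL.
def _demo_column_escape():
    print(Column('amount', float).escape(None))    # NULL
    print(Column('amount', float).escape(19.99))   # 19.99
    print(Column('name', str).escape("O'Reilly"))  # 'O''Reilly'
    print(Column('created', datetime.date).escape('2023-09-09'))
    # '2023-09-09T00:00:00' -- date columns fall into the datetime branch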


class ColumnCollection(OrderedDict):

    def __init__(self, columns: list):
        super().__init__([(c.name, c) for c in columns])


class BulkInsertFromIterator:

    def __init__(self, table, data: Iterator, columns: list,
                 batch_size: int = BATCH_SIZE, header: bool = False):
        """Bulk insert into Postgres from an iterator in fixed-size batches.

        Example::

            bulk = BulkInsertFromIterator(
                'table.name',
                iter([[1, 'Python'], [2, 'PyPy']]),
                [Column('id', int), Column('name', str)]
            )
            bulk.execute(db.engine.raw_connection)

        :param table: Name of the table.
        :param data: Iterable containing the data to insert.
        :param columns: List of :class:`Column` objects.
        :param batch_size: Rows to insert per batch.
        :param header: True if the first row is a header.
        """
        self.table = table
        self.data = data
        self.columns = columns
        self.batch_size = batch_size
        self.header = header

        if isinstance(self.data, list):
            self.data = iter(self.data)

        if not isinstance(self.data, Iterator):
            raise TypeError('Expected Iterator, got {}'.format(
                self.data.__class__))

        if not self.columns:
            raise ValueError('Columns cannot be empty')

        if isinstance(self.columns[0], tuple):
            self.columns = [Column(*c) for c in self.columns]
    def batch_execute(self, conn):
        """Insert data in batches of `batch_size`.

        :param conn: Function that returns a DB API 2.0 connection.
        """
        def batches(data, batch_size):
            """Yield lists of length `batch_size` from any object that
            supports iteration, without needing to know its length."""
            rv = []
            for idx, line in enumerate(data):
                if idx != 0 and idx % batch_size == 0:
                    yield rv
                    rv = []
                rv.append(line)
            # Guard against yielding an empty final batch, which
            # BulkInsertQuery.execute would reject
            if rv:
                yield rv

        columns = ColumnCollection(self.columns)

        if self.header:
            self.columns = [columns.get(h) for h in next(self.data)]
            columns = ColumnCollection(self.columns)

        total = 0
        query = BulkInsertQuery(self.table, columns)

        for batch in batches(self.data, self.batch_size):
            total += query.execute(conn, batch) or 0
            yield total
    def execute(self, conn):
        """Execute all batches and return the total number of rows inserted."""
        return max(self.batch_execute(conn), default=0)
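

# Hedged usage sketch: `make_conn` is any zero-argument callable that
# returns a DB API connection, e.g. ``engine.raw_connection`` or
# ``functools.partial(psycopg2.connect, dsn)``. The table name below is
# hypothetical.
def _demo_bulk_insert(make_conn):
    bulk = BulkInsertFromIterator(
        'public.languages',
        iter([[1, 'Python'], [2, 'PyPy']]),
        [Column('id', int), Column('name', str)])
    return bulk.execute(make_conn)  # total rows inserted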


class BulkInsertQuery:

    def __init__(self, table: str, columns):
        """Execute a multi-row INSERT statement.

        This does not take advantage of parameterized queries, but escapes
        string values manually in :class:`Column`.

        :param table: Name of the table being inserted into.
        :param columns: Columns required for type coercion.
        """
        self.table = table
        self.columns = columns
        self.query = 'INSERT INTO {} ({}) VALUES '.format(
            table, ', '.join(columns))
    def execute(self, conn, rows: list) -> int:
        """Execute a single multi-row INSERT for `rows`.

        :param conn: Function that returns a database connection.
        :param rows: List of rows in the same order as :attr:`columns`.
        """
        if not rows:
            raise ValueError('No data provided')

        if len(self.columns) != len(rows[0]):
            raise ValueError('Expecting {} columns, found {}'.format(
                len(self.columns), len(rows[0])))

        # Copy so escape_rows does not mutate the caller's list
        rows = list(rows)
        conn = conn()
        cursor = conn.cursor()
        try:
            cursor.execute(self.query + ', '.join(self.escape_rows(rows)))
            conn.commit()
        finally:
            cursor.close()
            conn.close()

        return len(rows)
    def escape_rows(self, rows: list):
        """Escape values for use in non-parameterized SQL queries.

        :param rows: List of dicts keyed by column name, or sequences in
            the same order as :attr:`columns`.
        """
        def to_tuple(values):
            rv = []
            if isinstance(values, dict):
                for name, column in self.columns.items():
                    rv.append(column.escape(values[name]))
            else:
                for idx, column in enumerate(self.columns.values()):
                    rv.append(column.escape(values[idx]))
            return tuple(rv)

        for idx, row in enumerate(rows):
            data = to_tuple(row)
            rows[idx] = '({})'.format(', '.join(map(str, data)))

        return rows
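

# Sketch of driving BulkInsertQuery directly; BulkInsertFromIterator
# normally does this batching for you. Rows may be dicts keyed by column
# name or sequences in column order. `make_conn` is as in the sketch above.
def _demo_single_batch(make_conn):
    columns = ColumnCollection([Column('id', int), Column('name', str)])
    query = BulkInsertQuery('public.languages', columns)
    return query.execute(make_conn, [{'id': 1, 'name': 'Python'}])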


def as_columns(columns) -> List[Column]:
    rv = []
    for column in columns:
        if isinstance(column, Column):
            rv.append(column)
        elif isinstance(column, tuple):
            rv.append(Column(*column))
        elif isinstance(column, str):
            rv.append(Column(column, str))
        elif isinstance(column, SAColumn):
            rv.append(Column(column.name, to_python_type(column)))
    return rv


def from_sqlalchemy_table(table: Table, data: Iterator, columns: List[str],
                          batch_size: int = BATCH_SIZE) -> BulkInsertFromIterator:
    """Return a :class:`BulkInsertFromIterator` based on the metadata
    of a SQLAlchemy table.

    Example::

        batch = from_sqlalchemy_table(
            Rating.__table__,
            data,
            ['rating_id', 'repo_id', 'login_id', 'rating']
        )

    :param table: A :class:`sqlalchemy.Table` instance.
    :param data: An iterator.
    :param columns: List of column names to use.
    :param batch_size: Number of rows to insert per SQL statement.
    """
    if not isinstance(table, Table):
        raise TypeError('Expected sqlalchemy.Table, got {}'.format(table))

    wrapped = []
    for name in columns:
        column = table.columns.get(name)
        wrapped.append(Column(str(column.name), to_python_type(column)))

    return BulkInsertFromIterator(table, data, wrapped, batch_size, False)
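

# Sketch: building the bulk insert from SQLAlchemy metadata. The table
# and column names are made up for illustration; `engine` is a live
# SQLAlchemy engine and `data` an iterator of rows.
def _demo_from_table(engine, data: Iterator):
    from sqlalchemy import BigInteger
    ratings = Table(
        'ratings', MetaData(),
        SAColumn('rating_id', BigInteger, primary_key=True),
        SAColumn('rating', BigInteger))
    batch = from_sqlalchemy_table(ratings, data, ['rating_id', 'rating'])
    return batch.execute(engine.raw_connection)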


def create_staging_table(engine, table: Table) -> Table:
    """Create a copy of the table to store intermediary results.

    Primary keys and other unique constraints are removed.

    :param engine: SQLAlchemy engine.
    :param table: SQLAlchemy table to clone the schema from.
    """
    table = table.tometadata(MetaData(), schema='staging')

    # Remove constraints to prevent errors when loading raw data
    for column in table.columns:
        if column.primary_key:
            column.primary_key = False
    table.indexes = []
    table.constraints = []
    table.primary_key = None

    log.info('Creating staging table {}.{}'.format(table.schema, table.name))
    table.drop(engine, checkfirst=True)
    table.create(engine)
    return table


def stage_and_merge(engine, target: Table, rows: Union[Iterator, Sized]):
    """Write data to an intermediary staging table before adding to `target`.

    Only rows whose primary keys are absent from `target` are inserted.

    :param engine: An instance of :class:`sqlalchemy.engine.Engine`.
    :param target: Table to write the results to.
    :param rows: Data to write.
    """
    if isinstance(rows, Sized) and len(rows) > 0:
        log.info('Staging Rows: {}'.format(len(rows)))

    # Drop & recreate the staging table
    source = create_staging_table(engine, target)

    # Insert data into the staging table prior to copying to the target
    try:
        bulk = BulkInsertFromIterator(source, rows, as_columns(source.columns))
        bulk.execute(engine.raw_connection)

        keys = filter(lambda c: c.primary_key, target.columns)
        where = map(lambda c: source.c[c.name] == target.c[c.name], keys)

        # Only insert rows that do not exist in the target table
        query = source.select().distinct().where(~exists().where(and_(*where)))
        result = engine.execute(target.insert().from_select(source.c, query))
        log.info('Updated Row Count: {}'.format(result.rowcount))
    finally:
        source.drop(engine)


def stage_and_replace(engine, target: Table, rows: Union[Iterator, Sized]):
    """Write data to a staging table, then drop and rebuild `target` from it.

    :param engine: An instance of :class:`sqlalchemy.engine.Engine`.
    :param target: Table to write the results to.
    :param rows: Data to write.
    """
    if isinstance(rows, Sized) and len(rows) > 0:
        log.info('Staging Rows: {}'.format(len(rows)))

    # Drop & recreate the staging table
    source = create_staging_table(engine, target)

    # Insert data into the staging table prior to copying to the target
    try:
        bulk = BulkInsertFromIterator(source, rows, as_columns(source.columns))
        bulk.execute(engine.raw_connection)

        # Re-create the target table prior to inserting
        if target.exists(engine):
            target.drop(engine)
        target.create(engine)

        query = source.select().distinct()
        result = engine.execute(target.insert().from_select(source.c, query))
        log.info('Updated Row Count: {}'.format(result.rowcount))
    finally:
        source.drop(engine)


def determine_columns(table: Table, rows):
    """Return the table's columns, narrowed to the keys of the first row
    when the rows are dicts."""
    columns = as_columns(table.columns)
    if not isinstance(rows[0], dict):
        return columns
    keys = rows[0].keys()
    return list(filter(lambda c: c.name in keys, columns))


def stage_and_update(engine, target: Table, rows: Union[Iterator, Sized]):
    """Write data to an intermediary staging table before adding to `target`.

    Rows with matching primary keys are deleted from `target` first, then
    the staged rows are appended.

    :param engine: An instance of :class:`sqlalchemy.engine.Engine`.
    :param target: Table to write the results to.
    :param rows: Data to write.
    """
    if isinstance(rows, Sized) and len(rows) > 0:
        log.info('Staging Rows: {}'.format(len(rows)))

    # Drop & recreate the staging table
    source = create_staging_table(engine, target)
    columns = determine_columns(source, rows)

    # Insert data into the staging table prior to copying to the target
    try:
        bulk = BulkInsertFromIterator(source, rows, columns)
        bulk.execute(engine.raw_connection)

        keys = filter(lambda c: c.primary_key, target.columns)
        where = map(lambda c: source.c[c.name] == target.c[c.name], keys)

        # Delete from the target table before appending
        delete = target.delete().where(exists().where(and_(*where)))
        engine.execute(delete)

        # Copy rows from the staging table to the target
        insert = source.select()
        result = engine.execute(target.insert().from_select(source.c, insert))
        log.info('Updated Row Count: {}'.format(result.rowcount))
    finally:
        source.drop(engine)
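

# End-to-end sketch of the staging helpers. Assumes a Postgres database
# with a "staging" schema and an existing `ratings` table; all names are
# illustrative.
def _demo_staging(engine, ratings: Table, rows: list):
    # Insert only rows whose primary keys are absent from the target
    stage_and_merge(engine, ratings, rows)
    # Or: delete rows with matching keys first, then append staged rows
    stage_and_update(engine, ratings, rows)
    # Or: drop and rebuild the target entirely from the staged rows
    stage_and_replace(engine, ratings, rows)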