@shaunagm
Created June 6, 2023 19:05
"accumulator" - internal helper script from TMC that may be usefully moved into Parsons

import petl
from parsons import Redshift, Table
from parsons.utilities.files import close_temp_file

from canalespy import logger

VALID_STRATEGIES = ['copy', 'local_file', 'upsert']


def get_key(dest, strategy):
    return f'{dest}__{strategy}'
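

# Accumulated data is bucketed per (destination, strategy) pair (see get_key),
# so one destination can buffer rows for several strategies independently.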
class Accumulator:
    """
    Class for accumulating Parsons Table data to push to a database.

    `Args:`
        db_client: object
            A Parsons DB client (anything with ``copy`` and ``upsert`` methods)
        chunk_size: int
            The number of buffered rows at which a table is written to its
            destination
        concat_threshold: int
            The number of times the accumulator will concat tables together
            before flushing the data for a destination
    """

    def __init__(self, db_client=None, chunk_size=20_000, concat_threshold=100):
        self.db_client = db_client
        self.chunk_size = chunk_size
        self.accumulated = {}
        self.concat_threshold = concat_threshold

    def __enter__(self):
        """
        Enter magic method for starting the context manager. No-op.
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit magic method for leaving the context manager.

        On a clean exit, flushes all tables currently in the ``Accumulator``;
        if an exception was raised, nothing is flushed and the exception
        propagates.
        """
        if not exc_value:
            self.flush_all()
        else:
            logger.warning('Exception raised, accumulator will not flush data')

    def flush_all(self):
        """
        Flush all of our accumulated data to its destinations.
        """
        logger.debug('Flushing all data in accumulator')
        for data in self.accumulated.values():
            logger.debug('Flushing data for %s', data['dest'])
            if data['current_rows'] > 0:
                logger.debug('Flushing %s rows for %s', data['current_rows'], data['dest'])
                self.flush(data['dest'], data['strategy'])

    def flush(self, dest, strategy=None):
        """
        Flush accumulated data for a specific destination.

        `Args:`
            dest: str
                The name of the destination (eg a DB table) to flush data for
            strategy: str
                `Optional`; The strategy to flush (eg ``copy``); if not provided,
                flush all strategies for the given destination
        """
        # If no strategy is provided, do all of them
        if not strategy:
            found_something = False
            logger.debug('No strategy provided to flush method, flushing all data '
                         'for destination %s', dest)
            for strategy in VALID_STRATEGIES:
                try:
                    self.flush(dest, strategy)
                    found_something = True
                except KeyError:
                    pass
            # If there's no data in here for the table, then raise an error
            if not found_something:
                raise KeyError(f'No data for DB table {dest}')
            return

        key = get_key(dest, strategy)
        if key not in self.accumulated:
            raise KeyError(f'No data for DB table {dest} and strategy {strategy}')
        data = self.accumulated[key]

        # If the Parsons Table is empty, just return
        if data['current_rows'] == 0:
            return

        if data['strategy'] == 'copy':
            logger.info(f'Copying {data["current_rows"]} rows to {data["dest"]}')
            self.db_client.copy(data['tbl'], data['dest'], **data['kwargs'])
        elif data['strategy'] == 'upsert':
            # Deduplicate on the primary key first, since the accumulated batches
            # may contain repeat rows that the upsert would otherwise reject
            logger.debug(f'Deduplicating {data["current_rows"]} rows before upsert '
                         f'to {data["dest"]}')
            tbl = data['tbl']
            primary_key = data['kwargs']['primary_key']
            tbl.table = petl.transform.dedup.distinct(tbl.table, key=primary_key)
            logger.info(f'Upserting up to {data["current_rows"]} rows to {data["dest"]}')
            self.db_client.upsert(tbl, data['dest'], **data['kwargs'])
            # Commenting out while we dig in on fixes
            # if type(self.db_client) == Redshift:
            #     logger.info(f'Running vacuum sort on {data["dest"]}')
            #     self.db_client.query(f'vacuum sort only {data["dest"]}')
        elif data['strategy'] == 'local_file':
            # NOTE: Table.to_csv writes a fresh file, so a second flush to the
            # same path replaces rows written by an earlier flush
            logger.info(f'Writing {data["current_rows"]} rows to {data["dest"]}')
            data['tbl'].to_csv(data['dest'])
        else:
            raise ValueError(f'Invalid strategy: {data["strategy"]}')

        data['written_rows'] += data['current_rows']
        # Reset our Parsons Table
        data['tbl'] = Table()
        data['concat_count'] = 0
        data['current_rows'] = 0

        logger.debug(f'Closing {len(data["files"])} files')
        for file in data['files']:
            close_temp_file(file)
        # Reset our temp files
        data['files'] = []
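
    # Callers can also force a mid-run write without waiting for chunk_size,
    # e.g. accumulator.flush('some_schema.some_table', 'copy') for one strategy,
    # or accumulator.flush('some_schema.some_table') for every strategy at that
    # destination (the table name here is just an illustration).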

    def get_destinations(self):
        """
        Get all destinations that we have accumulated data for.

        May list a destination more than once if data was accumulated for it
        under more than one strategy.
        """
        return [data['dest'] for data in self.accumulated.values()]

    def get_total_rows(self, destination, strategy=None):
        """
        Get the total number of rows accumulated for a destination.

        `Args:`
            destination: str
                The name of the destination (eg a DB table) to get counts for
            strategy: str
                `Optional`; The strategy to get counts for (eg ``copy``); if not
                provided, get counts for all strategies for the given destination
        `Returns:`
            Number of rows.
        """
        if not strategy:
            total = 0
            for strat in VALID_STRATEGIES:
                total += self.get_total_rows(destination, strat)
            return total
        key = get_key(destination, strategy)
        if key not in self.accumulated:
            return 0
        return self.accumulated[key]['total_rows']

    def get_written_rows(self, destination, strategy=None):
        """
        Get the number of rows written to a destination.

        `Args:`
            destination: str
                The name of the destination (eg a DB table) to get counts for
            strategy: str
                `Optional`; The strategy to get counts for (eg ``copy``); if not
                provided, get counts for all strategies for the given destination
        `Returns:`
            Number of rows.
        """
        if not strategy:
            total = 0
            for strat in VALID_STRATEGIES:
                total += self.get_written_rows(destination, strat)
            return total
        key = get_key(destination, strategy)
        if key not in self.accumulated:
            return 0
        return self.accumulated[key]['written_rows']

    def upsert(self, parsons_tbl, db_tbl, primary_key, underlying_file=None, **kwargs):
        """
        Accumulate data for upserting to a database table.
        """
        if not self.db_client:
            raise ValueError('Cannot upsert without a database client')
        strategy = 'upsert'
        kwargs['primary_key'] = primary_key
        return self.accumulate(parsons_tbl, db_tbl, strategy, underlying_file, **kwargs)

    def copy(self, parsons_tbl, db_tbl, underlying_file=None, **kwargs):
        """
        Accumulate data for copying to a database table.
        """
        if not self.db_client:
            raise ValueError('Cannot copy without a database client')
        strategy = 'copy'
        return self.accumulate(parsons_tbl, db_tbl, strategy, underlying_file, **kwargs)

    def to_csv(self, parsons_tbl, local_file, underlying_file=None):
        """
        Accumulate data for writing to a local disk as a CSV file.
        """
        return self.accumulate(parsons_tbl, local_file, 'local_file', underlying_file)

    def accumulate(self, parsons_tbl, dest, strategy='copy', underlying_file=None, **kwargs):
        """
        Accumulate data for a destination, flushing automatically once
        ``chunk_size`` rows or ``concat_threshold`` concats are exceeded.
        """
        if strategy not in VALID_STRATEGIES:
            raise ValueError(f'Invalid strategy: {strategy}')
        key = get_key(dest, strategy)
        if key not in self.accumulated:
            total_rows = parsons_tbl.num_rows
            self.accumulated[key] = {
                'tbl': parsons_tbl,
                'strategy': strategy,
                'dest': dest,
                'files': [],
                'kwargs': kwargs,
                'total_rows': total_rows,
                'concat_count': 0,
                'written_rows': 0,
                'current_rows': total_rows,
            }
        else:
            # NOTE: kwargs are captured the first time a key is seen; kwargs
            # passed on later calls for the same key are ignored
            tbl_rows = parsons_tbl.num_rows
            self.accumulated[key]['tbl'].concat(parsons_tbl)
            self.accumulated[key]['total_rows'] += tbl_rows
            self.accumulated[key]['current_rows'] += tbl_rows
            self.accumulated[key]['concat_count'] += 1
        if underlying_file:
            self.accumulated[key]['files'].append(underlying_file)
        if (self.accumulated[key]['current_rows'] > self.chunk_size
                or self.accumulated[key]['concat_count'] > self.concat_threshold):
            logger.info(f'Table for destination {self.accumulated[key]["dest"]} '
                        f'has {self.accumulated[key]["current_rows"]} rows and '
                        f'{self.accumulated[key]["concat_count"]} concats, exceeding '
                        f'our thresholds ({self.chunk_size} rows or '
                        f'{self.concat_threshold} concats); flushing')
            self.flush(dest, strategy)
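

# --- Example usage (illustrative sketch, not part of the original script) ---
# A minimal sketch of driving the Accumulator, assuming Redshift credentials
# are configured in the environment; the table name and sample rows below are
# made up. Any Parsons DB client exposing ``copy`` and ``upsert`` works.
if __name__ == '__main__':
    db = Redshift()
    batches = [
        Table([{'person_id': 1, 'email': 'a@example.com'}]),
        Table([{'person_id': 1, 'email': 'a@example.com'}]),  # duplicate, deduped on upsert
        Table([{'person_id': 2, 'email': 'b@example.com'}]),
    ]
    with Accumulator(db_client=db, chunk_size=10_000) as acc:
        for batch in batches:
            # Rows buffer per (destination, strategy) and only hit the database
            # once chunk_size or concat_threshold is exceeded
            acc.upsert(batch, 'scratch.accumulator_demo', primary_key='person_id')
        logger.info('Accumulated %s rows', acc.get_total_rows('scratch.accumulator_demo'))
    # __exit__ calls flush_all(), so any remaining buffered rows are written here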