Created
June 6, 2023 19:05
-
-
Save shaunagm/1dbd785ca98447b3df69adb3856f1f5c to your computer and use it in GitHub Desktop.
"accumulator" - internal helper script from TMC that may be usefully moved into Parsons
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import petl | |
from parsons import Redshift, Table | |
from parsons.utilities.files import close_temp_file | |
from canalespy import logger | |
# The accumulation/flush strategies Accumulator understands; see Accumulator.flush.
VALID_STRATEGIES = ['copy', 'local_file', 'upsert']
def get_key(dest, strategy):
    """Build the internal accumulator key for a (destination, strategy) pair."""
    key = f'{dest}__{strategy}'
    return key
class Accumulator:
    """
    Class for accumulating Parsons Table data to push to a database.

    Data is grouped by (destination, strategy) pairs. Each call to
    ``accumulate`` (or the ``copy``/``upsert``/``to_csv`` wrappers) concats
    the incoming table onto the pending table for that pair; once the
    pending row count exceeds ``chunk_size``, or the table has been
    concatenated more than ``concat_threshold`` times, the pending data is
    flushed to its destination.

    `Args:`
        db_client: object
            A Parsons DB Client (anything with a copy and upsert method).
            Required for the ``copy`` and ``upsert`` strategies.
        chunk_size: int
            The number of rows in a Table before the table should be written to the destination
        concat_threshold: int
            The number of times the accumulator will concat tables together before flushing the
            data for a destination.
    """

    def __init__(self, db_client=None, chunk_size=20_000, concat_threshold=100):
        self.db_client = db_client
        self.chunk_size = chunk_size
        # Maps get_key(dest, strategy) -> bookkeeping dict holding the pending
        # Parsons Table, row counters, temp files, and flush kwargs.
        self.accumulated = {}
        self.concat_threshold = concat_threshold

    def __enter__(self):
        """
        Enter magic method for starting the context manager.
        No-op.
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit magic method for leaving the context manager.

        Flushes all tables currently in the ``Accumulator`` on a clean exit.
        If an exception is propagating, nothing is flushed; the exception is
        re-raised since this method returns ``None``.
        """
        if not exc_value:
            self.flush_all()
        else:
            logger.warning('Exception raised, accumulator will not flush data')

    def flush_all(self):
        """
        Flush all of our accumulated data to its destinations.

        Destinations with no pending rows are skipped.
        """
        logger.debug('Flushing all data in accumulator')
        for data in self.accumulated.values():
            logger.debug('Flushing data for %s', data['dest'])
            if data['current_rows'] > 0:
                logger.debug('Flushing %s rows for %s', data['current_rows'], data['dest'])
                self.flush(data['dest'], data['strategy'])

    def flush(self, dest, strategy=None):
        """
        Flush accumulated data for a specific destination.

        `Args:`
            dest: str
                The name of the destination (eg a DB table) to flush data for
            strategy: str
                `Optional`; The strategy to flush for (eg ``copy``); if not provided, flush
                all strategies for a given destination

        `Raises:`
            KeyError
                If no data has been accumulated for the destination (and
                strategy, when one is given).
            ValueError
                If the stored strategy is not recognized.
        """
        # If no strategy is provided, do all of them
        if not strategy:
            logger.debug('No strategy provided to flush method, flushing all data for destination '
                         '%s', dest)
            # Bug fix: previously this wrapped the recursive flush in
            # try/except KeyError, which silently swallowed genuine KeyErrors
            # raised *inside* a flush (eg a missing 'primary_key' kwarg) and
            # misreported them as "no data". Check membership explicitly.
            found = [s for s in VALID_STRATEGIES
                     if get_key(dest, s) in self.accumulated]
            # If there's no data in here for the table, then raise an error
            if not found:
                raise KeyError(f'No data for DB table {dest}')
            for strat in found:
                self.flush(dest, strat)
            return

        key = get_key(dest, strategy)
        if key not in self.accumulated:
            raise KeyError(f'No data for DB table {dest} and strategy {strategy}')

        data = self.accumulated[key]

        # If the Parsons Table is empty, just return
        if data['current_rows'] == 0:
            return

        if data['strategy'] == 'copy':
            logger.info(f'Copying {data["current_rows"]} rows to {data["dest"]}')
            self.db_client.copy(data['tbl'], data['dest'], **data['kwargs'])
        elif data['strategy'] == 'upsert':
            # Upserting the same primary key twice in one batch is ambiguous,
            # so dedupe on the primary key before handing off to the client.
            logger.debug(f'Deduplicating {data["current_rows"]} before upsert to {data["dest"]}')
            tbl = data['tbl']
            primary_key = data['kwargs']['primary_key']
            tbl.table = petl.transform.dedup.distinct(tbl.table, key=primary_key)
            logger.info(f'Upserting {data["current_rows"]} rows to {data["dest"]}')
            self.db_client.upsert(tbl, data['dest'], **data['kwargs'])
            # Commenting out until we dig in on fixes
            # if type(self.db_client) == Redshift:
            #     logger.info(f'Running vacuum sort on {data["dest"]}')
            #     self.db_client.query(f'vacuum sort only {data["dest"]}')
        elif data['strategy'] == 'local_file':
            # NOTE(review): every flush calls to_csv on the same path — confirm
            # whether later flushes should append rather than rewrite the file.
            logger.info(f'Writing {data["current_rows"]} rows to {data["dest"]}')
            data['tbl'].to_csv(data['dest'])
        else:
            raise ValueError(f'Invalid strategy: {data["strategy"]}')

        data['written_rows'] += data['current_rows']

        # Reset our Parsons Table
        data['tbl'] = Table()
        data['concat_count'] = 0
        data['current_rows'] = 0

        logger.debug(f'Closing {len(data["files"])} files')
        for file in data['files']:
            close_temp_file(file)

        # Reset our temp files
        data['files'] = []

    def get_destinations(self):
        """
        Get all destinations that we have accumulated data for.

        Note: a destination appears once per strategy it has data for, so the
        returned list may contain duplicates.
        """
        return [data['dest'] for data in self.accumulated.values()]

    def get_total_rows(self, destination, strategy=None):
        """
        Get the total number of rows accumulated for a destination (rows
        already written plus rows still pending).

        `Args:`
            destination: str
                The name of the destination (eg a DB table) to get counts for
            strategy: str
                `Optional`; The strategy to get counts for (eg ``copy``); if not provided, get
                counts for all strategies for a given destination
        `Returns:`
            Number of rows.
        """
        if not strategy:
            # Bug fix: this previously summed get_written_rows(destination,
            # strategy) — the wrong method, called with the outer (falsy)
            # strategy instead of the loop variable — so it returned
            # len(VALID_STRATEGIES) times the *written* rows and ignored
            # pending ones. Sum the per-strategy totals instead.
            return sum(self.get_total_rows(destination, strat)
                       for strat in VALID_STRATEGIES)
        key = get_key(destination, strategy)
        if key not in self.accumulated:
            return 0
        return self.accumulated[key]['total_rows']

    def get_written_rows(self, destination, strategy=None):
        """
        Get the number of rows already written to a destination.

        `Args:`
            destination: str
                The name of the destination (eg a DB table) to get counts for
            strategy: str
                `Optional`; The strategy to get counts for (eg ``copy``); if not provided, get
                counts for all strategies for a given destination
        `Returns:`
            Number of rows.
        """
        if not strategy:
            return sum(self.get_written_rows(destination, strat)
                       for strat in VALID_STRATEGIES)
        key = get_key(destination, strategy)
        if key not in self.accumulated:
            return 0
        return self.accumulated[key]['written_rows']

    def upsert(self, parsons_tbl, db_tbl, primary_key, underlying_file=None, **kwargs):
        """
        Accumulate data for upserting to a database table.

        ``primary_key`` is forwarded to the DB client and also used to dedupe
        rows at flush time.
        """
        if not self.db_client:
            raise ValueError('Cannot upsert without a database client')
        kwargs['primary_key'] = primary_key
        return self.accumulate(parsons_tbl, db_tbl, 'upsert', underlying_file, **kwargs)

    def copy(self, parsons_tbl, db_tbl, underlying_file=None, **kwargs):
        """
        Accumulate data for copying to a database table.
        """
        if not self.db_client:
            raise ValueError('Cannot copy without a database client')
        return self.accumulate(parsons_tbl, db_tbl, 'copy', underlying_file, **kwargs)

    def to_csv(self, parsons_tbl, local_file, underlying_file=None):
        """
        Accumulate data for writing to a local disk as a CSV file.
        """
        return self.accumulate(parsons_tbl, local_file, 'local_file', underlying_file)

    def accumulate(self, parsons_tbl, dest, strategy='copy', underlying_file=None, **kwargs):
        """
        Accumulate data, flushing automatically when a threshold is exceeded.

        `Args:`
            parsons_tbl: Table
                The Parsons Table of rows to accumulate
            dest: str
                The destination (a DB table name or a local file path)
            strategy: str
                One of ``VALID_STRATEGIES``
            underlying_file: str
                `Optional`; a temp file backing ``parsons_tbl`` that should be
                closed once the data has been flushed
            kwargs: kwargs
                Passed through to the DB client at flush time

        `Raises:`
            ValueError
                If ``strategy`` is not one of ``VALID_STRATEGIES``.
        """
        if strategy not in VALID_STRATEGIES:
            raise ValueError(f'Invalid strategy: {strategy}')

        key = get_key(dest, strategy)
        if key not in self.accumulated:
            total_rows = parsons_tbl.num_rows
            self.accumulated[key] = {
                'tbl': parsons_tbl,
                'strategy': strategy,
                'dest': dest,
                'files': [],
                'kwargs': kwargs,
                'total_rows': total_rows,
                'concat_count': 0,
                'written_rows': 0,
                'current_rows': total_rows,
            }
        else:
            tbl_rows = parsons_tbl.num_rows
            self.accumulated[key]['tbl'].concat(parsons_tbl)
            self.accumulated[key]['total_rows'] += tbl_rows
            self.accumulated[key]['current_rows'] += tbl_rows
            self.accumulated[key]['concat_count'] += 1

        if underlying_file:
            self.accumulated[key]['files'].append(underlying_file)

        data = self.accumulated[key]
        # Flush when either threshold is exceeded (the old log message
        # wrongly claimed the row threshold was hit even when the concat
        # count was the trigger).
        if (data['current_rows'] > self.chunk_size
                or data['concat_count'] > self.concat_threshold):
            logger.info(f'Table for destination {data["dest"]} has '
                        f'{data["current_rows"]} rows and {data["concat_count"]} concats, '
                        f'exceeding a threshold (chunk_size={self.chunk_size}, '
                        f'concat_threshold={self.concat_threshold}); flushing')
            self.flush(dest, strategy)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment