Writes CSV data into multiple output shards, grouping rows by keys. Output shards are written to Google Cloud Storage.
import csv
import hashlib
import logging

# App Engine GCS client library (GoogleAppEngineCloudStorageClient).
import cloudstorage as gcs


class BatchedGcsCsvShardFileWriter(object):
    """Writes CSV data into multiple output shards, grouping rows by keys.

    This class is a context manager, which closes all shards upon exit.

    Say you are writing a lot of CSV data, like:

        [0, "Bakery"],
        [2, "Francisco"],
        [3, "Matt"],
        [0, "Matt"],
        [0, "Sam"],
        [2, "Dude"],

    and you want to write it to 2 output shards, where all rows that share
    the same value in the first column go to the same shard, like:

        /bucket/shard_0.csv:
            [0, "Bakery"],
            [3, "Matt"],
            [0, "Matt"],
            [0, "Sam"],

        /bucket/shard_1.csv:
            [2, "Francisco"],
            [2, "Dude"],

    You can do so pretty easily with this class. For example:

        writer = BatchedGcsCsvShardFileWriter("/bucket/shard_%(num)s.csv",
            key_columns=[0], num_shards=2)
        with writer:
            writer.writerow([0, "Bakery"])
            writer.writerow([2, "Francisco"])
            writer.writerow([3, "Matt"])
            writer.writerow([0, "Matt"])
            writer.writerow([0, "Sam"])
            writer.writerow([2, "Dude"])

    Within each shard, rows appear in the order they were written; which keys
    land in which shard depends on the hash, so the listing above is
    illustrative.

    NOTE: The output sharding may not be uniform if the distribution of
    values in the key_columns is not.
    """
    def __init__(self, shard_path_pattern, key_columns, num_shards):
        """Constructor.

        Arguments:
            shard_path_pattern - The naming pattern for the output shards.
                Must include "%(num)s" to indicate the shard number.
            key_columns - Indices of the columns used to determine which
                shard a row is written to.
            num_shards - The number of output shards to write.
        """
        self.shard_path_pattern = shard_path_pattern
        self.key_columns = key_columns
        self.num_shards = num_shards

        # A list of per-shard state dicts, initialized in __enter__().
        self.shard_files = [None] * self.num_shards
    def _get_shard(self, row):
        """Returns an integer in [0, self.num_shards)."""
        key = tuple(v for i, v in enumerate(row) if i in self.key_columns)
        # We must use hashlib because the built-in hash() is not stable
        # across processes or interpreter versions:
        # http://stackoverflow.com/questions/793761/built-in-python-hash-function
        big_hash = int(hashlib.md5(repr(key)).hexdigest(), 16)
        return big_hash % self.num_shards
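    # A worked example of _get_shard(): with key_columns=[0], the rows
    # [0, "Bakery"] and [0, "Sam"] both reduce to the key (0,), so they
    # hash identically and land in the same shard on every machine and run.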
    def writerow(self, row):
        shard = self.shard_files[self._get_shard(row)]
        shard['rows_written'] += 1
        if shard['rows_written'] % 100000 == 0:
            logging.info("Written %d rows to shard %s",
                         shard['rows_written'], shard['path'])
        return shard['csv_writer'].writerow(row)
    def __enter__(self):
        retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                       max_retry_period=60 * 5.0)
        for idx in range(self.num_shards):
            path = self.shard_path_pattern % {'num': idx}
            gcs_file = gcs.open(path, "w", retry_params=retry_params)
            csv_writer = csv.writer(gcs_file)
            self.shard_files[idx] = {
                'path': path,
                'gcs_file': gcs_file,
                'csv_writer': csv_writer,
                'rows_written': 0,
            }
        return self
    def __exit__(self, exception_type, exception_value, exception_traceback):
        for metadata in self.shard_files:
            try:
                logging.info("Closing shard %s", metadata['path'])
                metadata['gcs_file'].close()
            except Exception:
                logging.exception("Ignoring exception closing GCS file")
                # TODO(mattfaus): Re-raise exceptions?
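
For reference, here is a minimal usage sketch. The bucket path and rows are
hypothetical, and it assumes the App Engine cloudstorage client is available
(object paths take the form "/bucket_name/object_name"):

    rows = [
        [0, "Bakery"],
        [2, "Francisco"],
        [3, "Matt"],
        [0, "Matt"],
    ]

    writer = BatchedGcsCsvShardFileWriter("/my_bucket/shard_%(num)s.csv",
        key_columns=[0], num_shards=2)
    with writer:
        for row in rows:
            writer.writerow(row)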