"""Writes CSV data into multiple output shards, grouping rows by keys.

Output shards are written to Google Cloud Storage.
"""
import csv
import hashlib
import logging

# Assumed import: the App Engine GCS client library
# (GoogleAppEngineCloudStorageClient) provides the gcs.open() and
# gcs.RetryParams() calls used below.
import cloudstorage as gcs


class BatchedGcsCsvShardFileWriter(object):
"""Writes CSV data into multiple output shards, grouping rows by keys. | |
This class is a context manager, which closes all shards upon exit. | |
Say you are writing a lot of CSV data, like: | |
[0, "Bakery"], | |
[2, "Francisco"], | |
[3, "Matt"], | |
[0, "Matt"], | |
[0, "Sam"], | |
[2, "Dude"], | |
    And you want to write this to 2 output shards, but you need all rows
    that share the same value in the first column to be written to the same
    shard, like:
        /bucket/shard_0.csv:
            [0, "Matt"],
            [0, "Sam"],
            [0, "Bakery"],
            [3, "Matt"],

        /bucket/shard_1.csv:
            [2, "Dude"],
            [2, "Francisco"],
    You can do so pretty easily with this class. For example:

        writer = BatchedGcsCsvShardFileWriter("bucket/shard_%(num)s.csv",
            key_columns=[0], num_shards=2)

        with writer:
            writer.writerow([0, "Bakery"])
            writer.writerow([2, "Francisco"])
            writer.writerow([3, "Matt"])
            writer.writerow([0, "Matt"])
            writer.writerow([0, "Sam"])
            writer.writerow([2, "Dude"])
    NOTE: The output sharding may not be uniform if the distribution of
    values in the key_columns is not.
    """
    def __init__(self, shard_path_pattern, key_columns, num_shards):
        """Constructor.

        Arguments:
            shard_path_pattern - The naming pattern for the output shards.
                Must include "%(num)s" to indicate the shard number.
            key_columns - The columns that are used to determine which shard
                a row is written to.
            num_shards - The number of output shards to write.
        """
        self.shard_path_pattern = shard_path_pattern
        self.key_columns = key_columns
        self.num_shards = num_shards

        # A list of per-shard metadata dicts. Initialized in __enter__().
        self.shard_files = [None] * self.num_shards
    def _get_shard(self, row):
        """Returns an integer in [0, self.num_shards)."""
        key = tuple(v for i, v in enumerate(row) if i in self.key_columns)

        # We must use hashlib, because hash() is unreliable across Python
        # processes, which would make the sharding non-deterministic.
        # http://stackoverflow.com/questions/793761/built-in-python-hash-function
        big_hash = int(hashlib.md5(repr(key)).hexdigest(), 16)
        return big_hash % self.num_shards
    def writerow(self, row):
        shard = self.shard_files[self._get_shard(row)]

        shard['rows_written'] += 1
        if shard['rows_written'] % 100000 == 0:
            logging.info("Written %d rows to shard %s",
                shard['rows_written'], shard['path'])

        return shard['csv_writer'].writerow(row)
    def __enter__(self):
        retry_params = gcs.RetryParams(urlfetch_timeout=60,
            max_retry_period=60 * 5.0)

        for idx in range(self.num_shards):
            path = self.shard_path_pattern % {'num': idx}
            gcs_file = gcs.open(path, "w", retry_params=retry_params)
            csv_writer = csv.writer(gcs_file)

            self.shard_files[idx] = {
                'path': path,
                'gcs_file': gcs_file,
                'csv_writer': csv_writer,
                'rows_written': 0,
            }

        return self
    def __exit__(self, exception_type, exception_value, exception_traceback):
        for metadata in self.shard_files:
            try:
                logging.info("Closing shard %s", metadata['path'])
                metadata['gcs_file'].close()
            except Exception:
                logging.exception("Ignoring exception closing GCS file")
                # TODO(mattfaus): Re-raise exceptions?
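

# A minimal usage sketch, not part of the original class. The bucket name and
# shard pattern below are hypothetical; it simply mirrors the docstring
# example, writing each row to the shard chosen by its first column.
def _example_write_sharded_csv():
    writer = BatchedGcsCsvShardFileWriter(
        "/my-bucket/output/shard_%(num)s.csv", key_columns=[0], num_shards=2)

    with writer:
        for row in [[0, "Bakery"], [2, "Francisco"], [3, "Matt"],
                    [0, "Matt"], [0, "Sam"], [2, "Dude"]]:
            writer.writerow(row)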