import csv
import logging

# Google App Engine's GCS client library.
import cloudstorage as gcs

class SortedGcsCsvShardFileMergeReader(object):
    """Merges several sorted .csv files stored on GCS.

    This class is both an iterator and a context manager.

    Let's say there are 2 .csv files stored on GCS, with contents like:

    /bucket/file_1.csv:
    [0, "Matt"],
    [0, "Sam"],
    [2, "Dude"],

    /bucket/file_2.csv:
    [0, "Bakery"],
    [2, "Francisco"],
    [3, "Matt"],

    These files are already sorted by our key_columns = [0]. We want to read
    all of the rows with the same key_column value at once, regardless of
    which files those rows reside in. This class does that by opening handles
    to all of the files and picking off the top rows from each of the files
    as long as they share the same key.

    For example:

    merge_reader = SortedGcsCsvShardFileMergeReader("/bucket", "file_", [0])
    with merge_reader:
        for row in merge_reader:
            # Returns rows in totally-sorted order, like:
            # [0, "Matt"]       (from file_1.csv)
            # [0, "Sam"]        (from file_1.csv)
            # [0, "Bakery"]     (from file_2.csv)
            # [2, "Dude"]       (from file_1.csv)
            # [2, "Francisco"]  (from file_2.csv)
            # [3, "Matt"]       (from file_2.csv)

    The merge columns must be comparable, so that this class can return the
    results in totally sorted order. Note that csv.reader yields strings, so
    keys compare lexicographically (e.g. "10" sorts before "2").

    NOTE: All shards must have at least one row.

    To keep track of the shards and their current statuses, we maintain a
    bookkeeping dict, self.shard_files, with this format:

    {
        "shard_file_path_1": {
            "gcs_file": handle to the gcs file stream
            "csv_reader": csv reader reading from gcs_file
            "head": the row most recently read from the csv_reader
            "head_key": the key tuple extracted from head
            "rows_returned": a running count of the rows returned
        },
        ...
        "shard_file_path_2": {},
        ...
        "shard_file_path_n": {},
    }

    The self.shard_files dict is pruned as the shards are exhausted.
    """

    def __init__(self, input_bucket, input_pattern, merge_columns):
        """Constructor.

        Arguments:
            input_bucket - The bucket to read from.
            input_pattern - The filename prefix to match within the bucket.
            merge_columns - The columns used for determining row merging.
        """
        shard_paths = get_shard_files(
            input_bucket, input_pattern, full_path=True)

        if not shard_paths:
            raise ValueError("Could not find any shard files.")

        logging.info("Merge-reading: %s", shard_paths)

        self.shard_files = {
            shard_path: {} for shard_path in shard_paths
        }

        self.merge_columns = merge_columns
        self.current_key = None
        self.current_shard_path = None

    def __enter__(self):
        """Open handles to all of the shard files and read the first row."""
        retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                       max_retry_period=60 * 60.0)

        for shard_path in self.shard_files.keys():
            gcs_file = gcs.open(shard_path, "r", retry_params=retry_params)
            csv_reader = csv.reader(gcs_file)
            head = csv_reader.next()  # Assumes there is at least 1 row

            self.shard_files[shard_path] = {
                "gcs_file": gcs_file,
                "csv_reader": csv_reader,
                "head": head,
                "head_key": self._get_key(head),
                "rows_returned": 0,
            }

        return self

    def __iter__(self):
        return self

    def _get_key(self, row):
        """Extract the merge-key tuple from a row."""
        return tuple(v for i, v in enumerate(row) if i in self.merge_columns)

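    # For example, with merge_columns=[0], _get_key(["0", "Matt"]) returns
    # ("0",). csv.reader yields strings, hence the string key.
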
    def _advance_shard(self, shard_path):
        """Update the shard's head values; return the current head."""
        # Save the head, to return later
        metadata = self.shard_files[shard_path]
        row = metadata["head"]

        # Count the row we are about to return now, so the total stays
        # accurate even when this turns out to be the shard's last row.
        metadata["rows_returned"] += 1

        try:
            new_head = metadata["csv_reader"].next()
            metadata["head"] = new_head
            metadata["head_key"] = self._get_key(new_head)
        except StopIteration:
            # This shard is exhausted; close it and force next() to choose
            # a new shard on the following call.
            self._close_shard(shard_path)
            self.current_shard_path = None

        return row

    def _find_next_key(self):
        """Find the next key to start merge reading.

        We must always choose the next "lowest" key value to be the next key
        to read. Not all shards have all keys, so we must do this to ensure
        that we do not mis-order the rows in the final output.
        """
        lowest_key_value = None
        lowest_shard_path = None

        for path, metadata in self.shard_files.iteritems():
            if (lowest_key_value is None
                    or metadata["head_key"] < lowest_key_value):
                lowest_key_value = metadata["head_key"]
                lowest_shard_path = path

        return lowest_key_value, lowest_shard_path

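    # Note: this is a linear scan, so every key change costs O(number of
    # shards). With many shards, a heapq-based merge would be a natural
    # optimization, at the cost of complicating the grouped-by-key reads.
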
    def next(self):
        # We've exhausted all rows from all shards
        if not self.shard_files:
            raise StopIteration

        # This happens at the very beginning, or after exhausting a shard.
        # The choice is arbitrary; it is validated against current_key below.
        if self.current_shard_path is None:
            self.current_shard_path = self.shard_files.keys()[0]

        # This happens at the very beginning, or after exhausting a key
        if self.current_key is None:
            self.current_key, self.current_shard_path = self._find_next_key()

        # If the current shard has more rows for this key, just return that
        if (self.shard_files[self.current_shard_path]["head_key"]
                == self.current_key):
            return self._advance_shard(self.current_shard_path)

        # Otherwise, look for another shard whose head matches current_key
        for path, metadata in self.shard_files.iteritems():
            if metadata["head_key"] == self.current_key:
                self.current_shard_path = path
                return self._advance_shard(path)

        # We didn't find any rows for current_key, so pick the next key and
        # start over. This recurses at most once, because _find_next_key
        # always selects a shard whose head matches the new current_key.
        self.current_key = None
        return self.next()

    def _close_shard(self, shard_path):
        """Close the shard and remove it from the shard_files collection."""
        if shard_path not in self.shard_files:
            return

        metadata = self.shard_files[shard_path]
        logging.info(
            "Closing shard after reading %d rows. %d shards remain. %s",
            metadata["rows_returned"], len(self.shard_files) - 1, shard_path)
        try:
            metadata["gcs_file"].close()
        except Exception:
            logging.exception("Ignoring exception from %s", shard_path)

        del self.shard_files[shard_path]

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Closes all shards."""
        # keys() returns a copy, so it is safe to delete entries while
        # iterating over it.
        shard_paths = self.shard_files.keys()

        for path in shard_paths:
            self._close_shard(path)

        # Returning a falsy value lets any in-flight exception propagate,
        # so there is no need to re-raise it ourselves.


def get_shard_files(bucket, filename_prefix, full_path=False):
    """Find files in a bucket whose names match a filename prefix."""
    if not bucket.startswith("/"):
        bucket = "/%s" % bucket

    if bucket.endswith("/"):
        bucket = bucket[:-1]

    retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                   max_retry_period=60 * 30.0)

    filename_prefix = bucket + "/" + filename_prefix

    shard_files = []
    for file_stat in gcs.listbucket(filename_prefix,
                                    retry_params=retry_params):
        path = file_stat.filename
        if not full_path:
            # Remove the "/<bucket>" + "/" prefix
            path = path[len(bucket) + 1:]

        shard_files.append(path)

    # Sort for deterministic ordering
    shard_files.sort()
    return shard_files
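
# A minimal, hypothetical usage sketch: the bucket and filename prefix below
# are placeholders, and assume already-sorted .csv shards exist there.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    merge_reader = SortedGcsCsvShardFileMergeReader(
        "/my-bucket", "sorted_shard_", merge_columns=[0])
    with merge_reader:
        for row in merge_reader:
            print row  # Rows arrive grouped by key, in totally-sorted order.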
|