Create a gist now

Instantly share code, notes, and snippets.

Merge-reads several sorted .csv files stored on Google Cloud Storage.
class SortedGcsCsvShardFileMergeReader(object):
    """Merges several sorted .csv files stored on GCS.

    This class is both an iterator and a context manager.

    Let's say there are 2 .csv files stored on GCS, with contents like:

        /bucket/file_1.csv:
            [0, "Matt"],
            [0, "Sam"],
            [2, "Dude"],

        /bucket/file_2.csv:
            [0, "Bakery"],
            [2, "Francisco"],
            [3, "Matt"],

    These files are already sorted by our key_columns = [0]. We want to read
    all of the rows with the same key_column value at once, regardless of
    which files those rows reside in. This class does that by opening
    handles to all of the files and picking off the top rows from each of the
    files as long as they share the same key.

    For example:

        merge_reader = SortedGcsCsvShardFileMergeReader("/bucket", "file_", [0])
        with merge_reader:
            for row in merge_reader:
                # Returns rows in totally-sorted order, like:
                # [0, "Matt"]       (from file_1.csv)
                # [0, "Sam"]        (from file_1.csv)
                # [0, "Bakery"]     (from file_2.csv)
                # [2, "Dude"]       (from file_1.csv)
                # [2, "Francisco"]  (from file_2.csv)
                # [3, "Matt"]       (from file_2.csv)

    The merge columns must be comparable, so that this class can return the
    results in totally sorted order.

    NOTE: csv.reader yields every field as a string, so rows come back as
    lists of strings and keys are compared as tuples of strings (e.g.
    "10" < "9"). Each shard must be sorted consistently with that ordering.

    Empty shards are closed and skipped when the reader is entered.

    To do this, we build up a somewhat complex instance object to keep track
    of the shards and their current statuses. self.shard_files has this
    format:

        {
            "shard_file_path_1": {
                "gcs_file": handle to the gcs file stream,
                "csv_reader": csv_reader reading from gcs_file,
                "head": the next row to be returned from this shard,
                "head_key": the key tuple of "head",
                "rows_returned": a running count of the rows returned,
            },
            ...
            "shard_file_path_n": {...},
        }

    The self.shard_files object is pruned as the shards are exhausted.
    """

    def __init__(self, input_bucket, input_pattern, merge_columns):
        """Constructor.

        Arguments:
            input_bucket - The bucket to read from.
            input_pattern - The filename prefix of the shards to read. It is
                passed to get_shard_files as a literal prefix, not a glob.
            merge_columns - Indices of the columns used for row merging.

        Raises:
            ValueError if no shard files match.
        """
        shard_paths = get_shard_files(
            input_bucket, input_pattern, full_path=True)
        if not shard_paths:
            raise ValueError("Could not find any shard files.")
        logging.info("Merge-reading: %s", shard_paths)
        # Empty placeholders until __enter__ opens each shard.
        self.shard_files = {
            shard_path: {} for shard_path in shard_paths
        }
        self.merge_columns = merge_columns
        # Key currently being drained across shards; None between keys.
        self.current_key = None
        # Shard most recently read from; checked first for more rows of
        # current_key.
        self.current_shard_path = None

    def __enter__(self):
        """Open handles to all of the shard files and read each first row.

        Empty shards are closed and dropped. If opening or reading any shard
        fails, every handle opened so far is closed before the exception
        propagates. This matters because __exit__ is not called when
        __enter__ raises.

        Returns:
            self, so that `with reader as r:` also works.
        """
        retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                       max_retry_period=60 * 60.0)
        opened_all = False
        try:
            # Snapshot the keys: _open_shard may remove empty shards.
            for shard_path in list(self.shard_files):
                self._open_shard(shard_path, retry_params)
            opened_all = True
        finally:
            if not opened_all:
                self._close_all_shards()
        return self

    def _open_shard(self, shard_path, retry_params):
        """Open one shard and load its first row; drop it if it is empty."""
        gcs_file = gcs.open(shard_path, "r", retry_params=retry_params)
        csv_reader = csv.reader(gcs_file)
        # Register the handle before reading, so a failed read still lets
        # the caller close it.
        metadata = {
            "gcs_file": gcs_file,
            "csv_reader": csv_reader,
            "rows_returned": 0,
        }
        self.shard_files[shard_path] = metadata
        try:
            head = next(csv_reader)
        except StopIteration:
            # Empty shard: nothing to merge from it.
            self._close_shard(shard_path)
            return
        metadata["head"] = head
        metadata["head_key"] = self._get_key(head)

    def __iter__(self):
        return self

    def _get_key(self, row):
        """Return the tuple of row values at the merge column indices.

        Values appear in their order within the row, not in the order in
        which merge_columns lists them.
        """
        return tuple(v for i, v in enumerate(row) if i in self.merge_columns)

    def _advance_shard(self, shard_path):
        """Return the shard's current head row and load its next row.

        When the shard has no more rows, it is closed and removed from
        shard_files, and current_shard_path is reset to None.
        """
        metadata = self.shard_files[shard_path]
        row = metadata["head"]
        # Count the row handed out now, including each shard's final row.
        metadata["rows_returned"] += 1
        try:
            new_head = next(metadata["csv_reader"])
        except StopIteration:
            self._close_shard(shard_path)
            self.current_shard_path = None
        else:
            metadata["head"] = new_head
            metadata["head_key"] = self._get_key(new_head)
        return row

    def _find_next_key(self):
        """Find the next key to start merge reading.

        We must always choose the next "lowest" key value to be the next key
        to read. Not all shards have all keys, so we must do this to ensure
        that we do not mis-order the rows in the final output.

        Returns:
            (lowest_key, shard_path) for the first shard, in iteration order,
            whose head has the lowest key. Returns (None, None) if no shards
            remain.
        """
        lowest_key_value = None
        lowest_shard_path = None
        for path, metadata in self.shard_files.items():
            # Test the sentinel first: comparing a tuple with None raises
            # TypeError on Python 3.
            if (lowest_shard_path is None
                    or metadata["head_key"] < lowest_key_value):
                lowest_key_value = metadata["head_key"]
                lowest_shard_path = path
        return lowest_key_value, lowest_shard_path

    def next(self):
        """Return the next row in merged key order.

        Raises:
            StopIteration once every shard has been exhausted.
        """
        while True:
            # We've exhausted all rows from all shards.
            if not self.shard_files:
                raise StopIteration
            # At the very beginning, or after exhausting a key: move on to
            # the lowest remaining key.
            if self.current_key is None:
                self.current_key, self.current_shard_path = (
                    self._find_next_key())
            # If the current shard has more rows for this key, return those.
            current = self.shard_files.get(self.current_shard_path)
            if current is not None and current["head_key"] == self.current_key:
                return self._advance_shard(self.current_shard_path)
            # Otherwise, look for any other shard holding this key.
            for path, metadata in self.shard_files.items():
                if metadata["head_key"] == self.current_key:
                    self.current_shard_path = path
                    return self._advance_shard(path)
            # No shard has rows left for current_key, so start over.
            self.current_key = None

    # Python 3 iterator protocol; Python 2 uses next() above.
    __next__ = next

    def _close_shard(self, shard_path):
        """Close the shard and remove it from the shard_files collection.

        Closing is best-effort: exceptions from close() are logged and
        ignored. Unknown paths are ignored. Shards that were never opened
        are simply removed.
        """
        metadata = self.shard_files.pop(shard_path, None)
        if metadata is None:
            return
        gcs_file = metadata.get("gcs_file")
        if gcs_file is None:
            return
        # The entry is already popped, so len() counts the remaining shards.
        logging.info(
            "Closing shard after reading %d rows. %d shards remain. %s",
            metadata["rows_returned"], len(self.shard_files), shard_path)
        try:
            gcs_file.close()
        except Exception:
            logging.exception("Ignoring exception from %s", shard_path)

    def _close_all_shards(self):
        """Close and remove every shard still in shard_files."""
        for path in list(self.shard_files):
            self._close_shard(path)

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Closes all shards.

        Returns None (falsy), so an exception raised inside the `with` body
        still propagates to the caller after the shards are closed.
        """
        self._close_all_shards()
def get_shard_files(bucket, filename_prefix, full_path=False):
    """Find files in a bucket whose names start with a filename prefix.

    Arguments:
        bucket - The bucket name, with or without a leading slash and with
            any number of trailing slashes (e.g. "bucket", "/bucket/").
        filename_prefix - Literal prefix of the file names to match,
            relative to the bucket. It is not a glob pattern.
        full_path - If True, return "/<bucket>/<name>" paths. Otherwise,
            return names relative to the bucket.

    Returns:
        A sorted list of matching paths, for deterministic ordering.
    """
    if not bucket.startswith("/"):
        bucket = "/" + bucket
    bucket = bucket.rstrip("/")
    retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                   max_retry_period=60 * 30.0)
    path_prefix = bucket + "/" + filename_prefix
    paths = [file_stat.filename
             for file_stat in gcs.listbucket(path_prefix,
                                             retry_params=retry_params)]
    if not full_path:
        # Remove the "/<bucket>" + "/" prefix.
        relative_start = len(bucket) + 1
        paths = [path[relative_start:] for path in paths]
    # Sort for deterministic ordering
    return sorted(paths)
class GcsTest(gae_model.GAEModelTestCase):
    """Base test case that initializes the stubs the GCS client needs."""

    def setUp(self):
        super(GcsTest, self).setUp()
        testbed = self.testbed
        # Stubs necessary for exercising the GCS client in tests, set up in
        # a fixed order.
        stub_initializers = (
            testbed.init_app_identity_stub,
            testbed.init_urlfetch_stub,
            testbed.init_blobstore_stub,
        )
        for init_stub in stub_initializers:
            init_stub()
class MergeReaderTest(GcsTest):
    """Verifies SortedGcsCsvShardFileMergeReader functionality."""

    # Sorted by (column 0, column 1) so that every shard (and every prefix of
    # one) meets the reader's sorted-input precondition for both
    # key_columns=[0] and key_columns=[0, 1].
    data = [
        [0, "Ben"],
        [0, "Matt"],
        [1, "Sam"],
        [2, "Bam"],
        [2, "Bam"],
        [3, "Matt"],
    ]
    bucket = "/bucket"
    path_pattern = "/bucket/csv_shard_%d"
    filename_prefix = "csv_shard_"

    def _write_shard(self, shard, rows):
        """Writes `rows` as CSV to the GCS file for shard index `shard`."""
        path = self.path_pattern % shard
        with gcs.open(path, 'w') as gcs_file:
            csv_writer = csv.writer(gcs_file)
            for row in rows:
                csv_writer.writerow(row)

    def _write_shards(self, num_shards, num_rows=None):
        """Writes test shards to GCS.

        Arguments:
            num_shards - The number of shards to create.
            num_rows - A list with the number of rows each shard should
                contain. If None, all shards will contain all rows. Max value
                is len(self.data). You can use this to make some shards have
                less data than others.
        """
        if num_rows is None:
            num_rows = [len(self.data)] * num_shards
        for shard in range(num_shards):
            self._write_shard(shard, self.data[:num_rows[shard]])

    def _get_key(self, row, key_columns):
        return tuple([v for i, v in enumerate(row) if i in key_columns])

    def _verify_merge_reader(self, merge_reader, key_columns,
                             expected_row_count=None):
        """Reads merge_reader to exhaustion and checks the merged output.

        Asserts that each key's rows are contiguous and that keys come out
        in strictly increasing order between runs. When expected_row_count
        is given, also asserts that no rows were dropped or duplicated.
        """
        seen_keys = set()
        prev_key = None
        row_count = 0
        with merge_reader:
            for row in merge_reader:
                row_count += 1
                cur_key = self._get_key(row, key_columns)
                if cur_key != prev_key:
                    self.assertNotIn(cur_key, seen_keys)
                    if prev_key is not None:
                        self.assertLess(prev_key, cur_key)
                    seen_keys.add(cur_key)
                prev_key = cur_key
        if expected_row_count is not None:
            self.assertEqual(expected_row_count, row_count)

    def test_same_length_shards(self):
        num_shards = 3
        self._write_shards(num_shards)
        key_columns = [0]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)
        self._verify_merge_reader(
            merge_reader, key_columns,
            expected_row_count=num_shards * len(self.data))

    def test_different_length_shards(self):
        num_rows = [
            len(self.data),
            len(self.data) - 1,
            len(self.data) - 2]
        self._write_shards(3, num_rows=num_rows)
        key_columns = [0]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)
        self._verify_merge_reader(merge_reader, key_columns,
                                  expected_row_count=sum(num_rows))

    def test_complex_key(self):
        num_rows = [
            len(self.data),
            len(self.data) - 1,
            len(self.data) - 2]
        self._write_shards(3, num_rows=num_rows)
        key_columns = [0, 1]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)
        self._verify_merge_reader(merge_reader, key_columns,
                                  expected_row_count=sum(num_rows))

    def test_skip_sorting(self):
        # Files being sorted internally does not mean that each file contains
        # all keys. The merge reader should be robust to this and skip keys
        # until it finds one that matches.
        shards = [
            [
                [5, 'hi'],
                [5, 'hi'],
                [6, 'hi'],
            ],
            [
                [0, 'bye'],
                [2, 'bye'],
            ],
            [
                [2, 'fie'],
                [3, 'fie'],
            ],
            [
                [6, 'fie'],
                [7, 'fie'],
            ],
            [
                [1, 'fie'],
                [7, 'fie'],
            ],
        ]
        for idx, shard in enumerate(shards):
            self._write_shard(idx, shard)
        key_columns = [0]
        merge_reader = pipelines_util.SortedGcsCsvShardFileMergeReader(
            self.bucket, self.filename_prefix, key_columns)
        self._verify_merge_reader(
            merge_reader, key_columns,
            expected_row_count=sum(len(shard) for shard in shards))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment