Create a gist now

Instantly share code, notes, and snippets.

A Pipeline job that launches a map-only job to sort .csv files in memory. Each .csv file is read from Google Cloud Storage into memory, sorted by the specified columns, and then written back out to Google Cloud Storage. The machine running the sorting process needs roughly 10 times as much memory as the size of each .csv file.
class ParallelInMemorySortGcsCsvShardFiles(pipeline.Pipeline):
    """Sorts every matching GCS .csv file in memory, in parallel.

    Launches a map-only MapperPipeline whose handler is
    prediction.pipelines_util.in_memory_sort_gcs_csv_shard_files_map; that
    handler sorts one whole input file in memory and writes it back to GCS.
    """

    # Heuristic number of input files per mapper shard. These files are
    # usually ~170MB, and the machines we are using can handle processing at
    # least 6 at a time, due to memory constraints. The math is fuzzy because
    # 1) each machine may be processing >1 shard at a time, due to
    # multi-threading support, and 2) the auto-scaling mechanic is
    # CPU-utilization, not memory. This is little more than a heuristic that
    # has been found to work; feel free to tweak as necessary.
    FILES_PER_MAPPER_SHARD = 6

    def run(self, input_bucket, input_pattern, sort_columns,
            model_type, output_bucket, output_pattern):
        """Sorts each input file in-memory, then writes it to an output file.

        Arguments:
          input_bucket - The GCS bucket which contains the unsorted .csv
            files, with or without a leading "/" (e.g. "/my-bucket").
          input_pattern - Appended to input_bucket and passed to
            gcs.listbucket (via get_shard_files) to find the files to map
            over (e.g. "raw_problem_logs/"). NOTE(review): this was
            previously described as a regular expression
            ("raw_problem_logs/.*"), but listbucket takes a path prefix --
            confirm how callers use it.
          sort_columns - A list of column indexes to sort by, most
            significant first (e.g. [0, 1]).
          model_type - The model type being trained. Used in building the
            output filename for organization purposes.
          output_bucket - The GCS bucket to write sorted .csv files to.
          output_pattern - The pattern to use to name output files. You
            *must* include "%(num)s" in this pattern or the output files
            will overwrite each other. "%(num)s" will be replaced with the
            index of the input file in the list returned by
            get_shard_files(input_bucket, input_pattern).

        Raises:
          ValueError - If no input files are found.
        """
        shard_files = get_shard_files(input_bucket, input_pattern)
        logging.info("ParallelInMemorySortGcsCsvShardFiles processing: %s",
                     shard_files)
        count_shards = len(shard_files)
        if count_shards == 0:
            raise ValueError("No shards found to sort.")
        yield mapreduce_pipeline.MapperPipeline(
            job_name="ParallelInMemorySortGcsCsvShardFiles",
            handler_spec='prediction.pipelines_util'
                         '.in_memory_sort_gcs_csv_shard_files_map',
            input_reader_spec=('third_party.mapreduce.input_readers'
                               '._GoogleCloudStorageInputReader'),
            params={
                'input_reader': {
                    # The reader wants the bare bucket name. Strip any
                    # leading/trailing "/" instead of blindly chopping the
                    # first character, which would corrupt a bucket passed
                    # without a "/" prefix (a form get_shard_files accepts).
                    'bucket_name': input_bucket.strip('/'),
                    'objects': shard_files,
                },
                'sort_columns': sort_columns,
                'model_type': model_type,
                'input_bucket': input_bucket,
                'input_pattern': input_pattern,
                'output_bucket': output_bucket,
                'output_pattern': output_pattern,
                'root_pipeline_id': self.root_pipeline_id,
            },
            # Floor division keeps the shard count an int on Python 3 as well
            # as Python 2; always use at least one shard.
            shards=max(1, count_shards // self.FILES_PER_MAPPER_SHARD),
        )
def in_memory_sort_gcs_csv_shard_files_map(gcs_csv_file):
    """Mapper handler: sorts one .csv file in memory and writes it to GCS.

    This is the handler_spec used by ParallelInMemorySortGcsCsvShardFiles;
    its configuration comes from the mapper params set up there.

    Arguments:
      gcs_csv_file - A readable file object for one input .csv file. Its
        `name` must be one of the full paths returned by
        get_shard_files(input_bucket, input_pattern, full_path=True); that
        position is used as "%(num)s" in output_pattern.

    Raises:
      ValueError - If gcs_csv_file.name is not in the current listing of
        input files, so no output index can be assigned.
    """
    my_context = context.get()
    mapper_params = my_context.mapreduce_spec.mapper.params
    sort_columns = mapper_params['sort_columns']
    model_type = mapper_params['model_type']
    input_bucket = mapper_params['input_bucket']
    input_pattern = mapper_params['input_pattern']
    output_bucket = mapper_params['output_bucket']
    output_pattern = mapper_params['output_pattern']
    root_pipeline_id = mapper_params['root_pipeline_id']

    # Read all rows into memory.
    all_rows = list(csv.reader(gcs_csv_file))

    # Sort by the requested columns in the order the caller listed them, so
    # [2, 0] sorts by column 2 first. Columns a short row does not have are
    # skipped so ragged rows do not raise.
    def sort_key(csv_row):
        return tuple(csv_row[i] for i in sort_columns
                     if 0 <= i < len(csv_row))

    all_rows.sort(key=sort_key)

    # Derive the output name from this file's position in the input listing.
    # NOTE(review): this re-lists the bucket once per mapped file; if the
    # listing changes after the job starts, indexes can shift.
    shard_files = get_shard_files(
        input_bucket, input_pattern, full_path=True)
    try:
        shard_index = shard_files.index(gcs_csv_file.name)
    except ValueError:
        raise ValueError(
            "%s is not among the files listed for %s / %s; the bucket "
            "listing may have changed since the job started."
            % (gcs_csv_file.name, input_bucket, input_pattern))
    gcs_output_path = output_pattern % {'num': shard_index}
    gcs_output_path = build_full_gcs_path(output_bucket, model_type,
                                          gcs_output_path, root_pipeline_id)

    # Upload to GCS, retrying for up to an hour.
    logging.info("Uploading sorted file to %s", gcs_output_path)
    retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                   max_retry_period=60 * 60.0)
    with gcs.open(gcs_output_path, "w",
                  retry_params=retry_params) as gcs_output_file:
        csv_writer = csv.writer(gcs_output_file)
        for rows_written, row in enumerate(all_rows, 1):
            csv_writer.writerow(row)
            if rows_written % 100000 == 0:
                logging.info("Wrote %d rows.", rows_written)
def get_shard_files(bucket, pattern, full_path=False):
    """Lists the files in a GCS bucket under a path prefix.

    Arguments:
      bucket - The bucket name, with or without leading/trailing "/"
        (e.g. "my-bucket" or "/my-bucket/").
      pattern - Appended to "/<bucket>/" and passed to gcs.listbucket as the
        path to list. NOTE(review): listbucket presumably matches this as a
        literal prefix, not a regular expression -- confirm with callers.
      full_path - If True, return full "/<bucket>/<object>" paths; otherwise
        return object names relative to the bucket.

    Returns:
      A list of paths, in the order gcs.listbucket yields them.
    """
    if not bucket.startswith("/"):
        bucket = "/" + bucket
    # Drop every trailing "/" (not just one) so the prefix math below holds.
    bucket = bucket.rstrip("/")
    retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                   max_retry_period=60 * 30.0)
    # Length of the "/<bucket>/" prefix to strip for bucket-relative names.
    prefix_len = 0 if full_path else len(bucket) + 1
    return [file_stat.filename[prefix_len:]
            for file_stat in gcs.listbucket(bucket + "/" + pattern,
                                            retry_params=retry_params)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment