A Pipeline job which launches a map-only job to sort .csv files in memory. Each .csv file is read from Google Cloud Storage into memory, sorted by the specified key columns, and then written back to Google Cloud Storage. The machine running the sorting process needs roughly 10x as much memory as the size of each .csv file.
import csv
import logging

# NOTE: the import paths below are assumptions; they depend on how the App
# Engine MapReduce/Pipeline libraries are vendored in your project. The
# build_full_gcs_path() helper used by the mapper is defined elsewhere in
# this module and is not shown in this gist.
import cloudstorage as gcs
from third_party.mapreduce import context
from third_party.mapreduce import mapreduce_pipeline
from third_party.mapreduce.lib.pipeline import pipeline


class ParallelInMemorySortGcsCsvShardFiles(pipeline.Pipeline):

    def run(self, input_bucket, input_pattern, sort_columns,
            model_type, output_bucket, output_pattern):
"""Sorts each input file in-memory, then writes it to an output file. | |
Arguments: | |
input_bucket - The GCS bucket which contains the unsorted .csv | |
files. | |
input_pattern - A regular expression used to find files in the | |
input_bucket to map over. (e.g. "raw_problem_logs/.*"). | |
sort_columns - An array of column indexes to sort by (e.g. [0,1]) | |
model_type - The model type being trained. Used in building the | |
output filename for organization purposes. | |
output_bucket - The GCS bucket to write sorted .csv files to. | |
output_pattern - The pattern to use to name output files. You | |
*must* include "%(num)s" in this pattern or the output files | |
will overwrite each other. "%(num)s" will be replaced with the | |
index in the sorted listed of filenames found by input_pattern. | |
shards - The number of shards used to process the input files. | |
""" | |
        shard_files = get_shard_files(input_bucket, input_pattern)
        logging.info("ParallelInMemorySortGcsCsvShardFiles processing: %s",
                     shard_files)
        count_shards = len(shard_files)
        if count_shards == 0:
            raise ValueError("No shards found to sort.")
        yield mapreduce_pipeline.MapperPipeline(
            job_name="ParallelInMemorySortGcsCsvShardFiles",
            handler_spec='prediction.pipelines_util'
                         '.in_memory_sort_gcs_csv_shard_files_map',
            input_reader_spec=('third_party.mapreduce.input_readers'
                               '._GoogleCloudStorageInputReader'),
            params={
                'input_reader': {
                    'bucket_name': input_bucket[1:],  # Remove the "/" prefix
                    'objects': shard_files,
                },
                'sort_columns': sort_columns,
                'model_type': model_type,
                'input_bucket': input_bucket,
                'input_pattern': input_pattern,
                'output_bucket': output_bucket,
                'output_pattern': output_pattern,
                'root_pipeline_id': self.root_pipeline_id,
            },
            # These shards are usually ~170MB, and the machines we are using
            # can handle processing at least 6 at a time given their memory
            # constraints. The math is fuzzy, though, because 1) each machine
            # may process more than one shard at a time, thanks to
            # multi-threading support, and 2) the auto-scaling signal is CPU
            # utilization, not memory. That reduces this calculation to little
            # more than a heuristic that I've found to work (e.g. 60 input
            # files -> 10 mapper shards). Feel free to tweak as necessary.
            shards=max(1, count_shards // 6)
        )


def in_memory_sort_gcs_csv_shard_files_map(gcs_csv_file):
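    """Mapper: sort one GCS .csv shard entirely in memory.

    Args:
        gcs_csv_file - a readable file-like object for one input .csv object,
            as yielded by _GoogleCloudStorageInputReader; its .name is the
            full "/<bucket>/<object>" path.

    The whole file is read into memory, sorted by the configured columns, and
    written back to GCS under the configured output pattern.
    """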
    my_context = context.get()
    mapper_params = my_context.mapreduce_spec.mapper.params
    sort_columns = mapper_params['sort_columns']
    model_type = mapper_params['model_type']
    input_bucket = mapper_params['input_bucket']
    input_pattern = mapper_params['input_pattern']
    output_bucket = mapper_params['output_bucket']
    output_pattern = mapper_params['output_pattern']
    root_pipeline_id = mapper_params['root_pipeline_id']
    # Read rows into memory
    all_rows = []
    csv_reader = csv.reader(gcs_csv_file)
    for row in csv_reader:
        all_rows.append(row)

    # Sort them
    def sort_key(csv_row):
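        # Key on the values in the sort columns. Note that the key keeps the
        # order in which the columns appear in the row, not the order given
        # in sort_columns.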
        values = [v for i, v in enumerate(csv_row) if i in sort_columns]
        return tuple(values)

    all_rows.sort(key=sort_key)
    # Upload to GCS
    shard_files = get_shard_files(
        input_bucket, input_pattern, full_path=True)
    gcs_output_path = output_pattern % {
        'num': shard_files.index(gcs_csv_file.name)
    }
    gcs_output_path = build_full_gcs_path(output_bucket, model_type,
                                          gcs_output_path, root_pipeline_id)
    logging.info("Uploading sorted file to %s", gcs_output_path)
    retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                   max_retry_period=60 * 60.0)
    with gcs.open(gcs_output_path, "w", retry_params=retry_params
                  ) as gcs_output_file:
        csv_writer = csv.writer(gcs_output_file)
        for i, row in enumerate(all_rows):
            csv_writer.writerow(row)
            if (i + 1) % 100000 == 0:
                logging.info("Wrote %d rows.", i + 1)


def get_shard_files(bucket, pattern, full_path=False):
    """Find files in a bucket matching a pattern.

    Arguments:
        bucket - The GCS bucket to list, with or without a leading "/".
        pattern - The object-name pattern to match within the bucket.
        full_path - If True, return full "/<bucket>/<object>" paths;
            otherwise return object names relative to the bucket.
    """
    if not bucket.startswith("/"):
        bucket = "/%s" % bucket
    if bucket.endswith("/"):
        bucket = bucket[:-1]
    retry_params = gcs.RetryParams(urlfetch_timeout=60,
                                   max_retry_period=60 * 30.0)
    pattern = bucket + "/" + pattern
    shard_files = []
    for file_stat in gcs.listbucket(pattern, retry_params=retry_params):
        path = file_stat.filename
        if not full_path:
            # Remove the "/<bucket>" + "/" prefix
            path = path[len(bucket) + 1:]
        shard_files.append(path)
    return shard_files
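
For reference, launching this pipeline from elsewhere in the application might look roughly like the sketch below. The bucket names, pattern, model type, and output pattern are made-up placeholders, and the status URL assumes the standard Pipeline API UI is mounted at /_ah/pipeline.

# Hypothetical usage sketch; all argument values here are placeholders.
job = ParallelInMemorySortGcsCsvShardFiles(
    "/my-input-bucket", "raw_problem_logs/.*", [0, 1],
    "my_model", "/my-output-bucket", "sorted_logs_%(num)s.csv")
job.start()
logging.info("Started sort pipeline: /_ah/pipeline/status?root=%s",
             job.pipeline_id)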