Last active
November 10, 2021 19:06
-
-
Save Nipsuli/891148cf1d1fdf1402626d12721b39cd to your computer and use it in GitHub Desktop.
Two-pass shuffle implementation of the algorithm described in https://blog.janestreet.com/how-to-shuffle-a-big-dataset/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
import tempfile | |
import random | |
def two_pass_shuffle(input_files, output_files, temp_file_count, header_lines=0):
    """
    Shuffle data too large to be shuffled in memory, using two passes.

    Implementation based on:
    https://blog.janestreet.com/how-to-shuffle-a-big-dataset/

    -- First pass
    create empty piles p[0], ..., p[M - 1]
    for i = 0, ..., n - 1 do
        j := uniform random draw from {0, ..., M - 1}
        append x[i] to pile p[j]
    -- Second pass (perhaps done lazily)
    for j = 0, ..., M - 1 do
        shuffle p[j] in RAM with Fisher-Yates or whatever is convenient
        append p[j] to output file

    Memory usage is controlled by the number of temp files: each temp file
    holds roughly total_lines / temp_file_count rows and only one temp file
    at a time is loaded into RAM during the second pass.

    Header rows (e.g. csv headers) can be skipped in the data; they are
    automatically prepended to every output file.

    In addition, the shuffled data is batched into len(output_files)
    roughly equal-sized output files.

    Parameters:
        input_files (List[str]): list of input file names
        output_files (List[str]): list of output file names
        temp_file_count (int): number of temp files to use
        header_lines (int): number of header lines in each input file

    Raises:
        AssertionError: if input files have differing header rows.
    """
    lines = 0
    header_rows = None  # taken from the first input file, checked against the rest
    with contextlib.ExitStack() as temp_files_ctx:
        temp_files = [
            temp_files_ctx.enter_context(tempfile.TemporaryFile(mode="w+t"))
            for _ in range(temp_file_count)
        ]
        with contextlib.ExitStack() as i_files_ctx:
            # First pass: stream each data row to a uniformly random temp file.
            i_files = [i_files_ctx.enter_context(open(fname)) for fname in input_files]
            for i_file in i_files:
                _header_rows = [i_file.readline() for _ in range(header_lines)]
                if header_rows is None:
                    header_rows = _header_rows
                else:
                    assert (
                        header_rows == _header_rows
                    ), "Files need to have matching header rows"
                # Iterate lazily; readlines() would load the whole input file
                # into memory, defeating the purpose of a two-pass shuffle.
                for row in i_file:
                    lines += 1
                    random.choice(temp_files).write(row)
        for temp_file in temp_files:
            temp_file.seek(0)  # rewind for the second (read) pass
        # Ceiling division so all rows fit in len(output_files) files;
        # any remainder simply spills into the last output file.
        lines_per_o_file = -(-lines // len(output_files))

        def row_streamer():
            # Second pass: shuffle each temp file (pile) in RAM, one at a
            # time, and stream out its rows.
            for temp_file in temp_files:
                rows = temp_file.readlines()
                random.shuffle(rows)
                yield from rows

        with contextlib.ExitStack() as o_files_ctx:
            o_files = iter(
                [o_files_ctx.enter_context(open(fname, "w")) for fname in output_files]
            )

            def next_file(previous_file):
                # Advance to the next output file (writing its header rows),
                # or keep writing to the previous one once all outputs are
                # in use (absorbs any rounding remainder).
                try:
                    current_file = next(o_files)
                except StopIteration:
                    return previous_file
                # Guard: header_rows stays None when input_files is empty.
                current_file.writelines(header_rows or [])
                return current_file

            write_batch = 0
            current_file = next_file(None)
            for row in row_streamer():
                if write_batch == lines_per_o_file:
                    write_batch = 0
                    current_file = next_file(current_file)
                current_file.write(row)
                write_batch += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment