Skip to content

Instantly share code, notes, and snippets.

@Nipsuli
Last active November 10, 2021 19:06
Show Gist options
  • Save Nipsuli/891148cf1d1fdf1402626d12721b39cd to your computer and use it in GitHub Desktop.
Save Nipsuli/891148cf1d1fdf1402626d12721b39cd to your computer and use it in GitHub Desktop.
Two-pass shuffle implementation for the algorithm described in https://blog.janestreet.com/how-to-shuffle-a-big-dataset/
import contextlib
import tempfile
import random
def two_pass_shuffle(input_files, output_files, temp_file_count, header_lines=0):
    """Shuffle data too large to fit in memory using a two-pass algorithm.

    Implementation based on:
    https://blog.janestreet.com/how-to-shuffle-a-big-dataset/

    -- First pass
    create empty piles p[0], ..., p[M - 1]
    for i = 0, ..., n - 1 do
        j := uniform random draw from {0, ..., M - 1}
        append x[i] to pile p[j]
    -- Second pass (perhaps done lazily)
    for j = 0, ..., M - 1 do
        shuffle p[j] in RAM with Fisher-Yates or whatever is convenient
        append p[j] to output file

    The memory requirement is controlled by the number of temp files: only
    one temp file's worth of rows is held in RAM at a time in the second
    pass, and the first pass streams rows one at a time.

    Header rows (e.g. for CSV files) can be skipped; they are automatically
    written to every output file. The shuffled data is batched into the
    given output files in roughly equal-sized chunks.

    Parameters:
        input_files (List[str]): list of input file names
        output_files (List[str]): list of output file names
        temp_file_count (int): number of temp files to use
        header_lines (int): number of header lines in each input file

    Raises:
        ValueError: if the input files do not share identical header rows.
    """
    lines = 0
    header_rows = None
    with contextlib.ExitStack() as temp_files_ctx:
        temp_files = [
            temp_files_ctx.enter_context(tempfile.TemporaryFile(mode="w+t"))
            for _ in range(temp_file_count)
        ]
        # First pass: stream rows from the inputs into randomly chosen
        # temp files ("piles").
        with contextlib.ExitStack() as i_files_ctx:
            i_files = [i_files_ctx.enter_context(open(fname)) for fname in input_files]
            for i_file in i_files:
                _header_rows = [i_file.readline() for _ in range(header_lines)]
                if header_rows is None:
                    header_rows = _header_rows
                elif header_rows != _header_rows:
                    raise ValueError("Files need to have matching header rows")
                # Iterate the file lazily; readlines() would load the whole
                # input into memory and defeat the point of the algorithm.
                for row in i_file:
                    lines += 1
                    random.choice(temp_files).write(row)
        # Rewind the piles for the second pass.
        for temp_file in temp_files:
            temp_file.seek(0)
        if header_rows is None:
            header_rows = []  # no input files at all
        lines_per_o_file = round(lines / len(output_files))

        def row_streamer():
            # Second pass: shuffle one temp file at a time in RAM and
            # stream its rows out.
            for temp_file in temp_files:
                rows = temp_file.readlines()
                random.shuffle(rows)
                yield from rows

        with contextlib.ExitStack() as o_files_ctx:
            o_files = iter(
                [o_files_ctx.enter_context(open(fname, "w")) for fname in output_files]
            )

            def next_file(previous_file):
                # Advance to the next output file, writing its header rows.
                # Once all outputs have been used, keep appending to the
                # last one so no rows are ever dropped.
                try:
                    current_file = next(o_files)
                except StopIteration:
                    return previous_file
                current_file.writelines(header_rows)
                return current_file

            write_batch = 0
            current_file = next_file(None)
            for row in row_streamer():
                if write_batch == lines_per_o_file:
                    write_batch = 0
                    current_file = next_file(current_file)
                current_file.write(row)
                write_batch += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment