Created
September 29, 2022 15:53
-
-
Save jelmervdl/c0ef9af690a3ef40db7f6135c78068c7 to your computer and use it in GitHub Desktop.
Batch shuffling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import subprocess | |
import random | |
import os | |
import re | |
import sys | |
import numpy as np | |
from collections import defaultdict | |
from itertools import accumulate | |
from bisect import bisect_right | |
from shutil import rmtree | |
from multiprocessing.pool import Pool | |
# Target number of lines per output batch file (stage 2 rolls over to a new
# batch every LINES_PER_BATCH rows).
LINES_PER_BATCH = 1000000
def text_reader(path):
    """Yield the raw lines (as bytes, newline included) of a gzip file.

    Decompression is delegated to an external `pigz -cd` subprocess reading
    from the opened file.

    Raises:
        RuntimeError: if pigz exits with a non-zero status.
    """
    with open(path, 'rb') as fh:
        # Using Popen as a context manager closes the pipe and waits for the
        # child on exit, so pigz is reaped even if the consumer abandons this
        # generator early (the original never waited in that case).
        with subprocess.Popen(['pigz', '-cd'], stdin=fh, stdout=subprocess.PIPE) as decompress:
            yield from decompress.stdout
        if decompress.returncode != 0:
            raise RuntimeError(f'gunzip of {path} exited with non-zero exit status {decompress.returncode}')
def embeddings_reader(path):
    """Yield embedding rows one at a time from a raw float32 binary file.

    The file is memory-mapped and interpreted as a (n_rows, 3072) float32
    matrix; each yielded value is one row.
    """
    matrix = np.memmap(path, dtype=np.float32, mode="r").reshape((-1, 3072))
    for row in matrix:
        yield row
def batch_reader(in_path, batch, text_columns, bin_columns):
    """Yield aligned tuples for one batch: text lines first, embeddings after.

    Opens one text stream per language in `text_columns` and one embeddings
    stream per language in `bin_columns`, and zips them row by row.
    """
    text_streams = [
        text_reader(os.path.join(in_path, batch['gz'][lang]))
        for lang in text_columns
    ]
    bin_streams = [
        embeddings_reader(os.path.join(in_path, batch['bin'][lang]))
        for lang in bin_columns
    ]
    yield from zip(*(text_streams + bin_streams))
class text_writer:
    """Writes raw byte lines to a gzip file through a `pigz -9c` subprocess."""

    def __init__(self, path):
        self.path = path
        self.fh = open(path, 'wb')
        self.compress = subprocess.Popen(['pigz', '-9c'], stdin=subprocess.PIPE, stdout=self.fh)

    def write(self, value):
        # value: bytes, newline included; fed straight into the compressor.
        self.compress.stdin.write(value)

    def close(self):
        """Flush the compressor and close the output file.

        Raises:
            RuntimeError: if pigz exits with a non-zero status.
        """
        try:
            self.compress.stdin.close()
            if self.compress.wait() != 0:
                raise RuntimeError(f'gzip of {self.path} exited with non-zero exit status {self.compress.returncode}')
        finally:
            # Close the file handle even when pigz failed; the original
            # leaked it on the error path.
            self.fh.close()
class embeddings_writer:
    """Appends raw float32 embedding vectors to a binary output file."""

    def __init__(self, path):
        # path kept as an attribute for symmetry with text_writer.
        self.path = path
        self.fh = open(path, 'wb')

    def write(self, value):
        # value is a numpy array; its raw buffer is written as-is.
        self.fh.write(value.tobytes())

    def close(self):
        self.fh.close()
class batch_writer:
    """Writes aligned rows for one output batch across all column files.

    Creates corpus.<lang>.<index>.gz files for text columns and
    corpus.<lang>.<index>.bin files for embedding columns.
    """

    def __init__(self, out_path, index, text_columns, bin_columns):
        writers = [
            text_writer(os.path.join(out_path, f'corpus.{lang}.{index:04d}.gz'))
            for lang in text_columns
        ]
        writers.extend(
            embeddings_writer(os.path.join(out_path, f'corpus.{lang}.{index:04d}.bin'))
            for lang in bin_columns
        )
        self.writers = writers

    def write(self, values):
        # values must be ordered like the writers: text columns, then bin.
        for writer, value in zip(self.writers, values):
            writer.write(value)

    def close(self):
        for writer in self.writers:
            writer.close()
def shuffle_text(in_filename, out_filename, indices):
    """Rewrite a gzipped text file with its lines reordered by `indices`."""
    # Materialize every line up front; individual batch files are small.
    all_lines = list(text_reader(in_filename))
    out = text_writer(out_filename)
    for i in indices:
        out.write(all_lines[i])
    out.close()
def shuffle_embeddings(in_filename, out_filename, indices, dim=3072):
    """Rewrite a raw float32 embeddings file with its rows reordered.

    Args:
        in_filename: source .bin file, row-major float32, `dim` floats per row.
        out_filename: destination .bin file.
        indices: row indices to emit, in output order.
        dim: embedding dimensionality. Defaults to 3072, matching the rest of
            this script; parameterized so the function also works for other
            embedding sizes.

    Raises:
        RuntimeError: if the row count does not match len(indices).
    """
    ein = np.memmap(in_filename, dtype=np.float32, mode="r").reshape((-1, dim))
    if len(indices) != ein.shape[0]:
        raise RuntimeError(f'Number of lines in {in_filename} ({ein.shape[0]}) does not match number of indices ({len(indices)})')
    writer = embeddings_writer(out_filename)
    for index in indices:
        writer.write(ein[index, :])
    writer.close()
def shuffle_batch(args):
    """Shuffle one batch's text and embedding files with a shared permutation.

    Args:
        args: (in_path, tmp_path, batch_id, batch) packed in one tuple so the
            function can be dispatched through Pool.imap_unordered.

    Returns:
        batch_id, so the parent process can report progress.

    Raises:
        RuntimeError: wrapping any failure, tagged with batch id and filename.
    """
    in_path, tmp_path, batch_id, batch = args
    print(f"Shuffling batch {batch_id}")

    # One permutation shared by every file of the batch keeps rows aligned
    # across languages and embeddings.
    indices = list(range(batch['length']))
    random.shuffle(indices)

    # NOTE(review): the pasted source printed the literal "(unknown)" in these
    # messages; the filename was clearly the intended interpolation.
    for filename in batch['gz'].values():
        try:
            print(f"Shuffling text file {filename}")
            shuffle_text(
                os.path.join(in_path, filename),
                os.path.join(tmp_path, filename),
                indices)
        except Exception as exc:
            raise RuntimeError(f'Error reordering text lines in {batch_id}/{filename}') from exc

    for filename in batch['bin'].values():
        try:
            print(f"Shuffling embeddings file {filename}")
            shuffle_embeddings(
                os.path.join(in_path, filename),
                os.path.join(tmp_path, filename),
                indices)
        except Exception as exc:
            raise RuntimeError(f'Error reordering embeddings in {batch_id}/{filename}') from exc

    return batch_id
def shuffle(in_path, out_path):
    """Globally shuffle a batched bilingual corpus with aligned embeddings.

    Expects `in_path` to contain files named corpus.<lang>.<batch>.gz (text,
    exactly two languages) and corpus.<lang>.<batch>.bin (float32 embeddings,
    exactly one language, 3072 floats per row). Operates in two stages:

      1. Shuffle every batch internally, in parallel, into a temp directory.
      2. Interleave rows from all shuffled batches in random order into new
         batches of at most LINES_PER_BATCH rows under `out_path`.

    Raises:
        RuntimeError: on inconsistent inputs or leftover unread data.
    """
    # Stage 0: discover and group input files per batch id.
    files_per_batch = defaultdict(lambda: {'gz': {}, 'bin': {}, 'length': None})
    for file in os.scandir(in_path):
        if not file.is_file():
            continue
        match = re.match(r'^corpus\.(?P<lang>[a-z]+)\.(?P<batch>\d{4})\.(?P<type>gz|bin)$', file.name)
        if not match:
            continue
        files_per_batch[match.group('batch')][match.group('type')][match.group('lang')] = file.name
        if match.group('type') == 'bin':
            # Row count follows from file size: 3072 float32 (4-byte) values per row.
            files_per_batch[match.group('batch')]['length'] = int(file.stat().st_size / 3072 / 4)

    text_langs = set(lang
        for batch in files_per_batch.values()
        for lang in batch['gz'].keys())
    if len(text_langs) != 2:
        raise RuntimeError(f'Expected two languages in a dataset, found {text_langs!r}')

    bin_langs = set(lang
        for batch in files_per_batch.values()
        for lang in batch['bin'].keys())
    if len(bin_langs) != 1:
        raise RuntimeError(f'Expected one embedded language in a dataset, found {bin_langs!r}')

    # Assert all files match up
    for batch_id, batch in files_per_batch.items():
        if set(batch['gz']) != text_langs:
            raise RuntimeError(f'Number of text files for batch {batch_id} is not 2')
        if set(batch['bin']) != bin_langs:
            raise RuntimeError(f'Number of embedding files for batch {batch_id} is not 1')
        if batch['length'] > LINES_PER_BATCH:
            raise RuntimeError(f'Number of lines in embedding file is larger than expected')

    # Stage 1: shuffle all batches individually, in parallel, into tmp_path.
    tmp_path = os.path.join(out_path, f'.tmp{os.getpid()}')
    os.makedirs(tmp_path)

    print("Submitting shuffling jobs")
    with Pool(8) as pool:
        args_it = (
            (in_path, tmp_path, batch_id, batch)
            for batch_id, batch in files_per_batch.items()
        )
        for batch_id in pool.imap_unordered(shuffle_batch, args_it):
            print(f"Finished shuffling {batch_id}")

    # Stage 2: rewrite batches, drawing each output row from a random source batch.
    batches = files_per_batch.values()
    columns = list(text_langs), list(bin_langs)

    # batch_indices[i] names the source batch of the i-th output row; each
    # batch appears exactly batch['length'] times, so every row is consumed once.
    batch_indices = [
        n
        for n, batch in enumerate(batches)
        for _ in range(batch['length'])
    ]
    random.shuffle(batch_indices)

    # BUG FIX: read from tmp_path — the per-batch shuffled copies written in
    # stage 1. The original read from in_path, which silently discarded the
    # within-batch shuffle.
    readers = [batch_reader(tmp_path, batch, *columns) for batch in batches]

    writer = None
    write_index = 0
    for entry_index, read_index in enumerate(batch_indices):
        # Roll over to a fresh output batch every LINES_PER_BATCH rows.
        if entry_index % LINES_PER_BATCH == 0:
            print(f"Writing batch {write_index:04d}")
            if writer:
                writer.close()
            writer = batch_writer(out_path, write_index, *columns)
            write_index += 1
        writer.write(next(readers[read_index]))
    if writer:
        writer.close()

    print(f"Checking that all batches have been read")
    for reader in readers:
        if next(reader, None) is not None:
            raise RuntimeError('Reader for batch still returned data after writing entire dataset')

    print(f"Removing temporary files")
    rmtree(tmp_path)
if __name__ == '__main__':
    # CLI entry point. Usage: script.py IN_PATH OUT_PATH
    shuffle(*sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment