Batch shuffling
#!/usr/bin/env python3
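"""Batch shuffling of a parallel corpus with sentence embeddings, in two passes.

Pass 1 shuffles every batch individually and in parallel: each batch consists
of two gzipped text files (corpus.<lang>.<batch>.gz) and one raw float32
embeddings file (corpus.<lang>.<batch>.bin), all reordered with the same
permutation into a temporary directory. Pass 2 rebuilds the output batches by
drawing each line from a randomly chosen batch, so lines also get mixed across
batch boundaries.
"""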
import subprocess
import random
import os
import re
import sys
import numpy as np
from collections import defaultdict
from shutil import rmtree
from multiprocessing.pool import Pool

# Maximum number of lines per output batch
LINES_PER_BATCH = 1_000_000

# Width of each float32 embedding row in the .bin files
EMBEDDING_DIM = 3072

def text_reader(path):
    """Yield raw lines from a gzipped file, decompressed through pigz."""
    with open(path, 'rb') as fh:
        decompress = subprocess.Popen(['pigz', '-cd'], stdin=fh, stdout=subprocess.PIPE)
        yield from decompress.stdout
        if decompress.wait() != 0:
            raise RuntimeError(f'gunzip of {path} exited with non-zero exit status {decompress.returncode}')

def embeddings_reader(path):
    """Yield embedding rows one at a time from a raw float32 file."""
    embeddings = np.memmap(path, dtype=np.float32, mode="r").reshape((-1, EMBEDDING_DIM))
    yield from embeddings

def batch_reader(in_path, batch, text_columns, bin_columns):
    """Yield tuples of (text line, ..., embedding row) for one batch."""
    readers = [
        text_reader(os.path.join(in_path, batch['gz'][lang]))
        for lang in text_columns
    ] + [
        embeddings_reader(os.path.join(in_path, batch['bin'][lang]))
        for lang in bin_columns
    ]
    yield from zip(*readers)

class text_writer:
    """Writes lines to a gzipped file, compressing through pigz."""

    def __init__(self, path):
        self.path = path
        self.fh = open(path, 'wb')
        self.compress = subprocess.Popen(['pigz', '-9c'], stdin=subprocess.PIPE, stdout=self.fh)

    def write(self, value):
        self.compress.stdin.write(value)

    def close(self):
        self.compress.stdin.close()
        if self.compress.wait() != 0:
            raise RuntimeError(f'gzip of {self.path} exited with non-zero exit status {self.compress.returncode}')
        self.fh.close()

class embeddings_writer:
    """Writes embedding rows as raw float32 bytes."""

    def __init__(self, path):
        self.path = path
        self.fh = open(path, 'wb')

    def write(self, value):
        self.fh.write(value.tobytes())

    def close(self):
        self.fh.close()

class batch_writer:
    """Writes the parallel text and embedding columns of one output batch."""

    def __init__(self, out_path, index, text_columns, bin_columns):
        self.writers = [
            text_writer(os.path.join(out_path, f'corpus.{lang}.{index:04d}.gz'))
            for lang in text_columns
        ] + [
            embeddings_writer(os.path.join(out_path, f'corpus.{lang}.{index:04d}.bin'))
            for lang in bin_columns
        ]

    def write(self, values):
        for writer, value in zip(self.writers, values):
            writer.write(value)

    def close(self):
        for writer in self.writers:
            writer.close()

def shuffle_text(in_filename, out_filename, indices):
    """Write the lines of a text file in the order given by indices."""
    # Read all lines into memory (a single batch is small anyway)
    lines = list(text_reader(in_filename))
    if len(indices) != len(lines):
        raise RuntimeError(f'Number of lines in {in_filename} ({len(lines)}) does not match number of indices ({len(indices)})')
    writer = text_writer(out_filename)
    for index in indices:
        writer.write(lines[index])
    writer.close()

def shuffle_embeddings(in_filename, out_filename, indices):
    """Write the rows of an embeddings file in the order given by indices."""
    ein = np.memmap(in_filename, dtype=np.float32, mode="r").reshape((-1, EMBEDDING_DIM))
    if len(indices) != ein.shape[0]:
        raise RuntimeError(f'Number of lines in {in_filename} ({ein.shape[0]}) does not match number of indices ({len(indices)})')
    writer = embeddings_writer(out_filename)
    for index in indices:
        writer.write(ein[index, :])
    writer.close()

def shuffle_batch(args):
    """Reorder all files of one batch with a single shared permutation."""
    in_path, tmp_path, batch_id, batch = args
    print(f"Shuffling batch {batch_id}")
    # Use a freshly seeded generator: forked pool workers inherit the parent's
    # global RNG state, which could otherwise repeat permutations across workers.
    rng = random.Random()
    indices = list(range(batch['length']))
    rng.shuffle(indices)
    for filename in batch['gz'].values():
        try:
            print(f"Shuffling text file {filename}")
            shuffle_text(
                os.path.join(in_path, filename),
                os.path.join(tmp_path, filename),
                indices)
        except Exception as exc:
            raise RuntimeError(f'Error reordering text lines in {batch_id}/{filename}') from exc
    for filename in batch['bin'].values():
        try:
            print(f"Shuffling embeddings file {filename}")
            shuffle_embeddings(
                os.path.join(in_path, filename),
                os.path.join(tmp_path, filename),
                indices)
        except Exception as exc:
            raise RuntimeError(f'Error reordering embeddings in {batch_id}/{filename}') from exc
    return batch_id

def shuffle(in_path, out_path):
    """Shuffle an entire dataset of corpus.<lang>.<batch>.(gz|bin) files."""
    # Discover which files belong to which batch
    files_per_batch = defaultdict(lambda: {'gz': {}, 'bin': {}, 'length': None})
    for file in os.scandir(in_path):
        if not file.is_file():
            continue
        match = re.match(r'^corpus\.(?P<lang>[a-z]+)\.(?P<batch>\d{4})\.(?P<type>gz|bin)$', file.name)
        if not match:
            continue
        files_per_batch[match.group('batch')][match.group('type')][match.group('lang')] = file.name
        if match.group('type') == 'bin':
            # Each row in a .bin file is EMBEDDING_DIM float32 values of 4 bytes
            files_per_batch[match.group('batch')]['length'] = file.stat().st_size // (EMBEDDING_DIM * 4)
    text_langs = set(lang
        for batch in files_per_batch.values()
        for lang in batch['gz'].keys())
    if len(text_langs) != 2:
        raise RuntimeError(f'Expected two languages in a dataset, found {text_langs!r}')

    bin_langs = set(lang
        for batch in files_per_batch.values()
        for lang in batch['bin'].keys())
    if len(bin_langs) != 1:
        raise RuntimeError(f'Expected one embedded language in a dataset, found {bin_langs!r}')

    # Assert that every batch has the full set of files
    for batch_id, batch in files_per_batch.items():
        if set(batch['gz']) != text_langs:
            raise RuntimeError(f'Text files for batch {batch_id} do not cover {text_langs!r}')
        if set(batch['bin']) != bin_langs:
            raise RuntimeError(f'Embedding files for batch {batch_id} do not cover {bin_langs!r}')
        if batch['length'] > LINES_PER_BATCH:
            raise RuntimeError(f'Number of lines in batch {batch_id} ({batch["length"]}) is larger than expected')
    # Pass 1: shuffle all batches individually, in parallel, into a temporary
    # directory inside the output directory
    tmp_path = os.path.join(out_path, f'.tmp{os.getpid()}')
    os.makedirs(tmp_path)

    print("Submitting shuffling jobs")
    with Pool(8) as pool:
        args_it = (
            (in_path, tmp_path, batch_id, batch)
            for batch_id, batch in files_per_batch.items()
        )
        for batch_id in pool.imap_unordered(shuffle_batch, args_it):
            print(f"Finished shuffling {batch_id}")
    # Pass 2: rewrite batches by taking each next line from a randomly chosen
    # batch, mixing lines across batch boundaries
    batches = files_per_batch.values()
    columns = list(text_langs), list(bin_langs)

    # One entry per line in the dataset, naming the batch it will come from
    batch_indices = [
        n
        for n, batch in enumerate(batches)
        for _ in range(batch['length'])
    ]
    random.shuffle(batch_indices)

    # Read from the per-batch shuffled files in tmp_path, not the originals
    readers = [batch_reader(tmp_path, batch, *columns) for batch in batches]

    writer = None
    write_index = 0
    for entry_index, read_index in enumerate(batch_indices):
        # Start a new output batch every LINES_PER_BATCH lines
        if entry_index % LINES_PER_BATCH == 0:
            print(f"Writing batch {write_index:04d}")
            if writer:
                writer.close()
            writer = batch_writer(out_path, write_index, *columns)
            write_index += 1
        writer.write(next(readers[read_index]))
    if writer:
        writer.close()

    print("Checking that all batches have been read")
    for reader in readers:
        if next(reader, None) is not None:
            raise RuntimeError('Reader for batch still returned data after writing entire dataset')

    print("Removing temporary files")
    rmtree(tmp_path)


if __name__ == '__main__':
    shuffle(*sys.argv[1:])
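
# For reference, a sketch of how this might be run (the file name shuffle.py is
# hypothetical; the script itself only assumes pigz on PATH and an input
# directory of corpus.<lang>.<batch>.gz / corpus.<lang>.<batch>.bin files):
#
#     python3 shuffle.py /path/to/batches /path/to/shuffled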