Skip to content

Instantly share code, notes, and snippets.

@daskol
Last active December 11, 2023 21:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save daskol/b088233435506a959535d0336c389cd3 to your computer and use it in GitHub Desktop.
Save daskol/b088233435506a959535d0336c389cd3 to your computer and use it in GitHub Desktop.
Benchmarking of the OpenWebText parallel loader.
from argparse import ArgumentParser, Namespace
from pathlib import Path
# NOTE: alias the stdlib in-place shuffle so it is not shadowed by the
# streaming `shuffle` generator imported from `openwebtext` below. The
# original code imported both under the same name, so `shuffle(subsets)`
# silently built a generator and never actually shuffled the list.
from random import shuffle as random_shuffle
from tqdm import tqdm
from openwebtext import OpenWebTextLoader, shuffle, take

parser = ArgumentParser()
parser.add_argument('-j', '--jobs', type=int, default=2)
parser.add_argument('data_dir', type=Path)


def main():
    """Benchmark the parallel OpenWebText loader.

    Reads up to 500k samples from the `.tar` subsets in `data_dir` with the
    requested number of worker processes and reports throughput via tqdm.
    """
    ns: Namespace = parser.parse_args()
    subsets = sorted(ns.data_dir.glob('*.tar'))
    random_shuffle(subsets)  # in-place: randomize subset processing order
    print(f'load {len(subsets)} parts from {ns.data_dir} '
          f'with {ns.jobs} workers')
    loader = OpenWebTextLoader(subsets, num_workers=ns.jobs)
    samples = iter(loader)
    samples = shuffle(samples, 1024)  # streaming buffer shuffle
    samples = take(samples, 500_000)
    sample = None
    for sample in tqdm(samples, total=500_000, unit='sample'):
        continue
    # Show indices and names of the last sample as a sanity check; guard
    # against an empty stream (the original would raise NameError).
    if sample is not None:
        print(*sample[:2])


if __name__ == '__main__':
    main()
#!/bin/bash
# Sweep over worker counts and append one CSV row per run to timings.csv:
# "<workers>,<elapsed seconds>,<command line>".
data_dir=subsets
for num_workers in 1 2 3 4 8 16 20; do
    /usr/bin/time -a -o timings.csv -f "$num_workers,%e,%C" \
        python main.py -j $num_workers $data_dir
done
import multiprocessing as mp
from sys import version_info
from codecs import StreamReader, getreader
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
from itertools import islice
from pathlib import Path
from tarfile import TarFile
from typing import Iterable, Optional
import numpy as np
# Before Python 3.11 `concurrent.futures` raised its own TimeoutError type;
# since 3.11 it is an alias of the builtin. Normalize the name for both.
if version_info < (3, 11):
    from concurrent.futures import TimeoutError as FutureTimeoutError
else:
    FutureTimeoutError = TimeoutError

__all__ = ('OpenWebTextLoader', 'shuffle', 'take')

# Reader *class* (not instance) that decodes a binary stream as UTF-8; it is
# instantiated per-entry in `iter_subset`.
DecodingStreamReader: type[StreamReader] = getreader('utf-8')

# Communication primitives between producing workers and consuming parent.
# These module-level globals are empty in the parent; in each worker process
# they are populated by `pool_init` (the ProcessPoolExecutor initializer).
queues: list[mp.JoinableQueue] = []
done_event: Optional[mp.Event] = None
sema: Optional[mp.Semaphore] = None
def iter_subset(path: Path):
    """Iterate over all text samples of one OpenWebText `.tar` subset.

    A subset archive contains xz-compressed inner tarballs ("parts"); every
    entry of a part is a single UTF-8 text document.

    Yields:
        Tuples `((part_index, entry_index), (part_name, entry_name), text)`.
    """
    with TarFile(path) as outer:
        for part_ix, member in enumerate(outer.getmembers()):
            stream = outer.extractfile(member)
            short_name = Path(member.name).name
            # The member is itself an xz-compressed tarball; open it straight
            # from the in-memory stream without unpacking to disk.
            with TarFile.xzopen(member.name, fileobj=stream) as inner:
                for entry_ix, entry in enumerate(inner):
                    raw = inner.extractfile(entry)
                    text = DecodingStreamReader(raw).read()
                    yield (part_ix, entry_ix), (short_name, entry.name), text
def pool_init(*args):
    """Process-pool initializer: stash shared primitives in module globals.

    Receives the parent's `(queues, done_event, ...)` tuple; only the first
    two values are consumed, any extras are ignored.
    """
    global queues, done_event
    head, _extras = args[:2], args[2:]
    queues, done_event = head
def _pool_exec(worker_id: int, subset: Path):
    """Worker body: stream samples of `subset` into this worker's queue.

    Stops early if the parent sets the shared done event. A `None` message
    marks normal end-of-subset for the consumer. `iter_subset` only yields
    tuples, so `None` is safe to use as the exhaustion sentinel below.

    Returns:
        Pair of the worker id (= queue index) and the number of samples put.
    """
    global queues, done_event
    out_queue = queues[worker_id]
    produced = 0  # samples actually pushed to the queue
    stream = iter_subset(subset)
    while not done_event.is_set():
        if (item := next(stream, None)) is None:
            # Push an empty message to notify the consumer this queue is over.
            out_queue.put(None)
            break
        produced += 1
        out_queue.put(item)
    # Block until the consumer has acknowledged every queued element.
    out_queue.join()
    return worker_id, produced
def pool_exec(ix: int, subset: str):
    """Entry point submitted to the process pool; wraps `_pool_exec`.

    Prints any worker-side exception for immediate diagnostics and then
    RE-RAISES it. The original swallowed the exception and returned `None`,
    which defeated the consumer's error handling: `future.exception()` would
    be `None` and `pid, _ = future.result()` would crash unpacking `None`.
    """
    try:
        return _pool_exec(ix, subset)
    except Exception as e:
        print('pool_exec(): exception;', e)
        raise
class OpenWebTextLoader:
    """Multiprocessing loader for the OpenWebText corpus.

    Each worker process unpacks one `.tar` subset and pushes samples into a
    bounded per-worker `JoinableQueue`; the parent draws from the live queues
    uniformly at random to merge the streams.

    Args:
        subsets: list of partitions of corpus on filesystem.
        num_workers: degree of parallelism.
        prefetch: size of prefetch buffer per worker.
        random_state: random number generator for merging streams of samples.

    Example:

    .. code:: python

        subsets = [Path('.../urlsf_subset00.tar'), ...]
        num_workers = 4
        for datapoint in OpenWebTextLoader(subsets, num_workers):
            pass
    """

    def __init__(self, subsets: list[Path], num_workers=1, prefetch=2048,
                 random_state=None):
        if prefetch <= 1:
            raise ValueError('prefetch buffer should be strictly greater than 1')
        self.subsets = subsets
        self.num_workers = num_workers
        self.prefetch = prefetch
        self.random_state = random_state or np.random.RandomState()
        # Shared primitives are created lazily in __iter__; declared here so
        # the full attribute set is visible up front (`sema` was previously
        # introduced only in __iter__).
        self.done_event: Optional[mp.Event] = None
        self.sema: Optional[mp.Semaphore] = None
        self.queues: list[mp.JoinableQueue] = []
        self.pool: Optional[ProcessPoolExecutor] = None
        self.futures: dict[int, Future] = {}

    def __del__(self):
        # If there is no pool then there is nothing to close.
        if self.pool is not None:
            self.close()
            self.drain()
            self.pool.shutdown(cancel_futures=True)

    def __iter__(self):
        # TODO: We should split implementation of __iter__ and __next__.
        if self.pool is not None:
            # Already initialized: keep iterating over the same pool.
            # (The original had an unreachable `raise RuntimeError` after
            # this return; it has been removed.)
            return self
        self.done_event = mp.Event()
        self.queues = [
            mp.JoinableQueue(self.prefetch) for _ in range(self.num_workers)
        ]
        self.sema = mp.Semaphore(0)
        self.pool = ProcessPoolExecutor(
            max_workers=self.num_workers, initializer=pool_init,
            initargs=(self.queues, self.done_event, self.sema))
        # Start one worker per queue (or fewer if there are not enough
        # subsets); the unstarted tail of `self.subsets` is kept for later.
        num_started = min(len(self.queues), len(self.subsets))
        for pid in range(num_started):
            self.futures[pid] = self.pool.submit(pool_exec, pid,
                                                 self.subsets[pid])
        self.subsets = self.subsets[num_started:]
        # Prepare data structures for drawing samples from different queues.
        self.pids = [*range(len(self.queues))]
        self.result_queues = dict(enumerate(self.queues))
        return self

    def __next__(self):
        """Return the next sample, merging worker queues uniformly at random.

        Raises:
            StopIteration: once every queue is exhausted and no subsets remain.
        """
        while self.pids:
            pid = self.random_state.choice(self.pids)
            if (sample := self.result_queues[pid].get()) is not None:
                self.result_queues[pid].task_done()
                return sample
            elif self.subsets:
                # `None` is the worker's end-of-subset marker. Acknowledge it
                # first: the worker blocks in `queue.join()` until every put
                # element (marker included) is acknowledged.
                self.result_queues[pid].task_done()
                # Wait until worker with id `pid` finished before we start
                # another background job on the same queue.
                future = self.futures[pid]
                if (exc := future.exception()) is not None:
                    raise exc
                # Start another worker with the same id `pid` on the next
                # unprocessed subset.
                subset, self.subsets = self.subsets[0], self.subsets[1:]
                self.futures[pid] = self.pool.submit(pool_exec, pid, subset)
            else:
                # No subsets left for this queue: retire it from the draw.
                self.result_queues[pid].task_done()
                self.pids.remove(pid)
        raise StopIteration

    def close(self):
        """Signal workers to exit and reap their futures.

        We have to send an exit event and then drain queues (i.e. read at
        least one element from each) in order to guarantee that a worker
        blocked on a full queue observes the done event and exits.
        """
        self.done_event.set()
        while self.futures:
            try:
                # We cannot multiplex events from different sources here
                # (no `select` over channels as in Go). So we wait for a
                # completed future with a tiny timeout and, if none finishes
                # in time, drain the queues so blocked workers can progress.
                # Simple and naive; draining in a separate thread would be
                # nicer but is tedious without a live event loop.
                finished = as_completed(self.futures.values(), timeout=1e-3)
                for future in finished:
                    if (exc := future.exception()) is not None:
                        raise exc
                    pid, _ = future.result()
                    self.futures.pop(pid)
            except FutureTimeoutError:
                for pid in self.futures:
                    queue: mp.JoinableQueue = self.queues[pid]
                    while not queue.empty():
                        queue.get()
                        queue.task_done()

    def drain(self):
        """Discard any samples left in the queues, acknowledging each one."""
        for queue in self.queues:
            while not queue.empty():
                queue.get()
                queue.task_done()
def shuffle(it: Iterable, buffer_size=1024, random_state=None):
    """Approximately shuffle a stream with a fixed-size reservoir buffer.

    Fills a buffer of `buffer_size` elements, then repeatedly yields a
    random buffer slot and refills it from the stream; the remaining buffer
    is emitted in a random permutation at the end. The RNG call sequence is
    unchanged, so a seeded `random_state` reproduces the original output.
    """
    rng = random_state or np.random.RandomState()
    # Candidate slot indices and the initial buffer fill.
    slots = list(range(buffer_size))
    pool = list(islice(it, buffer_size))
    # Main phase: only entered while the stream outlives the buffer, which
    # guarantees the buffer is completely filled here.
    for incoming in it:
        slot = rng.choice(slots)
        yield pool[slot]
        pool[slot] = incoming
    # Tail phase: flush what is left of the buffer in random order.
    for slot in rng.permutation(len(pool)):
        yield pool[slot]
def take(it: Iterable, size: int):
    """Return an iterator over at most the first `size` elements of `it`.

    Consumes exactly the yielded elements from the underlying iterator,
    same as the generator form it replaces.
    """
    return islice(it, size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment