Benchmarking of OpenWebText parallel loader.
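The gist consists of three files: main.py, the benchmark driver; a shell script that sweeps the number of workers; and openwebtext.py, the loader itself. Assuming the subset archives live in a local subsets/ directory, as in the timing script below, a typical invocation of the driver is

    python main.py -j 4 subsets

The driver, main.py, follows.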
from argparse import ArgumentParser, Namespace
from pathlib import Path
from random import shuffle as shuffle_list

from tqdm import tqdm

from openwebtext import OpenWebTextLoader, shuffle, take

parser = ArgumentParser()
parser.add_argument('-j', '--jobs', type=int, default=2)
parser.add_argument('data_dir', type=Path)


def main():
    ns: Namespace = parser.parse_args()
    subsets = [*sorted(ns.data_dir.glob('*.tar'))]
    # Shuffle the subset list in place (stdlib shuffle, aliased to avoid
    # clashing with the streaming shuffle imported from openwebtext).
    shuffle_list(subsets)
    print(f'load {len(subsets)} parts from {ns.data_dir} '
          f'with {ns.jobs} workers')

    loader = OpenWebTextLoader(subsets, num_workers=ns.jobs)
    samples = iter(loader)
    samples = shuffle(samples, 1024)
    samples = take(samples, 500_000)
    for sample in tqdm(samples, total=500_000, unit='sample'):
        continue
    # Show the indices and names of the last sample drawn.
    print(*sample[:2])


if __name__ == '__main__':
    main()
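The shell script below sweeps the worker count and, for every run, appends a single "jobs,elapsed seconds,command" row to timings.csv via GNU time's -f format string.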
#!/bin/bash

data_dir=subsets
num_workers_grid=(1 2 3 4 8 16 20)

for num_workers in "${num_workers_grid[@]}"; do
    # GNU time appends a "jobs,elapsed seconds,command" row per run.
    /usr/bin/time -a -o timings.csv -f "$num_workers,%e,%C" \
        python main.py -j "$num_workers" "$data_dir"
done
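A minimal way to summarize the collected timings might look as follows. This helper is not part of the gist; it assumes only the row format produced above and the 500,000-sample budget hard-coded in main.py.

import csv

N_SAMPLES = 500_000  # main.py takes exactly this many samples per run.

with open('timings.csv') as fin:
    for jobs, elapsed, _command in csv.reader(fin):
        rate = N_SAMPLES / float(elapsed)
        print(f'{jobs:>2} workers: {float(elapsed):8.1f}s, '
              f'{rate:,.0f} samples/s')

The loader implementation, openwebtext.py, follows.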
import multiprocessing as mp
from codecs import StreamReader, getreader
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
from itertools import islice
from pathlib import Path
from sys import version_info
from tarfile import TarFile
from typing import Iterable, Optional

import numpy as np

if version_info < (3, 11):
    from concurrent.futures import TimeoutError as FutureTimeoutError
else:
    # Since 3.11 concurrent.futures.TimeoutError is an alias of the builtin.
    FutureTimeoutError = TimeoutError

__all__ = ('OpenWebTextLoader', 'shuffle', 'take')

# getreader() returns a StreamReader subclass which is instantiated per file.
DecodingStreamReader: type[StreamReader] = getreader('utf-8')

# Communication primitives between producing workers and the consuming
# parent; they are bound in each worker process by pool_init().
queues: list[mp.JoinableQueue] = []
done_event: Optional[mp.Event] = None
sema: Optional[mp.Semaphore] = None
def iter_subset(path: Path):
    """Iterate over samples of a single .tar subset, which in turn contains
    .xz-compressed tar parts with one UTF-8 text file per sample.
    """
    with TarFile(path) as tar:
        for i, info in enumerate(tar.getmembers()):
            arch = tar.extractfile(info)
            part_name = Path(info.name).name
            with TarFile.xzopen(info.name, fileobj=arch) as part:
                for j, entry in enumerate(part):
                    fin = part.extractfile(entry)
                    reader = DecodingStreamReader(fin)
                    yield (i, j), (part_name, entry.name), reader.read()
def pool_init(*args):
    """Initializer for worker processes: bind shared primitives to globals."""
    global queues, done_event
    queues, done_event, *_ = args


def _pool_exec(worker_id: int, subset: Path):
    global queues, done_event
    queue = queues[worker_id]
    # Count the number of samples put to the queue during iteration.
    total_samples = 0
    samples = iter_subset(subset)
    while not done_event.is_set():
        try:
            sample = next(samples)
        except StopIteration:
            # Push an empty message to the queue in order to notify the
            # consumer that this queue is exhausted.
            queue.put(None)
            break
        total_samples += 1
        queue.put(sample)
    # Wait until all elements of the queue are consumed on the parent side.
    queue.join()
    # Return worker id (= queue index) and the number of produced samples.
    return worker_id, total_samples


def pool_exec(ix: int, subset: Path):
    try:
        return _pool_exec(ix, subset)
    except Exception as e:
        # Log in the worker (tracebacks are easy to lose in a pool) and
        # re-raise so that the parent sees the failure via the future.
        print('pool_exec(): exception;', e)
        raise
class OpenWebTextLoader:
    """Multiprocessing loader for the OpenWebText corpus.

    Args:
        subsets: list of partitions of the corpus on the filesystem.
        num_workers: degree of parallelism.
        prefetch: size of the prefetch buffer per worker.
        random_state: random number generator for merging streams of samples.

    Example:

    .. code:: python

        subsets = [Path('.../urlsf_subset00.tar'), ...]
        num_workers = 4
        for datapoint in OpenWebTextLoader(subsets, num_workers):
            pass
    """

    def __init__(self, subsets: list[Path], num_workers=1, prefetch=2048,
                 random_state=None):
        if prefetch <= 1:
            raise ValueError('prefetch buffer should be strictly larger '
                             'than 1')
        self.subsets = subsets
        self.num_workers = num_workers
        self.prefetch = prefetch
        self.random_state = random_state or np.random.RandomState()
        self.done_event: Optional[mp.Event] = None
        self.queues: list[mp.JoinableQueue] = []
        self.pool: Optional[ProcessPoolExecutor] = None
        self.futures: dict[int, Future] = {}
    def __del__(self):
        # If there is no pool then there is nothing to close.
        if self.pool is not None:
            self.close()
            self.drain()
            self.pool.shutdown(cancel_futures=True)
    def __iter__(self):
        # TODO: We should split the implementations of __iter__ and __next__.
        if self.pool is not None:
            # Workers are already running: behave as an iterator.
            return self
        self.done_event = mp.Event()
        self.queues = [
            mp.JoinableQueue(self.prefetch) for _ in range(self.num_workers)
        ]
        self.sema = mp.Semaphore(0)
        self.pool = ProcessPoolExecutor(
            max_workers=self.num_workers, initializer=pool_init,
            initargs=(self.queues, self.done_event, self.sema))
        # Start worker processes: one initial subset per queue.
        for pid, (_, subset) in enumerate(zip(self.queues, self.subsets)):
            self.futures[pid] = self.pool.submit(pool_exec, pid, subset)
        # The remaining subsets are fed to workers as they finish.
        self.subsets = self.subsets[len(self.futures):]
        # Prepare data structures for drawing samples from different queues.
        self.pids = [*range(len(self.queues))]
        self.result_queues = dict(enumerate(self.queues))
        return self
    def __next__(self):
        # Draw samples from the queues uniformly at random.
        while self.pids:
            pid = self.random_state.choice(self.pids)
            if (sample := self.result_queues[pid].get()) is not None:
                self.result_queues[pid].task_done()
                return sample
            elif self.subsets:
                self.result_queues[pid].task_done()
                # Wait until the worker with id `pid` finishes before we
                # start another background job.
                future = self.futures[pid]
                if (exc := future.exception()) is not None:
                    raise exc
                # Start another worker with the same id `pid` on the next
                # unprocessed subset.
                subset, self.subsets = self.subsets[0], self.subsets[1:]
                self.futures[pid] = self.pool.submit(pool_exec, pid, subset)
            else:
                # No subsets left for this worker: retire its queue.
                self.result_queues[pid].task_done()
                self.pids.remove(pid)
        else:
            raise StopIteration
    def close(self):
        # We have to send an exit event and then drain the queues (i.e. read
        # at least one element from each) in order to guarantee that every
        # worker process sees the done event and exits.
        self.done_event.set()
        while self.futures:
            try:
                # The issue here is that we cannot multiplex events from
                # different sources (like select over channels in Go). So we
                # wait for a completed future with a timeout, and if no
                # future completes in time, we drain the queues.
                #
                # It is a simple and naive solution; it would be better to
                # run the draining tasks in a separate thread, but that is
                # tedious in Python without a live event loop...
                finished = as_completed(self.futures.values(), timeout=1e-3)
                for future in finished:
                    if (exc := future.exception()) is not None:
                        raise exc
                    pid, _ = future.result()
                    self.futures.pop(pid)
            except FutureTimeoutError:
                for pid, _ in self.futures.items():
                    queue: mp.JoinableQueue = self.queues[pid]
                    while not queue.empty():
                        queue.get()
                        queue.task_done()
    def drain(self):
        for queue in self.queues:
            while not queue.empty():
                queue.get()
                queue.task_done()
def shuffle(it: Iterable, buffer_size=1024, random_state=None):
    """Approximately shuffle a stream with a fixed-size buffer: yield a
    randomly chosen buffered element and replace it with the next incoming
    one.
    """
    random_state = random_state or np.random.RandomState()
    # Initialize the shuffling buffer.
    index = [*range(buffer_size)]
    buffer = [*islice(it, buffer_size)]
    # In general, we draw a sample randomly and refill the buffer while the
    # iterable is alive.
    for el in it:
        ix = random_state.choice(index)
        yield buffer[ix]
        buffer[ix] = el
    # Shuffle the rest of the buffer.
    for ix in random_state.permutation(len(buffer)):
        yield buffer[ix]


def take(it: Iterable, size: int):
    """Yield at most `size` leading elements of an iterable."""
    yield from islice(it, size)
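Finally, a quick self-contained illustration of the streaming helpers (a hypothetical demo, not part of the benchmark): shuffle() keeps a fixed-size buffer and emits one buffered element per incoming one, so memory stays bounded no matter how long the stream is.

import numpy as np
from openwebtext import shuffle, take

rng = np.random.RandomState(42)
stream = iter(range(10_000))
# Draw the first five elements of a buffered shuffle over the stream.
print([*take(shuffle(stream, buffer_size=256, random_state=rng), 5)])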