Skip to content

Instantly share code, notes, and snippets.

@cemoody
Created December 16, 2022 21:00
Show Gist options
  • Save cemoody/f8457345c941da693b04675593d13a3c to your computer and use it in GitHub Desktop.
A multiprocess Parquet DataLoader for PyTorch. Great for loading large sequential access datasets. Easy to install, modify, and use.
import multiprocessing
import queue
import time

import pandas as pd
from loguru import logger
def chunks(df, chunk_size=1000):
    """Yield consecutive slices of *df* containing at most ``chunk_size`` rows.

    Works on anything that supports ``len`` and slicing (lists, DataFrames).
    The final slice may be shorter than ``chunk_size``.
    """
    total = len(df)
    start = 0
    while start < total:
        yield df[start : start + chunk_size]
        start += chunk_size
def parquet_reader(path):
    """Read the Parquet file at *path* into a pandas DataFrame."""
    frame = pd.read_parquet(path)
    return frame
def worker_fn(
    input_queue,
    output_queue,
    func=None,
    batch_size=128,
    worker_id=0,
    verbose=False,
    file_reader=parquet_reader,
):
    """Worker-process loop: read file paths from *input_queue*, load each file
    with *file_reader*, and put ``batch_size``-row chunks on *output_queue*.

    Protocol (the consumer, ``FileDataLoader.get``, relies on this):
      * On startup the worker puts its integer ``worker_id`` on the output
        queue so the parent can verify it is alive.
      * A ``None`` path on the input queue is the stop signal.
      * After finishing each file the worker puts ``worker_id`` on the output
        queue again so the consumer can track completed work. Because control
        messages are bare ints, data batches must not be bare ints.

    Args:
        input_queue: queue of file paths to process (``None`` = stop).
        output_queue: queue receiving chunks and worker-id control messages.
        func: optional transform applied to each chunk before it is queued.
        batch_size: number of rows per chunk.
        worker_id: identifier reported in logs and control messages.
        verbose: log per-chunk progress when True.
        file_reader: callable mapping a path to an indexable dataset.
    """
    logger.info(f"starting worker {worker_id}")
    output_queue.put(worker_id)
    while True:
        try:
            # BUGFIX: timeout=0 made an idle worker busy-spin at 100% CPU.
            # A short blocking timeout keeps shutdown latency low without
            # spinning while waiting for work.
            path = input_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        if path is None:
            logger.info(f"worker {worker_id} got stop signal")
            break
        logger.info(f"worker {worker_id} reading {path}")
        fh = file_reader(path)
        if verbose:
            logger.info(f"worker {worker_id} did read {path}")
        for j, chunk in enumerate(chunks(fh, batch_size)):
            if verbose:
                logger.info(f"worker {worker_id} putting batch {j}")
            output_queue.put(func(chunk) if func else chunk)
        logger.info(f"worker {worker_id} finished {path}")
        output_queue.put(worker_id)
class FileDataLoader:
    """Iterate batches produced by worker processes reading a list of files.

    Worker processes (``worker_fn``) pull file paths from an input queue,
    read each file, and push ``batch_size``-row chunks onto a bounded output
    queue; iterating this loader yields those chunks until every worker has
    finished and both queues are drained.

    NOTE: workers announce startup and per-file completion by putting their
    integer worker id on the output queue; ``get`` filters those control
    messages out, so data batches must not be bare ints.
    """

    def __init__(
        self,
        paths,
        batch_size=64,
        num_workers=1,
        prefetch_batches=100,
        transform=None,
    ):
        """Spawn ``num_workers`` daemon processes and enqueue *paths*.

        Args:
            paths: file paths handed to the workers' file reader.
            batch_size: rows per yielded chunk.
            num_workers: number of worker processes to start.
            prefetch_batches: maximum chunks buffered in the output queue.
            transform: optional callable applied to each chunk in the worker.
        """
        self.num_workers = num_workers
        self.prefetch_batches = prefetch_batches
        self.input_queue = multiprocessing.Queue()
        # Bounded so workers block instead of buffering every batch in RAM.
        self.output_queue = multiprocessing.Queue(maxsize=prefetch_batches)
        self.transform = transform
        # BUGFIX: this was a class attribute (`finished_workers = set()`), so
        # the set of finished workers leaked across FileDataLoader instances;
        # it must be per-instance state.
        self.finished_workers = set()
        # Start workers (daemonized so they die with the parent process).
        self.workers = []
        for i in range(num_workers):
            worker = multiprocessing.Process(
                target=worker_fn,
                args=(self.input_queue, self.output_queue, transform, batch_size, i),
            )
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
        logger.debug(f"started {num_workers} workers")
        # Each worker announces itself by putting its id on the output queue;
        # collect all ids to verify every worker actually started.
        started_worker_ids = set()
        for _ in range(num_workers):
            started_worker_ids.add(self.output_queue.get(timeout=3))
        assert started_worker_ids == set(range(num_workers))
        logger.debug(f"verified {num_workers} workers")
        # Load up input queue
        self.paths = paths
        for path in paths:
            self.input_queue.put(path)
        logger.debug(f"loaded {len(paths)} files into queue")
        # Wait until at least one batch exists before queueing stop signals,
        # so workers cannot exit before producing any data.
        while self.output_queue.empty():
            time.sleep(1)
        logger.debug("discovered first batch of data")
        # One None per worker signals each worker to stop.
        for _ in range(num_workers):
            self.input_queue.put(None)
        logger.debug("queued stop signals to workers")

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise StopIteration once fully drained."""
        batch = self.get()
        if batch is None:
            raise StopIteration
        return batch

    def is_done(self):
        """True when no data remains anywhere in the pipeline.

        Data can be in the input queue, inside a worker, or in the output
        queue; we are done only when all workers have reported completion and
        both queues are empty.
        """
        return (
            (len(self.finished_workers) == self.num_workers)
            and self.input_queue.empty()
            and self.output_queue.empty()
        )

    def get(self):
        """Return the next data batch, or ``None`` when the loader is done.

        Integer messages on the output queue are worker start/completion acks
        and are consumed here rather than returned.
        """
        timer = time.time()
        while True:
            try:
                # BUGFIX: timeout=0 busy-spun at 100% CPU for up to 10 s
                # before the slow-path sleep; a short blocking timeout keeps
                # latency low without spinning.
                batch = self.output_queue.get(timeout=0.1)
                if isinstance(batch, int):
                    worker_id = batch
                    logger.debug(f"dataloader acks worker {worker_id} finished")
                    self.finished_workers.add(worker_id)
                    continue
                timer = time.time()
                return batch
            except queue.Empty:  # output queue empty, keep trying
                pass
            if self.is_done():
                logger.debug("Loader is done with data")
                return None
            if time.time() - timer > 10:
                logger.debug("Loader has been waiting for data for 10 sec")
                logger.debug(f"input_queue: {self.input_queue.empty()}")
                logger.debug(f"output_queue: {self.output_queue.empty()}")
                logger.debug(f"finished_workers: {self.finished_workers}")
                time.sleep(1)

    def __del__(self):
        """Best-effort teardown: signal, join, then terminate stragglers."""
        try:
            for _ in self.workers:
                self.input_queue.put(None)
            for w in self.workers:
                w.join(timeout=5.0)
            # Drop queue feeder threads so interpreter shutdown cannot hang.
            self.input_queue.cancel_join_thread()
            self.input_queue.close()
            self.output_queue.cancel_join_thread()
            self.output_queue.close()
            logger.debug("closed queues")
        finally:
            for w in self.workers:
                if w.is_alive():
                    w.terminate()
            logger.debug("terminated workers")
if __name__ == "__main__":
    import time

    # Demo paths; duplicates and commented-out entries are intentional.
    fns = [
        "/Users/chris/Downloads/temp2/4ec3fd44-d1ec-4610-82a6-5b409a796abf.wds_img_vectors.parquet_80a55350-2927-4716-94b9-996a56606bf3.wds_img_vectors.parquet.parquet",
        "/Users/chris/Downloads/temp2/3fd02503-663c-4048-830e-eab3f349ff26.wds_img_vectors.parquet_5ae4cc08-5c07-4ed2-aea3-27b2e13dd84b.wds_img_vectors.parquet.parquet",
        "/Users/chris/Downloads/temp2/1fcbf028-49b2-4c84-87c1-d0c4622cdaf3.wds_img_vectors.parquet_5ec6a0c3-6771-4b11-8d32-fcde1d6af428.wds_img_vectors.parquet.parquet",
        "/Users/chris/Downloads/temp2/4ec3fd44-d1ec-4610-82a6-5b409a796abf.wds_img_vectors.parquet_80a55350-2927-4716-94b9-996a56606bf3.wds_img_vectors.parquet.parquet",
        # "/Users/chris/Downloads/temp2/3fd02503-663c-4048-830e-eab3f349ff26.wds_img_vectors.parquet_5ae4cc08-5c07-4ed2-aea3-27b2e13dd84b.wds_img_vectors.parquet.parquet",
        # "/Users/chris/Downloads/temp2/1fcbf028-49b2-4c84-87c1-d0c4622cdaf3.wds_img_vectors.parquet_5ec6a0c3-6771-4b11-8d32-fcde1d6af428.wds_img_vectors.parquet.parquet",
        # "/Users/chris/Downloads/temp2/4ec3fd44-d1ec-4610-82a6-5b409a796abf.wds_img_vectors.parquet_80a55350-2927-4716-94b9-996a56606bf3.wds_img_vectors.parquet.parquet",
        # "/Users/chris/Downloads/temp2/3fd02503-663c-4048-830e-eab3f349ff26.wds_img_vectors.parquet_5ae4cc08-5c07-4ed2-aea3-27b2e13dd84b.wds_img_vectors.parquet.parquet",
        # "/Users/chris/Downloads/temp2/1fcbf028-49b2-4c84-87c1-d0c4622cdaf3.wds_img_vectors.parquet_5ec6a0c3-6771-4b11-8d32-fcde1d6af428.wds_img_vectors.parquet.parquet",
    ]
    loader = FileDataLoader(fns, batch_size=2048, num_workers=2)
    # Log batch indices with exponentially growing spacing: 0, 2, 4, 8, ...
    log_i = 0
    for i, batch in enumerate(loader):
        if i % (1 << log_i) == 0:
            logger.debug(f"batch {i}")
            log_i += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment