@tjyuyao
Created July 19, 2023 14:05
PyTorch-based parallel tqdm loader
def pqdm(func, data, n_jobs=2):
    # A PyTorch DataLoader yields results as workers produce them instead of
    # blocking until every result is ready, and it also accepts locally defined
    # functions. These features make it favorable compared to the pqdm library
    # or the standard multiprocessing library.
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    datalen = len(data)

    class Dataset:
        def __len__(self):
            return datalen

        def __getitem__(self, i):
            # data[i] is expected to be a tuple of positional arguments for func.
            return func(*data[i])

    # collate_fn=lambda x: x[0] unwraps the single-item batch so the raw return
    # value of func is yielded; num_workers runs func in parallel worker processes.
    dataloader = DataLoader(Dataset(), collate_fn=lambda x: x[0], num_workers=n_jobs)
    return tqdm(dataloader, total=datalen, smoothing=0.1)
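
A minimal usage sketch (the square function and the list of argument tuples below are made up for illustration, not part of the gist): each element of data must be a tuple of positional arguments for func, and results are yielded lazily as you iterate over the returned tqdm-wrapped loader.

# Hypothetical usage example; `square` and `data` are illustrative only.
def square(x):
    return x * x

data = [(i,) for i in range(100)]              # each item is a tuple of args for func
results = list(pqdm(square, data, n_jobs=4))   # progress bar advances as workers finish

Note that passing locally defined functions and datasets to the workers relies on the fork start method (the default on Linux); with the spawn start method the locally defined Dataset class would fail to pickle.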