@harpone
Created January 15, 2021 12:05
Test WebDataset with torch-xla in a multiprocessing distributed setting
from itertools import islice
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import webdataset as wds
"""
This is a toy implementation of webdataset for torch-xla.
I was suspecting that the pl.MpDeviceLoader was somehow not splitting examples across cores and workers, but
it seems to be fine (test by setting `split_per_core` and `split_per_worker` to True/False).
"""
batch_size = 5
num_iters = 64
dataset_length = 8 * 100
tpuip = '10.44.70.138' # set this to your TPU's IP
num_cores = 8
num_workers = 4
split_per_core = True
split_per_worker = True
# This is just the webdataset default OpenImages dataset:
urls = "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar"
urls = f"pipe:curl -L -s {urls} || true"
# Setting up:
os.environ['OMP_NUM_THREADS'] = '1' # good to have for webdataset
os.environ["XRT_TPU_CONFIG"] = f"tpu_worker;0;{tpuip}:8470"
def identity(x):
    return x
def main(device_idx):
    print(f'Starting process {device_idx}.')
    device = xm.xla_device()

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    preproc = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = (
        wds.Dataset(urls, length=dataset_length)
        .decode("pil")
        .to_tuple("jpg;png", "json")
        .map_tuple(preproc, identity)
        .batched(batch_size)
    )
    def shard_selection(urls_):
        """Split urls across TPU cores: each process takes every `num_cores`-th shard.

        :param urls_: full list of shard urls
        :return: slice of urls_ for this device
        """
        urls_this = urls_[device_idx::num_cores]
        return urls_this

    def shard_shuffle(urls_):  # not really a *shuffle*...
        """Split urls across DataLoader workers: each worker takes every `num_workers`-th shard.

        :param urls_: shard urls already assigned to this device
        :return: slice of urls_ for this worker
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return urls_
        else:
            num_workers_ = worker_info.num_workers
            worker_id = worker_info.id
            urls_ = urls_[worker_id::num_workers_]
            return urls_

    if split_per_core:
        dataset.shard_selection = shard_selection
    if split_per_worker:
        dataset.shard_shuffle = shard_shuffle
    dataloader = DataLoader(dataset,
                            num_workers=num_workers,
                            batch_size=None,
                            collate_fn=None)  # batching and collation are done in the dataset, hence the Nones

    device_dataloader = pl.MpDeviceLoader(dataloader, device)

    xm.rendezvous('init')  # wait until all processes have started

    ids = list()
    for i, sample in enumerate(islice(device_dataloader, 0, num_iters)):
        print(f'device={device_idx} :: iter={i}')
        ids_this = list()
        for example in sample[1]:  # sample[1] is the batch of json metadata
            try:
                id = example[0]['ImageID']
            except KeyError:
                id = 'None'
            ids_this.append(id)
        ids += ids_this
        print(f'ids={ids_this}')
        xm.rendezvous('step')

    xm.rendezvous('all_collected')
    print(f'Process {device_idx} done.')
    # Count unique image ids:
    ids_int = list()
    for id in ids:
        id_int = np.frombuffer(bytes(id, encoding='utf-8'), np.uint8).astype(np.int32)  # to int because I want to gather from all devices
        ids_int.append(torch.tensor(id_int))
    ids_int = torch.stack(ids_int).to(device)  # shape [batch_size * num_iters, 16]

    # Check that all ids are unique:
    xm.rendezvous('before_gather')
    ids_int_all = xm.all_gather(ids_int, 0)
    xm.rendezvous('after_gather')
    ids_unique = torch.unique(ids_int_all, dim=0)
    if len(ids_unique) == len(ids_int_all):
        xm.master_print('All examples unique.')
    else:
        xm.master_print('Non-unique examples detected:')
        xm.master_print(f'Num examples={len(ids_int_all)}')
        xm.master_print(f'Num uniques={len(ids_unique)}')
        raise KeyboardInterrupt

if __name__ == '__main__':
    xmp.spawn(main, args=(), nprocs=num_cores, start_method='fork')
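# Usage note (added; not part of the original gist): run this file directly on a machine with
# torch-xla and webdataset installed, after setting `tpuip` above to your TPU's address. If the
# splitting works, each of the 8 spawned processes prints a disjoint set of image ids and the
# master prints 'All examples unique.'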