@rish-16
Created May 29, 2021 06:20
A guide on Colab TPU training using PyTorch XLA (Part 6)
import torch
import torch_xla.core.xla_model as xm

'''
num_replicas is the number of processes taking part in training: one per
TPU core, as reported by xm.xrt_world_size(). Together with this process's
rank (xm.get_ordinal()), it lets the sampler hand each core its own
disjoint shard of the dataset.

im_train, im_test and the flags dict are assumed to come from the earlier
parts of this guide.
'''
train_sampler = torch.utils.data.distributed.DistributedSampler(
    im_train,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True
)
test_sampler = torch.utils.data.distributed.DistributedSampler(
    im_test,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False
)
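
# Illustrative aside (not in the original gist): every TPU process runs this
# same script with its own rank, so each core's sampler yields a distinct
# ~1/world_size slice of the data. On the master ordinal you could verify:
xm.master_print(f"samples per core: {len(train_sampler)} / {len(im_train)}")
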
# ignore batch_size and num_workers for now
train_loader = torch.utils.data.DataLoader(
    im_train,
    batch_size=flags['batch_size'],
    sampler=train_sampler,
    num_workers=flags['num_workers'],
    drop_last=True
)
test_loader = torch.utils.data.DataLoader(
    im_test,
    batch_size=flags['batch_size'],
    sampler=test_sampler,
    num_workers=flags['num_workers'],
    drop_last=True
)
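
Below is a minimal sketch of how these loaders are typically consumed on each TPU core with torch_xla's ParallelLoader. The names model, loss_fn, and optimizer are placeholders not defined in this snippet, and the actual training loop in the later parts of this guide may differ.

import torch_xla.distributed.parallel_loader as pl

def train_one_epoch(model, loss_fn, optimizer):
    device = xm.xla_device()  # the TPU core assigned to this process
    model.train()
    # ParallelLoader preloads batches from train_loader onto the XLA device
    para_loader = pl.ParallelLoader(train_loader, [device])
    for x, y in para_loader.per_device_loader(device):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        # xm.optimizer_step all-reduces gradients across cores before stepping
        xm.optimizer_step(optimizer)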