Skip to content

Instantly share code, notes, and snippets.

@giacaglia
Created December 9, 2019 01:15
Show Gist options
  • Save giacaglia/c6154f422ff756f1a1ebd6a3e68fe149 to your computer and use it in GitHub Desktop.
Save giacaglia/c6154f422ff756f1a1ebd6a3e68fe149 to your computer and use it in GitHub Desktop.
def train(gpu, args):
############################################################
rank = args.nr * args.gpus + gpu
dist.init_process_group(
backend='nccl',
init_method='env://',
world_size=args.world_size,
rank=rank
)
############################################################
torch.manual_seed(0)
model = ConvNet()
torch.cuda.set_device(gpu)
model.cuda(gpu)
batch_size = 100
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda(gpu)
optimizer = torch.optim.SGD(model.parameters(), 1e-4)
###############################################################
# Wrap the model
model = nn.parallel.DistributedDataParallel(model,
device_ids=[gpu])
###############################################################
# Data loading code
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
transform=transforms.ToTensor(),
download=True
)
################################################################
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=args.world_size,
rank=rank
)
################################################################
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=batch_size,
##############################
shuffle=False, #
##############################
num_workers=0,
pin_memory=True,
#############################
sampler=train_sampler) #
#############################
...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment