Skip to content

Instantly share code, notes, and snippets.

@giacaglia
Last active December 22, 2019 00:30
Show Gist options
  • Save giacaglia/f726bd62dd6d27d8d5cc2629d4ef712b to your computer and use it in GitHub Desktop.
Save giacaglia/f726bd62dd6d27d8d5cc2629d4ef712b to your computer and use it in GitHub Desktop.
def train(gpu, args):
    """Train ``ConvNet`` on the MNIST training set, with ``gpu`` as the primary device.

    The model is seeded with ``torch.manual_seed(0)``, wrapped in
    ``nn.DataParallel``, and optimised with SGD (lr 1e-4) against
    cross-entropy loss. Mini-batches of 100 are drawn from a shuffled loader.
    Progress is printed every 100 steps, followed by the total wall-clock
    time. Printing happens only when ``gpu == 0``.

    Args:
        gpu: Index of the CUDA device used as the primary device. It is made
            the current device, and the model, loss module and every batch are
            placed on it.
        args: Object exposing an integer ``epochs`` attribute (e.g. an
            argparse namespace).

    NOTE(review): the ``gpu == 0`` print guards suggest this is launched once
    per device (e.g. via ``torch.multiprocessing.spawn``). Confirm with the
    caller. ``ConvNet``, ``torchvision`` and ``transforms`` are expected to be
    defined/imported at module level.
    """
    torch.manual_seed(0)
    model = ConvNet()
    # nn.DataParallel requires the wrapped module to live on device_ids[0].
    # The default device_ids are all GPUs starting at 0. Because the model is
    # moved to cuda:{gpu} below, any gpu != 0 would fail inside forward().
    # Listing `gpu` first makes it the source/output device. For gpu == 0 this
    # is exactly the list DataParallel would use by default.
    device_ids = [gpu] + [d for d in range(torch.cuda.device_count()) if d != gpu]
    model = nn.DataParallel(model, device_ids=device_ids)
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    model.train()
    batch_size = 100

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    # Data loading code
    train_dataset = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,  # page-locked host memory so non_blocking copies below can overlap
    )

    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            # Same device as the implicit current device set above; named explicitly for clarity.
            images = images.cuda(gpu, non_blocking=True)
            labels = labels.cuda(gpu, non_blocking=True)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0 and gpu == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1,
                    args.epochs,
                    i + 1,
                    total_step,
                    loss.item())
                )
    if gpu == 0:
        print("Training complete in: " + str(datetime.now() - start))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment