Skip to content

Instantly share code, notes, and snippets.

@AndrewBMartin
Last active February 28, 2019 16:52
Show Gist options
  • Save AndrewBMartin/6bd45a6bd8fa69aee2bc712e6e67266d to your computer and use it in GitHub Desktop.
Save AndrewBMartin/6bd45a6bd8fa69aee2bc712e6e67266d to your computer and use it in GitHub Desktop.
Example training function with tensorboardcolab
def train(model, train_loader, device, optimizer, log_interval, epoch, globaliter, tb):
    """Run one training epoch and record the loss to tensorboardcolab.

    Relies on ``F`` being available in the enclosing module scope
    (presumably ``torch.nn.functional`` -- confirm against the file's
    imports).

    Args:
        model: Module to train; switched to training mode on entry.
        train_loader: Iterable of ``(data, target)`` batches. Must support
            ``len()`` and expose a ``dataset`` attribute that supports
            ``len()`` (both are used for the progress message).
        device: Device that each batch is moved to before the forward pass.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        log_interval: Log every ``log_interval``-th batch (by batch index).
        epoch: Epoch number, used only in the progress message.
        globaliter: Global step counter value before this epoch starts.
        tb: Recorder exposing ``save_value(title, name, step, value)``.

    Returns:
        The updated global step counter. Integers are passed by value, so
        the caller must carry this into the next call to keep the recorded
        step axis increasing across epochs.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        globaliter += 1
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        predictions = model(data)
        loss = F.nll_loss(predictions, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))
            # Record to Tensorboard. This is a free function, so the step
            # comes from the local counter (there is no ``self`` here).
            tb.save_value('Train Loss', 'train_loss', globaliter, loss_value)

    return globaliter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment