Skip to content

Instantly share code, notes, and snippets.

@AndrewBMartin
Last active February 28, 2019 16:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AndrewBMartin/09e3976c818d99a92b1aa42d8bbe7afb to your computer and use it in GitHub Desktop.
Save AndrewBMartin/09e3976c818d99a92b1aa42d8bbe7afb to your computer and use it in GitHub Desktop.
Example PyTorch training function that logs the training loss to TensorBoard
def train(model, train_loader, device, optimizer, log_interval, epoch, globaliter,
          summary_writer=None):
    """
    Example training function for PyTorch recording to TensorBoard.

    Runs one pass over ``train_loader``. Every ``log_interval`` batches
    (batch 0 included), it prints a progress line and records the batch loss
    as the ``'loss'`` scalar via ``tf.summary.scalar``.

    Args:
        model: PyTorch module to train; put into training mode first.
        train_loader: Iterable of ``(data, target)`` batches. When a batch is
            logged it must also support ``len()`` and expose ``.dataset``
            (used for the progress message).
        device: Device each batch is moved to before the forward pass.
        optimizer: Optimizer whose ``zero_grad``/``step`` run once per batch.
        log_interval: Print and record the loss every this many batches.
        epoch: Epoch number; used only in the progress message.
        globaliter: Global batch counter before this call; incremented once
            per batch and used as the TensorBoard ``step``.
        summary_writer: Summary writer to record into (anything with an
            ``as_default()`` context manager). Defaults to the module-level
            ``train_summary_writer`` for backward compatibility.

    Returns:
        The updated global batch counter (``globaliter`` plus the number of
        batches processed). Integers are immutable, so callers must rebind
        it, e.g. ``globaliter = train(..., globaliter)``. Otherwise every
        epoch restarts at the same TensorBoard steps.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        globaliter += 1
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        predictions = model(data)
        # F.nll_loss expects log-probabilities, so the model presumably ends
        # in a log_softmax -- TODO confirm against the model definition.
        loss = F.nll_loss(predictions, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Pull the scalar off the device once and reuse it for both outputs.
            loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))
            # Resolve the default writer lazily, so callers that pass their
            # own writer don't need a module-level ``train_summary_writer``.
            if summary_writer is not None:
                writer = summary_writer
            else:
                writer = train_summary_writer
            # This is where the loss is recorded to TensorBoard.
            with writer.as_default():
                tf.summary.scalar('loss', loss_value, step=globaliter)
    return globaliter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment