@rish-16
Created May 29, 2021 06:46
A guide on Colab TPU training using PyTorch XLA (Part 8)
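This snippet continues from the earlier parts of the guide, where the model, data loaders, loss, optimizer, and TPU device have already been set up. As a rough sketch (the exact names here are assumptions carried over from those parts, not defined in this snippet), the setup it relies on looks like this:

# assumed setup from earlier parts of the guide (names are illustrative)
import torch
import torch_xla.core.xla_model as xm                # XLA device + optimizer utilities
import torch_xla.distributed.parallel_loader as pl   # per-device parallel data loaders

device = xm.xla_device()  # this process's TPU core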
# helper function to get the test accuracy at the end of each epoch
def get_test_stats(model, loader):
    total_samples = 0
    correct = 0
    model.eval()  # switch to eval mode
    with torch.no_grad():  # no gradients needed during evaluation
        for batch_idx, data in enumerate(loader):
            x, y = data
            logits = model(x)
            preds = torch.argmax(logits, dim=1)
            correct += torch.eq(y, preds).sum().item()
            total_samples += flags['batch_size']  # more on flags later
    accuracy = 100.0 * (correct / total_samples)
    return accuracy
EPOCHS = 10  # feel free to change

for epoch in range(EPOCHS):
    # (optional) track the running batch-wise loss
    running_loss = 0
    steps = 0
    model.train()  # switch to train mode since we switch to eval mode later

    # get the specialised parallel train loader for this device
    para_loader_train = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)

    for batch_idx, data in enumerate(para_loader_train):
        steps += 1
        x, y = data
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        running_loss += loss.item()
        xm.optimizer_step(optimizer)  # XLA-aware step that syncs gradients across TPU cores

        if batch_idx % 20 == 0:  # print stats to the console from the master process only
            xm.master_print(
                '{} | RunningLoss={} | Loss={}'.format(batch_idx, running_loss / steps, loss.item()),
                flush=True
            )

    xm.master_print("Finished training epoch {}".format(epoch))

    # get the specialised parallel test loader for this device
    para_loader_test = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
    val_accuracy = get_test_stats(model, para_loader_test)
    xm.master_print("Validation Accuracy: {}".format(val_accuracy))