@rish-16
Created May 29, 2021 06:31
A guide on Colab TPU training using PyTorch XLA (Part 7)
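import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
# grab the XLA device (a TPU core on Colab) for this process
device = xm.xla_device()
# define some hyper-params you'd feed into your model (fill in the ... placeholders)
in_channels = ...
random_param = ...
# create your model (MyCustomNet is your own nn.Module) using the hyper-params above
net = MyCustomNet(...)
# move the model onto the TPU device and switch it to train mode
net = net.to(device).train()
# define the loss function and optimizer (any choice works here)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=..., betas=(...))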
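To see how these pieces fit together in a single training step, here is a minimal sketch under a few stated assumptions. `MyCustomNet` is filled in as a hypothetical small CNN, the hyper-parameter values (`in_channels=3`, `num_classes=10`, `lr=1e-3`, `betas=(0.9, 0.999)`) are illustrative choices rather than the guide's, and the CPU fallback exists only so the sketch also runs outside Colab. On a single TPU core, `xm.optimizer_step(optimizer, barrier=True)` updates the weights and also cuts the lazily traced XLA graph so the step actually executes.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Use the TPU through torch_xla when it is available; otherwise fall back
# to the CPU so this sketch still runs outside Colab (assumption, not part
# of the original guide).
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    ON_XLA = True
except (ImportError, RuntimeError):
    xm = None
    device = torch.device("cpu")
    ON_XLA = False


class MyCustomNet(nn.Module):
    """Hypothetical stand-in for the model: a tiny CNN classifier."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (N, 16, 1, 1)
        )
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x).flatten(1)  # -> (N, 16)
        return self.classifier(x)


# Example hyper-params (assumed values for illustration only).
in_channels = 3
num_classes = 10

net = MyCustomNet(in_channels, num_classes).to(device).train()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))


def train_step(images: torch.Tensor, labels: torch.Tensor) -> float:
    """Run one forward/backward/update pass and return the loss value."""
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(net(images), labels)
    loss.backward()
    if ON_XLA:
        # Single-core TPU: barrier=True also triggers the XLA graph execution.
        xm.optimizer_step(optimizer, barrier=True)
    else:
        optimizer.step()
    return loss.item()


# A random batch of 8 RGB 32x32 "images" with integer class labels.
torch.manual_seed(0)
images = torch.randn(8, in_channels, 32, 32)
labels = torch.randint(0, num_classes, (8,))
first_loss = train_step(images, labels)
print(f"loss after one step: {first_loss:.4f}")
```

In a real run you would loop `train_step` over batches from a `DataLoader`; for multi-core TPU training, the usual route is `torch_xla`'s parallel loader and multiprocessing utilities rather than the single-device `barrier=True` call used here.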