@jessfraz
Created September 1, 2021 15:24
A patch for running https://github.com/karpathy/minGPT on a Google TPU: it swaps the trainer's CUDA device selection for an XLA device obtained through torch_xla.
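In isolation, the device change amounts to the following sketch (assuming torch_xla is installed, as it is on Cloud TPU / Colab TPU runtimes; the Linear model is just a stand-in for minGPT):

import torch
import torch_xla.core.xla_model as xm

# Before: fall back to CPU unless CUDA is available.
# device = 'cpu'
# if torch.cuda.is_available():
#     device = torch.cuda.current_device()

# After: always target the XLA device that torch_xla exposes (the TPU core).
device = xm.xla_device()

model = torch.nn.Linear(10, 10)  # stand-in for the minGPT model
# The DataParallel wrapper is kept so code that reaches the raw model
# via .module (e.g. checkpointing) keeps working unchanged.
model = torch.nn.DataParallel(model).to(device)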
diff --git a/mingpt/trainer.py b/mingpt/trainer.py
index 0ac491b..e246a10 100644
--- a/mingpt/trainer.py
+++ b/mingpt/trainer.py
@@ -36,6 +36,8 @@ class TrainerConfig:
         for k,v in kwargs.items():
             setattr(self, k, v)
 
+import torch_xla.core.xla_model as xm
+
 class Trainer:
 
     def __init__(self, model, train_dataset, test_dataset, config):
@@ -45,10 +47,8 @@ class Trainer:
         self.config = config
 
         # take over whatever gpus are on the system
-        self.device = 'cpu'
-        if torch.cuda.is_available():
-            self.device = torch.cuda.current_device()
-            self.model = torch.nn.DataParallel(self.model).to(self.device)
+        self.device = xm.xla_device()
+        self.model = torch.nn.DataParallel(self.model).to(self.device)
 
     def save_checkpoint(self):
         # DataParallel wrappers keep raw model object in .module attribute
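
The patch leaves the rest of the trainer untouched. On TPU, the optimizer step is usually routed through torch_xla so that the lazily built XLA graph actually executes; a minimal sketch of such a step (not part of this patch, and the train_step helper is hypothetical) would be:

import torch_xla.core.xla_model as xm

def train_step(model, optimizer, x, y):
    optimizer.zero_grad()
    logits, loss = model(x, y)  # minGPT's forward returns (logits, loss)
    loss.backward()
    # xm.optimizer_step applies the gradients and forces the pending
    # XLA graph to run on the device.
    xm.optimizer_step(optimizer, barrier=True)
    return loss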