import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L
from pathlib import Path
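
# Demonstrates fault-tolerant multi-node training with PyTorch Lightning: checkpoints land in a
# shared artifacts directory and training auto-resumes from the latest one after a node failure.
# Assumes the platform exports LIGHTNING_ARTIFACTS_DIR (shared storage visible to every node) and,
# for the failure simulation below, NUM_RESTARTS (how many times the job has been restarted).
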
def find_latest_checkpoint(artifacts_dir):
    # Logic holds only for the default TensorBoardLogger directory layout
    if not artifacts_dir:
        return None
    versions_dir = Path(artifacts_dir, "lightning_logs")
    if not versions_dir.is_dir():
        return None
    # Sort numerically so that e.g. version_10 ranks above version_9
    versions = sorted(versions_dir.glob("version_*"), key=lambda p: int(p.name.split("_")[-1]))
    if not versions:
        return None
    ckpt_dir = Path(versions[-1], "checkpoints")
    if not ckpt_dir.is_dir():
        return None
    ckpt_paths = sorted(ckpt_dir.glob("*.ckpt"))
    if not ckpt_paths:
        return None
    ckpt_path = ckpt_paths[-1]
    print("Resuming training from", ckpt_path)
    return ckpt_path
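
# find_latest_checkpoint() expects the layout written by the default TensorBoardLogger together
# with the default ModelCheckpoint callback, for example:
#   <artifacts_dir>/lightning_logs/version_0/checkpoints/epoch=2-step=1718.ckpt
# Each new run creates a higher version_* directory; the last .ckpt (sorted by filename) of the
# latest version is used to resume.
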
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def on_train_epoch_end(self):
        if "NUM_RESTARTS" not in os.environ:
            return
        # Simulate failures to trigger automatic fault-tolerance recovery in multi-machine training (MMT)
        node_rank = self.trainer.node_rank
        # On the first run, fail on node 0 after 2 epochs
        fail1 = os.environ["NUM_RESTARTS"] == "0" and node_rank == 0 and self.current_epoch > 2
        # On the second run, fail on node 1 after 4 epochs
        fail2 = os.environ["NUM_RESTARTS"] == "1" and node_rank == 1 and self.current_epoch > 4
        # On the third run, fail on node 0 after 5 epochs
        fail3 = os.environ["NUM_RESTARTS"] == "2" and node_rank == 0 and self.current_epoch > 5
        if fail1 or fail2 or fail3:
            raise RuntimeError("Intentional error to simulate a failure and test recovery")
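
    # The staged failures above assume the job scheduler sets NUM_RESTARTS to "0" on the first
    # attempt and increments it on each automatic restart, so every simulated crash fires exactly
    # once and the final attempt runs to completion.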

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


def run():
    # This is a shared directory all nodes can access
    # Checkpoints are written here so training can auto-resume after a failure
    artifacts_dir = os.getenv("LIGHTNING_ARTIFACTS_DIR")
    # Find the last saved checkpoint (if any)
    ckpt_path = find_latest_checkpoint(artifacts_dir)
    dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
    train, val = data.random_split(dataset, [55000, 5000])
    autoencoder = LitAutoEncoder()
    trainer = L.Trainer(default_root_dir=artifacts_dir)
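    # Optional tweak (not part of the original script): an explicit ModelCheckpoint with
    # save_last=True always keeps a "last.ckpt", which can make the resume lookup simpler, e.g.
    #   from lightning.pytorch.callbacks import ModelCheckpoint
    #   trainer = L.Trainer(default_root_dir=artifacts_dir, callbacks=[ModelCheckpoint(save_last=True)])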
    trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val), ckpt_path=ckpt_path)


if __name__ == "__main__":
    run()