# Source: GitHub gist tchaton/b738efc773a38c9857ccc050fe1b6b6e (created June 13, 2024 19:53).
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. Review it in an editor that
# reveals hidden Unicode characters.
import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F | |
import lightning as L | |
from pathlib import Path | |
def find_latest_checkpoint(artifacts_dir):
    """Return the most recently written checkpoint under ``artifacts_dir``, or ``None``.

    The lookup assumes the directory layout produced by Lightning's default
    ``TensorBoardLogger``::

        <artifacts_dir>/lightning_logs/version_<N>/checkpoints/*.ckpt

    Version directories are visited from the highest ``N`` downwards and the
    first one holding at least one ``*.ckpt`` file wins, so a run that crashed
    before saving anything does not hide the checkpoints of earlier runs.

    Args:
        artifacts_dir: Root directory shared by all runs (``str`` or ``Path``).
            ``None`` or an empty string means there is nowhere to look.

    Returns:
        A ``pathlib.Path`` pointing at the newest checkpoint file, or ``None``
        when no checkpoint exists.
    """
    if not artifacts_dir:
        return None

    # Logic holds only for default TensorBoardLogger
    versions_dir = Path(artifacts_dir, "lightning_logs")
    if not versions_dir.is_dir():
        return None

    def version_number(version_dir):
        # Compare numerically: a plain string sort puts "version_10" before "version_2".
        suffix = version_dir.name.rpartition("_")[2]
        return int(suffix) if suffix.isdecimal() else -1

    def write_order(ckpt_file):
        # Modification time reflects save order; file names such as
        # "epoch=10-..." and "epoch=9-..." do not sort chronologically as strings.
        # The name is only a deterministic tie-breaker.
        return (ckpt_file.stat().st_mtime, ckpt_file.name)

    version_dirs = sorted(
        (path for path in versions_dir.glob("version_*") if path.is_dir()),
        key=version_number,
        reverse=True,
    )
    for version_dir in version_dirs:
        ckpt_dir = Path(version_dir, "checkpoints")
        if not ckpt_dir.is_dir():
            continue
        ckpt_files = list(ckpt_dir.glob("*.ckpt"))
        if ckpt_files:
            ckpt_path = max(ckpt_files, key=write_order)
            print("Resuming training from", ckpt_path)
            return ckpt_path
    return None
class LitAutoEncoder(L.LightningModule):
    """Small fully connected autoencoder used to demonstrate fault-tolerant training.

    The encoder maps a flattened 28 * 28 input to a 3-dimensional embedding and
    the decoder maps that embedding back to 28 * 28 values. When the
    ``NUM_RESTARTS`` environment variable is set, the module deliberately
    raises at the end of selected epochs to simulate a node failure.
    """

    # Simulated-failure schedule keyed by the value of NUM_RESTARTS:
    # restart count -> (node rank that fails, epoch value that must be exceeded).
    _FAILURE_SCHEDULE = {
        "0": (0, 2),  # first run: fail on node 0 once current_epoch > 2
        "1": (1, 4),  # second run: fail on node 1 once current_epoch > 4
        "2": (0, 5),  # third run: fail on node 0 once current_epoch > 5
    }

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        """Return the encoder embedding of ``x`` (the decoder is not applied)."""
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE reconstruction loss for one batch.

        Args:
            batch: Pair whose first element is the input tensor; the second
                element is ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            The mean-squared error between the reconstruction and the
            flattened input.
        """
        x, _ = batch
        # Flatten every sample to a vector so it fits the first Linear layer.
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def on_train_epoch_end(self):
        """Simulate a failure to trigger automatic fault-tolerance recovery in MMT.

        Does nothing unless ``NUM_RESTARTS`` is set to one of the values in
        ``_FAILURE_SCHEDULE``.

        Raises:
            RuntimeError: On the scheduled node once ``current_epoch`` exceeds
                the scheduled threshold.
        """
        plan = self._FAILURE_SCHEDULE.get(os.environ.get("NUM_RESTARTS"))
        if plan is None:
            return
        failing_node, epoch_threshold = plan
        if self.trainer.node_rank == failing_node and self.current_epoch > epoch_threshold:
            # Raise explicitly instead of `assert 0`: asserts are stripped
            # under `python -O`, which would silently disable the simulation.
            raise RuntimeError("Intentional error to simulate failure and recovery")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
def run():
    """Train the autoencoder on MNIST, resuming from the latest checkpoint if one exists.

    The artifacts directory is read from ``LIGHTNING_ARTIFACTS_DIR``; when the
    variable is unset or empty, the current working directory is used so the
    script still runs (and can still resume) outside the managed environment.
    """
    # This is a shared directory all nodes can access
    # We should write checkpoints here in case we need to auto-resume
    artifacts_dir = os.getenv("LIGHTNING_ARTIFACTS_DIR") or os.getcwd()

    # Find the last saved checkpoint (if any)
    ckpt_path = find_latest_checkpoint(artifacts_dir)

    dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
    # Seed the split so every process and every restart sees the same
    # train/val partition; an unseeded split would differ per process and leak
    # validation samples into training after a resume.
    split_generator = torch.Generator().manual_seed(42)
    train, val = data.random_split(dataset, [55000, 5000], generator=split_generator)

    autoencoder = LitAutoEncoder()
    trainer = L.Trainer(default_root_dir=artifacts_dir)
    trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val), ckpt_path=ckpt_path)
# Script entry point: start training only when executed directly, not on import.
if __name__ == "__main__":
    run()
# (GitHub page footer: sign up or sign in on GitHub to comment on this gist.)