Created: January 9, 2024 01:33
Save woshiyyya/f0f945a2ed7a0a04d7113328adcbad08 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# pylint: skip-file
import os

import torch
from torch import distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import pytorch_lightning as pl

import ray.train.lightning
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer
# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    """ResNet-18 classifier adapted to single-channel FashionMNIST images."""

    def __init__(self):
        super().__init__()
        self.model = resnet18(num_classes=10)
        # Replace the stock 3-channel stem with a 1-channel one so the
        # network accepts grayscale input.
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self.forward(images)
        loss = self.criterion(logits, labels)
        # Log per-step so the progress bar shows the running loss.
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
# Cluster / run configuration.
nnode = 1              # number of nodes in the cluster
num_gpus_per_node = 8  # GPU workers launched per node
num_cpus = 20          # CPUs per worker (also used as DataLoader workers)
max_epochs = 10
dataset_repeat = 1000  # inflate the dataset so each epoch runs longer
# Change these two paths yourself
data_root = "/mnt/local_storage"       # where FashionMNIST is downloaded
storage_path = "/mnt/cluster_storage"  # where Ray Train persists results
def train_func(config):
    """Per-worker training loop executed by Ray Train on each GPU worker."""
    # Data: normalized FashionMNIST, concatenated with itself to
    # lengthen each epoch for benchmarking purposes.
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    mnist = FashionMNIST(root=data_root, train=True, download=True, transform=transform)
    repeated = torch.utils.data.ConcatDataset([mnist] * dataset_repeat)
    loader = DataLoader(repeated, batch_size=128, shuffle=True, num_workers=num_cpus)

    model = ImageClassifier()

    # [1] Configure PyTorch Lightning Trainer with Ray's DDP strategy
    # and cluster environment so Lightning cooperates with Ray Train.
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(find_unused_parameters=False),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        logger=None,
        benchmark=True,
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=loader)
# [2] Configure scaling and resource requirements:
# one worker per GPU, PACKed onto as few nodes as possible.
scaling_config = ScalingConfig(
    num_workers=nnode * num_gpus_per_node,
    use_gpu=True,
    resources_per_worker={"CPU": num_cpus, "GPU": 1},
    placement_strategy="PACK",
)
# Results and checkpoints land under storage_path/name.
run_config = RunConfig(storage_path=storage_path, name="lightning_train_example")

# [3] Launch distributed training job.
trainer = TorchTrainer(train_func, scaling_config=scaling_config, run_config=run_config)
result = trainer.fit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.