Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Created January 9, 2024 01:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save woshiyyya/f0f945a2ed7a0a04d7113328adcbad08 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# pylint: skip-file
import os
import torch
from torch import distributed as dist
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader, DistributedSampler
import pytorch_lightning as pl
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig
import ray.train.lightning
# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    """ResNet-18 classifier adapted to 10-class, single-channel input.

    The stock ResNet-18 stem expects 3-channel images; `conv1` is swapped
    for a 1-channel version so grayscale FashionMNIST batches work.
    """

    def __init__(self):
        # Python 3 zero-argument super() — equivalent to the legacy
        # super(ImageClassifier, self).__init__() form.
        super().__init__()
        self.model = resnet18(num_classes=10)
        # Replace the 3-channel stem convolution with a 1-channel one,
        # keeping the original kernel/stride/padding configuration.
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log, and return cross-entropy loss for one batch."""
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all model parameters, fixed lr=0.001."""
        return torch.optim.Adam(self.model.parameters(), lr=0.001)
# --- Cluster topology and training hyperparameters ---
nnode = 1              # number of nodes in the Ray cluster
num_gpus_per_node = 8  # GPUs (one Ray Train worker each) per node
num_cpus = 20          # CPUs per worker; also DataLoader worker count
max_epochs = 10        # Lightning epochs per run
dataset_repeat = 1000  # concatenate FashionMNIST this many times per epoch

# Change these two paths yourself
data_root = "/mnt/local_storage"      # FashionMNIST download location
storage_path = "/mnt/cluster_storage" # Ray Train results/checkpoints root
def train_func(config):
    """Per-worker training loop executed by Ray Train on each worker.

    Builds the (repeated) FashionMNIST dataloader, wraps a Lightning
    Trainer with Ray's DDP strategy/environment, and runs `fit`.
    """
    # Data: normalized grayscale FashionMNIST, repeated to enlarge an epoch.
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    base_dataset = FashionMNIST(
        root=data_root, train=True, download=True, transform=transform
    )
    repeated_dataset = torch.utils.data.ConcatDataset([base_dataset] * dataset_repeat)
    loader = DataLoader(
        repeated_dataset, batch_size=128, shuffle=True, num_workers=num_cpus
    )

    # Training
    classifier = ImageClassifier()

    # [1] Configure PyTorch Lightning Trainer.
    lightning_trainer = pl.Trainer(
        max_epochs=max_epochs,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(find_unused_parameters=False),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        logger=None,
        benchmark=True,
    )
    lightning_trainer = ray.train.lightning.prepare_trainer(lightning_trainer)
    lightning_trainer.fit(classifier, train_dataloaders=loader)
# [2] Configure scaling and resource requirements.
# One Ray Train worker per GPU; PACK keeps workers on as few nodes as possible.
scaling_config = ScalingConfig(
    num_workers=nnode * num_gpus_per_node,
    use_gpu=True,
    resources_per_worker={"CPU": num_cpus, "GPU": 1},
    placement_strategy="PACK",
)

run_config = RunConfig(
    storage_path=storage_path,
    name="lightning_train_example",
)

# [3] Launch distributed training job.
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)
result = trainer.fit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment