Torch_DDP_Example
import os
import tempfile
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
# If using GPUs, set this to True.
use_gpu = True
# Number of processes to run training on.
num_workers = 8
hidden_dim = 32000
# Define your network structure.
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(1, hidden_dim)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_dim, 1)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))
# Training loop.
def train_loop_per_worker(config):
    # Read configurations.
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]

    # Fetch training dataset.
    train_dataset_shard = ray.train.get_dataset_shard("train")

    # Instantiate and prepare model for training.
    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    # Define loss and optimizer.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create data loader.
    dataloader = train_dataset_shard.iter_torch_batches(
        batch_size=batch_size, dtypes=torch.float
    )

    # Train multiple epochs.
    for epoch in range(num_epochs):
        # Train epoch.
        for batch in dataloader:
            output = model(batch["input"])
            loss = loss_fn(output, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Create checkpoint.
        base_model = (model.module
                      if isinstance(model, DistributedDataParallel) else model)
        checkpoint_dir = tempfile.mkdtemp()
        torch.save(
            {"model_state_dict": base_model.state_dict()},
            os.path.join(checkpoint_dir, "model.pt"),
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report metrics and checkpoint.
        ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)
# Storage path for checkpoints and results ($ANYSCALE_ARTIFACT_STORAGE with a "test" subdirectory).
storage_path = "s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/org_7c1Kalm9WcX2bNIjW53GUT/cld_kvedZWag2qA8i5BjxUevf5i7/artifact_storage/test"
# Define configurations.
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1), storage_path=storage_path)
# Define datasets.
train_dataset = ray.data.from_items(
    [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}
# Initialize the Trainer.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=datasets,
)
# Train the model.
result = trainer.fit()
# Inspect the results.
final_loss = result.metrics["loss"]