# KubeRay resilient training script: a RayJob whose training code checkpoints to a
# hostPath volume and resumes from the latest checkpoint when the job restarts.
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: rayjob-sample
spec:
  # submitterConfig:
  #   backoffLimit: 20
  # submissionMode specifies how RayJob submits the Ray job to the RayCluster.
  # The default value is "K8sJobMode", meaning RayJob will submit the Ray job via a submitter Kubernetes Job.
  # The alternative value is "HTTPMode", indicating that KubeRay will submit the Ray job by sending an HTTP request to the RayCluster.
  # submissionMode: "K8sJobMode"
  entrypoint: python /home/ray/samples/sample_code.py
  # shutdownAfterJobFinishes specifies whether the RayCluster should be deleted after the RayJob finishes. Default is false.
  # shutdownAfterJobFinishes: false
  # ttlSecondsAfterFinished specifies the number of seconds after which the RayCluster will be deleted after the RayJob finishes.
  # ttlSecondsAfterFinished: 10
  # activeDeadlineSeconds is the duration in seconds that the RayJob may be active before
  # KubeRay actively tries to terminate the RayJob; the value must be a positive integer.
  # activeDeadlineSeconds: 120
  # runtimeEnvYAML represents the runtime environment configuration provided as a multi-line YAML string.
  # See https://docs.ray.io/en/latest/ray-core/handling-dependencies.html for details.
  # (New in KubeRay version 1.0.)
  runtimeEnvYAML: |
    pip:
      - requests==2.26.0
      - pendulum==2.1.2
      - torch
      - torchvision
    env_vars:
      counter_name: "test_counter"
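  # The env_vars above are set in the job's driver and worker processes, so
  # os.environ["counter_name"] would return "test_counter" inside sample_code.py.
  # (counter_name is carried over from the upstream quick-start sample; the
  # training script below does not use it.)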
  # suspend specifies whether the RayJob controller should create a RayCluster instance.
  # If a job is applied with the suspend field set to true, the RayCluster will not be created, and we will wait for the transition to false.
  # If the RayCluster is already created, it will be deleted. On transition to false, a new RayCluster will be created.
  # suspend: false
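  # For example, an already-running RayJob can be suspended (tearing down its
  # RayCluster) and later resumed from the command line; the job name below is
  # this manifest's metadata.name:
  #   kubectl patch rayjob rayjob-sample --type=merge -p '{"spec": {"suspend": true}}'
  #   kubectl patch rayjob rayjob-sample --type=merge -p '{"spec": {"suspend": false}}'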
  # rayClusterSpec specifies the RayCluster instance to be created by the RayJob controller.
  rayClusterSpec:
    rayVersion: '2.9.0' # should match the Ray version in the image of the containers
    # Ray head pod template
    headGroupSpec:
      # The `rayStartParams` are used to configure the `ray start` command.
      # See https://github.com/ray-project/kuberay/blob/master/docs/guidance/rayStartParams.md for the default settings of `rayStartParams` in KubeRay.
      # See https://docs.ray.io/en/latest/cluster/cli.html#ray-start for all available options in `rayStartParams`.
      rayStartParams:
        dashboard-host: '0.0.0.0'
        num-cpus: "0" # advertise zero CPUs so Ray does not schedule tasks on the head node
      # Pod template
      template:
        spec:
          containers:
            - name: ray-head
              image: rayproject/ray:2.9.0
              securityContext:
                privileged: true
              ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265 # Ray dashboard
                  name: dashboard
                - containerPort: 10001
                  name: client
              resources:
                limits:
                  cpu: "2"
                requests:
                  cpu: "2"
              volumeMounts:
                - mountPath: /home/ray/samples
                  name: code-sample
                - mountPath: /home/ray/ray_results
                  name: ray-results
          volumes:
            # You set volumes at the Pod level, then mount them into containers inside that Pod.
            - name: code-sample
              configMap:
                # Provide the name of the ConfigMap you want to mount.
                name: ray-job-code-sample
                # An array of keys from the ConfigMap to create as files
                items:
                  - key: sample_code.py
                    path: sample_code.py
            - name: ray-results
              hostPath:
                path: /ray_results
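            # Resilience note: /ray_results is a hostPath volume, so checkpoints
            # written under /home/ray/ray_results survive pod restarts as long as
            # replacement pods land on the same node. For multi-node durability you
            # would swap this for shared storage (e.g. a PVC, or the commented-out
            # storage_path="s3://..." in the script below).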
    workerGroupSpecs:
      # The pod replicas in this group are typed as workers.
      - replicas: 1
        minReplicas: 1
        maxReplicas: 5
        # Logical group name; here it is called small-group, but it can also be functional.
        groupName: small-group
        # The `rayStartParams` are used to configure the `ray start` command.
        # See https://github.com/ray-project/kuberay/blob/master/docs/guidance/rayStartParams.md for the default settings of `rayStartParams` in KubeRay.
        # See https://docs.ray.io/en/latest/cluster/cli.html#ray-start for all available options in `rayStartParams`.
        rayStartParams: {}
        # Pod template
        template:
          spec:
            volumes:
              - name: ray-results
                hostPath:
                  path: /ray_results
            containers:
              - name: ray-worker # must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name' or '123-abc')
                image: rayproject/ray:2.9.0
                securityContext:
                  privileged: true
                lifecycle:
                  preStop:
                    exec:
                      command: ["/bin/sh", "-c", "ray stop"]
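                # The preStop hook runs `ray stop` so Ray shuts down cleanly on this
                # node before Kubernetes terminates the pod (e.g. during scale-down).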
                volumeMounts:
                  - mountPath: /home/ray/ray_results
                    name: ray-results
                resources:
                  limits:
                    cpu: "4"
                  requests:
                    cpu: "4"
  # submitterPodTemplate is the template for the pod that will run the `ray job submit` command against the RayCluster.
  # If submitterPodTemplate is specified, the first container is assumed to be the submitter container.
  # submitterPodTemplate:
  #   spec:
  #     restartPolicy: Never
  #     containers:
  #       - name: my-custom-rayjob-submitter-pod
  #         image: rayproject/ray:2.9.0
  #         # If Command is not specified, the correct command will be supplied at runtime using the RayJob spec `entrypoint` field.
  #         # Specifying Command is not recommended.
  #         # command: ["sh", "-c", "ray job submit --address=http://$RAY_DASHBOARD_ADDRESS --submission-id=$RAY_JOB_SUBMISSION_ID -- echo hello world"]
###################### Ray code sample #################################
# This sample is from https://docs.ray.io/en/latest/cluster/job-submission.html#quick-start-example
# It is mounted into the container and executed to show the Ray job at work.
---
apiVersion: v1
kind: ConfigMap
metadata:
  name: ray-job-code-sample
data:
  sample_code.py: |
    import os
    import tempfile

    import torch
    from torch.nn import CrossEntropyLoss
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from torchvision.models import resnet18
    from torchvision.datasets import FashionMNIST
    from torchvision.transforms import ToTensor, Normalize, Compose

    import ray.train.torch

    def findLatestCheckpoint(dirName, runName):
        # Scan <dirName>/<runName> for session directories (named with a trailing
        # timestamp) and return the newest checkpoint directory of the newest
        # session that has one, or None if no checkpoint exists yet.
        path = dirName + "/" + runName
        if not os.path.exists(path):
            return None
        items = os.listdir(path)
        session_dirs = {}
        for item in items:
            tpath = path + "/" + item
            if os.path.isdir(tpath):
                # The last 19 characters of a session directory name are its
                # timestamp (e.g. 2024-05-06_10-19-49), which sorts chronologically.
                session_dirs[tpath[-19:]] = tpath
        # Visit sessions newest-first.
        keys = list(session_dirs.keys())
        keys.sort()
        keys.reverse()
        for key in keys:
            files = os.listdir(session_dirs[key])
            checkpointdirs = []
            for file in files:
                if file.startswith("checkpoint_"):
                    checkpointdirs.append(file)
            if len(checkpointdirs) > 0:
                # Return the highest-numbered checkpoint in this session.
                checkpointdirs.sort()
                checkpointdirs.reverse()
                return session_dirs[key] + "/" + checkpointdirs[0]
        return None
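
    # Example (hypothetical directory layout): if
    #   /home/ray/ray_results/test/TorchTrainer_..._2024-05-06_10-19-49/checkpoint_000003
    # exists on disk, findLatestCheckpoint("/home/ray/ray_results", "test")
    # returns the full path to that checkpoint_000003 directory.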

    def train_func():
        # Model, Loss, Optimizer
        model = resnet18(num_classes=10)
        model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        # [1] Prepare model.
        model = ray.train.torch.prepare_model(model)
        # model.to("cuda")  # This is done by `prepare_model`
        criterion = CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=0.001)

        # Data
        transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
        data_dir = os.path.join(tempfile.gettempdir(), "data")
        train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
        train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
        # [2] Prepare dataloader.
        train_loader = ray.train.torch.prepare_data_loader(train_loader)

        # Restore model/optimizer state and the epoch counter if Ray Train hands
        # this worker a checkpoint to resume from.
        start_epoch = 0
        checkpoint = ray.train.get_checkpoint()
        if checkpoint:
            print("### Found checkpoint")
            with checkpoint.as_directory() as checkpoint_dir:
                model_state_dict = torch.load(
                    os.path.join(checkpoint_dir, "model.pt"),
                    # map_location=...,  # Load onto a different device if needed.
                )
                model.module.load_state_dict(model_state_dict)
                optimizer.load_state_dict(
                    torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
                )
                start_epoch = (
                    torch.load(os.path.join(checkpoint_dir, "epoch.pt"))["epoch"] + 1
                )
        print("#### Starting with epoch", start_epoch)
        # Training
        for epoch in range(start_epoch, 10):
            if ray.train.get_context().get_world_size() > 1:
                # Required so the distributed sampler reshuffles across epochs.
                train_loader.sampler.set_epoch(epoch)
            for images, labels in train_loader:
                # This is done by `prepare_data_loader`!
                # images, labels = images.to("cuda"), labels.to("cuda")
                outputs = model(images)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # [3] Report metrics and checkpoint.
            metrics = {"loss": loss.item(), "epoch": epoch}
            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                torch.save(
                    model.module.state_dict(),
                    os.path.join(temp_checkpoint_dir, "model.pt"),
                )
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(temp_checkpoint_dir, "optimizer.pt"),
                )
                torch.save(
                    {"epoch": epoch},
                    os.path.join(temp_checkpoint_dir, "epoch.pt"),
                )
                ray.train.report(
                    metrics,
                    checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
                )
            if ray.train.get_context().get_world_rank() == 0:
                print(metrics)
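
    # Each ray.train.report call above persists the checkpoint directory under the
    # run's storage_path, which is how findLatestCheckpoint can locate it after the
    # job is restarted.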
    # [4] Configure scaling and resource requirements.
    scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=False)

    # Look for a checkpoint left behind by a previous (interrupted) run of this job.
    checkpt = findLatestCheckpoint("/home/ray/ray_results", "test")
    print("#####", checkpt)
    if checkpt is not None:
        # checkpt = ray.train.Checkpoint("/home/ray/ray_results/TorchTrainer_2024-05-06_10-19-49/TorchTrainer_2b79a_00000_0_2024-05-06_10-19-49/checkpoint_000000/")
        # [5] Launch distributed training job, resuming from the latest checkpoint.
        trainer = ray.train.torch.TorchTrainer(
            train_func,
            scaling_config=scaling_config,
            # [5a] If running in a multi-node cluster, this is where you
            # should configure the run's persistent storage that is accessible
            # across all worker nodes.
            # run_config=ray.train.RunConfig(storage_path="s3://..."),
            resume_from_checkpoint=ray.train.Checkpoint(checkpt),
            run_config=ray.train.RunConfig(name="test"),
        )
    else:
        trainer = ray.train.torch.TorchTrainer(
            train_func,
            scaling_config=scaling_config,
            # [5a] If running in a multi-node cluster, this is where you
            # should configure the run's persistent storage that is accessible
            # across all worker nodes.
            # run_config=ray.train.RunConfig(storage_path="s3://..."),
            run_config=ray.train.RunConfig(name="test", storage_path="/home/ray/ray_results"),
        )
    result = trainer.fit()

    # [6] Load the trained model.
    with result.checkpoint.as_directory() as checkpoint_dir:
        model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
        model = resnet18(num_classes=10)
        model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        model.load_state_dict(model_state_dict)
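# To try this out (assuming a cluster with the KubeRay operator installed and this
# manifest saved as rayjob-resilient.yaml, a hypothetical filename):
#   kubectl apply -f rayjob-resilient.yaml
#   kubectl get rayjobs.ray.io rayjob-sample -o jsonpath='{.status.jobStatus}'
# If the job fails and is retried (see the commented-out submitterConfig.backoffLimit
# above), the rerun resumes from the newest checkpoint under /home/ray/ray_results/test.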