# This script was tested with PR https://github.com/ray-project/ray/pull/39130 from the AWS team,
# which configures the required environment variables for Neuron XLA.
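# TorchXLAConfig (passed to TorchTrainer below) is what sets each Ray Train worker
# up for torch_xla's "xla" distributed backend, so DDP communication is expected to
# run over XLA collectives on the Neuron cores rather than NCCL/Gloo.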
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend # noqa: F401
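# Importing torch_xla.distributed.xla_backend registers the "xla" backend with
# torch.distributed, which is why the import is kept even though it looks unused.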
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, prepare_model
from ray.train.torch.xla import TorchXLAConfig
from torch.nn.parallel import DistributedDataParallel as DDP


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.net1 = nn.Linear(10, 128)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(128, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
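
# Note: the random batches created in train_func below match the model's shapes
# (inputs 20 x 10, targets 20 x 5).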


def train_func():
    device = xm.xla_device()
    rank = xm.get_ordinal()

    # Create the model and move it to the XLA device.
    model = Model().to(device)

    # Alternative: wrap with Ray Train's prepare_model() instead of DDP directly.
    # ddp_model = prepare_model(
    #     model,
    #     move_to_device=False,
    #     parallel_strategy_kwargs={"gradient_as_bucket_view": True},
    # )
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for step in range(50):
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10).to(device))
        labels = torch.randn(20, 5).to(device)
        loss = loss_fn(outputs, labels)
        loss.backward()
        # xm.optimizer_step() all-reduces gradients across XLA replicas and then
        # calls optimizer.step(); the commented-out lines are the manual variant
        # (plain step plus an explicit XLA mark_step).
        xm.optimizer_step(optimizer)
        # optimizer.step()
        # xm.mark_step()
        if rank == 0:
            print(f"Loss after step {step}: {loss.cpu()}")

    print("Finished Training")


# trn1.32xlarge -> 32 neuron_cores, 128 CPUs
# 2x trn1.32xlarge -> 64 Neuron cores total, hence num_workers=64 below with
# one Neuron core per worker
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    torch_config=TorchXLAConfig(),
    scaling_config=ScalingConfig(
        num_workers=64, resources_per_worker={"neuron_cores": 1}
    ),
)
result = trainer.fit()
print(result)
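
# A minimal launch sketch, assuming a Ray cluster with two trn1.32xlarge nodes is
# already up (the address placeholder and file name are illustrative, not from the gist):
#
#   ray start --head                          # on the head node
#   ray start --address=<head-node-ip>:6379   # on the second node
#   python train_neuron_xla.py                # runs TorchTrainer with 64 workers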