# This script was tested with the Ray + AWS Neuron XLA backend PR from the
# AWS team: https://github.com/ray-project/ray/pull/39130.
# That backend configures the environment variables required for Neuron XLA
# on each training worker.
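# Background (an assumption of this note, hedged): when Ray schedules a worker
# with a `neuron_cores` resource request, it restricts the NeuronCores visible
# to that process, much like CUDA_VISIBLE_DEVICES does for GPUs, reportedly via
# the Neuron runtime variable NEURON_RT_VISIBLE_CORES. A manual, single-process
# equivalent might look like this sketch (hypothetical value):
#
#     import os
#     os.environ["NEURON_RT_VISIBLE_CORES"] = "0"  # pin this process to core 0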
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # noqa: F401  (registers the "xla" process group backend)

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, prepare_model  # noqa: F401  (used in the commented-out alternative below)
from ray.train.torch.xla import TorchXLAConfig
from torch.nn.parallel import DistributedDataParallel as DDP
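# These imports assume each node has torch, torch-xla with the AWS Neuron
# plugin (torch-neuronx), and a Ray build that includes TorchXLAConfig from
# the PR referenced at the top of this file; exact versions are not pinned
# here and are an assumption of this note.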
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 128)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(128, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
def train_func():
    device = xm.xla_device()  # the NeuronCore assigned to this worker
    rank = xm.get_ordinal()   # global rank across all workers

    # Create the model and move it to the XLA device.
    model = Model().to(device)

    # Ray's prepare_model() can apply the same DDP wrapping; the raw DDP call
    # below is used here since the model is already on the device:
    # ddp_model = prepare_model(
    #     model,
    #     move_to_device=False,
    #     parallel_strategy_kwargs={"gradient_as_bucket_view": True},
    # )
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for step in range(50):
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10).to(device))
        labels = torch.randn(20, 5).to(device)
        loss = loss_fn(outputs, labels)
        loss.backward()
        # xm.optimizer_step() reduces gradients across workers before stepping.
        # Since DDP already averages gradients during backward(), the
        # commented-out pair below is the usual DDP-on-XLA alternative:
        xm.optimizer_step(optimizer)
        # optimizer.step()
        # xm.mark_step()
        if rank == 0:
            print(f"Loss after step {step}: {loss.cpu()}")

    print("Finished Training")
# Cluster shape: 2x trn1.32xlarge (each with 32 NeuronCores and 128 vCPUs),
# so num_workers=64 schedules one training worker per NeuronCore across both nodes.
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    torch_config=TorchXLAConfig(),
    scaling_config=ScalingConfig(
        num_workers=64, resources_per_worker={"neuron_cores": 1}
    ),
)
result = trainer.fit()
print(result)
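# Usage sketch (an assumption of this note, not part of the tested setup):
# start a Ray cluster spanning the two Trainium nodes, then run this script on
# the head node; trainer.fit() attaches to the running cluster.
#
#   ray start --head                      # on the first trn1.32xlarge node
#   ray start --address=<head_ip>:6379    # on the second node
#   python train_neuron_xla.py            # hypothetical file name
#
# `result` is a ray.train.Result; result.metrics and result.checkpoint would
# hold reported metrics and checkpoints (this script reports neither).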