import torch
import torch.nn as nn
import torch.optim as optim

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # noqa: F401

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, prepare_model
from ray.train.torch.xla import TorchXLAConfig
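
# Cluster assumptions (my note, not part of the original gist): this script
# expects a Ray cluster whose worker nodes are AWS Trainium instances, with
# ray[train] and the Neuron build of torch-xla (torch-neuronx) installed on
# every node. TorchXLAConfig is the piece that wires Ray Train workers up to
# the XLA distributed backend for this setup.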

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 128)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(128, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def train_func():
    device = xm.xla_device()
    rank = xm.get_ordinal()

    # Create the model and move it to the XLA device before wrapping it.
    model = Model().to(device)
    ddp_model = prepare_model(
        model,
        move_to_device=False,  # already moved to the XLA device above
        parallel_strategy_kwargs={"gradient_as_bucket_view": True},
    )
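
    # A gloss on the wrapper (my note, not part of the original gist):
    # prepare_model() wraps the module in torch DistributedDataParallel over
    # the process group that TorchXLAConfig initializes, forwarding
    # parallel_strategy_kwargs to DDP. gradient_as_bucket_view=True makes the
    # .grad tensors views into DDP's communication buckets, saving memory.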

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # Train on random data for a few steps.
    for step in range(5):
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10).to(device))
        labels = torch.randn(20, 5).to(device)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        xm.mark_step()  # flush the lazily-recorded XLA graph and execute it
        if rank == 0:
            print(f"Loss after step {step}: {loss.cpu()}")

num_workers = 64  # 2 trn1.32xlarge instances x 32 Neuron cores per instance
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    torch_config=TorchXLAConfig(),
    scaling_config=ScalingConfig(
        num_workers=num_workers, resources_per_worker={"neuron_cores": 1}
    ),
)
result = trainer.fit()
print(result)
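
# trainer.fit() returns a ray.train.Result. Assuming a recent Ray version, the
# fields you would usually inspect are result.metrics (the last metrics
# reported via ray.train.report, if any) and result.checkpoint (the latest
# reported checkpoint, or None for this toy loop, which reports neither).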