
@wookim3
Created October 19, 2020 14:31
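
A minimal example combining DistributedDataParallel (DDP) with the torch.distributed.rpc framework: a tensor is produced remotely on "worker1", a local model is wrapped in DDP, and a DistributedOptimizer updates both the remote tensor and the local parameters inside a distributed autograd context. my_model, loss_func, and target are assumed to be defined elsewhere.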
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef
from torch.nn.parallel import DistributedDataParallel as DDP

# Compute t1 + t2 remotely on "worker1"; rref points at the remote result
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
rref = rpc.remote("worker1", torch.add, args=(t1, t2))

# Wrap the local model (assumed defined elsewhere) with DDP
ddp_model = DDP(my_model)

# Setup optimizer over the remote tensor and the DDP model's
# local parameters, all passed as RRefs
optimizer_params = [rref]
for param in ddp_model.parameters():
    optimizer_params.append(RRef(param))

dist_optim = DistributedOptimizer(
    optim.SGD,
    optimizer_params,
    lr=0.05,
)

# Forward, distributed backward, and optimizer step all run inside
# a distributed autograd context
with dist_autograd.context() as context_id:
    pred = ddp_model(rref.to_here())
    loss = loss_func(pred, target)  # loss_func and target assumed defined
    dist_autograd.backward(context_id, [loss])  # backward takes a list of roots
    dist_optim.step(context_id)  # DistributedOptimizer.step needs the context id
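
The snippet assumes that both the process group used by DDP and the RPC framework are already initialized on each trainer. A minimal sketch of that setup follows; the worker names, ports, and environment variables are assumptions for illustration, and the separate rendezvous port for the process group is one common way to keep it from colliding with the RPC rendezvous.

import os
import torch.distributed as dist
import torch.distributed.rpc as rpc

# Assumed rendezvous settings; a launcher would normally provide these
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rank = int(os.environ["RANK"])            # assumed set by the launcher
world_size = int(os.environ["WORLD_SIZE"])

# Process group used by DDP for gradient all-reduce;
# a distinct port from the RPC rendezvous (an assumption here)
dist.init_process_group(
    "gloo", rank=rank, world_size=world_size,
    init_method="tcp://localhost:29501",
)

# RPC layer used by rpc.remote, dist_autograd, and DistributedOptimizer
rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

# ... run the training snippet above ...

rpc.shutdown()

With this in place, each process can run the training snippet; rpc.remote targets the peer by the worker name chosen at init_rpc time ("worker1" above).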