
@wookim3
Created October 19, 2020 14:29
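```python
import torch

# Assumes the process group is already initialized on this worker
# (e.g. via torch.distributed.init_process_group) and that `model`,
# `rank`, and `inputs` are defined by the surrounding worker code.
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[rank],
    output_device=rank,
)

# join() lets ranks that run out of inputs early keep shadowing the
# all-reduce calls of ranks that are still training, so an uneven
# number of inputs across ranks does not cause a hang.
with model.join():
    for _ in range(5):
        for inp in inputs:
            loss = model(inp).sum()
            loss.backward()
```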
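The snippet does not show how gradients are averaged once some ranks have joined. The toy model below is a pure-Python sketch and not PyTorch's implementation. The function `simulate_join_allreduce` is hypothetical. It assumes the documented behavior of the `divide_by_initial_world_size` argument of `DistributedDataParallel.join()`: a joined rank contributes zero-filled gradients to each all-reduce. The summed gradient is then divided either by the initial world size (the default) or by the number of ranks that still have inputs.

```python
from typing import List


def simulate_join_allreduce(
    grads_per_rank: List[List[float]],
    divide_by_initial_world_size: bool = True,
) -> List[float]:
    """Hypothetical toy model of gradient averaging under DDP join().

    grads_per_rank[r][s] is the scalar gradient rank r computes at step s.
    A rank with no input at step s has "joined": it still takes part in the
    all-reduce but contributes 0.0, so the other ranks never block on it.
    """
    world_size = len(grads_per_rank)
    num_steps = max(len(grads) for grads in grads_per_rank)
    averaged = []
    for step in range(num_steps):
        total = 0.0
        active = 0
        for grads in grads_per_rank:
            if step < len(grads):
                total += grads[step]
                active += 1
            # Otherwise this rank has joined and adds a zero gradient.
        divisor = world_size if divide_by_initial_world_size else active
        averaged.append(total / divisor)
    return averaged


# Rank 0 has two inputs, rank 1 has three.
print(simulate_join_allreduce([[1.0, 1.0], [3.0, 3.0, 3.0]]))
print(simulate_join_allreduce([[1.0, 1.0], [3.0, 3.0, 3.0]], False))
```

In the last step only rank 1 is active. Dividing by the initial world size (2) halves its gradient to 1.5. Dividing by the effective world size (1) keeps it at 3.0.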