all_reduce.py by @sandeepkumar-skb (last active April 19, 2023)
import datetime
import os

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record


@record
def main():
    # Surface NCCL errors instead of letting a failed collective hang.
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=1800))
    rank = dist.get_rank()
    # torchrun exports LOCAL_RANK; use it to pick this process's GPU on the node
    # (rank % world_size is just rank again, which breaks on multi-node runs).
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    print("rank:", rank, "device:", torch.cuda.current_device())
    # Each rank contributes its own rank; the SUM all-reduce leaves every rank
    # holding 0 + 1 + ... + (world_size - 1).
    var = torch.tensor(rank, device="cuda")
    dist.all_reduce(var, op=dist.ReduceOp.SUM)
    print(var)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
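
The same pattern can be sanity-checked on a machine without GPUs by swapping NCCL for the gloo backend, which runs the collective on CPU tensors. A minimal sketch, assuming a CPU-only box (this variant is not part of the original gist):

import datetime
import torch
import torch.distributed as dist

def main():
    # gloo needs no CUDA device; torchrun still supplies RANK/WORLD_SIZE etc.
    dist.init_process_group(backend="gloo", timeout=datetime.timedelta(seconds=1800))
    rank = dist.get_rank()
    var = torch.tensor(rank)
    dist.all_reduce(var, op=dist.ReduceOp.SUM)
    # Every rank ends up holding 0 + 1 + ... + (world_size - 1).
    print(f"rank {rank}: {var.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

It launches with the same torchrun command shown in the comment below.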
@sandeepkumar-skb (Author) commented:
torchrun --nnodes <> --nproc_per_node <> all_reduce.py
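
For example, a single node with 2 GPUs (the node and GPU counts here are illustrative, not from the gist) would be launched as:

torchrun --nnodes 1 --nproc_per_node 2 all_reduce.py

Each of the two ranks should then print a tensor holding 1, since the SUM all-reduce of the rank values 0 and 1 is 1.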
