@sandeepkumar-skb
Created April 19, 2023 21:20
import datetime
import os

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record


@record
def main():
    # Surface asynchronous NCCL failures as errors instead of silent hangs;
    # must be set before the process group is initialized.
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=1800))
    rank = dist.get_rank()
    # torchrun exports LOCAL_RANK; the original rank % world_size always
    # equals rank, which breaks device selection on multi-node jobs.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    print("device:", torch.cuda.current_device())

    # Each rank contributes a tensor holding its own rank id; all_gather
    # fills var_list with one tensor from every rank, in rank order.
    var = torch.tensor(rank, device="cuda")
    var_list = [torch.ones_like(var) for _ in range(dist.get_world_size())]
    dist.all_gather(var_list, var)
    print(var_list)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
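
To run this, launch it with torchrun, which sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables the script relies on. A minimal single-node invocation with 4 GPUs (the filename all_gather_example.py is an assumption, not part of the gist):

torchrun --nproc_per_node=4 all_gather_example.py

With 4 processes, every rank should print the same gathered list, [tensor(0), tensor(1), tensor(2), tensor(3)], with each element resident on that rank's own GPU.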