Created
October 31, 2018 17:00
-
-
Save EnisBerk/a3e53d4d45c04086d2cf1784ee2d2012 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""run.py:""" | |
#!/usr/bin/env python | |
import os | |
import torch | |
import torch.distributed as dist | |
from torch.multiprocessing import Process | |
def run(rank, size):
    """All-reduce a ones tensor across every rank and print the result.

    Args:
        rank: This process's rank; also used as the CUDA device index.
        size: World size (part of the standard (rank, size) worker
            signature; not read directly here).

    Assumes the default process group is already initialized (see
    init_processes) and that one CUDA device per rank is available.
    """
    print("cuda:{}".format(rank))
    device = torch.device("cuda:{}".format(rank))
    # Each rank contributes a vector of ones; after the SUM all-reduce
    # every element equals the world size on every rank.
    mn = torch.ones(10).to(device)
    # dist.reduce_op was deprecated and later removed; ReduceOp is the
    # supported enum for reduction operations.
    dist.all_reduce(mn, op=dist.ReduceOp.SUM)
    mn = mn.to("cpu")
    print(mn[0])
    print("done", rank)
def init_processes(rank, size, fn, backend='gloo'):
    """Join the distributed process group, then run the worker function.

    Args:
        rank: Rank of this process within the group.
        size: Total number of processes (world size).
        fn: Worker callable invoked as fn(rank, size) once the group
            is initialized.
        backend: torch.distributed backend name (default 'gloo').
    """
    print("started init at{}".format(rank))
    # Rendezvous settings for the env:// init method, so the processes
    # can discover each other (including across a network).
    rendezvous = {
        'MASTER_ADDR': '127.0.0.1',
        'MASTER_PORT': '29500',
        'WORLD_SIZE': str(size),
        'RANK': str(rank),
    }
    os.environ.update(rendezvous)
    dist.init_process_group(backend, rank=rank, world_size=size)
    print("end init at{}".format(rank))
    # Hand control to the worker now that the group is ready.
    fn(rank, size)
    print("fn called at{}".format(rank))
# Entry point: spawn one worker process per rank. The __main__ guard is
# required by multiprocessing — with the spawn start method each child
# re-imports this module, and without the guard that re-import would
# recursively launch new processes.
if __name__ == "__main__":
    size = 2  # world size: total number of worker processes
    processes = []
    for rank in range(size):
        # target is the function each new process runs; it initializes
        # the process group and then calls `run`.
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)
    # Wait for every worker to finish before the parent exits.
    for p in processes:
        p.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment