The simplest distributed PyTorch example
# To run this program, type the following commands in a terminal.
#
# MASTER_ADDR=localhost MASTER_PORT=12306 RANK=0 WORLD_SIZE=2 python3 s.py &
# MASTER_ADDR=localhost MASTER_PORT=12306 RANK=1 WORLD_SIZE=2 python3 s.py &
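#
# Alternatively, torchrun (bundled with recent PyTorch releases) can set
# RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for you. A sketch,
# assuming this script is saved as s.py:
#
# torchrun --nproc_per_node=2 s.py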
import os
import sys
import logging

import torch
import torch.distributed as dist

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
log.addHandler(logging.StreamHandler(stream=sys.stderr))


def compute_distributed_sum():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    log.info(
        f"Executing distributed sum with rank: {rank}, world_size: {world_size}"
    )
    # init_process_group reads MASTER_ADDR, MASTER_PORT, RANK, and
    # WORLD_SIZE from the environment to rendezvous the processes.
    dist.init_process_group(backend="gloo")
    # all_reduce is in-place: after the call, every rank holds the sum of
    # the tensors contributed by all ranks.
    tensor = torch.tensor(rank + 10)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    res = tensor.item()
    # Each rank contributes rank + 10, so the reduced value must be
    # 10 * world_size plus 0 + 1 + ... + (world_size - 1).
    expected_sum = sum(r + 10 for r in range(world_size))
    assert res == expected_sum, f"expected {expected_sum}, got {res}"
    log.info(
        f"Distributed sum with rank: {rank}, world_size: {world_size}, got result: {res}"
    )
    dist.destroy_process_group()
    return res


def main():
    compute_distributed_sum()


if __name__ == "__main__":
    main()
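With WORLD_SIZE=2, rank 0 contributes 10 and rank 1 contributes 11, so all_reduce leaves every rank holding 21; in general, N processes produce 10*N + N*(N-1)/2. Given the log format above, each process should print a line like:

Distributed sum with rank: 0, world_size: 2, got result: 21
Distributed sum with rank: 1, world_size: 2, got result: 21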