
@jeffra
Created February 8, 2021 18:28
MP4 SHARP bug
import torch
import torch.distributed as dist
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

dist.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

world_size = dist.get_world_size()
model_parallel_size = 4

_DATA_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None

rank = dist.get_rank()

# Data-parallel groups: ranks that share the same position within a
# model-parallel block, i.e. {i, i + mp, i + 2*mp, ...}.
for i in range(model_parallel_size):
    ranks = range(i, world_size, model_parallel_size)
    group = torch.distributed.new_group(ranks)
    if i == (rank % model_parallel_size):
        _DATA_PARALLEL_GROUP = group

# Model-parallel groups: consecutive blocks of model_parallel_size ranks.
for i in range(world_size // model_parallel_size):
    ranks = range(i * model_parallel_size,
                  (i + 1) * model_parallel_size)
    group = torch.distributed.new_group(ranks)
    if i == (rank // model_parallel_size):
        _MODEL_PARALLEL_GROUP = group


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_model_parallel_rank():
    """Return my rank for the model parallel group."""
    return torch.distributed.get_rank(group=get_model_parallel_group())


def ag_test():
    """All-gather a large random tensor across the caller's data-parallel group."""
    src_rank = get_model_parallel_rank()  # unused; kept from the original repro
    dp_size = dist.get_world_size(get_data_parallel_group())
    mats = []
    for _ in range(dp_size):
        mats.append(torch.rand(1, 268 * 1024 * 1024 // dp_size, device=device))
    dist.all_gather(mats,
                    mats[dist.get_rank(get_data_parallel_group())],
                    group=get_data_parallel_group())


for _ in range(100):
    ag_test()
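
For reference, a minimal sketch of the process-group layout the two loops above construct, assuming the 32-GPU, model_parallel_size=4 configuration reported in the comment below (this snippet is illustrative only and is not part of the gist):

    # Illustrative only: group layout for world_size = 32, model_parallel_size = 4.
    world_size = 32
    model_parallel_size = 4

    # Data-parallel groups: ranks sharing a position within a model-parallel block.
    data_parallel_groups = [list(range(i, world_size, model_parallel_size))
                            for i in range(model_parallel_size)]
    # -> [[0, 4, 8, ..., 28], [1, 5, 9, ..., 29], [2, 6, ..., 30], [3, 7, ..., 31]]

    # Model-parallel groups: consecutive blocks of model_parallel_size ranks.
    model_parallel_groups = [list(range(i * model_parallel_size,
                                        (i + 1) * model_parallel_size))
                             for i in range(world_size // model_parallel_size)]
    # -> [[0, 1, 2, 3], [4, 5, 6, 7], ..., [28, 29, 30, 31]]

Assuming 8 GPUs per node with ranks assigned to nodes in order (the gist does not state the rank-to-node mapping), every data-parallel group spans all four nodes, so each all_gather in ag_test() runs over the inter-node fabric where SHARP offload applies.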
jeffra commented Feb 8, 2021

mp4_sharp_bug.py

On A100s, with 4 nodes and 32 GPUs, this causes one of the machines to always reboot.
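
For scale, a rough back-of-the-envelope of the traffic each ag_test() call generates under that configuration (my arithmetic, not stated in the gist):

    # Rough per-call traffic estimate; assumes 32 GPUs with model_parallel_size = 4.
    dp_size = 32 // 4                               # 8 ranks per data-parallel group
    elems_per_shard = 268 * 1024 * 1024 // dp_size  # 35,127,296 float32 elements
    bytes_per_shard = elems_per_shard * 4           # ~134 MiB contributed per rank
    bytes_gathered = bytes_per_shard * dp_size      # ~1.05 GiB received per rank per call

The loop of 100 back-to-back calls therefore keeps the inter-node fabric under sustained all_gather load.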

Reproduced this:
