@stas00
Forked from jeffra/mp4_sharp_bug.py
Last active February 24, 2022 21:19
MP4 SHARP bug repro (edited to support the modern launcher, with status printing added to make it easier to see what's going on)
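A minimal launch sketch, assuming the script is saved as mp4_sharp_bug.py and run on a single node with 8 GPUs (both assumptions; the process count must be a multiple of model_parallel_size = 4). The modern launcher, torchrun, sets the LOCAL_RANK environment variable that the script reads:

torchrun --nproc_per_node=8 mp4_sharp_bug.py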
import torch
import torch.distributed as dist
import os

# LOCAL_RANK is set per process by the modern launcher (torchrun / torch.distributed.run).
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend='nccl')
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

world_size = dist.get_world_size()
model_parallel_size = 4

_DATA_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None

rank = dist.get_rank()

# Build the data-parallel groups: ranks with the same model-parallel rank,
# i.e. {i, i + mp, i + 2*mp, ...}.
for i in range(model_parallel_size):
    ranks = range(i, world_size, model_parallel_size)
    group = torch.distributed.new_group(ranks)
    if i == (rank % model_parallel_size):
        _DATA_PARALLEL_GROUP = group

# Build the model-parallel groups: consecutive blocks of model_parallel_size ranks.
for i in range(world_size // model_parallel_size):
    ranks = range(i * model_parallel_size,
                  (i + 1) * model_parallel_size)
    group = torch.distributed.new_group(ranks)
    if i == (rank // model_parallel_size):
        _MODEL_PARALLEL_GROUP = group


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_model_parallel_rank():
    """Return my rank for the model parallel group."""
    return torch.distributed.get_rank(group=get_model_parallel_group())


def ag_test():
    # All-gather across the data-parallel group: each rank contributes
    # 268*1024*1024 // dp_world_size random fp32 elements (~1 GiB gathered in total).
    src_rank = get_model_parallel_rank()  # unused in this repro
    mats = []
    for _ in range(dist.get_world_size(get_data_parallel_group())):
        mats.append(torch.rand(1, 268*1024*1024 // dist.get_world_size(get_data_parallel_group()), device=device))
    dist.all_gather(mats, mats[dist.get_rank(get_data_parallel_group())], group=get_data_parallel_group())


# Run the all-gather repeatedly, printing progress from rank 0.
for i in range(100):
    if rank == 0:
        print(f"round {i}")
    ag_test()

if rank == 0:
    print("Done")