@nelhage · Last active January 22, 2022
#!/usr/bin/env python
import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
INTERVAL = 1
COMM_SIZE = (10,)

def run(rank, size):
    torch.cuda.set_device(rank)
    pg = torch.distributed.new_group(list(range(size)), backend="nccl")
    if rank == 0:
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()

        dist.barrier()
        torch.cuda.synchronize()

        outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
        mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        # Do an allgather to warm up comms. Somehow the first
        # all-gather we do isn't actually async and waits for the comm
        # to complete.
        pg._allgather_base(outputs, mine).wait()

        with torch.cuda.stream(s1):
            # Allocate a tensor whose backing block comes from stream
            # `s1`.
            mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
            # When we execute the _allgather_base,
            # ProcessGroupNCCL::collective calls `recordStream` to
            # record `mine` as having been used on the NCCL comms
            # stream.
            handle = pg._allgather_base(outputs, mine)
            # Now we free `mine`. This ends up in
            # DeviceCachingAllocator::free, which notices that the
            # block has non-empty stream uses, and queues an event on
            # the NCCL comms stream.
            #
            # Note that the bug wouldn't show up with point-to-point
            # comms, because they hold on to their input or output
            # tensors in WorkNCCL::outputs_, and so the tensor would
            # not actually be freed here.
            mine = None

        print("[0] Queued the receive.")
        t = time.time()
        # Now we do some concurrent work while the comms happen in the
        # background.
        while not handle.is_completed():
            # We allocate a tensor, and then we `record_stream` to
            # make the allocator record it as having stream_uses. This
            # is the simplest demo for a reproducer; in real code this
            # can happen in autograd, by other comms, or a handful of
            # other ways.
            data = torch.randn((1024,), device="cuda")
            data.record_stream(s2)
            # Now we free `data`. Since it has `stream_uses`, the
            # allocator enqueues an event and marks the underlying
            # buffer for later free.
            #
            # However, `process_events` will walk the event list in
            # order, and stop at the first event which isn't
            # ready. Since we queued an event on the NCCL comms stream
            # up above, it will always stop there, and no memory will
            # be released until the comms complete.
            data = None

            now = time.time()
            if (now - t) > INTERVAL:
                # Dump memory stats every second
                t = now
                print(torch.cuda.memory_summary(abbreviated=True))
        handle.wait()
    else:
        dist.barrier()
        outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
        mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        pg._allgather_base(outputs, mine)
        # On rank 1, we just sleep 10s and then do an all-gather, to
        # achieve the effect of a long-running op on the NCCL stream
        # on rank 0.
        print("[1] Sleeping...")
        time.sleep(10)
        pg._allgather_base(outputs, mine)
        print("[1] Sent a tensor")

def init_process(rank, size, fn, backend="nccl"):
    """ Initialize the distributed environment. """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
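As a reference for the allocator behavior the reproducer above exercises, here is a minimal single-GPU sketch. It is not part of the gist, and the caching allocator's event handling has changed across PyTorch releases, so newer versions may not show the pile-up; treat it as an illustration of the mechanism described in the comments above, not a guaranteed reproduction.

# Minimal single-GPU sketch of the deferred-free pile-up (assumption:
# a CUDA build of PyTorch; behavior varies across releases).
import torch

busy = torch.cuda.Stream()   # stands in for the NCCL comms stream
side = torch.cuda.Stream()   # stands in for `s2` in the reproducers

# Keep `busy` occupied with a long chain of matmuls, so that an event
# recorded on it stays pending for a while.
with torch.cuda.stream(busy):
    a = torch.randn((4096, 4096), device="cuda")
    for _ in range(300):
        a = a @ a

# Free a tensor whose block has a recorded use on the busy stream. The
# allocator queues an event on `busy` and defers releasing the block.
t = torch.randn((1 << 20,), device="cuda")
t.record_stream(busy)
t = None

# Allocate and free small tensors that record a use on the idle `side`
# stream. Their events complete almost immediately, but process_events
# walks the event queue in order and stops at the first pending event
# (the one on `busy`), so none of these blocks are reused and reserved
# memory keeps growing until `busy` drains.
for i in range(5000):
    data = torch.randn((1024,), device="cuda")
    data.record_stream(side)
    data = None
    if i % 1000 == 0:
        print(f"iter {i}: reserved bytes =", torch.cuda.memory_reserved())

torch.cuda.synchronize()
print("after sync: reserved bytes =", torch.cuda.memory_reserved())

The second script below is a variant of the same reproducer that uses point-to-point irecv/isend instead of an all-gather.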
#!/usr/bin/env python
import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
INTERVAL = 1
COMM_SIZE = (10,)

def run(rank, size):
    torch.cuda.set_device(rank)
    if rank == 0:
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()

        dist.barrier()
        torch.cuda.synchronize()

        buf = torch.empty(COMM_SIZE, dtype=torch.float, device="cuda")
        # Do a warm-up comm; the first comm seems to block until
        # completion whether or not we do it async.
        dist.irecv(buf, src=1).wait()

        with torch.cuda.stream(s1):
            # Allocate a tensor whose backing block comes from stream
            # `s1`.
            buf = torch.empty(COMM_SIZE, dtype=torch.float, device="cuda")
            # Now use it in an NCCL
            # comm. ProcessGroupNCCL::pointToPoint will call
            # `recordStream` to record `buf` as having been used on
            # the NCCL comms stream.
            #
            # This comm will be fast since rank 1 sends promptly.
            dist.irecv(buf, src=1).wait()
            # Now we start a long-running comm. Rank 1 will sleep
            # before sending this tensor, so this results in a
            # long-running op on the NCCL CUDA stream.
            handle = dist.irecv(
                torch.empty(COMM_SIZE, dtype=torch.float, device="cuda"), src=1
            )
            # Now we free `buf`. This eventually ends up in
            # DeviceCachingAllocator::free; it notices that the block
            # has non-empty stream uses, and so queues an event on the
            # NCCL comms stream to make sure the tensor is actually
            # done being used before it is actually released to CUDA.
            buf = None

        print("[0] Queued the receive.")
        t = time.time()
        # Now we do some concurrent work while the comms happen in the
        # background.
        while not handle.is_completed():
            # We allocate a tensor, and then we `record_stream` to
            # make the allocator record it as having stream_uses. This
            # is the simplest demo for a reproducer; in real code this
            # can happen in autograd, by other comms, or a handful of
            # other ways.
            data = torch.randn((1024,), device="cuda")
            data.record_stream(s2)
            # Now we free `data`. Since it has `stream_uses`, the
            # allocator enqueues an event and marks the underlying
            # buffer for later free.
            #
            # However, `process_events` will walk the event list in
            # order, and stop at the first event which isn't
            # ready. Since we queued an event on the NCCL comms stream
            # up above, it will always stop there, and no memory will
            # be released until the comms complete.
            data = None

            now = time.time()
            if (now - t) > INTERVAL:
                # Dump memory stats every second
                t = now
                print(torch.cuda.memory_summary(abbreviated=True))
        handle.wait()
    else:
        dist.barrier()
        buf = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        # One warm-up send
        dist.isend(buf, dst=0).wait()
        # One send for the first (fast) `irecv` on rank 0
        dist.isend(buf, dst=0).wait()
        # Now we sleep for 10s and then do a final isend, to cause the
        # final `irecv` on rank 0 to be long-running.
        print("[1] Sleeping...")
        time.sleep(10)
        dist.isend(buf, dst=0).wait()
        print("[1] Sent a tensor")

def init_process(rank, size, fn, backend="nccl"):
    """ Initialize the distributed environment. """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
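Both reproducers assume a host with at least two CUDA GPUs and an NCCL build of PyTorch; each script spawns one process per rank and rendezvouses on 127.0.0.1:29500, so it can be run directly with python. On affected PyTorch versions, the once-a-second memory_summary dumps from rank 0 should show the active and reserved numbers climbing while rank 1 sleeps, and the growth should stop once the 10-second comm completes.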