@andyljones
Last active October 28, 2022 09:28

Demonstration of the MPS hang bug
"""This script should be run on a machine with at least 2 GPUs and an MPS server running. You can launch an MPS daemon with
```
nvidia-cuda-mps-control -d
```
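If you need to shut the daemon down again afterwards, you can do so with
```
echo quit | nvidia-cuda-mps-control
```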
The script first uses `test_cuda` to verify that a CUDA context can be created on each GPU. It then spawns two
workers: a 'good' worker and a 'bad' worker. The workers collaborate through PyTorch's DistributedDataParallel
module to calculate the gradient of a trivial computation. The 'good' worker carries out both the forward and
backward passes, while the 'bad' worker carries out the forward pass and then exits. This seems to lock up the
MPS server, and any subsequent attempt to create a CUDA context hangs indefinitely.

Environment:
```
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
Nvidia driver version: 440.26
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] torch==1.4.0
[pip] torchfile==0.1.0
[pip] torchvision==0.5.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] pytorch 1.4.0 py3.7_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchvision 0.5.0 py37_cu101 pytorch
```
"""
import os
import torch
import time
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import distributed as dist
from torch import multiprocessing as mp
from torch import nn

def _test_cuda(device):
    import torch
    torch.tensor([0]).to(device)


def test_cuda():
    """Tests whether creating a CUDA context hangs on any device by creating a subprocess for each
    device and having them call the minimal `_test_cuda` function. If a subprocess doesn't
    terminate in 10s, the corresponding device has hung."""
    procs = []
    for device in range(torch.cuda.device_count()):
        proc = mp.Process(target=_test_cuda, args=(device,))
        proc.start()
        procs.append(proc)

    for _ in range(10):
        status = ['checking' if p.is_alive() else 'ready' for p in procs]
        if all(s == 'ready' for s in status):
            print(f"CUDA ready")
            return
        time.sleep(1)
    else:
        raise ValueError(f"One of the CUDA checks overrun: {', '.join(status)}")


def _setup(rank):
    """Create a trivial network and batch on each device, preparing to go forward/backward with
    them via DDP"""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=2)
    net = DDP(nn.Linear(1, 1).cuda(), device_ids=[rank])
    batch = torch.zeros(1).cuda()
    return net, batch


def bad_worker():
    """This worker goes forward and then exits, leaving the good worker hanging"""
    print('Bad worker starting')
    net, batch = _setup(0)
    print('Bad worker set up')
    loss = net(batch).sum()
    print('Bad worker finished')


def good_worker():
    """This worker goes forward and then tries to go backwards. It'll be left hanging by the bad worker"""
    print('Good worker starting')
    net, batch = _setup(1)
    print('Good worker set up')
    loss = net(batch).sum()
    loss.backward()
    print('Good worker finished')


def run():
    test_cuda()

    bad = mp.Process(target=bad_worker)
    good = mp.Process(target=good_worker)

    bad.start()
    good.start()

    bad.join()
    print('Bad joined')
    good.join()
    print('Good joined')

    test_cuda()


if __name__ == '__main__':
    mp.set_start_method('forkserver')
    run()
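For contrast, a minimal sketch of a well-behaved worker: it reuses `_setup` from the script above, finishes both the
forward and backward pass, and tears the process group down explicitly before exiting. Assuming the lock-up really is
triggered by the bad worker dropping out between the forward and backward pass, spawning two of these (ranks 0 and 1)
in place of the good/bad pair should let both `test_cuda` calls pass.
```
def well_behaved_worker(rank):
    """Completes the forward and backward pass on the given rank, then shuts the
    process group down cleanly before exiting."""
    net, batch = _setup(rank)
    loss = net(batch).sum()
    loss.backward()
    dist.destroy_process_group()
```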
@dajianderichang

Dude, is there a solution?

@andyljones
Author

Issue thread is here

@dajianderichang

The issue thread is here

thanks~~~~
