# Source: GitHub gist heiner/7ed5802eb9b3218a418262d542ecb827 (by @heiner, created August 4, 2023).
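"""
Repro for an ownership issue with Python-implemented stores for torch.distributed.

Run two ranks (the store directory is removed first so no stale keys remain):
rm -rf /tmp/store_bug; for i in {0..1}; do python -u store_bug.py $i 2 & done && wait

An optional third argument (e.g. 1) enables the workaround in init(), which
keeps a Python reference to the store alive after init() returns.
"""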
import os
import datetime
import sys
import time
import fcntl
import pathlib
import torch
import numpy as np
# Flip to True to print a trace line for every Store operation.
DEBUG: bool = False
# Sleep interval, in seconds, used by the Store while polling for key files.
POLL_SLEEP_TIME: float = 0.5
def _sanitize(key):
    """Map a store key to a safe file name inside the store directory.

    Alphanumerics and the characters ``._-()`` are kept; every other
    character becomes ``_``. Distinct keys may therefore map to the same name.

    Args:
        key: Store key as ``str`` or ``bytes``. ``bytes`` keys are decoded as
            UTF-8; undecodable bytes become U+FFFD and are then replaced by ``_``.

    Returns:
        A non-empty file name that is never ``.`` or ``..``.
    """
    if isinstance(key, bytes):
        # Iterating bytes yields ints, which have no isalnum(); decode first.
        key = key.decode("utf-8", errors="replace")
    name = "".join(ch if ch.isalnum() or ch in "._-()" else "_" for ch in key)
    if name in ("", ".", ".."):
        # These would resolve to the store directory itself or its parent.
        name = "_" + name
    return name
class Store(torch.distributed.Store):
    """File-system backed key/value store usable with ``torch.distributed``.

    Each key is one file in ``path``, named by ``_sanitize(key)``, so every
    rank must point at the same directory.

    Writers (``set``/``add``) serialize on an ``flock`` of a store-wide lock
    file. They publish values by writing a temp file and atomically renaming
    it over the key file. Readers (``get``/``wait``) therefore never see an
    empty or half-written value and need no lock.

    Args:
        path: Shared directory (``pathlib.Path`` or ``str``); created if
            missing.
        rank: Rank of this process; only used in debug output.
        timeout: Default timeout in seconds for ``get`` and ``wait``. A falsy
            value waits forever.
    """

    # "#" is never produced by _sanitize, so these helper files cannot
    # collide with a key's file.
    _LOCK_NAME = "#writer.lock"
    _TMP_PREFIX = "#tmp-"

    def __init__(self, path, rank, timeout=5):
        super().__init__()
        self.path = pathlib.Path(path)
        self.path.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self._timeout = timeout  # self.timeout is not writable here.

    def _lock_writers(self):
        """Open the store-wide lock file and take an exclusive ``flock`` on it.

        Returns the open file object; closing it (e.g. via ``with``) releases
        the lock.
        """
        lock_file = (self.path / self._LOCK_NAME).open("ab")
        try:
            fcntl.flock(lock_file, fcntl.LOCK_EX)
        except BaseException:
            lock_file.close()
            raise
        return lock_file

    def _publish(self, path, data):
        """Atomically make ``path`` contain ``data``. Caller must hold the lock.

        The bytes go to a per-process temp file first, which is then renamed
        over ``path``. Readers see either the previous state or the complete
        new contents, never a truncated file.
        """
        tmp = self.path / f"{self._TMP_PREFIX}{os.getpid()}"
        with tmp.open("wb") as f:
            f.write(data)
        os.replace(tmp, path)

    def set(self, key, value):
        """Store ``value`` (bytes) under ``key``, replacing any previous value.

        Raises:
            AssertionError: If ``key`` is not str/bytes or ``value`` is not bytes.
        """
        if DEBUG:
            print("set with rank %i, key %s, value %s" % (self.rank, key, value))
        if not isinstance(key, (str, bytes)):
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        with self._lock_writers():
            self._publish(self.path / _sanitize(key), value)

    def get(self, key):
        """Return the bytes stored under ``key``, polling until it exists.

        Raises:
            RuntimeError: If the key does not appear within the store timeout.
        """
        path = self.path / _sanitize(key)
        if DEBUG:
            print("get on rank %i, key %s" % (self.rank, key))
        # Busy loop; monotonic time is immune to wall-clock adjustments.
        start = time.monotonic()
        while True:
            if self._timeout and time.monotonic() - start > self._timeout:
                raise RuntimeError(f"Timeout {self._timeout} hit on get({key!r})")
            try:
                # Values are published atomically, so a plain read is consistent.
                value = path.read_bytes()
            except FileNotFoundError:
                time.sleep(POLL_SLEEP_TIME)
                continue
            if DEBUG:
                print("%i get done, key %s, value %s" % (self.rank, key, repr(value)))
            return value

    def add(self, key, value):
        """Atomically add ``value`` to the counter under ``key``; return the new total.

        A missing key counts as 0. The counter is stored as an ASCII decimal
        string, the same encoding c10d's built-in stores use. As a result,
        ``get`` on the key returns e.g. ``b"3"``. Raises ValueError if the key
        holds a non-numeric value.
        """
        if DEBUG:
            print("add with rank %i, key %s, value %s" % (self.rank, key, repr(value)))
        path = self.path / _sanitize(key)
        with self._lock_writers():
            try:
                current = path.read_bytes()
            except FileNotFoundError:
                current = b""
            result = int(current) + value if current else value
            self._publish(path, str(result).encode())
        return result

    def wait(self, keys, timeout=None):
        """Block until every key in ``keys`` exists.

        Args:
            keys: Keys to wait for.
            timeout: Optional ``datetime.timedelta``; defaults to the store
                timeout. A zero/falsy timeout waits forever.

        Raises:
            RuntimeError: If keys are still missing after ``timeout``; the
                message lists the missing keys.
        """
        if timeout is not None:
            timeout = timeout.total_seconds()  # timedelta object.
        else:
            timeout = self._timeout
        if DEBUG:
            print("wait with rank %i, keys %s, timeout %s" % (self.rank, keys, timeout))
        # Keep each key paired with its path so the pending set can shrink
        # without keys and paths falling out of alignment.
        pending = [(key, self.path / _sanitize(key)) for key in keys]
        missing = keys
        start = time.monotonic()
        while True:
            if timeout and time.monotonic() - start > timeout:
                raise RuntimeError(f"Timeout {timeout} hit on wait({missing})")
            # Files appear only via atomic rename, so existence means complete.
            pending = [(key, path) for key, path in pending if not path.exists()]
            if not pending:
                return
            missing = [key for key, _ in pending]
            time.sleep(POLL_SLEEP_TIME)
def init(rank, world_size, use_workaround=False, *, _cache={}):
    """Create the file-backed Store and initialize the default gloo process group.

    Args:
        rank: Rank of this process.
        world_size: Total number of ranks.
        use_workaround: If truthy, stash the store in ``_cache`` so a Python
            reference to it outlives this call. This is presumably what avoids
            the ownership issue being reproduced; confirm against the repro.
        _cache: Deliberately mutable default. The same dict persists across
            calls and serves as the holder for the store.
    """
    shared_dir = pathlib.Path("/tmp/store_bug")
    file_store = Store(shared_dir, rank)
    pg_timeout = datetime.timedelta(seconds=3.0)
    torch.distributed.init_process_group(
        backend="gloo",
        store=file_store,
        rank=rank,
        world_size=world_size,
        timeout=pg_timeout,
    )
    # Without this, the only Python reference to file_store dies on return.
    if use_workaround:
        _cache.update(store=file_store)
def main():
    """Run one rank of the repro.

    Usage: ``store_bug.py RANK WORLD_SIZE [USE_WORKAROUND]``.

    USE_WORKAROUND is optional. Any value other than ``0``, ``false``, ``no``,
    ``off`` or an empty string (case-insensitive) enables the workaround in
    ``init``. After initialization, a new group is created and an all-reduce
    is run on ``3 * rank + 1``; the sum is printed.
    """
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} RANK WORLD_SIZE [USE_WORKAROUND]")
    rank = int(sys.argv[1])
    world_size = int(sys.argv[2])
    # argv entries are strings, so "0"/"false" would otherwise be truthy.
    use_workaround = len(sys.argv) > 3 and sys.argv[3].strip().lower() not in (
        "", "0", "false", "no", "off",
    )
    print(f"starting rank {rank}/{world_size} with use_workaround={use_workaround}")
    init(rank, world_size, use_workaround)
    g = torch.distributed.new_group(timeout=datetime.timedelta(seconds=2.0))
    payload = torch.tensor([3 * rank + 1], dtype=torch.int64)
    torch.distributed.all_reduce(payload, group=g)
    print(f"Rank {rank}, sum is {payload}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment