@jayelm, created October 22, 2021 20:27
Testing multi-process reads/writes to a shared-memory tensor
import torch
from torch import multiprocessing as mp
import numpy as np
import time

# Use the fork start method so child processes inherit the shared-memory tensor directly.
ctx = mp.get_context("fork")

# Tensor backed by shared memory, visible to every child process.
shared_arr = torch.zeros(10, dtype=torch.float32).share_memory_()
procs = []


def write(proc, arr):
    """Writer loop: keep storing random values in [0, 1) at random indices."""
    np.random.seed(proc)
    while True:
        i = np.random.randint(10)
        val = np.random.random()
        arr[i] = val


now = time.time()
try:
    # Launch 16 writer processes that all hammer the same shared tensor.
    for i in range(16):
        proc = ctx.Process(target=write, args=(i, shared_arr))
        proc.start()
        procs.append(proc)
    # Reader loop: every entry should stay in [0, 1] unless a write was torn.
    while True:
        i = np.random.randint(10)
        val = shared_arr[i].item()  # read once so the check and the message report the same value
        if not (0 <= val <= 1):
            elapsed = time.time() - now
            raise RuntimeError(
                f"Got value {val} at position {i} after {elapsed:f}s. Full arr: {shared_arr}"
            )
except KeyboardInterrupt:
    pass
finally:
    # The writers loop forever, so terminate them before joining.
    for proc in procs:
        proc.terminate()
        proc.join(timeout=1)
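
For comparison, a variant with the writes serialized through a lock is sketched below. This is not part of the original gist: the write_locked helper, the Lock taken from the same fork context, the four-process count, and the fixed iteration count are all assumptions, shown only as the locked counterpart of the racing test above.

# Minimal sketch (not in the original gist): same shared tensor, but writes
# are serialized through a context Lock so only one process mutates it at a time.
import torch
from torch import multiprocessing as mp
import numpy as np


def write_locked(proc, arr, lock, n_iters=10000):
    np.random.seed(proc)
    for _ in range(n_iters):
        i = np.random.randint(10)
        val = np.random.random()
        with lock:  # one writer at a time
            arr[i] = val


if __name__ == "__main__":
    ctx = mp.get_context("fork")
    shared_arr = torch.zeros(10, dtype=torch.float32).share_memory_()
    lock = ctx.Lock()
    procs = [ctx.Process(target=write_locked, args=(p, shared_arr, lock)) for p in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(shared_arr)  # every entry should lie in [0, 1)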