@Delaunay · Last active March 1, 2024
import json
import multiprocessing
import time
from multiprocessing.managers import SharedMemoryManager

SHM_INDEX_IN = -1
SHM_INDEX_OUT = -2
SHM_ON = -3
SHM_SIZE = -4
SHM_MAX = 4
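# Layout of the shared ShareableList backing the ring buffer (a sketch
# derived from the code below, not a comment from the original gist):
#
#   [key_0, value_0, key_1, value_1, ..., key_{size-1}, value_{size-1},
#    SHM_SIZE, SHM_ON, SHM_INDEX_OUT, SHM_INDEX_IN]
#
# The SHM_MAX (= 4) bookkeeping fields live at the tail of the list,
# which is why the constants above are negative indices.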
class Observer:
    def __init__(self) -> None:
        pass

    def __call__(self, key, value):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass
def _worker(buffer, observer_cls, in_lock, out_lock, on_lock, index=0):
    # logical capacity in (key, value) pairs; smaller than len(buffer),
    # which also holds the bookkeeping fields
    size = buffer[SHM_SIZE]

    # signal the parent that the worker is up
    with on_lock:
        buffer[SHM_ON] = 1

    with observer_cls() as observer:
        while buffer[SHM_ON] > 0:
            # take the in lock to make sure we are reading a fully written value
            with in_lock:
                index_in = buffer[SHM_INDEX_IN]

            while buffer[SHM_INDEX_OUT] < index_in:
                # read only, no need for a lock
                counter = buffer[SHM_INDEX_OUT]
                idx = (counter % size) * 2
                key = buffer[idx]
                value = buffer[idx + 1]

                observer(key, value)

                # finished reading;
                # make sure the index update is atomic
                with out_lock:
                    buffer[SHM_INDEX_OUT] = counter + 1
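# Note on the protocol above: it is single-producer / single-consumer.
# The parent process is the only writer of SHM_INDEX_IN and the worker is
# the only writer of SHM_INDEX_OUT; each index is mutated only while its
# lock is held, so the other side never observes a partially written value,
# and the data slots themselves can be read and written without locking.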
def _preallocate_buffer(size, key_size=256):
    template = []

    # one (key, value) pair per slot; the original allocated size * 2 pairs,
    # of which only the first `size` were ever addressed
    for _ in range(size):
        template.append(" " * key_size)  # Key
        template.append(int(0))          # Value

    template.append(size)  # SHM_SIZE -4
    template.append(0)     # SHM_ON -3
    template.append(0)     # SHM_INDEX_OUT -2
    template.append(0)     # SHM_INDEX_IN -1
    return template
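# For illustration (not from the original gist): _preallocate_buffer(2, key_size=4)
# yields ["    ", 0, "    ", 0, 2, 0, 0, 0], i.e. two (key, value) slots followed
# by the four bookkeeping fields SHM_SIZE, SHM_ON, SHM_INDEX_OUT, SHM_INDEX_IN.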
class NotInitialized(Exception):
    pass


class Backpressure(Exception):
    pass
class AsyncMetric:
    def __init__(self, observer_cls, size=20):
        self.smm = SharedMemoryManager()
        # must stay None until __enter__ so push() can detect missing init
        self.ringbuffer = None
        self.worker = None
        self.key_size = 256
        self.size = size
        self.observer_cls = observer_cls
        self.out_lock = None
        self.in_lock = None

    def __enter__(self):
        self.smm.start()
        self.ringbuffer = self.smm.ShareableList(
            _preallocate_buffer(self.size, self.key_size)
        )
        self.out_lock = multiprocessing.Lock()
        self.in_lock = multiprocessing.Lock()
        self._init_worker()
        return self

    def _init_worker(self):
        on_lock = multiprocessing.Lock()
        self.worker = multiprocessing.Process(
            target=_worker,
            args=(self.ringbuffer, self.observer_cls, self.in_lock, self.out_lock, on_lock),
        )
        self.worker.start()
        self._wait_worker_init(on_lock)

    def _wait_worker_init(self, on_lock):
        # busy-wait until the worker sets SHM_ON to signal that it is ready
        while True:
            with on_lock:
                is_ready = self.ringbuffer[SHM_ON]
            if is_ready:
                break

    def wait(self):
        # Wait for the worker to catch up with everything pushed so far
        with self.in_lock:
            in_pos = self.ringbuffer[SHM_INDEX_IN]

        while True:
            with self.out_lock:
                out_pos = self.ringbuffer[SHM_INDEX_OUT]
            if out_pos == in_pos:
                break

    def __exit__(self, *args):
        self.wait()
        self.ringbuffer[SHM_ON] = 0
        self.worker.join()
        return self.smm.__exit__(*args)

    def _push_unsafe(self, key, value, counter):
        # no need to lock the slot, we are the only one writing to it
        idx = (counter % self.size) * 2
        self.ringbuffer[idx] = key
        self.ringbuffer[idx + 1] = value

        # finished writing;
        # SHM_INDEX_IN is read by the worker and needs to be
        # updated under the lock to avoid partial reads
        with self.in_lock:
            self.ringbuffer[SHM_INDEX_IN] = counter + 1

    def push_unsafe(self, key, value):
        # skips the backpressure and key-size checks done by push()
        self._push_unsafe(key, value, self.ringbuffer[SHM_INDEX_IN])

    def push(self, key, value):
        if self.ringbuffer is None:
            raise NotInitialized("Shared memory is not initialized")

        in_index = self.ringbuffer[SHM_INDEX_IN]

        # the worker could be updating SHM_INDEX_OUT;
        # take the out lock to make sure its write is finished
        with self.out_lock:
            out_index = self.ringbuffer[SHM_INDEX_OUT]

        # the buffer holds `size` pairs; refuse to overwrite unread entries
        if in_index - out_index >= self.size:
            raise Backpressure("Worker is not able to process all those events")

        if len(key) > self.key_size:
            raise ValueError("Key is bigger than storage")

        self._push_unsafe(key, value, in_index)
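# A minimal usage sketch (not from the original gist; PrintObserver and _demo
# are illustrative names). It shows the intended flow: the context manager
# starts the shared memory and the worker process, push() hands events to the
# worker, and wait() blocks until the worker has consumed everything.
class PrintObserver(Observer):
    def __call__(self, key, value):
        print(f"{key} = {value}")


def _demo():
    with AsyncMetric(PrintObserver, size=8) as metric:
        metric.push("step", 1)
        metric.push("batch_size", 2048)
        metric.wait()  # block until the worker has consumed both events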
my_perf_counter = time.perf_counter_ns


class MyObserver(Observer):
    def __init__(self) -> None:
        self.fp = None
        self.acc = {}
        self.time_diff = 0
        self.count = 0
        self.start_time = None
        self.end_time = None

    def __enter__(self):
        self.fp = open("results.txt", "w")
        return self

    def __exit__(self, *args):
        return self.fp.__exit__(*args)

    def __call__(self, key, value):
        if key == "start_time":
            self.acc = {}
            # self.time_diff += my_perf_counter() - value
            # self.count += 1
            self.start_time = value
            return

        if key == "end_time":
            self.end_time = value
            # self.time_diff += my_perf_counter() - value
            # self.count += 1

            total_time_s = (self.end_time - self.start_time) * 1e-9
            self.acc["elapsed_s"] = total_time_s
            # self.acc["latency_s"] = (self.time_diff / self.count) * 1e-9
            self.time_diff = 0
            self.count = 0

            data = json.dumps(self.acc) + "\n"
            self.fp.write(data)
            self.acc = {}
            print(data, end="")
            return

        self.acc[key] = value
# Usage example: a separate file in the gist, importing the module above
# as taranis.core.perfcounter. It times batched matmuls on the GPU.
import time
from contextlib import nullcontext

import torch

from taranis.core.perfcounter import AsyncMetric, MyObserver, my_perf_counter


class Nothing:
    """No-op stand-in for AsyncMetric, used when instrumentation is disabled."""

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return

    def push(self, *args):
        pass

    def push_unsafe(self, *args):
        pass
ENABLED = False


def stuff():
    if not ENABLED:
        return Nothing()
    return AsyncMetric(MyObserver, 200)


def main():
    s = 2048
    a = torch.randn((s, s)).cuda()
    b = torch.randn((s, s)).cuda()
    c = torch.randn((s, s)).cuda()

    import timeit

    with stuff() as metric:
        # 1,000,000 calls => ~20 s
        # [20.2530603, 20.3786452, 20.1735227, 19.87201780000001, 19.49251009999999]
        # print(timeit.timeit('metric.push("start_time", 0)', globals=locals()))
        # 10.921498800000002
        # print(timeit.timeit('metric.push_unsafe("start_time", 0)', globals=locals()))

        for i in range(30):
            start = my_perf_counter()
            metric.push_unsafe("start_time", start)
            metric.push_unsafe("batch_size", s)

            for j in range(20):
                torch.matmul(a, b, out=c)

            torch.cuda.synchronize()
            end = my_perf_counter()
            metric.push_unsafe("end_time", end)

            if not ENABLED:
                print((end - start) * 1e-9)

    print("Done")


if __name__ == "__main__":
    main()