Last active
March 1, 2024 20:21
-
-
Save Delaunay/83adde64adeb50a847931e53ad8f2864 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing | |
from multiprocessing.managers import SharedMemoryManager | |
import time | |
import json | |
# Control-field indices into the shared ring buffer (ShareableList).
# Negative values address the four control entries appended at the tail
# by _preallocate_buffer, after the key/value data slots.
SHM_INDEX_IN = -1   # producer write cursor (monotonically increasing counter)
SHM_INDEX_OUT = -2  # consumer (worker) read cursor
SHM_ON = -3         # 1 while the worker runs; parent clears it to request shutdown
SHM_SIZE = -4       # number of key/value slots in the ring
SHM_MAX = 4         # number of control fields stored at the tail
class Observer:
    """Base observer: a no-op context manager that receives (key, value) events.

    Subclasses override __call__ to consume events, and __enter__/__exit__
    to acquire/release whatever resources they need.
    """

    def __init__(self) -> None:
        # The base implementation keeps no state.
        pass

    def __call__(self, key, value):
        """Handle one event; the base class discards it."""

    def __enter__(self):
        # Nothing to acquire; hand back the observer itself.
        return self

    def __exit__(self, *args):
        """Nothing to release; returning None lets exceptions propagate."""
def _worker(buffer, observer_cls, in_lock, out_lock, on_lock, index=0):
    """Consumer loop run in a child process: drain (key, value) pairs from
    the shared ring buffer and feed them to a fresh observer instance.

    buffer: ShareableList with alternating key/value slots followed by the
    four control fields addressed via the negative SHM_* indices.
    observer_cls: context-manager class instantiated here, in the child.
    in_lock / out_lock: guard SHM_INDEX_IN / SHM_INDEX_OUT respectively.
    on_lock: guards SHM_ON for the startup handshake with the parent.
    index: unused — NOTE(review): appears to be dead parameter.
    """
    # Logical slot count; the list itself is larger (data slots + control tail).
    size = buffer[SHM_SIZE]
    # Startup handshake: the parent polls SHM_ON under on_lock until it is 1.
    with on_lock:
        buffer[SHM_ON] = 1
    # NOTE(review): "obserer" is a misspelling of "observer" (kept verbatim here).
    with observer_cls() as obserer:
        # Run until the parent clears SHM_ON in AsyncMetric.__exit__.
        while buffer[SHM_ON] > 0:
            # Get the in lock to make sure we a reading a full value
            with in_lock:
                index_in = buffer[SHM_INDEX_IN]
            # Drain everything published up to the snapshot taken above.
            while buffer[SHM_INDEX_OUT] < index_in:
                # read only, no need for lock
                counter = buffer[SHM_INDEX_OUT]
                # Each logical slot is a (key, value) pair → stride of 2.
                idx = (counter % size) * 2
                key = buffer[idx]
                value = buffer[idx + 1]
                obserer(key, value)
                # finished reading
                # make sure we do an atomic write
                with out_lock:
                    buffer[SHM_INDEX_OUT] = counter + 1
def _preallocate_buffer(size, key_size=256): | |
template = [] | |
for i in range(size * 2): | |
template.append(" " * 256) # Key | |
template.append(int(0)) # Value | |
template.append(size) # SHM_SIZE -4 | |
template.append(0) # SHM_ON -3 | |
template.append(0) # SHM_INDEX_OUT -2 | |
template.append(0) # SHM_INDEX_IN -1 | |
return template | |
class NotInitialized(Exception):
    """Raised when the shared-memory ring buffer has not been set up yet."""
class Backpressure(Exception):
    """Raised when the producer outruns the worker by a full ring's worth of events."""
class AsyncMetric:
    """Context manager that streams (key, value) metric events to an observer
    running in a separate process, through a shared-memory ring buffer.

    The ShareableList holds ``size`` alternating key/value slots plus four
    control fields at the tail (see the SHM_* constants). This process is
    the single producer; the worker process is the single consumer.
    """

    def __init__(self, observer_cls, size=20):
        self.smm = SharedMemoryManager()
        self.ringbuffer = []          # replaced by a ShareableList in __enter__
        self.worker = None
        self.key_size = 256           # max characters reserved per key slot
        self.size = size              # number of key/value slots in the ring
        self.observer_cls = observer_cls
        self.out_lock = None          # guards SHM_INDEX_OUT (written by worker)
        self.in_lock = None           # guards SHM_INDEX_IN (written by producer)
        # NOTE(review): duplicate of the self.worker assignment above.
        self.worker = None

    def __enter__(self):
        """Allocate shared memory, create the locks, and start the worker."""
        self.smm.start()
        self.ringbuffer = self.smm.ShareableList(
            _preallocate_buffer(self.size, self.key_size)
        )
        self.out_lock = multiprocessing.Lock()
        self.in_lock = multiprocessing.Lock()
        self._init_worker()
        return self

    def _init_worker(self):
        # on_lock is only used for the startup handshake in _wait_worker_init.
        on_lock = multiprocessing.Lock()
        self.worker = multiprocessing.Process(
            target=_worker,
            args=(self.ringbuffer, self.observer_cls, self.in_lock, self.out_lock, on_lock),
        )
        self.worker.start()
        self._wait_worker_init(on_lock)

    def _wait_worker_init(self, on_lock):
        # Busy-wait until the worker sets SHM_ON to 1 (its readiness signal).
        while True:
            with on_lock:
                is_ready = self.ringbuffer[SHM_ON]
            if is_ready:
                break

    def wait(self):
        # Wait for worker to catch up
        with self.in_lock:
            in_pos = self.ringbuffer[SHM_INDEX_IN]
        # Busy-wait until the consumer cursor reaches the producer cursor.
        while True:
            with self.out_lock:
                out_pos = self.ringbuffer[SHM_INDEX_OUT]
            if out_pos == in_pos:
                break

    def __exit__(self, *args):
        """Drain pending events, stop the worker, and release shared memory."""
        self.wait()
        # Clearing SHM_ON makes the worker's outer loop exit.
        self.ringbuffer[SHM_ON] = 0
        self.worker.join()
        return self.smm.__exit__(*args)

    def _push_unsafe(self, key, value, counter):
        # no need to lock, we are the only one writing to it
        idx = (counter % self.size) * 2
        self.ringbuffer[idx] = key
        self.ringbuffer[idx + 1] = value
        # finished writing
        # "SHM_INDEX_IN" is read by the worker
        # and need to be locked to avoid partial reads
        with self.in_lock:
            self.ringbuffer[SHM_INDEX_IN] = counter + 1

    def push_unsafe(self, key, value):
        """Publish an event without backpressure or key-size checks (fast path)."""
        self._push_unsafe(key, value, self.ringbuffer[SHM_INDEX_IN])

    def push(self, key, value):
        """Publish an event with initialization, backpressure and key-size checks.

        Raises:
            NotInitialized: if the ring buffer is None.
                NOTE(review): __init__ sets it to [] (not None), so before
                __enter__ this check never fires — an IndexError would occur
                instead; verify the intended guard.
            Backpressure: if the worker lags a full ring behind.
            ValueError: if the key exceeds the reserved slot width.
        """
        if self.ringbuffer is None:
            raise NotInitialized("Shared memory is not initialized")
        # worker could be writing to it
        # get the out lock to make sure writing is finished
        in_index = self.ringbuffer[SHM_INDEX_IN]
        with self.out_lock:
            out_index = self.ringbuffer[SHM_INDEX_OUT]
        # NOTE(review): this permits in_index - out_index == size, letting the
        # next write land on the exact slot the worker reads next — looks like
        # an off-by-one (should probably be out_index + self.size <= in_index).
        if (out_index + self.size + 1 <= in_index ):
            raise Backpressure("Worker is not able to process all those events")
        if len(key) > self.key_size:
            raise ValueError("Key is bigger than storage")
        self._push_unsafe(key, value, in_index)
# Monotonic nanosecond clock used for all timestamps in this module.
my_perf_counter = time.perf_counter_ns
class MyObserver(Observer):
    """Observer that aggregates key/value events into per-iteration records.

    Events arriving between a ``"start_time"`` and an ``"end_time"`` key are
    accumulated in ``self.acc``; on ``"end_time"`` the elapsed wall time is
    computed (values are ``time.perf_counter_ns`` timestamps) and the record
    is appended to ``results.txt`` as one JSON line, echoed to stdout.

    Cleanup vs. original: commented-out latency-accounting experiments
    removed; behavior is unchanged.
    """

    def __init__(self) -> None:
        self.fp = None           # results file, opened in __enter__
        self.acc = {}            # record being accumulated for the current iteration
        self.time_diff = 0       # reserved for latency accounting (currently unused)
        self.count = 0           # reserved for latency accounting (currently unused)
        self.start_time = None   # perf_counter_ns at the last "start_time" event
        self.end_time = None     # perf_counter_ns at the last "end_time" event

    def __enter__(self):
        self.fp = open("results.txt", "w")
        return self

    def __exit__(self, *args):
        # Delegate to the file's own context-manager exit so it is closed.
        return self.fp.__exit__(*args)

    def __call__(self, key, value):
        if key == "start_time":
            # New iteration: discard any partial record from before.
            self.acc = {}
            self.start_time = value
            return
        if key == "end_time":
            self.end_time = value
            # perf_counter_ns timestamps -> seconds.
            total_time_s = (self.end_time - self.start_time) * 1e-9
            # NOTE(review): key spelling "elasped_s" kept verbatim so the
            # output schema stays compatible with existing consumers.
            self.acc["elasped_s"] = total_time_s
            self.time_diff = 0
            self.count = 0
            data = json.dumps(self.acc) + "\n"
            self.fp.write(data)
            self.acc = {}
            print(data, end="")
            return
        # Any other key is stored verbatim on the current record.
        self.acc[key] = value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from contextlib import nullcontext | |
from taranis.core.perfcounter import AsyncMetric, MyObserver, my_perf_counter | |
import torch | |
class Nothing():
    """Inert drop-in for AsyncMetric: a context manager whose push methods
    silently discard every event. Used when metric collection is disabled."""

    def __enter__(self):
        # Nothing to set up; hand back the sink itself.
        return self

    def __exit__(self, *exc_info):
        # Nothing to tear down; returning None lets exceptions propagate.
        return None

    def push(self, *_ignored):
        """Discard the event."""

    def push_unsafe(self, *_ignored):
        """Discard the event."""
# Global switch for metric collection (False => use the no-op Nothing sink).
# NOTE(review): name is a typo for "ENABLED"; kept because it is referenced below.
ENALBED = False
def stuff():
    """Return the active metric sink.

    A real AsyncMetric (200-slot ring, MyObserver consumer) when collection
    is enabled, otherwise the no-op Nothing sink.
    """
    if ENALBED:
        return AsyncMetric(MyObserver, 200)
    return Nothing()
def main():
    """Benchmark 30 iterations of 20 back-to-back (2048 x 2048) GPU matmuls,
    publishing per-iteration timings to the configured metric sink.

    When metrics are disabled (ENALBED is False) the elapsed seconds for each
    iteration are printed directly instead.

    Cleanup vs. original: commented-out timeit experiments and the unused
    ``import timeit`` removed; "DOne" typo in the final print fixed.
    """
    s = 2048
    a = torch.randn((s, s)).cuda()
    b = torch.randn((s, s)).cuda()
    c = torch.randn((s, s)).cuda()

    with stuff() as metric:
        for i in range(30):
            start = my_perf_counter()
            metric.push_unsafe("start_time", start)
            metric.push_unsafe("batch_size", s)

            for j in range(20):
                torch.matmul(a, b, out=c)

            # matmul launches are asynchronous; wait for the GPU before
            # reading the clock, or the measurement is meaningless.
            torch.cuda.synchronize()

            end = my_perf_counter()
            metric.push_unsafe("end_time", end)

            if not ENALBED:
                # perf_counter_ns difference -> seconds.
                print((end - start) * 1e-9)

    print("Done")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment