Skip to content

Instantly share code, notes, and snippets.

@thehesiod
Last active May 15, 2023 22:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save thehesiod/2f56f98370bea45f021d3704b21707a9 to your computer and use it in GitHub Desktop.
Save thehesiod/2f56f98370bea45f021d3704b21707a9 to your computer and use it in GitHub Desktop.
Memory Tracer
import concurrent.futures
import gc
import linecache
import logging
import os
import re
import time
import tracemalloc
from io import StringIO
from typing import List, Optional

import wrapt

# NOTE: `log_elapsed` and `get_max_rss` are called below but are not defined in
# this file; they have to be provided by the surrounding project.
# Exclusion filters (inclusive=False) applied to every snapshot taken by the tracer.
# Each entry is (filename pattern, whether the filter applies to all frames of a traceback).
_TRACE_FILTERS = tuple(
    tracemalloc.Filter(inclusive=False, filename_pattern=pattern, all_frames=all_frames)
    for pattern, all_frames in (
        ("<frozen importlib._bootstrap>", False),
        (tracemalloc.__file__, True),  # needed because tracemalloc calls fnmatch
        (linecache.__file__, False),
        (os.path.abspath(__file__), True),  # since we call weakref
    )
)
class MemTracer:
    """
    Periodically reports which tracebacks allocated the most memory since a reference snapshot.

    Call `tick()` regularly (or `patch_and_tick()` to lazily attach a tracer to some object).
    Once more than `period_s` seconds have passed since the previous period started, the
    tracer takes a tracemalloc snapshot, diffs it against the stored one and writes a report
    to `file_path` (or stdout when no path was given).
    """

    # Number of frames tracemalloc stores per allocation traceback.
    _TRACEBACK_FRAMES = 40
    # Upper bound on the number of growing tracebacks listed in one report.
    _MAX_REPORTED_STATS = 40
    _REPORT_SEPARATOR = '==============================='
    # Set once the ThreadPoolExecutor hook has been installed, so it is installed only once.
    _executor_patched = False

    def __init__(self, logger: logging.Logger, period_s: float, file_path: Optional[str] = None, incremental: bool = True):
        """
        Will create an instance of Memory Tracer that dumps out results each `period_s`, first period is ignored to warm-up the tracer

        :param logger: logger handed to `log_elapsed` while a snapshot is being captured
        :param period_s: minimum number of seconds between two dumps
        :param file_path: file path to log results to; when falsy the report is printed instead
        :param incremental: set to True to be incremental snapshots (each report is relative to
            the previous dump); when False every report is relative to the first snapshot
        """
        self._logger = logger
        self._trace_start = None  # monotonic timestamp of the current period start
        self._last_snapshot = None  # snapshot the next dump is compared against
        self._period_s = period_s
        self._file_path = file_path
        self._incremental = incremental
        self._num_periods = 0
        self._num_ticks = 0

        # Start from a clean slate: remove a report left behind by a previous run.
        if file_path and os.path.exists(file_path):
            os.unlink(file_path)

        self.patch_thread_pool_executor()

    @classmethod
    def patch_thread_pool_executor(cls):
        """
        Wrap `ThreadPoolExecutor._adjust_thread_count` with `_adjust_thread_count`, at most once.

        A flag is used as the guard because comparing the installed wrapt proxy with the bound
        classmethod never yields equality, which would stack one more wrapper per tracer created.
        """
        if MemTracer._executor_patched:
            return

        wrapt.wrap_function_wrapper('concurrent.futures', 'ThreadPoolExecutor._adjust_thread_count', cls._adjust_thread_count)
        MemTracer._executor_patched = True

    @classmethod
    def _adjust_thread_count(cls, wrapped, instance, args, kwargs):
        """
        wrapt wrapper: keep calling the original method until the executor owns `_max_workers` threads.

        NOTE(review): presumably this forces all worker threads to be created up front so that
        later thread creation does not show up as growth in the reports - confirm.
        """
        while len(instance._threads) < instance._max_workers:
            wrapped(*args, **kwargs)

    @classmethod
    def patch_and_tick(cls, obj, *args, **kwargs):
        """
        Will add an instance of MemTracer to "obj" as a private property if it doesn't exist and call tick

        :param obj: object to patch onto
        :param args: init args to MemTracer
        :param kwargs: init args to MemTracer
        """
        tracer = getattr(obj, '_tracer', None)
        if not tracer:
            tracer = obj._tracer = cls(*args, **kwargs)

        tracer.tick()

    def start(self):
        """Enable tracemalloc if it is not tracing yet and mark the beginning of a period."""
        if not tracemalloc.is_tracing():
            tracemalloc.start(self._TRACEBACK_FRAMES)

        # Monotonic clock: period measurement must not be affected by wall-clock adjustments.
        self._trace_start = time.monotonic()

    def capture(self, store: bool = True):
        """
        Take a tracemalloc snapshot with `_TRACE_FILTERS` applied.

        :param store: when True the snapshot also becomes the reference for the next dump
        :return: the filtered snapshot
        """
        # to avoid this popping up in the traces
        re.purge()
        gc.collect()

        if self._trace_start is None:
            self.start()

        with log_elapsed(self._logger, "Capturing trace"):
            snapshot = tracemalloc.take_snapshot()

        snapshot = snapshot.filter_traces(_TRACE_FILTERS)

        if store:
            self._last_snapshot = snapshot

        return snapshot

    def tick(self):
        """
        Count a tick and dump a report once the current period has lasted longer than `period_s`.

        The very first tick only starts tracing.
        """
        self._num_ticks += 1

        if self._trace_start is None:
            self.start()
            return

        elapsed_s = time.monotonic() - self._trace_start
        if elapsed_s > self._period_s:
            self._num_periods += 1
            try:
                self.dump_snapshop(elapsed_s)
            finally:
                # want to set this at the end so we get the correct period after this dump
                self._trace_start = time.monotonic()

    def _format_report(self, top_stats, elapsed_s, max_rss):
        """
        Render the textual report for a list of `tracemalloc.StatisticDiff`.

        :param top_stats: statistic diffs to summarise (any order)
        :param elapsed_s: seconds covered by the report, rounded in the header
        :param max_rss: value shown as "max RSS" in the header
        :return: the report; every line is terminated by a newline character
        """
        max_stats = min(len(top_stats), self._MAX_REPORTED_STATS)
        lines = [
            self._REPORT_SEPARATOR,
            f"[Top {max_stats}/{len(top_stats)} differences elapsed: {round(elapsed_s)}] in periods: {self._num_periods} and ticks: {self._num_ticks} max RSS: {max_rss} MB",
        ]

        total_acquired = 0
        total_released = 0
        num_printed = 0

        # Largest growth first; totals cover every entry, details only the top growing ones.
        for stat in sorted(top_stats, key=lambda x: x.size_diff, reverse=True):
            if stat.size_diff <= 0:
                total_released += -stat.size_diff
            else:
                total_acquired += stat.size_diff

            if num_printed < max_stats and stat.size_diff > 0:
                lines.append(f"{stat.count_diff} memory blocks: {stat.size_diff / 1024} KB")
                lines.extend('\t' + str(line) for line in stat.traceback.format())
                num_printed += 1

        lines.append(f"total KB acquired: {total_acquired / 1024} released: {total_released / 1024}")
        lines.append(self._REPORT_SEPARATOR)

        # Plain "\n" on purpose: text-mode files and print() translate it to the platform
        # line ending themselves, so os.linesep would double the carriage return on Windows.
        return '\n'.join(lines) + '\n'

    def dump_snapshop(self, elapsed_s=-1):
        """
        Capture a snapshot and report its difference to the stored reference snapshot.

        When no reference snapshot exists yet, one is captured and nothing is reported.

        :param elapsed_s: seconds since the period started; only used in the report header
        """
        if self._last_snapshot is None:
            self.capture()
            return

        snapshot = self.capture(False)
        top_stats: List[tracemalloc.StatisticDiff] = snapshot.compare_to(self._last_snapshot, 'traceback')
        report = self._format_report(top_stats, elapsed_s, get_max_rss())

        if self._file_path:
            # The file is truncated on every dump, so it only ever holds the latest report.
            # NOTE(review): if a running log of all periods is wanted, this should append - confirm.
            with open(self._file_path, 'w') as f:
                f.write(report)
        else:
            print(report)

        if self._incremental:
            self._last_snapshot = snapshot

    # Correctly spelled alias; `dump_snapshop` is kept for existing callers.
    dump_snapshot = dump_snapshop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment