Skip to content

Instantly share code, notes, and snippets.

@Hasenpfote
Created August 6, 2018 08:57
Show Gist options
  • Save Hasenpfote/5f43fa400468ef7080820b90bd3679ac to your computer and use it in GitHub Desktop.
Save Hasenpfote/5f43fa400468ef7080820b90bd3679ac to your computer and use it in GitHub Desktop.
malloc_tracer
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
class Klass(object):
    """Sample class whose methods allocate memory on specific lines.

    Used as a fixture for ``malloc_tracer.Tracer`` (see test2/test3): the
    tracer reports allocations per source line, so the statement layout
    below is intentional.
    """
    def __init__(self, value):
        self._value = value

    def func(self, x):
        """Allocate several arrays and a list, then return a branch code.

        Returns:
            0 if ``x == 0``, 1 if ``x == 1``, otherwise 2.
        """
        dataset1 = np.empty((100, ), dtype=np.float64)
        print('x', x)
        # Rebinding drops the reference to the 100-element array above.
        dataset1 = np.empty((1000, ), dtype=np.float64)
        l = [i for i in range(100000)]
        if x == 0:
            dataset4a = np.empty((100000, ), dtype=np.float64)
            return 0
        elif x == 1:
            dataset4b = np.empty((100000, ), dtype=np.float64)
            return 1
        # Only reached when x is neither 0 nor 1.
        dataset3 = np.empty((3000, ), dtype=np.float64)
        return 2

    @staticmethod
    def sfunc():
        """Allocate an array and a list, print a greeting, return the array."""
        dataset = np.empty((100, ), dtype=np.float64)
        l = [i for i in range(100000)]
        print('Hello')
        return dataset
import numpy as np
from foo import Klass
from malloc_tracer import *
def function(x, y, z):
    """Sample function that allocates memory on specific lines.

    Fixture for ``Tracer`` (see test1): the tracer reports allocations per
    source line, so the statement layout is intentional. ``y`` and ``z``
    are accepted but unused.

    Returns:
        0 if ``x == 0``, 1 if ``x == 1``, otherwise 2.
    """
    dataset1 = np.empty((100, ), dtype=np.float64)
    print('x', x)
    # Rebinding drops the reference to the 100-element array above.
    dataset1 = np.empty((1000, ), dtype=np.float64)
    l = [i for i in range(100000)]
    if x == 0:
        dataset4a = np.empty((100000, ), dtype=np.float64)
        return 0
    elif x == 1:
        dataset4b = np.empty((100000, ), dtype=np.float64)
        return 1
    # Only reached when x is neither 0 nor 1.
    dataset3 = np.empty((3000, ), dtype=np.float64)
    return 2
def test1():
    """Trace the module-level ``function`` called with x=1, y=2, z=3."""
    Tracer(function).trace(
        target_args={'x': 1, 'y': 2, 'z': 3},
        # The tracer re-executes ``function`` with its own module globals,
        # so the ``np`` it uses has to be supplied here.
        setup='import numpy as np',
    )
def test2():
    """Trace ``Klass.func`` (x=1) on an instance constructed with value=1."""
    Tracer(Klass).trace(
        init_args={'value': 1},
        target_name='func',
        target_args={'x': 1},
        setup='import numpy as np',
    )
def test3():
    """Trace the static method ``Klass.sfunc`` (called on the class itself)."""
    Tracer(Klass).trace(target_name='sfunc', setup='import numpy as np')
def main():
    """Run the enabled tracing demos.

    Only test1 is enabled; add test2 and/or test3 to the tuple to also trace
    the instance-method and static-method examples.
    """
    for demo in (test1,):
        demo()
# Run the demos only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import ast
import inspect
import math
import textwrap
import types
from tracemalloc import start, take_snapshot, stop, Filter
# Public API: ``from malloc_tracer import *`` exports only Tracer.
__all__ = ['Tracer', ]
# Pseudo file name under which the instrumented source (and setup code) is
# compiled; Tracer.trace() filters tracemalloc traces by this name so that
# only allocations made by the traced code are reported.
DUMMY_SRC_NAME = '<malloc_tracer-src>'
def bytes_to_hrf(size):
    '''Convert bytes to human readable format.

    The value is divided by 1024 until its magnitude drops below 1024 or
    the largest unit (TiB) is reached.

    Args:
        size: Number of bytes (int or float). Negative values (e.g. from a
            snapshot diff) are scaled by magnitude and keep their sign.

    Returns:
        str: A 6-character-wide number followed by the unit, e.g.
        ``'   1.5 KiB'``. Plain bytes are shown without decimals.
    '''
    units = ('B', 'KiB', 'MiB', 'GiB', 'TiB')
    order = 0
    scaled = size
    # Repeated division instead of int(log(size) / log(1024)): the float
    # log ratio can fall just below an integer at exact unit boundaries, and
    # for 0 < size < 1 it produced a negative index (mislabelled as 'TiB').
    # Dividing by a power of two is exact, so results match size / 1024**k.
    while abs(scaled) >= 1024 and order < len(units) - 1:
        scaled /= 1024
        order += 1
    fmt = '6.0f' if order == 0 else '6.1f'
    return '{0:{1}} {2}'.format(scaled, fmt, units[order])
class Transformer(ast.NodeTransformer):
    '''Instrument function definitions with tracemalloc hooks.

    Each visited ``def`` has its body rewritten to::

        try:
            start()
            <original body>
        finally:
            global <result_id>
            <result_id> = take_snapshot()
            stop()

    The names ``start``, ``take_snapshot`` and ``stop`` are resolved in the
    globals the rewritten code is later executed with. ``generic_visit`` is
    not called from ``visit_FunctionDef``, so functions nested inside a
    ``def`` are left untouched (methods of a visited class are still
    instrumented).
    '''

    def __init__(self, result_id):
        super().__init__()
        # Name of the global that receives the snapshot.
        self._result_id = result_id

    def visit_FunctionDef(self, node):
        def call(name):
            # Expression node for the zero-argument call ``name()``.
            return ast.Call(
                func=ast.Name(id=name, ctx=ast.Load()),
                args=[],
                keywords=[],
            )

        snapshot_target = ast.Name(id=self._result_id, ctx=ast.Store())
        wrapper = ast.Try(
            body=[ast.Expr(value=call('start'))] + list(node.body),
            handlers=[],
            orelse=[],
            finalbody=[
                ast.Global(names=[self._result_id]),
                ast.Assign(targets=[snapshot_target], value=call('take_snapshot')),
                ast.Expr(value=call('stop')),
            ],
        )
        # Replace the body in place with the single try/finally wrapper.
        node.body[:] = [wrapper]
        return ast.fix_missing_locations(node)
class Tracer(object):
    '''Print per-line memory allocations of a function or class.

    The source of ``obj`` is re-parsed, and every function definition in it
    is instrumented by ``Transformer``: tracemalloc is started on entry, and
    on exit a snapshot is stored in this module's global ``SNAPSHOT``. The
    result is re-executed under the pseudo file name ``DUMMY_SRC_NAME``, so
    its allocations can be separated from everything else.

    NOTE: the re-executed code runs with *this module's* globals. Names the
    traced code needs (e.g. ``np``) must therefore be provided through the
    ``setup`` argument of ``trace``.
    '''

    def __init__(
        self,
        obj
    ):
        '''
        Args:
            obj: A function or class whose source ``inspect`` can retrieve.

        Raises:
            TypeError: If ``obj`` is neither a function nor a class.
        '''
        if not (inspect.isfunction(obj) or inspect.isclass(obj)):
            raise TypeError('The obj must be a function or a class.')
        source_lines, line_no = inspect.getsourcelines(obj)
        # Use dedent() rather than strip(). A definition nested in another
        # function or class is indented on every line, and stripping only
        # the first line leaves the rest misaligned (IndentationError).
        # Line numbering is preserved, which trace() relies on.
        source_text = textwrap.dedent(''.join(source_lines))
        node = ast.parse(source_text)
        node = Transformer(result_id='SNAPSHOT').visit(node)
        locals_ = {}
        code = compile(node, DUMMY_SRC_NAME, 'exec')
        exec(code, globals(), locals_)
        # Keep the freshly built, instrumented object, not the original.
        self._obj = locals_[obj.__name__]
        self._source_lines = source_lines
        self._line_no = line_no
        self._filepath = inspect.getfile(obj)

    def _take_snapshot(
        self,
        init_args=None,
        target_name=None,
        target_args=None,
        setup='pass'
    ):
        '''Run the instrumented target once and return the snapshot it took.

        Names bound by ``setup`` that are not already module globals are
        injected into this module's globals for the duration of the call.
        They are removed again afterwards, even if setup or the target
        raises.

        Raises:
            KeyError: If ``target_name`` is not defined directly in the
                traced class body.
            RuntimeError: If the call did not run any instrumented code.
        '''
        global SNAPSHOT
        if init_args is None:
            init_args = dict()
        if target_args is None:
            target_args = dict()
        module_globals = globals()
        added = []  # names this call injected into module_globals
        try:
            # Add modules temporarily.
            temp = {'SNAPSHOT': None}
            code = compile(setup, DUMMY_SRC_NAME, 'exec')
            exec(code, module_globals, temp)
            for key, value in temp.items():
                # Never shadow, and therefore never later delete, a name
                # the module already defines.
                if key not in module_globals:
                    module_globals[key] = value
                    added.append(key)
            if target_name is None:
                self._obj(**target_args)
            else:
                # Look only in the class's own namespace: just the methods
                # written in the traced source were instrumented.
                attr = self._obj.__dict__[target_name]
                if isinstance(attr, (staticmethod, classmethod)):
                    # No instance is needed for static/class methods.
                    method = getattr(self._obj, target_name)
                else:
                    instance = self._obj(**init_args)
                    method = getattr(instance, target_name)
                method(**target_args)
            if SNAPSHOT is None:
                raise RuntimeError(
                    'No snapshot was taken; the target did not run any '
                    'instrumented code.'
                )
            return SNAPSHOT
        finally:
            # Restore: remove only what this call added.
            for key in added:
                module_globals.pop(key, None)

    def trace(
        self,
        init_args=None,
        target_name=None,
        target_args=None,
        setup='pass'
    ):
        '''Run the target once and print a per-line allocation report.

        Args:
            init_args: Keyword arguments used to construct an instance when
                ``target_name`` names a regular (instance) method.
            target_name: Name of a method defined in the traced class. If
                None, the traced function (or class) itself is called.
            target_args: Keyword arguments passed to the target call.
            setup: Source code executed before the call. Its new top-level
                names (e.g. imports) are temporarily visible to the traced
                code.
        '''
        snapshot = self._take_snapshot(
            init_args=init_args,
            target_name=target_name,
            target_args=target_args,
            setup=setup
        )
        # Keep only allocations made by the re-compiled (instrumented) code.
        snapshot = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME), ])
        stats = snapshot.statistics('lineno')
        total = 0
        detected_lines = {}
        for stat in stats:
            frame = stat.traceback[0]
            # Line 1 is the first source line of the traced object.
            detected_lines[frame.lineno] = stat.size
            total += stat.size
        print('File "{}"'.format(self._filepath))
        print('Total {}(raw {} B)'.format(bytes_to_hrf(total), total))
        print('Line # Trace Line Contents')
        print('=' * (24+80))
        # Use the same dedented text __init__ parsed, so line numbers line up.
        source_text = textwrap.dedent(''.join(self._source_lines)).rstrip()
        for line_no, line in enumerate(source_text.split('\n'), 1):
            size = detected_lines.get(line_no)
            trace = ' '*10 if size is None else bytes_to_hrf(size)
            print('{line_no:6d} {trace:10s} {contents}'.format(
                line_no=self._line_no + line_no - 1, trace=trace, contents=line))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment