Skip to content

Instantly share code, notes, and snippets.

@Hasenpfote
Last active August 3, 2018 07:38
Show Gist options
  • Save Hasenpfote/32ef187f27b0a5008783839ca1935751 to your computer and use it in GitHub Desktop.
Save Hasenpfote/32ef187f27b0a5008783839ca1935751 to your computer and use it in GitHub Desktop.
Take a snapshot in a function.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import os
import inspect
import ast
import types
import linecache
import math
from tracemalloc import start, take_snapshot, Filter
# Pseudo-filename passed to compile() for the code that take_inner_snapshot()
# generates at runtime. display_top() and profile() filter snapshot traces on
# it, and display_top() prints such frames' lines from the returned source
# text instead of looking them up with linecache.
DUMMY_SRC_NAME = '<tracemalloc_utils-src>'
class Transformer(ast.NodeTransformer):
    '''Add tracemalloc functions.

    Every function definition visited gets its body wrapped as::

        try:
            start()
            <original body>
        finally:
            global <result_id>
            <result_id> = take_snapshot()

    ``start`` and ``take_snapshot`` are resolved by name in the globals of the
    namespace the transformed code is executed in, and the snapshot is stored
    in that namespace's global named *result_id*. Because visit_FunctionDef
    does not call generic_visit, function definitions nested inside a visited
    function are left untouched.
    '''

    def __init__(self, result_id):
        # Name of the global the snapshot gets assigned to.
        self._result_id = result_id

    @staticmethod
    def _call_by_name(name):
        '''Build the AST for a zero-argument call ``name()``.'''
        callee = ast.Name(id=name, ctx=ast.Load())
        return ast.Call(func=callee, args=[], keywords=[])

    def visit_FunctionDef(self, node):
        start_tracing = ast.Expr(value=self._call_by_name('start'))
        store_snapshot = ast.Assign(
            targets=[ast.Name(id=self._result_id, ctx=ast.Store())],
            value=self._call_by_name('take_snapshot'),
        )
        # The snapshot is taken in `finally` so it is recorded even when the
        # original body returns early or raises.
        wrapped = ast.Try(
            body=[start_tracing] + list(node.body),
            handlers=[],
            orelse=[],
            finalbody=[ast.Global(names=[self._result_id]), store_snapshot],
        )
        node.body = [wrapped]
        return ast.fix_missing_locations(node)
def take_inner_snapshot(
    function,
    function_args,
    setup='pass'
):
    '''Take an inner snapshot.

    Recompiles *function* from its source with tracemalloc hooks injected by
    ``Transformer``, calls it once as ``function(**function_args)`` and returns
    the snapshot taken when it finished (normally or by raising).

    Args:
        function: a plain Python function whose source is available to
            ``inspect.getsource``.
        function_args: dict of keyword arguments for the call.
        setup: Python source executed first in the namespace the instrumented
            function runs in (e.g. imports the function body needs).

    Returns:
        ``(snapshot, source)``: the ``tracemalloc.Snapshot`` and the dedented
        source text that was compiled under ``DUMMY_SRC_NAME``. Line numbers of
        that source match the snapshot's ``DUMMY_SRC_NAME`` frames.

    Raises:
        TypeError: if an argument has the wrong type.
    '''
    import textwrap
    import tracemalloc

    if not isinstance(function, types.FunctionType):
        raise TypeError('function must be a plain Python function')
    if not isinstance(function_args, dict):
        raise TypeError('function_args must be a dict')
    if not isinstance(setup, str):
        raise TypeError('setup must be a str')

    # Run everything in a private copy of this module's globals. The function
    # still sees the same names as before (start, take_snapshot, ...) plus
    # whatever setup defines, but nothing is injected into, overwritten in,
    # or popped from the real module globals (e.g. setup='import os' no
    # longer deletes this module's own `os`).
    namespace = dict(globals())
    namespace['SNAPSHOT'] = None
    # Setup gets its own pseudo-filename. If tracing is already active, its
    # allocations must not be attributed to lines of the profiled function.
    exec(compile(setup, '<tracemalloc_utils-setup>', 'exec'), namespace)

    # dedent (not just strip) so that indented sources, e.g. decorated
    # methods or nested functions, still parse.
    source = textwrap.dedent(inspect.getsource(function)).strip()
    node = Transformer(result_id='SNAPSHOT').visit(ast.parse(source))
    code = compile(node, DUMMY_SRC_NAME, 'exec')

    was_tracing = tracemalloc.is_tracing()
    try:
        # Define the instrumented function in the namespace it reads its
        # globals from, so recursive calls by its own name resolve.
        exec(code, namespace)
        namespace[function.__name__](**function_args)
    finally:
        # The snapshot was already taken inside the function. Stop tracing
        # only if it was started on our behalf, so that an outer user of
        # tracemalloc is left undisturbed.
        if not was_tracing:
            tracemalloc.stop()
    return namespace['SNAPSHOT'], source
def display_top(
    snapshot,
    source=None,
    key_type='lineno',
    limit=10
):
    '''Print the *limit* largest allocation entries of an inner snapshot.

    Args:
        snapshot: a ``tracemalloc.Snapshot``, e.g. from take_inner_snapshot().
            Only traces whose most recent frame is in ``DUMMY_SRC_NAME`` are
            reported.
        source: the source text returned by take_inner_snapshot(). When given,
            the source line of each ``DUMMY_SRC_NAME`` entry is printed under
            it.
        key_type: grouping key passed to ``Snapshot.statistics``
            ('lineno', 'filename' or 'traceback').
        limit: number of entries listed individually. The remainder is
            summarised on one line, followed by the total.
    '''
    source_lines = None if source is None else source.split(sep='\n')
    snapshot = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME),])
    top_stats = snapshot.statistics(key_type)

    print('Top {} lines'.format(limit))
    for index, stat in enumerate(top_stats[:limit], 1):
        frame = stat.traceback[0]
        # replace "/path/to/module/file.py" with "module/file.py"
        filename = os.sep.join(frame.filename.split(os.sep)[-2:])
        fmt = '#{index}: {filename}:{lineno}: {size:.3f} KiB'.format(
            index=index,
            filename=filename,
            lineno=frame.lineno,
            size=stat.size / 1024
        )
        print(fmt)
        line = ''
        if frame.filename == DUMMY_SRC_NAME:
            # Runtime-compiled code is not on disk for linecache, so use
            # *source* instead. The range check matters for two cases:
            # lineno is 0 for key_type='filename', where lineno - 1 == -1
            # would silently show the last line; and a *source* that does
            # not match the compiled text could give an out-of-range lineno.
            if source_lines is not None and 0 < frame.lineno <= len(source_lines):
                line = source_lines[frame.lineno - 1].strip()
        else:
            line = linecache.getline(frame.filename, frame.lineno).strip()
        if line:
            print(' %s' % line)

    other = top_stats[limit:]
    if other:
        size = sum(stat.size for stat in other)
        fmt = '{length} other: {size:.3f} KiB'.format(
            length=len(other),
            size=size / 1024
        )
        print(fmt)
    total = sum(stat.size for stat in top_stats)
    fmt = 'Total allocated size: {size:.3f} KiB'.format(size=total / 1024)
    print(fmt)
def bytes_to_hrf(size):
    '''Convert bytes to human readable format.

    Divides *size* by the largest power of 1024 (capped at TiB) whose
    magnitude does not exceed ``abs(size)``. The result is formatted as
    ``'{:6.1f} {unit}'``, e.g. ``1536`` -> ``'   1.5 KiB'`` and
    ``-2048`` -> ``'  -2.0 KiB'``.
    '''
    units = ('B', 'KiB', 'MiB', 'GiB', 'TiB')
    # Pick the unit by exact comparison rather than log(size)/log(1024). The
    # float ratio left negative sizes (e.g. snapshot diffs) unscaled, gave a
    # negative index ('TiB') for sizes below 1/1024, and may misround near
    # exact powers of 1024.
    magnitude = abs(size)
    order = 0
    while order < len(units) - 1 and magnitude >= 1024 ** (order + 1):
        order += 1
    return '{:6.1f} {}'.format(size / (1024 ** order), units[order])
def profile(snapshot, source):
    '''Print *source* line by line with the memory allocated on each line.

    The total of all ``DUMMY_SRC_NAME`` allocations in *snapshot* is printed
    first. Then each line of *source* (numbered from 1) is printed next to
    the size the snapshot attributes to that line number, or a blank column
    when nothing was recorded for it.
    '''
    stats = snapshot.filter_traces(
        [Filter(True, DUMMY_SRC_NAME)]
    ).statistics('lineno')
    total = sum(stat.size for stat in stats)
    size_by_line = {stat.traceback[0].lineno: stat.size for stat in stats}

    print('Total {}(raw {} B)'.format(bytes_to_hrf(total), total))
    print('Line # Increment Line Contents')
    print('=' * (24 + 80))
    for line_no, text in enumerate(source.split(sep='\n'), start=1):
        size = size_by_line.get(line_no)
        if size is None:
            usage = ' ' * 10
        else:
            usage = bytes_to_hrf(size)
        print('{no:6d} {usage:10s} {contents}'.format(
            no=line_no, usage=usage, contents=text))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment