Last active
August 3, 2018 07:38
-
-
Save Hasenpfote/32ef187f27b0a5008783839ca1935751 to your computer and use it in GitHub Desktop.
Take a snapshot in a function.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import sys\n", | |
"sys.path.append(os.getcwd())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tracemalloc_utils" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def foo(x, y, z):\n", | |
" dataset1 = np.random.uniform(low=-1., high=1., size=100).astype(np.float64)\n", | |
" print('x', x)\n", | |
" dataset2 = np.random.uniform(low=-1., high=1., size=1000).astype(np.float64)\n", | |
"\n", | |
" l = [i for i in range(100000)]\n", | |
" \n", | |
" if x == 0:\n", | |
" dataset_a = np.array([i for i in range(100000)], dtype=np.float64)\n", | |
" return 0\n", | |
" elif x == 1:\n", | |
" dataset_b = np.array([i for i in range(100000)], dtype=np.float64)\n", | |
" return 1\n", | |
"\n", | |
" dataset3 = np.random.uniform(low=-1., high=1., size=3000).astype(np.float64)\n", | |
" return 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"x 0\n" | |
] | |
} | |
], | |
"source": [ | |
"snapshot, source = tracemalloc_utils.take_inner_snapshot(\n", | |
" function=foo,\n", | |
" function_args=dict(x=0, y=1, z=1),\n", | |
" setup='import numpy as np',\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total 4.2 MiB(raw 4427494 B)\n", | |
"Line # Increment Line Contents\n", | |
"========================================================================================================\n", | |
" 1 def foo(x, y, z):\n", | |
" 2 1.6 KiB dataset1 = np.random.uniform(low=-1., high=1., size=100).astype(np.float64)\n", | |
" 3 434.0 B print('x', x)\n", | |
" 4 7.9 KiB dataset2 = np.random.uniform(low=-1., high=1., size=1000).astype(np.float64)\n", | |
" 5 \n", | |
" 6 3.4 MiB l = [i for i in range(100000)]\n", | |
" 7 \n", | |
" 8 if x == 0:\n", | |
" 9 781.3 KiB dataset_a = np.array([i for i in range(100000)], dtype=np.float64)\n", | |
" 10 return 0\n", | |
" 11 elif x == 1:\n", | |
" 12 dataset_b = np.array([i for i in range(100000)], dtype=np.float64)\n", | |
" 13 return 1\n", | |
" 14 \n", | |
" 15 dataset3 = np.random.uniform(low=-1., high=1., size=3000).astype(np.float64)\n", | |
" 16 return 2\n" | |
] | |
} | |
], | |
"source": [ | |
"tracemalloc_utils.profile(snapshot, source=source)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import inspect | |
import ast | |
import types | |
import linecache | |
import math | |
from tracemalloc import start, take_snapshot, Filter | |
# Pseudo file name under which the instrumented function (and any setup code)
# is compiled; allocations are later selected with Filter(True, DUMMY_SRC_NAME).
# Filter treats it as an fnmatch pattern, so it must not contain '*', '?' or '['.
DUMMY_SRC_NAME = '<tracemalloc_utils-src>'
class Transformer(ast.NodeTransformer):
    '''AST transformer that wraps a function body in tracemalloc hooks.

    Each visited ``FunctionDef`` is rewritten from::

        def f(...):
            <body>

    into::

        def f(...):
            try:
                start()
                <body>
            finally:
                global <result_id>
                <result_id> = take_snapshot()

    ``start`` and ``take_snapshot`` are emitted as bare names, so they must be
    resolvable in the globals the rewritten code is executed with. Because
    the whole original body (including a leading docstring, if any) is moved
    into the ``try`` block, such a docstring is no longer the function's
    ``__doc__``. ``generic_visit`` is not called, so functions nested inside
    an instrumented function are left untouched.
    '''

    def __init__(self, result_id):
        # Name of the global variable that receives the snapshot.
        self._result_id = result_id

    def _bare_call(self, func_name):
        '''Return the AST expression for the zero-argument call ``func_name()``.'''
        return ast.Call(
            func=ast.Name(id=func_name, ctx=ast.Load()),
            args=[],
            keywords=[],
        )

    def visit_FunctionDef(self, node):
        # Tracing starts as the very first statement inside the try block.
        start_stmt = ast.Expr(value=self._bare_call('start'))

        # The snapshot is published through a global so that the caller can
        # read it once the function has returned (or raised).
        publish_stmts = [
            ast.Global(names=[self._result_id]),
            ast.Assign(
                targets=[ast.Name(id=self._result_id, ctx=ast.Store())],
                value=self._bare_call('take_snapshot'),
            ),
        ]

        guarded = ast.Try(
            body=[start_stmt, *node.body],
            handlers=[],
            orelse=[],
            finalbody=publish_stmts,
        )
        # Replace the contents of the existing body list in place.
        node.body[:] = [guarded]
        return ast.fix_missing_locations(node)
def take_inner_snapshot(
    function,
    function_args,
    setup='pass'
):
    '''Take an inner snapshot of the memory allocated while running a function.

    The source of ``function`` is re-parsed and instrumented with
    ``Transformer``. Tracing starts at the top of the body, and a snapshot is
    taken in a ``finally`` clause while the function's locals are still
    alive. The result is recompiled under the pseudo file name
    ``DUMMY_SRC_NAME`` and called once with ``function_args``.

    The recompiled function runs in a private copy of this module's globals,
    not in the globals of the module that defined ``function``. Any names it
    needs (e.g. ``np``) must therefore be provided by ``setup``. This module's
    own globals are never modified.

    Args:
        function: a plain Python function (``types.FunctionType``) whose
            source is retrievable with ``inspect.getsource``.
        function_args: dict of keyword arguments for the single call.
        setup: source code run with ``exec`` in the private namespace before
            the function is defined, typically imports. Only pass trusted
            code.

    Returns:
        ``(snapshot, source)``: the ``tracemalloc.Snapshot`` taken inside the
        function, and the dedented function source whose line numbers match
        the frames recorded under ``DUMMY_SRC_NAME``.

    Raises:
        TypeError: if an argument has the wrong type.
    '''
    import textwrap
    import tracemalloc

    if not isinstance(function, types.FunctionType):
        raise TypeError(
            'function must be a Python function, not {}'.format(type(function).__name__))
    if not isinstance(function_args, dict):
        raise TypeError(
            'function_args must be a dict, not {}'.format(type(function_args).__name__))
    if not isinstance(setup, str):
        raise TypeError(
            'setup must be a str, not {}'.format(type(setup).__name__))

    result_id = 'SNAPSHOT'

    # Work in a private copy of this module's globals. Names defined by
    # ``setup`` can then neither leak into nor be deleted from the real module
    # namespace (e.g. ``setup='import os'`` must not remove this module's
    # ``os``), and nothing needs restoring if the function raises.
    namespace = dict(globals())
    namespace[result_id] = None
    exec(compile(setup, DUMMY_SRC_NAME, 'exec'), namespace)

    # getsource() keeps the original indentation; dedent so that methods and
    # nested functions parse as top-level code (line numbers are unchanged).
    source = textwrap.dedent(inspect.getsource(function)).strip()
    node = ast.parse(source)
    node = Transformer(result_id=result_id).visit(node)

    # Define the function into a separate locals dict so that a function
    # named e.g. ``start`` cannot shadow the hook its instrumented body calls.
    # Its globals (where ``global SNAPSHOT`` writes) are ``namespace``.
    defined = {}
    exec(compile(node, DUMMY_SRC_NAME, 'exec'), namespace, defined)
    dst_function = defined[function.__name__]

    # The instrumented body calls start(); only stop tracing that this call
    # turned on, so a caller's own tracemalloc session is left alone.
    was_tracing = tracemalloc.is_tracing()
    try:
        dst_function(**function_args)
    finally:
        if not was_tracing and tracemalloc.is_tracing():
            tracemalloc.stop()
    return namespace[result_id], source
def display_top(
    snapshot,
    source=None,
    key_type='lineno',
    limit=10
):
    '''Print the largest allocation sites of an inner snapshot.

    Only traces that match ``Filter(True, DUMMY_SRC_NAME)`` are considered,
    i.e. allocations made by code compiled by ``take_inner_snapshot``.

    Args:
        snapshot: a ``tracemalloc.Snapshot``, e.g. from ``take_inner_snapshot``.
        source: optional source text of the instrumented function. When given,
            the matching source line is printed under each entry.
        key_type: grouping key passed to ``Snapshot.statistics``
            ('lineno', 'filename' or 'traceback').
        limit: number of entries printed individually; the rest are summed
            into a single "other" line.
    '''
    source_lines = None if source is None else source.split(sep='\n')
    snapshot = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME),])
    top_stats = snapshot.statistics(key_type)
    print('Top {} lines'.format(limit))
    for index, stat in enumerate(top_stats[:limit], 1):
        # NOTE(review): with key_type='traceback' and more than one recorded
        # frame, which frame is at index 0 depends on the Python version.
        frame = stat.traceback[0]
        # replace "/path/to/module/file.py" with "module/file.py"
        filename = os.sep.join(frame.filename.split(os.sep)[-2:])
        fmt = '#{index}: {filename}:{lineno}: {size:.3f} KiB'.format(
            index=index,
            filename=filename,
            lineno=frame.lineno,
            size=stat.size / 1024
        )
        print(fmt)
        line = ''
        if frame.filename == DUMMY_SRC_NAME:
            # key_type='filename' reports lineno 0, which has no source line;
            # indexing with lineno-1 == -1 would wrongly show the LAST line.
            if source_lines is not None and 1 <= frame.lineno <= len(source_lines):
                line = source_lines[frame.lineno - 1].strip()
        else:
            line = linecache.getline(frame.filename, frame.lineno).strip()
        if line:
            print(' %s' % line)
    other = top_stats[limit:]
    if other:
        size = sum(stat.size for stat in other)
        fmt = '{length} other: {size:.3f} KiB'.format(
            length=len(other),
            size=size / 1024
        )
        print(fmt)
    total = sum(stat.size for stat in top_stats)
    fmt = 'Total allocated size: {size:.3f} KiB'.format(size=total / 1024)
    print(fmt)
def bytes_to_hrf(size):
    '''Convert a byte count to a fixed-width human readable string.

    The value is scaled by the largest binary unit (B, KiB, MiB, GiB, TiB)
    that does not exceed it and formatted as ``'{:6.1f} {unit}'``. Sizes of
    1024 TiB and more stay in TiB. Sizes below 1024, including zero,
    negative and fractional values, are printed in bytes.

    Exact threshold comparisons are used instead of a floating-point
    ``log(size) / log(1024)`` ratio. That ratio can misround at exact powers
    of 1024, and for 0 < size < 1/1024 it produced a negative order that
    indexed ``units[-1]``.
    '''
    units = ('B', 'KiB', 'MiB', 'GiB', 'TiB')
    order = 0
    threshold = 1024
    while order < len(units) - 1 and size >= threshold:
        order += 1
        threshold *= 1024
    return '{:6.1f} {}'.format(size/(1024**order), units[order])
def profile(snapshot, source):
    '''Print a line-by-line memory report for an inner snapshot.

    Traces are restricted to ``Filter(True, DUMMY_SRC_NAME)`` and grouped by
    line number. Every line of ``source`` is then printed with its line
    number and, when that line allocated memory, the size formatted by
    ``bytes_to_hrf``. The report opens with the total over all grouped lines.

    Args:
        snapshot: a ``tracemalloc.Snapshot``, e.g. from ``take_inner_snapshot``.
        source: the instrumented function's source text, whose line numbers
            are expected to match the recorded frames.
    '''
    traces = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME),])
    line_stats = traces.statistics('lineno')

    total = sum(stat.size for stat in line_stats)
    # Map line number -> bytes allocated on that line.
    size_by_line = {stat.traceback[0].lineno: stat.size for stat in line_stats}

    print('Total {}(raw {} B)'.format(bytes_to_hrf(total), total))
    print('Line # Increment Line Contents')
    print('=' * (24+80))
    for lineno, text in enumerate(source.split(sep='\n'), start=1):
        size = size_by_line.get(lineno)
        if size is None:
            usage = ' ' * 10
        else:
            usage = bytes_to_hrf(size)
        print('{no:6d} {usage:10s} {contents}'.format(no=lineno, usage=usage, contents=text))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment