Skip to content

Instantly share code, notes, and snippets.

@Hasenpfote
Last active August 3, 2018 07:38
Show Gist options
  • Save Hasenpfote/32ef187f27b0a5008783839ca1935751 to your computer and use it in GitHub Desktop.
Save Hasenpfote/32ef187f27b0a5008783839ca1935751 to your computer and use it in GitHub Desktop.
Take a snapshot in a function.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"sys.path.append(os.getcwd())"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tracemalloc_utils"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def foo(x, y, z):\n",
" dataset1 = np.random.uniform(low=-1., high=1., size=100).astype(np.float64)\n",
" print('x', x)\n",
" dataset2 = np.random.uniform(low=-1., high=1., size=1000).astype(np.float64)\n",
"\n",
" l = [i for i in range(100000)]\n",
" \n",
" if x == 0:\n",
" dataset_a = np.array([i for i in range(100000)], dtype=np.float64)\n",
" return 0\n",
" elif x == 1:\n",
" dataset_b = np.array([i for i in range(100000)], dtype=np.float64)\n",
" return 1\n",
"\n",
" dataset3 = np.random.uniform(low=-1., high=1., size=3000).astype(np.float64)\n",
" return 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x 0\n"
]
}
],
"source": [
"snapshot, source = tracemalloc_utils.take_inner_snapshot(\n",
" function=foo,\n",
" function_args=dict(x=0, y=1, z=1),\n",
" setup='import numpy as np',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total 4.2 MiB(raw 4427494 B)\n",
"Line # Increment Line Contents\n",
"========================================================================================================\n",
" 1 def foo(x, y, z):\n",
" 2 1.6 KiB dataset1 = np.random.uniform(low=-1., high=1., size=100).astype(np.float64)\n",
" 3 434.0 B print('x', x)\n",
" 4 7.9 KiB dataset2 = np.random.uniform(low=-1., high=1., size=1000).astype(np.float64)\n",
" 5 \n",
" 6 3.4 MiB l = [i for i in range(100000)]\n",
" 7 \n",
" 8 if x == 0:\n",
" 9 781.3 KiB dataset_a = np.array([i for i in range(100000)], dtype=np.float64)\n",
" 10 return 0\n",
" 11 elif x == 1:\n",
" 12 dataset_b = np.array([i for i in range(100000)], dtype=np.float64)\n",
" 13 return 1\n",
" 14 \n",
" 15 dataset3 = np.random.uniform(low=-1., high=1., size=3000).astype(np.float64)\n",
" 16 return 2\n"
]
}
],
"source": [
"tracemalloc_utils.profile(snapshot, source=source)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import os
import inspect
import ast
import types
import linecache
import math
from tracemalloc import start, take_snapshot, Filter
# Pseudo file name passed to compile() for both the setup code and the
# instrumented function. The reporting helpers filter snapshot traces by this
# name so that only allocations attributed to that code are shown.
DUMMY_SRC_NAME = '<tracemalloc_utils-src>'
class Transformer(ast.NodeTransformer):
    '''Rewrite function definitions so that they record a snapshot.

    Each visited ``def`` ends up with a body equivalent to::

        try:
            start()
            <original body>
        finally:
            global <result_id>
            <result_id> = take_snapshot()

    ``start`` and ``take_snapshot`` are emitted as bare names, so both must
    be resolvable in the globals the rewritten function runs with.
    '''

    def __init__(self, result_id):
        # Name of the global variable that receives the snapshot.
        self._result_id = result_id

    @staticmethod
    def _bare_call(name):
        '''Return the AST of the argument-less call ``name()``.'''
        return ast.Call(
            func=ast.Name(id=name, ctx=ast.Load()),
            args=[],
            keywords=[]
        )

    def visit_FunctionDef(self, node):
        '''Wrap the body of ``node`` in the try/finally shown above.

        ``generic_visit`` is not called, so definitions nested inside
        ``node`` are left as they are.
        '''
        begin_tracing = ast.Expr(value=self._bare_call('start'))
        declare_global = ast.Global(names=[self._result_id])
        store_snapshot = ast.Assign(
            targets=[ast.Name(id=self._result_id, ctx=ast.Store())],
            value=self._bare_call('take_snapshot')
        )
        guarded = ast.Try(
            body=[begin_tracing, *node.body],
            handlers=[],
            orelse=[],
            finalbody=[declare_global, store_snapshot]
        )
        # Swap the contents of the existing list instead of rebinding it.
        node.body[:] = [guarded]
        # The synthesized nodes carry no line/column data yet.
        return ast.fix_missing_locations(node)
def take_inner_snapshot(
    function,
    function_args,
    setup='pass'
):
    '''Take an inner snapshot of ``function``.

    The source of ``function`` is re-parsed, rewritten by ``Transformer`` and
    compiled under the file name ``DUMMY_SRC_NAME``. The rewritten copy is
    then called with ``function_args``; the snapshot it stores on exit is
    returned.

    The copy does not run in the namespace ``function`` was defined in.
    It runs in a private copy of this module's globals, extended by whatever
    ``setup`` defines, so ``setup`` must provide every name the function
    needs (typically its imports). Because the namespace is private, names
    created by ``setup`` can neither overwrite nor remove this module's own
    globals, and nothing has to be restored if the call raises.

    Note: the injected code starts tracing; nothing here stops it again.

    Args:
        function: A plain Python function whose source can be retrieved
            with ``inspect.getsource``.
        function_args: dict of keyword arguments for the call.
        setup: Source code executed once before the function is defined.

    Returns:
        A ``(snapshot, source)`` tuple: the value stored by the injected
        ``take_snapshot()`` call and the stripped source text of
        ``function``.

    Raises:
        TypeError: If an argument has the wrong type.
    '''
    if not isinstance(function, types.FunctionType):
        raise TypeError(
            'function must be a plain Python function, not {}'.format(
                type(function).__name__
            )
        )
    if not isinstance(function_args, dict):
        raise TypeError(
            'function_args must be a dict, not {}'.format(
                type(function_args).__name__
            )
        )
    if not isinstance(setup, str):
        raise TypeError(
            'setup must be a str, not {}'.format(type(setup).__name__)
        )

    # Private execution namespace. It starts as a copy of the module globals
    # so the injected start()/take_snapshot() names resolve, but writes made
    # by the setup code or the function never reach the real module.
    namespace = dict(globals())
    namespace['SNAPSHOT'] = None
    exec(compile(setup, DUMMY_SRC_NAME, 'exec'), namespace)

    # Modify the function.
    source = inspect.getsource(function).strip()
    tree = Transformer(result_id='SNAPSHOT').visit(ast.parse(source))

    # Take the snapshot. Defining the copy in the namespace that also serves
    # as its globals lets the function refer to itself by name.
    exec(compile(tree, DUMMY_SRC_NAME, 'exec'), namespace)
    namespace[function.__name__](**function_args)
    return namespace['SNAPSHOT'], source
def display_top(
    snapshot,
    source=None,
    key_type='lineno',
    limit=10
):
    '''Print the ``limit`` largest entries of ``snapshot``.

    Args:
        snapshot: Object providing ``filter_traces`` and ``statistics``
            (presumably a ``tracemalloc.Snapshot`` -- confirm with callers).
        source: Optional source text of the instrumented function; when
            given, the matching source line is printed under each entry.
        key_type: Grouping key handed to ``snapshot.statistics``.
        limit: Number of entries listed individually; the remainder is
            summarized in a single "other" line.
    '''
    # NOTE(review): the indentation of this function was lost in the copy
    # under review. The layout below assumes (a) the trace filter is applied
    # unconditionally and (b) the ``else`` further down pairs with the
    # ``frame.filename == DUMMY_SRC_NAME`` test -- TODO confirm against the
    # original file.
    if source is not None:
        source_lines = source.split(sep='\n')
    # Keep only traces whose file name matches the instrumented pseudo file.
    snapshot = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME),])
    top_stats = snapshot.statistics(key_type)
    print('Top {} lines'.format(limit))
    for index, stat in enumerate(top_stats[:limit], 1):
        frame = stat.traceback[0]
        # replace "/path/to/module/file.py" with "module/file.py"
        filename = os.sep.join(frame.filename.split(os.sep)[-2:])
        fmt = '#{index}: {filename}:{lineno}: {size:.3f} KiB'.format(
            index=index,
            filename=filename,
            lineno=frame.lineno,
            size=stat.size / 1024
        )
        print(fmt)
        # Look up the text of the line: from ``source`` for the instrumented
        # pseudo file, from linecache for any other file. Stays empty when
        # the pseudo file is hit but no source text was supplied.
        line = ''
        if frame.filename == DUMMY_SRC_NAME:
            if source is not None:
                line = source_lines[frame.lineno-1].strip()
        else:
            line = linecache.getline(frame.filename, frame.lineno).strip()
        if line:
            print(' %s' % line)
    # Everything beyond the first ``limit`` entries is reported as one sum.
    other = top_stats[limit:]
    if other:
        size = sum(stat.size for stat in other)
        fmt = '{length} other: {size:.3f} KiB'.format(
            length=len(other),
            size=size / 1024
        )
        print(fmt)
    total = sum(stat.size for stat in top_stats)
    fmt = 'Total allocated size: {size:.3f} KiB'.format(size=total / 1024)
    print(fmt)
def bytes_to_hrf(size):
    '''Convert bytes to human readable format.

    Args:
        size: Number of bytes (int or float).

    Returns:
        The value scaled to the largest unit of B/KiB/MiB/GiB/TiB that keeps
        it at or above 1, formatted as ``'{:6.1f} {}'``. Values below 1024,
        including zero, fractions and negative numbers, are reported in
        bytes; values beyond the TiB range stay in TiB.
    '''
    units = ('B', 'KiB', 'MiB', 'GiB', 'TiB')
    # Determine the unit by repeated division rather than a logarithm: a
    # log ratio turns negative for tiny positive sizes (which would index
    # the unit table from the end) and relies on floating-point rounding at
    # exact powers of 1024.
    order = 0
    scaled = size
    while scaled >= 1024 and order < len(units) - 1:
        scaled /= 1024
        order += 1
    return '{:6.1f} {}'.format(size / (1024 ** order), units[order])
def profile(snapshot, source):
    '''Print a line-by-line allocation report for ``source``.

    Only traces whose file name matches ``DUMMY_SRC_NAME`` are considered.
    The report starts with the total over those traces, followed by every
    line of ``source`` prefixed with its line number and, where a statistic
    exists for that line, the size reported for it.

    Args:
        snapshot: Object providing ``filter_traces`` and ``statistics``.
        source: Source text whose lines are listed in the report.
    '''
    relevant = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME)])
    line_stats = relevant.statistics('lineno')

    # Size per line number, taken from the first frame of each statistic.
    size_by_line = {
        stat.traceback[0].lineno: stat.size
        for stat in line_stats
    }
    grand_total = sum(stat.size for stat in line_stats)

    print('Total {}(raw {} B)'.format(bytes_to_hrf(grand_total), grand_total))
    print('Line # Increment Line Contents')
    print('=' * 104)

    blank = ' ' * 10
    for number, text in enumerate(source.split('\n'), 1):
        if number in size_by_line:
            increment = bytes_to_hrf(size_by_line[number])
        else:
            increment = blank
        print('{no:6d} {usage:10s} {contents}'.format(
            no=number,
            usage=increment,
            contents=text
        ))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment