Skip to content

Instantly share code, notes, and snippets.

@Hasenpfote
Created August 4, 2018 06:15
Show Gist options
  • Save Hasenpfote/191c904cf490a85ce59c406efbdda242 to your computer and use it in GitHub Desktop.
Save Hasenpfote/191c904cf490a85ce59c406efbdda242 to your computer and use it in GitHub Desktop.
Take a snapshot in a function.
import numpy as np
from memory_profiler import *
def foo(x, y, z):
    '''Demo target for MemoryProfiler: allocate objects of different sizes.

    Each allocating statement sits on its own line so that the per-line
    report shows separate increments. ``y`` and ``z`` are accepted but not
    used in the body; they only demonstrate passing several keyword
    arguments through ``function_args``.

    Returns:
        int: 0 when ``x == 0``, 1 when ``x == 1``, otherwise 2.
    '''
    # 100 random float64 values between -1 and 1.
    dataset1 = np.random.uniform(low=-1., high=1., size=100).astype(np.float64)
    print('x', x)
    # Rebinds dataset1 to a 1000-element array; the 100-element one is dropped.
    dataset1 = np.random.uniform(low=-1., high=1., size=1000).astype(np.float64)
    # A 100000-element list of ints.
    l = [i for i in range(100000)]
    if x == 0:
        # Branch-specific allocation: which line allocates depends on x.
        dataset4a = np.array([i for i in range(100000)], dtype=np.float64)
        return 0
    elif x == 1:
        dataset4b = np.array([i for i in range(100000)], dtype=np.float64)
        return 1
    # Reached only when x is neither 0 nor 1.
    dataset3 = np.random.uniform(low=-1., high=1., size=3000).astype(np.float64)
    return 2
class Klass:
    '''Tiny demo class whose static method is used as a profiling target.'''

    def __init__(self, value):
        # Stored privately; nothing in this demo reads it back.
        self._value = value

    @staticmethod
    def func():
        '''Allocate a small random float64 array, then print a greeting.'''
        samples = np.random.uniform(high=1., low=-1., size=100).astype(np.float64)
        print('Hello')
def main():
    '''Profile the demo function ``foo`` and then the static method ``Klass.func``.

    Each entry below holds the keyword arguments for one MemoryProfiler;
    the runs are executed in list order, one report each.
    '''
    runs = [
        # function
        dict(
            function=foo,
            function_args=dict(x=0, y=11, z=12),
            setup='import numpy as np',
        ),
        # static method
        dict(
            function=Klass.func,
            function_args=dict(),
            setup='import numpy as np\nfrom __main__ import Klass',
        ),
    ]
    for options in runs:
        MemoryProfiler(**options).profile()


if __name__ == '__main__':
    main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import inspect
import ast
import types
import math
import textwrap
from tracemalloc import start, take_snapshot, Filter
# Public API of this module: only MemoryProfiler is exported by ``import *``.
__all__ = ['MemoryProfiler', ]
# Pseudo filename given to the recompiled, instrumented function. profile()
# filters tracemalloc traces by this name so that only allocations made on
# lines of the profiled function are reported.
DUMMY_SRC_NAME = '<memory_profiler-src>'
class Transformer(ast.NodeTransformer):
    '''AST pass that instruments function definitions with tracemalloc calls.

    A visited definition::

        def f(...):
            <body>

    is rewritten into::

        def f(...):
            try:
                start()
                <body>
            finally:
                global <result_id>
                <result_id> = take_snapshot()

    so the snapshot is stored in a global even when the body raises. The
    children of a ``def`` are not visited, so nested definitions inside the
    body are left unchanged.

    Args:
        result_id (str): name of the global variable receiving the snapshot.
    '''

    def __init__(self, result_id):
        self._result_id = result_id

    def _no_arg_call(self, callee):
        '''Return the expression node for ``callee()``.'''
        return ast.Call(func=ast.Name(id=callee, ctx=ast.Load()), args=[], keywords=[])

    def visit_FunctionDef(self, node):
        protected = [ast.Expr(value=self._no_arg_call('start'))]
        protected += node.body
        cleanup = [
            ast.Global(names=[self._result_id]),
            ast.Assign(
                targets=[ast.Name(id=self._result_id, ctx=ast.Store())],
                value=self._no_arg_call('take_snapshot'),
            ),
        ]
        # Replace the contents of the existing body list, not the list object.
        node.body[:] = [
            ast.Try(body=protected, handlers=[], orelse=[], finalbody=cleanup)
        ]
        # The synthesized nodes carry no positions; borrow them from parents.
        return ast.fix_missing_locations(node)
def bytes_to_hrf(size):
    '''Convert a byte count to a human-readable string using binary units.

    The value is scaled by the largest power of 1024 not exceeding it (capped
    at TiB) and right-aligned in a 6-character field. Whole bytes are printed
    without decimals; larger units get one decimal place.

    Args:
        size (int | float): number of bytes. Values below 1024 (including
            zero, negatives and fractions) are printed in bytes.

    Returns:
        str: e.g. ``'   512 B'`` or ``'   1.5 KiB'``.
    '''
    units = ('B', 'KiB', 'MiB', 'GiB', 'TiB')
    # Pick the unit by exact comparison rather than a floating-point logarithm.
    # Exact powers of 1024 always land on the right unit, and tiny positive
    # sizes can no longer yield a negative index (which selected 'TiB').
    order = 0
    while order < len(units) - 1 and size >= 1024 ** (order + 1):
        order += 1
    fmt = '6.0f' if order == 0 else '6.1f'
    return '{0:{1}} {2}'.format(size / (1024 ** order), fmt, units[order])
class MemoryProfiler(object):
    '''Line-by-line memory profiler built on tracemalloc.

    The target function's source is re-parsed and rewritten by
    ``Transformer``. The rewritten function calls ``start()`` on entry and, in
    a ``finally`` block, stores ``take_snapshot()`` in this module's global
    ``SNAPSHOT``. It is recompiled under the pseudo filename
    ``DUMMY_SRC_NAME``; filtering the snapshot by that filename keeps only
    allocations made on lines of the profiled function.

    The rewritten function runs against *this module's* globals, so any names
    it needs (e.g. ``np``) must be provided through ``setup``.

    Args:
        function (types.FunctionType): plain Python function (or the function
            behind a static method) to profile; its source must be
            retrievable with ``inspect``.
        function_args (dict): keyword arguments passed to the function.
        setup (str): code executed before each run. The names it binds are
            visible to the function during the run, and their previous module
            bindings are restored afterwards.

    Raises:
        TypeError: if an argument has the wrong type.
    '''

    # Global name the instrumented function stores its snapshot under.
    _RESULT_ID = 'SNAPSHOT'

    def __init__(self, function, function_args, setup='pass'):
        if not isinstance(function, types.FunctionType):
            raise TypeError('function must be a Python function, got {}'.format(
                type(function).__name__))
        if not isinstance(function_args, dict):
            raise TypeError('function_args must be a dict, got {}'.format(
                type(function_args).__name__))
        if not isinstance(setup, str):
            raise TypeError('setup must be a str, got {}'.format(
                type(setup).__name__))
        self._function = function
        self._function_args = function_args
        self._setup = setup

        # Modify the function.
        self._source_text, self._start_pos = self._load_source(function)
        node = ast.parse(self._source_text)
        node = Transformer(result_id=self._RESULT_ID).visit(node)
        locals_ = {}
        code = compile(node, DUMMY_SRC_NAME, 'exec')
        # The def runs with this module's globals; the function object lands in locals_.
        exec(code, globals(), locals_)
        self._inner = locals_[function.__name__]
        # Other information.
        self._filepath = inspect.getfile(function)

    @staticmethod
    def _load_source(function):
        '''Return ``(source_text, start_pos)`` for *function*.

        ``source_text`` is the dedented source with static-method decorator
        lines blanked out, so the recompiled definition yields a plain
        function. Lines are blanked rather than deleted, so every line keeps
        its offset. Only leading blank lines are dropped, and ``start_pos`` is
        advanced past them. Line ``k`` (1-based) of ``source_text`` is
        therefore line ``start_pos + k - 1`` of the original file.
        '''
        source_lines, start_pos = inspect.getsourcelines(function)
        lines = textwrap.dedent(''.join(source_lines)).split('\n')
        # Match whole decorator lines only (ignoring a trailing comment), so the
        # same text inside strings or comments is left alone.
        lines = ['' if line.split('#', 1)[0].strip() == '@staticmethod' else line
                 for line in lines]
        while lines and not lines[0].strip():
            del lines[0]
            start_pos += 1
        # strip() here only removes indentation left on the first line and
        # trailing whitespace; it cannot remove whole lines anymore.
        return '\n'.join(lines).strip(), start_pos

    def _take_snapshot(self):
        '''Execute ``setup`` and the instrumented function; return the snapshot.

        Names bound by ``setup`` (plus the result global) are injected into
        this module's globals, which the instrumented function runs against.
        Their previous bindings are restored afterwards, so pre-existing module
        names (``math``, ``Filter``, ...) are never deleted. tracemalloc is
        stopped again afterwards unless it was already tracing before the call.
        '''
        import tracemalloc

        namespace = globals()
        missing = object()  # Marks names that did not exist before injection.
        saved = {}
        was_tracing = tracemalloc.is_tracing()
        try:
            injected = {self._RESULT_ID: None}
            # A separate pseudo filename keeps allocations made by the setup
            # code out of the DUMMY_SRC_NAME traces reported by profile().
            code = compile(self._setup, '<memory_profiler-setup>', 'exec')
            exec(code, namespace, injected)
            for name, value in injected.items():
                saved[name] = namespace.get(name, missing)
                namespace[name] = value
            # The instrumented function calls start() itself and stores its
            # snapshot in the result global, even if it raises.
            self._inner(**self._function_args)
            return namespace[self._RESULT_ID]
        finally:
            if not was_tracing:
                tracemalloc.stop()
            for name, previous in saved.items():
                if previous is missing:
                    namespace.pop(name, None)
                else:
                    namespace[name] = previous

    def profile(self):
        '''Run the function once and print its per-line memory usage.

        The snapshot is taken in the ``finally`` block injected by
        ``Transformer``. The report therefore shows allocations attributed to
        lines of the profiled function that are still alive at that point.
        '''
        snapshot = self._take_snapshot()
        snapshot = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME), ])
        stats = snapshot.statistics('lineno')
        # Line number within self._source_text -> bytes attributed to it.
        detected_lines = {stat.traceback[0].lineno: stat.size for stat in stats}
        total = sum(stat.size for stat in stats)

        print('File "{}"'.format(self._filepath))
        print('Total {}(raw {} B)'.format(bytes_to_hrf(total), total))
        print('Line # Increment Line Contents')
        print('=' * (24+80))
        for number, line in enumerate(self._source_text.split(sep='\n'), 1):
            size = detected_lines.get(number)
            usage = ' '*10 if size is None else bytes_to_hrf(size)
            print('{number:6d} {usage:10s} {contents}'.format(
                number=self._start_pos + (number - 1), usage=usage, contents=line))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment