Created August 4, 2018 06:15
-
-
Save Hasenpfote/191c904cf490a85ce59c406efbdda242 to your computer and use it in GitHub Desktop.
Take a tracemalloc snapshot inside a function to profile its line-by-line memory usage.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from memory_profiler import * | |
def foo(x, y, z):
    '''Allocate several arrays so the memory profiler has lines to report.

    Args:
        x: selects the exit path -- 0 and 1 return early, anything else
            falls through to the final allocation.
        y, z: unused; they only demonstrate forwarding keyword arguments.

    Returns:
        int: 0 if ``x == 0``, 1 if ``x == 1``, otherwise 2.
    '''
    samples = np.random.uniform(low=-1., high=1., size=100).astype(np.float64)
    print('x', x)
    # The first array is no longer referenced after this rebinding.
    samples = np.random.uniform(low=-1., high=1., size=1000).astype(np.float64)
    numbers = [n for n in range(100000)]
    if x == 0:
        early_a = np.array([n for n in range(100000)], dtype=np.float64)
        return 0
    if x == 1:
        early_b = np.array([n for n in range(100000)], dtype=np.float64)
        return 1
    late = np.random.uniform(low=-1., high=1., size=3000).astype(np.float64)
    return 2
class Klass:
    '''Minimal class used to demonstrate profiling a static method.'''

    def __init__(self, value):
        # Stored only; nothing in this example reads it back.
        self._value = value

    @staticmethod
    def func():
        '''Allocate a 100-element float64 array, then print a greeting.'''
        samples = np.random.uniform(low=-1., high=1., size=100).astype(np.float64)
        print('Hello')
def main():
    '''Profile a module-level function, then a static method.'''
    runs = [
        # Plain function: its body only needs numpy.
        dict(
            function=foo,
            function_args=dict(x=0, y=11, z=12),
            setup='import numpy as np',
        ),
        # Static method: the setup additionally imports the owning class.
        dict(
            function=Klass.func,
            function_args=dict(),
            setup='import numpy as np\nfrom __main__ import Klass',
        ),
    ]
    for options in runs:
        profiler = MemoryProfiler(**options)
        profiler.profile()


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import inspect | |
import ast | |
import types | |
import math | |
import textwrap | |
from tracemalloc import start, take_snapshot, Filter | |
# Public API of this module.
__all__ = ['MemoryProfiler']

# Pseudo filename the instrumented function is compiled under; profile()
# keeps only tracemalloc traces whose filename matches it.
DUMMY_SRC_NAME = '<memory_profiler-src>'
class Transformer(ast.NodeTransformer):
    '''Add tracemalloc functions to a function definition.

    The outermost ``def`` visited is rewritten to::

        def f(...):
            global <result_id>
            try:
                start()
                <original body>
            finally:
                <result_id> = take_snapshot()

    ``start`` and ``take_snapshot`` are resolved as globals when the
    rewritten function runs, so the namespace it is executed in must provide
    them (this module imports tracemalloc's functions of those names).
    '''

    def __init__(self, result_id):
        '''
        Args:
            result_id (str): name of the global variable that receives the
                value of ``take_snapshot()`` when the function exits.
        '''
        super().__init__()
        self._result_id = result_id

    def visit_FunctionDef(self, node):
        # Pre-hook: start tracing on entry.
        pre_hook = ast.Expr(
            value=ast.Call(
                func=ast.Name(id='start', ctx=ast.Load()),
                args=[],
                keywords=[],
            )
        )
        # Post-hook: placed in ``finally`` so it runs on return and on raise.
        post_hook = ast.Assign(
            targets=[ast.Name(id=self._result_id, ctx=ast.Store())],
            value=ast.Call(
                func=ast.Name(id='take_snapshot', ctx=ast.Load()),
                args=[],
                keywords=[],
            ),
        )
        # A global declaration must precede every use of the name in the
        # function. Emitting it first (rather than inside ``finally``) keeps
        # the rewrite compilable even when the body mentions result_id.
        node.body = [
            ast.Global(names=[self._result_id]),
            ast.Try(
                body=[pre_hook, *node.body],
                handlers=[],
                orelse=[],
                finalbody=[post_hook],
            ),
        ]
        # generic_visit is deliberately not called: nested functions are left
        # uninstrumented, so only the outer function takes the snapshot.
        return ast.fix_missing_locations(node)
def bytes_to_hrf(size):
    '''Convert bytes to human readable format.

    Args:
        size (int or float): number of bytes. Negative values are scaled by
            their magnitude and keep their sign.

    Returns:
        str: the value right-aligned in 6 columns followed by a binary unit,
        e.g. ``'   1.5 KiB'``. Plain bytes are shown without decimals; the
        largest unit is TiB.
    '''
    units = ('B', 'KiB', 'MiB', 'GiB', 'TiB')
    # Choose the unit with plain comparisons rather than
    # log(size) / log(1024), whose floating-point quotient is not guaranteed
    # to be exact at powers of 1024.
    magnitude = abs(size)
    order = 0
    while magnitude >= 1024 and order < len(units) - 1:
        magnitude /= 1024
        order += 1
    fmt = '6.0f' if order == 0 else '6.1f'
    return '{0:{1}} {2}'.format(size / (1024 ** order), fmt, units[order])
class MemoryProfiler(object):
    '''This class uses tracemalloc to perform memory profile.

    The function's source is re-parsed, wrapped by ``Transformer`` so that it
    starts tracemalloc on entry and takes a snapshot on exit, re-compiled and
    run. Traces from the re-compiled code are mapped back to the lines of the
    original source.

    Args:
        function (types.FunctionType): a function defined with ``def`` (not a
            generator or coroutine function). Decorator lines are shown in
            the listing but are not applied when the function is profiled.
        function_args (dict): keyword arguments passed to ``function``.
        setup (str): code executed before each run, in the namespace the
            function runs in; use it to import the names the body needs
            (e.g. ``'import numpy as np'``).

    Raises:
        TypeError: if an argument has the wrong type or ``function`` cannot
            be profiled.
    '''

    # Global name the instrumented function stores its snapshot under.
    _RESULT_ID = 'SNAPSHOT'
    # Setup code gets its own pseudo filename so its allocations are never
    # attributed to lines of the profiled function (DUMMY_SRC_NAME).
    _SETUP_SRC_NAME = '<memory_profiler-setup>'

    def __init__(self, function, function_args, setup='pass'):
        if not isinstance(function, types.FunctionType):
            raise TypeError('function must be a plain function, got {}'.format(
                type(function).__name__))
        if not isinstance(function_args, dict):
            raise TypeError('function_args must be a dict, got {}'.format(
                type(function_args).__name__))
        if not isinstance(setup, str):
            raise TypeError('setup must be a str, got {}'.format(
                type(setup).__name__))
        if (inspect.isgeneratorfunction(function)
                or inspect.iscoroutinefunction(function)
                or inspect.isasyncgenfunction(function)):
            # Calling these does not execute the body, so no snapshot would
            # ever be taken.
            raise TypeError('generator and coroutine functions cannot be profiled')
        self._function = function
        self._function_args = function_args
        self._setup = setup

        # Modify the function. Dedenting lets methods and nested functions
        # parse as top-level code; decorator lines are kept in the text so
        # the listing's line numbers match the file exactly.
        source_lines, start_pos = inspect.getsourcelines(function)
        self._source_text = textwrap.dedent(''.join(source_lines)).strip()
        self._start_pos = start_pos
        self._code, self._name = self._instrument(self._source_text)

        # Other information.
        self._filepath = inspect.getfile(function)

    def _instrument(self, source_text):
        '''Compile ``source_text`` with tracemalloc hooks.

        Returns:
            tuple: (code object defining the function, the function's name).
        '''
        module = ast.parse(source_text)
        func_def = module.body[0] if module.body else None
        if not isinstance(func_def, ast.FunctionDef):
            raise TypeError('function must be defined with a def statement')
        # Drop decorators (e.g. @staticmethod): they would otherwise be
        # evaluated when the definition is executed, and the bare function is
        # what gets profiled.
        func_def.decorator_list = []
        module = Transformer(result_id=self._RESULT_ID).visit(module)
        return compile(module, DUMMY_SRC_NAME, 'exec'), func_def.name

    def _take_snapshot(self):
        '''Run the instrumented function once and return the snapshot it took.'''
        import tracemalloc

        # Work on a copy of this module's globals: the function sees the same
        # names as the module (including start/take_snapshot), but names
        # added by ``setup`` never leak into -- or, on cleanup, get deleted
        # from -- the real module namespace.
        namespace = dict(globals())
        namespace[self._RESULT_ID] = None
        local_ns = {}
        exec(self._code, namespace, local_ns)
        inner = local_ns[self._name]
        exec(compile(self._setup, self._SETUP_SRC_NAME, 'exec'), namespace)
        was_tracing = tracemalloc.is_tracing()
        try:
            inner(**self._function_args)
        finally:
            # The instrumented function starts tracing itself; stop it again
            # so later runs begin with no stale traces, unless the caller had
            # tracing enabled beforehand.
            if not was_tracing:
                tracemalloc.stop()
        return namespace[self._RESULT_ID]

    def profile(self):
        '''Run the function once and print a per-line memory listing.

        For each source line, prints the size of the traced memory allocated
        by that line and still alive when the snapshot was taken on exit.
        '''
        snapshot = self._take_snapshot()
        snapshot = snapshot.filter_traces([Filter(True, DUMMY_SRC_NAME)])
        total = 0
        usage_by_line = {}
        for stat in snapshot.statistics('lineno'):
            lineno = stat.traceback[0].lineno
            usage_by_line[lineno] = usage_by_line.get(lineno, 0) + stat.size
            total += stat.size

        print('File "{}"'.format(self._filepath))
        print('Total {}(raw {} B)'.format(bytes_to_hrf(total), total))
        print('Line # Increment Line Contents')
        print('=' * (24+80))
        for number, line in enumerate(self._source_text.split('\n'), 1):
            size = usage_by_line.get(number)
            usage = ' ' * 10 if size is None else bytes_to_hrf(size)
            print('{number:6d} {usage:10s} {contents}'.format(
                number=self._start_pos + (number - 1), usage=usage, contents=line))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.