Last active
May 7, 2020 14:03
-
-
Save asodeur/2a5d9101cc3a49b8181192ff56aafec2 to your computer and use it in GitHub Desktop.
helpers to profile Tasks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from asyncio import ensure_future, gather, Task | |
from contextlib import contextmanager | |
from contextvars import ContextVar | |
from more_itertools import pairwise | |
from networkx import all_simple_paths, DiGraph | |
import os | |
from pstats import Stats | |
import sys | |
import threading | |
from typing import TYPE_CHECKING | |
import yappi, _yappi | |
if TYPE_CHECKING: | |
from typing import Dict, Tuple | |
__all__ = [ | |
'call_tree_to_pstat', 'clear', 'create_call_tree', 'rewrite_call_tree', 'run', | |
'start', 'stop', 'TaggedYFuncStats' | |
] | |
class ProfileNode:
    """Node in the (networkx) call graph; a slimmed-down yappi.YFuncStat.

    Identity (``__eq__``/``__hash__``) is the pair ``(full_name, tag)`` so the
    same function invoked under different Task tags yields distinct graph nodes.
    """
    __slots__ = (
        'name', 'module', 'lineno', 'full_name', 'ncall', 'nactualcall', 'ttot', 'tsub', 'tag'
    )

    def __init__(self, name, module, lineno, full_name, ncall, nactualcall, ttot, tsub, tag=0):
        self.name = name
        self.module = module
        self.lineno = lineno
        self.full_name = full_name
        self.ncall = ncall
        self.nactualcall = nactualcall
        self.ttot = ttot
        self.tsub = tsub
        self.tag = tag

    def copy(self):
        """Return an independent copy (all attributes are plain scalars)."""
        return ProfileNode(
            self.name, self.module, self.lineno, self.full_name, self.ncall, self.nactualcall,
            self.ttot, self.tsub, self.tag
        )

    def __eq__(self, other):
        # Fix: comparing against a foreign type used to raise AttributeError
        # (missing `full_name`/`tag`); return NotImplemented so Python falls
        # back to the reflected comparison / inequality instead.
        if not isinstance(other, ProfileNode):
            return NotImplemented
        return self.full_name == other.full_name and self.tag == other.tag

    def __hash__(self):
        return hash((self.full_name, self.tag))

    def __repr__(self):
        return f'{self.full_name}<{self.tag}>'
class TaggedYFuncStats(yappi.YFuncStats):
    """YFuncStats subclass that keeps per-tag entries distinct.

    Stock ``YFuncStat`` objects compare equal whenever ``full_name`` matches,
    which collapses stats gathered under different Task tags.  This subclass
    appends the tag to the function name so each tag keeps its own entry.
    """

    def _enumerator(self, stat_entry):
        """appends the tag to the function name: name -> name<tag>

        works around the fact that YFuncStat compares equal if full_name matches
        regardless of the tags
        """
        # NOTE(review): this `global` appears to be a leftover — the code below
        # writes yappi._fn_descriptor_dict (an attribute), never a module global.
        global _fn_descriptor_dict
        # unpack the raw stat tuple handed over by yappi's C extension
        fname, fmodule, flineno, fncall, fnactualcall, fbuiltin, fttot, ftsub, \
            findex, fchildren, fctxid, fctxname, ftag, ffn_descriptor = stat_entry
        if ftag:
            # tag 0 means "untagged"; only decorate names with a real tag
            fname = f'{fname}<{ftag}>'
            stat_entry = (fname,) + stat_entry[1:]
        # builtin function?
        ffull_name = yappi._func_fullname(bool(fbuiltin), fmodule, flineno, fname)
        ftavg = fttot / fncall
        fstat = yappi.YFuncStat(stat_entry + (ftavg, ffull_name))
        yappi._fn_descriptor_dict[ffull_name] = ffn_descriptor
        # do not show profile stats of yappi itself.
        if os.path.basename(
            fstat.module
        ) == "yappi.py" or fstat.module == "_yappi":
            return
        fstat.builtin = bool(fstat.builtin)
        if self._filter_callback:
            if not self._filter_callback(fstat):
                return
        self.append(fstat)
        # hold the max idx number for merging new entries(for making the merging
        # entries indexes unique)
        if self._idx_max < fstat.index:
            self._idx_max = fstat.index
def create_call_tree(stats: yappi.YFuncStats) -> DiGraph:
    """Build a directed call graph from yappi stats.

    Each node is a ``ProfileNode`` carrying the aggregate timings of a function;
    each edge carries the per-caller timings (``ncall``/``nactualcall``/``ttot``/
    ``tsub``) taken from the child entry of the calling function.
    """
    graph = DiGraph()
    for stat in stats:
        parent = ProfileNode(
            stat.name, stat.module, stat.lineno, stat.full_name,
            stat.ncall, stat.nactualcall, stat.ttot, stat.tsub, stat.tag
        )
        graph.add_node(parent)
        for child_stat in stat.children:
            # look up the aggregate entry for the child; the timings on
            # `child_stat` itself are only those attributed to this caller
            agg = stats[child_stat.full_name]
            child = ProfileNode(
                agg.name, agg.module, agg.lineno, agg.full_name,
                agg.ncall, agg.nactualcall, agg.ttot, agg.tsub, agg.tag
            )
            graph.add_edge(
                parent, child,
                ncall=child_stat.ncall, nactualcall=child_stat.nactualcall,
                ttot=child_stat.ttot, tsub=child_stat.tsub
            )
    return graph
# Tag of the Task currently being created; 0 means "no task" (see task_factory).
_marker = ContextVar('yappi_task_marker', default=0)
# Monotonically increasing id handed out per loop.create_task call; doubles as the yappi tag.
_task_counter = 0
# tag -> (filename, firstlineno) of the coroutine wrapped by the Task with that tag.
_callees = {}
# tag -> ContextVar token, so the profile callback can reset _marker when create_task returns.
_tokens = {}
def task_factory(loop, coro):
    """Task factory that records which coroutine runs under the current tag.

    Reads the tag set around ``loop.create_task`` (see ``start``), remembers
    the wrapped coroutine's code location in ``_callees`` and then builds a
    plain ``asyncio.Task``.
    """
    tag = _marker.get()
    code = coro.cr_code
    _callees[tag] = (code.co_filename, code.co_firstlineno)
    task = Task(coro, loop=loop)
    # drop the frame of this factory from the debug traceback, as the
    # default factory does
    traceback = task._source_traceback
    if traceback:
        del traceback[-1]
    return task
def start(loop, builtins=True, profile_threads=True):
    """installs a profile function and a tag callback to help profiling Tasks

    ATTN: does nothing with `profile_threads=True` (even for single-threaded)
    and crashes for multi-threaded programs

    NOTE(review): despite the parameter defaults, the asserts below only
    accept builtins=True and profile_threads=False.
    """
    assert builtins, "Must use `builtins=True`"
    assert not profile_threads, "`profile_threads=True` currently not working"
    # only the default factory may be installed, since stop() restores None
    original_task_factory = loop.get_task_factory()
    assert original_task_factory is None, "Task factory must be default task factory."

    def _task_tag_cbk():
        # yappi queries this on every profile event; the marker is set while
        # loop.create_task is on the stack (see profile_thread_callback)
        return _marker.get(0)

    yappi.set_tag_callback(_task_tag_cbk)
    # code location of loop.create_task, used to recognize its frames below
    _loop_create_task_loc = (
        loop.create_task.__code__.co_filename, loop.create_task.__code__.co_firstlineno
    )

    def profile_thread_callback(frame, event, arg):
        """Brackets every loop.create_task call with a fresh tag, then forwards
        the event to yappi's own profile hook."""
        global _task_counter
        if event == 'call' and (frame.f_code.co_filename, frame.f_code.co_firstlineno) == _loop_create_task_loc:
            # entering create_task: allocate a new tag and remember the reset token
            _task_counter += 1
            _tokens[_task_counter] = _marker.set(_task_counter)
        elif event == 'return' and (frame.f_code.co_filename, frame.f_code.co_firstlineno) == _loop_create_task_loc:
            # leaving create_task: restore the previous marker value
            token = _tokens.pop(_marker.get(None), None)
            if token:
                _marker.reset(token)
        return _yappi._profile_event(frame, event, arg)

    if profile_threads:
        threading.setprofile(profile_thread_callback)
    # fingers crossed this is properly initializing yappi and setting the profile hook. seems you
    # cannot get the yappi hook via sys.getprofile and wrap
    _yappi.start(builtins, profile_threads)
    sys.setprofile(profile_thread_callback)
    loop.set_task_factory(task_factory)
def stop(loop):
    """Undo `start`: detach the custom task factory from *loop* and stop yappi."""
    loop.set_task_factory(None)
    yappi.stop()
def clear():
    """Reset the module-level bookkeeping accumulated while profiling."""
    global _task_counter
    _task_counter = 0
    _tokens.clear()
    _callees.clear()
@contextmanager
def run(loop, builtins=True, profile_threads=False):
    """Context manager for profiling a block of code on *loop*.

    Starts profiling before entering the context and stops profiling when
    exiting the context.

    Fix: the previous defaults (``builtins=False, profile_threads=True``)
    always tripped both asserts in ``start``, so ``run(loop)`` could never
    succeed.  The defaults now match the only combination ``start`` accepts;
    callers passing explicit arguments are unaffected.

    Usage:
        with run(loop):
            print("this call is profiled")

    Warning: don't use this recursively, the inner context will stop profiling
    when exited:
        with run(loop):
            with run(loop):
                print("this call will be profiled")
            print("this call will *not* be profiled")
    """
    start(loop, builtins, profile_threads)
    try:
        yield
    finally:
        stop(loop)
def rewrite_call_tree(tree: DiGraph, propagate_timings=True) -> DiGraph:
    """move coroutines run via loop.create_task in the call tree

    Without rewrite the caller for coroutines run via loop.create_task
    is Handle._run. This rewrites the call tree to make them appear as
    children of loop.create_task. If `propagate_timings` ttot of the
    coroutine is propagated up until the first non-asyncio caller.

    Returns a new DiGraph; nodes are copies, the input tree is not modified.
    """
    g = DiGraph()
    _copies = {}

    def copy_node(n):
        # one copy per original node, so edges added later share identity
        return _copies.setdefault(n, n.copy())

    # code location of our task_factory: its stat nodes stand in for the
    # loop.create_task call sites (one per tag)
    task_factory_loc = (
        task_factory.__code__.co_filename, task_factory.__code__.co_firstlineno
    )
    task_factory_calls = {fs.tag: copy_node(fs) for fs in tree if (fs.module, fs.lineno) == task_factory_loc}
    callee_nodes = set()
    for u, v, data in tree.edges(data=True):
        # a "callee" is the coroutine the Task with tag v.tag wraps
        # (recorded by task_factory in _callees), invoked by the event
        # loop through Context.run
        if (
            (callee_loc := _callees.get(v.tag))
            and (v.module, v.lineno) == callee_loc
            and u.full_name == "__builtin__.<method 'run' of 'Context' objects>"
        ):
            # make this a child of task_factory
            tfc = task_factory_calls[v.tag]
            callee = copy_node(v)
            callee_nodes.add(callee)
            g.add_edge(tfc, callee, **data)
        else:
            g.add_edge(copy_node(u), copy_node(v), **data)
    if propagate_timings:
        roots = [n for n, d in g.in_degree if not d]
        seen_callees = set()
        edges_to_fix = set()
        # timing propagation only works if each rewritten callee is reached
        # by exactly one path from exactly one root
        for r in roots:
            for callee in callee_nodes:
                paths = list(all_simple_paths(g, r, callee))
                if not paths:
                    continue
                elif len(paths) > 1:
                    raise ValueError('Cannot update timings due to recursion.')
                if callee in seen_callees:
                    raise ValueError('Cannot update timings b/c there is no unique root for callee.')
                else:
                    seen_callees.add(callee)
                # walk the path bottom-up and collect edges whose caller is
                # task_factory or lives inside the stdlib asyncio package;
                # propagation stops at the first non-asyncio caller
                for u, v in reversed(list(pairwise(paths[0]))):
                    if (u.module, u.lineno) == task_factory_loc or os.path.commonpath(
                        [os.path.join(sys.base_prefix, 'lib/asyncio'), u.module]
                    ).endswith('asyncio'):
                        edges_to_fix.add((u, v))
                    else:
                        break
        gg = g.edge_subgraph(edges_to_fix)

        def fix_subgraph_timings(u: ProfileNode, ttot=0.) -> float:
            """Recursively recompute ttot along the asyncio-internal edges in gg."""
            gather_loc = (gather.__code__.co_filename, gather.__code__.co_firstlineno)
            ensure_future_loc = (ensure_future.__code__.co_filename, ensure_future.__code__.co_firstlineno)
            if (u.module, u.lineno) == gather_loc:
                # gather's ttot is max(ttot for ttot in coros) but coros are run via ensure_future
                for _, v, data in gg.out_edges(u, data=True):
                    vttot = 0.
                    parallel_vttot = 0.
                    if (v.module, v.lineno) == ensure_future_loc:  # TODO make this more selective
                        fix_subgraph_timings(v)
                        # children inside gg ran concurrently (take the max);
                        # children outside ran sequentially (sum them up)
                        for _, w, ddata in g.out_edges(v, data=True):  # note this is in g
                            if w in gg:
                                parallel_vttot = max(parallel_vttot, ddata['ttot'])
                            else:
                                vttot += ddata['ttot']
                        vttot = vttot + parallel_vttot
                        ttot += vttot - data['ttot']
                        u.ttot += vttot - data['ttot']
                        data['ttot'] = vttot
                    else:
                        vttot = fix_subgraph_timings(v, data['ttot'])
                        data['ttot'] = vttot
                        ttot += vttot
                        u.ttot += vttot
            else:
                for _, v, data in gg.out_edges(u, data=True):
                    vttot = fix_subgraph_timings(v, data['ttot'])
                    data['ttot'] += (vttot - data['ttot'])
                    ttot += vttot
                    u.ttot += vttot
            return ttot

        roots = [n for n, d in gg.in_degree if not d]
        for r in roots:
            fix_subgraph_timings(r)
    return g
def call_tree_to_pstat(tree: DiGraph) -> Stats:
    """Convert a call tree into a `pstats.Stats` object.

    Each node becomes a pstats entry keyed by ``(module, lineno, name)``;
    its callers are collected from the node's incoming edges, with the
    per-caller timings stored on the edge data.
    """

    class _StatsSource:
        # minimal duck type accepted by pstats.Stats: a `stats` dict plus
        # a no-op create_stats()
        def __init__(self, stats_dict):
            self.stats = stats_dict

        def create_stats(self):
            pass

    def _key(node):
        return (node.module, node.lineno, node.name)

    stats_dict = {}
    for node in tree:
        callers = {
            _key(caller): (edge['ncall'], edge['nactualcall'], edge['tsub'], edge['ttot'])
            for caller, _, edge in tree.in_edges(node, data=True)
        }
        stats_dict[_key(node)] = (
            node.ncall,
            node.nactualcall,
            node.tsub,
            node.ttot,
            callers,
        )
    return Stats(_StatsSource(stats_dict))
if __name__ == '__main__':
    # demo / smoke test: profile a tiny asyncio program, then compare the raw
    # call tree against the rewritten one
    from asyncio import create_task, get_event_loop, sleep
    from itertools import chain
    import networkx

    async def doit():
        """One Task created via create_task plus three coroutines run via gather."""
        tsk = create_task(sleep(1.))
        coro1 = sleep(1.)
        coro2 = sleep(2.)
        coro3 = sleep(3.)
        await gather(coro1, coro2, coro3)
        return await tsk

    loop = get_event_loop()
    yappi.set_clock_type('wall')
    with run(loop, builtins=True, profile_threads=False):
        loop.run_until_complete(doit())
    stats = TaggedYFuncStats().get(filter_callback=lambda fs: fs.tag >= 0)
    tree = create_call_tree(stats)
    rewritten = rewrite_call_tree(tree)
    # get the node for doit (TaggedYFuncStats appended the tag, so goes by doit<1> now)
    doit = next(fs for fs in tree if fs.name == 'doit<1>')
    # w/o rewrite you see create_task and gather etc as descendants of doit but not sleep as
    # it is always wrapped in a Task
    create_task, = [fs for fs in networkx.descendants(tree, doit) if fs.name.startswith('create_task')]
    gather, = [fs for fs in networkx.descendants(tree, doit) if fs.name.startswith('gather')]
    assert not [fs for fs in networkx.descendants(tree, doit) if fs.name.startswith('sleep')]
    # after rewrite sleep shows-up in the tree under create_task/gather
    [fs.name for fs in networkx.descendants(rewritten, create_task) if fs.name.startswith('sleep')]  # ['sleep<2>']
    [fs.name for fs in networkx.descendants(rewritten, gather) if fs.name.startswith('sleep')]  # ['sleep<3>', 'sleep<5>', 'sleep<4>']
    # w/o rewrite doit's own time includes the time spend awaiting tsk (none as run in parallel
    # with gather) and gather (~3. from sleep(3.), create_task and gather almost do not consume
    # time at all
    doit.ttot, doit.tsub  # ~ (3., 3.)
    [fs.ttot for fs in tree if fs.name.startswith('create_task')]  # ~ [0.]
    [fs.ttot for fs in tree if fs.name.startswith('gather')]  # ~ [0.]
    # doit's timings are unaffected by the rewrite but create_task and gather now show
    assert (doit.ttot, doit.tsub) == ((r_doit := next(fs for fs in rewritten if fs.name == 'doit<1>')).ttot, r_doit.tsub)
    [fs.ttot for fs in rewritten if fs.name.startswith('create_task')]  # ~ [1.]
    [fs.ttot for fs in rewritten if fs.name.startswith('gather')]  # ~ [3.]
    # sleep's timings are not affected by rewrite
    # {'sleep<2>': 1.,
    #  'sleep<3>': 1.,
    #  'sleep<4>': 2.,
    #  'sleep<5>': 3.}
    assert {
        fs.name: fs.ttot for fs in sorted(tree, key=lambda fs: fs.tag) if fs.name.startswith('sleep')
    } == {
        fs.name: fs.ttot for fs in sorted(rewritten, key=lambda fs: fs.tag) if fs.name.startswith('sleep')
    }
    # w/o rewrite caller of doit/sleep is the loop (via Context.run)
    assert set(fs.name for fs in chain.from_iterable(
        tree.predecessors(fs) for fs in tree if fs.name.startswith('sleep') or fs.name.startswith('doit')
    )) == set(["<method 'run' of 'Context' objects>"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment