Skip to content

Instantly share code, notes, and snippets.

@asodeur
Last active May 7, 2020 14:03
Show Gist options
  • Save asodeur/2a5d9101cc3a49b8181192ff56aafec2 to your computer and use it in GitHub Desktop.
Save asodeur/2a5d9101cc3a49b8181192ff56aafec2 to your computer and use it in GitHub Desktop.
helpers to profile Tasks
from __future__ import annotations
from asyncio import ensure_future, gather, Task
from contextlib import contextmanager
from contextvars import ContextVar
from more_itertools import pairwise
from networkx import all_simple_paths, DiGraph
import os
from pstats import Stats
import sys
import threading
from typing import TYPE_CHECKING
import yappi, _yappi
if TYPE_CHECKING:
from typing import Dict, Tuple
# Public API: names exported by `from <module> import *`.
__all__ = [
    'call_tree_to_pstat', 'clear', 'create_call_tree', 'rewrite_call_tree', 'run',
    'start', 'stop', 'TaggedYFuncStats'
]
class ProfileNode:
    """Node in the (networkx) call graph; omits some attributes of yappi.YFuncStat.

    Identity is the pair ``(full_name, tag)`` — the same function profiled
    under different task tags yields distinct nodes.  Instances are mutable
    (``rewrite_call_tree`` adjusts ``ttot`` in place) but hash on immutable
    identity fields only.
    """
    __slots__ = (
        'name', 'module', 'lineno', 'full_name', 'ncall', 'nactualcall',
        'ttot', 'tsub', 'tag'
    )

    def __init__(self, name, module, lineno, full_name, ncall, nactualcall,
                 ttot, tsub, tag=0):
        self.name = name                # function name (may carry a <tag> suffix)
        self.module = module            # source file of the function
        self.lineno = lineno            # first line number of the function
        self.full_name = full_name      # yappi-style fully qualified name
        self.ncall = ncall              # total number of calls
        self.nactualcall = nactualcall  # calls excluding recursive re-entries
        self.ttot = ttot                # total time including sub-calls
        self.tsub = tsub                # time excluding sub-calls
        self.tag = tag                  # task tag (0 = untagged)

    def copy(self):
        """Return an independent copy (all slot values are immutable scalars)."""
        return ProfileNode(
            self.name, self.module, self.lineno, self.full_name, self.ncall,
            self.nactualcall, self.ttot, self.tsub, self.tag
        )

    def __eq__(self, other):
        # Fix: return NotImplemented for foreign types instead of raising
        # AttributeError on `other.full_name`.
        if not isinstance(other, ProfileNode):
            return NotImplemented
        return self.full_name == other.full_name and self.tag == other.tag

    def __hash__(self):
        return hash((self.full_name, self.tag))

    def __repr__(self):
        return f'{self.full_name}<{self.tag}>'
class TaggedYFuncStats(yappi.YFuncStats):
    """YFuncStats variant that folds the task tag into each function name.

    yappi.YFuncStat instances compare equal whenever ``full_name`` matches,
    regardless of their tags; renaming entries to ``name<tag>`` keeps stats
    of the same function under different task tags distinct.
    """

    def _enumerator(self, stat_entry):
        """Callback invoked by yappi once per raw stat tuple.

        Mirrors yappi's own enumerator except that tagged entries get the
        tag appended to the function name: name -> name<tag>.
        """
        # NOTE: the original declared `global _fn_descriptor_dict` but never
        # assigned any module global — only `yappi._fn_descriptor_dict` is
        # written below — so the dead declaration has been removed.
        fname, fmodule, flineno, fncall, fnactualcall, fbuiltin, fttot, ftsub, \
            findex, fchildren, fctxid, fctxname, ftag, ffn_descriptor = stat_entry
        if ftag:
            fname = f'{fname}<{ftag}>'
            stat_entry = (fname,) + stat_entry[1:]

        # builtin function?
        ffull_name = yappi._func_fullname(bool(fbuiltin), fmodule, flineno, fname)
        ftavg = fttot / fncall
        fstat = yappi.YFuncStat(stat_entry + (ftavg, ffull_name))
        yappi._fn_descriptor_dict[ffull_name] = ffn_descriptor

        # do not show profile stats of yappi itself.
        if os.path.basename(
            fstat.module
        ) == "yappi.py" or fstat.module == "_yappi":
            return

        fstat.builtin = bool(fstat.builtin)

        if self._filter_callback:
            if not self._filter_callback(fstat):
                return

        self.append(fstat)

        # hold the max idx number for merging new entries (keeps merged
        # entries' indexes unique)
        if self._idx_max < fstat.index:
            self._idx_max = fstat.index
def create_call_tree(stats: yappi.YFuncStats) -> DiGraph:
    """Build a directed call graph from yappi function stats.

    Nodes are ProfileNode instances (aggregate timings); each edge carries
    the per-caller ncall/nactualcall/ttot/tsub of that particular call
    relationship.
    """
    def _as_node(stat):
        # project a yappi stat onto the slim ProfileNode representation
        return ProfileNode(
            stat.name, stat.module, stat.lineno, stat.full_name,
            stat.ncall, stat.nactualcall, stat.ttot, stat.tsub, stat.tag
        )

    graph = DiGraph()
    for parent_stat in stats:
        parent = _as_node(parent_stat)
        graph.add_node(parent)
        for edge_stat in parent_stat.children:
            # look up the aggregate child stat; edge data keeps the
            # per-caller timings (not those of the aggregate child)
            child = _as_node(stats[edge_stat.full_name])
            graph.add_edge(
                parent, child,
                ncall=edge_stat.ncall, nactualcall=edge_stat.nactualcall,
                ttot=edge_stat.ttot, tsub=edge_stat.tsub
            )
    return graph
# Context variable carrying the tag of the Task currently being created;
# yappi's tag callback reads it to label stats per Task (0 = untagged).
_marker = ContextVar('yappi_task_marker', default=0)
_task_counter = 0  # monotonically increasing tag handed out per created Task
_callees = {}  # tag -> (filename, firstlineno) of the coroutine wrapped by that Task
_tokens = {}  # tag -> ContextVar token, used to restore _marker when create_task returns
def task_factory(loop, coro):
    """Task factory recording the wrapped coroutine's code location per tag.

    Stores ``(filename, firstlineno)`` of *coro* under the tag currently in
    ``_marker`` so ``rewrite_call_tree`` can later match coroutine nodes to
    the ``create_task`` call that spawned them.
    """
    tag = _marker.get()
    code = coro.cr_code
    _callees[tag] = (code.co_filename, code.co_firstlineno)
    task = Task(coro, loop=loop)
    if task._source_traceback:
        # drop the frame pointing at this factory (mirrors asyncio's
        # default factory behaviour)
        del task._source_traceback[-1]
    return task
def start(loop, builtins=True, profile_threads=True):
    """Install a profile function and a yappi tag callback to profile Tasks.

    Wraps yappi's profile hook so that entering/leaving ``loop.create_task``
    bumps/restores a ContextVar tag; everything profiled while a Task is
    created (and later run) then carries that Task's tag.

    Only ``builtins=True`` together with ``profile_threads=False`` is
    supported (multi-threaded profiling currently crashes).

    Raises:
        ValueError: if *builtins* is falsy, *profile_threads* is truthy, or
            *loop* already has a non-default task factory.
    """
    # Explicit raises instead of asserts: asserts are stripped under `python -O`.
    if not builtins:
        raise ValueError("Must use `builtins=True`")
    if profile_threads:
        raise ValueError("`profile_threads=True` currently not working")
    if loop.get_task_factory() is not None:
        raise ValueError("Task factory must be default task factory.")

    def _task_tag_cbk():
        return _marker.get(0)

    yappi.set_tag_callback(_task_tag_cbk)

    _loop_create_task_loc = (
        loop.create_task.__code__.co_filename,
        loop.create_task.__code__.co_firstlineno
    )

    def profile_thread_callback(frame, event, arg):
        global _task_counter
        frame_loc = (frame.f_code.co_filename, frame.f_code.co_firstlineno)
        if event == 'call' and frame_loc == _loop_create_task_loc:
            # entering loop.create_task: hand out a fresh tag for the new Task
            _task_counter += 1
            _tokens[_task_counter] = _marker.set(_task_counter)
        elif event == 'return' and frame_loc == _loop_create_task_loc:
            # leaving loop.create_task: restore the previous tag
            token = _tokens.pop(_marker.get(None), None)
            if token:
                _marker.reset(token)
        return _yappi._profile_event(frame, event, arg)

    # NOTE: the original had `if profile_threads: threading.setprofile(...)`
    # here, but profile_threads is always falsy past the guard above, so the
    # branch was unreachable dead code and has been removed.

    # fingers crossed this is properly initializing yappi and setting the
    # profile hook. seems you cannot get the yappi hook via sys.getprofile
    # and wrap
    _yappi.start(builtins, profile_threads)
    sys.setprofile(profile_thread_callback)
    loop.set_task_factory(task_factory)
def stop(loop):
    """Stop profiling: restore the default task factory and stop yappi."""
    loop.set_task_factory(None)
    # NOTE(review): sys.setprofile is not explicitly cleared here — presumably
    # yappi.stop() unsets the profile hook; confirm against yappi's docs.
    yappi.stop()
def clear():
    """Reset all module-level profiling state for a fresh run."""
    global _task_counter
    _task_counter = 0
    for state in (_callees, _tokens):
        state.clear()
@contextmanager
def run(loop, builtins=True, profile_threads=False):
    """Context manager for profiling a block of code on *loop*.

    Starts profiling before entering the context and stops profiling when
    exiting from the context.

    The defaults match the only configuration :func:`start` accepts
    (``builtins=True, profile_threads=False``); the previous defaults
    (``builtins=False, profile_threads=True``) were both rejected by
    ``start``'s validation, so calling ``run(loop)`` always failed.

    Usage:

        with run(loop):
            loop.run_until_complete(...)  # this work is profiled

    Warning: don't use this recursively — the inner context will stop
    profiling when exited:

        with run(loop):
            with run(loop):
                print("this call will be profiled")
            print("this call will *not* be profiled")
    """
    start(loop, builtins, profile_threads)
    try:
        yield
    finally:
        stop(loop)
def rewrite_call_tree(tree: DiGraph, propagate_timings=True) -> DiGraph:
    """move coroutines run via loop.create_task in the call tree

    Without rewrite the caller for coroutines run via loop.create_task
    is Handle._run. This rewrites the call tree to make them appear as
    children of loop.create_task. If `propagate_timings` ttot of the
    coroutine is propagated up until the first non-asyncio caller.
    """
    g = DiGraph()
    _copies = {}

    def copy_node(n):
        # one shared copy per original node, so every edge added to g below
        # references the same (mutable) copy
        return _copies.setdefault(n, n.copy())

    task_factory_loc = (
        task_factory.__code__.co_filename, task_factory.__code__.co_firstlineno
    )
    # tag -> copied node of the task_factory call that created that Task
    task_factory_calls = {fs.tag: copy_node(fs) for fs in tree if (fs.module, fs.lineno) == task_factory_loc}
    callee_nodes = set()
    for u, v, data in tree.edges(data=True):
        # v is a Task-wrapped coroutine iff its location matches what
        # task_factory recorded for v's tag AND its recorded caller is the
        # event loop's Context.run
        if (
            (callee_loc := _callees.get(v.tag))
            and (v.module, v.lineno) == callee_loc
            and u.full_name == "__builtin__.<method 'run' of 'Context' objects>"
        ):
            # make this a child of task_factory
            tfc = task_factory_calls[v.tag]
            callee = copy_node(v)
            callee_nodes.add(callee)
            g.add_edge(tfc, callee, **data)
        else:
            g.add_edge(copy_node(u), copy_node(v), **data)
    if propagate_timings:
        roots = [n for n, d in g.in_degree if not d]
        seen_callees = set()
        edges_to_fix = set()  # edges whose ttot must absorb the moved callee's time
        for r in roots:
            for callee in callee_nodes:
                paths = list(all_simple_paths(g, r, callee))
                if not paths:
                    continue
                elif len(paths) > 1:
                    raise ValueError('Cannot update timings due to recursion.')
                if callee in seen_callees:
                    raise ValueError('Cannot update timings b/c there is no unique root for callee.')
                else:
                    seen_callees.add(callee)
                # walk the unique root->callee path bottom-up, collecting
                # edges while the caller is task_factory or lives in the
                # stdlib asyncio package; stop at the first other caller
                for u, v in reversed(list(pairwise(paths[0]))):
                    if (u.module, u.lineno) == task_factory_loc or os.path.commonpath(
                        [os.path.join(sys.base_prefix, 'lib/asyncio'), u.module]
                    ).endswith('asyncio'):
                        edges_to_fix.add((u, v))
                    else:
                        break
        # sub-view of g restricted to the edges whose timings need fixing
        gg = g.edge_subgraph(edges_to_fix)

        def fix_subgraph_timings(u: ProfileNode, ttot=0.) -> float:
            # recursively pushes callee ttot up through gg, mutating both
            # edge data and node ttot in place; returns the accumulated ttot
            gather_loc = (gather.__code__.co_filename, gather.__code__.co_firstlineno)
            ensure_future_loc = (ensure_future.__code__.co_filename, ensure_future.__code__.co_firstlineno)
            if (u.module, u.lineno) == gather_loc:
                # gather's ttot is max(ttot for ttot in coros) but coros are run via ensure_future
                for _, v, data in gg.out_edges(u, data=True):
                    vttot = 0.
                    parallel_vttot = 0.
                    if (v.module, v.lineno) == ensure_future_loc:  # TODO make this more selective
                        fix_subgraph_timings(v)
                        for _, w, ddata in g.out_edges(v, data=True):  # note this is in g
                            if w in gg:
                                # children awaited in parallel count via max
                                parallel_vttot = max(parallel_vttot, ddata['ttot'])
                            else:
                                vttot += ddata['ttot']
                        vttot = vttot + parallel_vttot
                        ttot += vttot - data['ttot']
                        u.ttot += vttot - data['ttot']
                        data['ttot'] = vttot
                    else:
                        vttot = fix_subgraph_timings(v, data['ttot'])
                        data['ttot'] = vttot
                        ttot += vttot
                        u.ttot += vttot
            else:
                for _, v, data in gg.out_edges(u, data=True):
                    vttot = fix_subgraph_timings(v, data['ttot'])
                    data['ttot'] += (vttot - data['ttot'])
                    ttot += vttot
                    u.ttot += vttot
            return ttot

        roots = [n for n, d in gg.in_degree if not d]
        for r in roots:
            fix_subgraph_timings(r)
    return g
def call_tree_to_pstat(tree: DiGraph) -> Stats:
    """Convert a call tree into a `pstats.Stats` object.

    Each node becomes one pstats entry keyed by (module, lineno, name); the
    incoming edges supply the per-caller timing tuples.
    """
    class _StatsSource:
        # minimal object satisfying the protocol Stats() expects:
        # a `stats` dict plus a no-op `create_stats`
        def __init__(self, stats):
            self.stats = stats

        def create_stats(self):
            pass

    def _key(node):
        return (node.module, node.lineno, node.name)

    stats_dict = {}
    for node in tree:
        callers = {
            _key(pred): (edge['ncall'], edge['nactualcall'], edge['tsub'], edge['ttot'])
            for pred, _, edge in tree.in_edges(node, data=True)
        }
        stats_dict[_key(node)] = (
            node.ncall,
            node.nactualcall,
            node.tsub,
            node.ttot,
            callers,
        )
    return Stats(_StatsSource(stats_dict))
if __name__ == '__main__':
    # Demo / smoke test: profile a small Task-spawning coroutine and check
    # that the call-tree rewrite moves Task-wrapped coroutines under their
    # create_task/gather callers without disturbing timings.
    from asyncio import create_task, get_event_loop, sleep
    from itertools import chain
    import networkx

    async def doit():
        tsk = create_task(sleep(1.))
        coro1 = sleep(1.)
        coro2 = sleep(2.)
        coro3 = sleep(3.)
        await gather(coro1, coro2, coro3)
        return await tsk

    loop = get_event_loop()
    yappi.set_clock_type('wall')
    with run(loop, builtins=True, profile_threads=False):
        loop.run_until_complete(doit())
    stats = TaggedYFuncStats().get(filter_callback=lambda fs: fs.tag >= 0)
    tree = create_call_tree(stats)
    rewritten = rewrite_call_tree(tree)
    # get the node for doit (TaggedYFuncStats appended the tag, so goes by doit<1> now)
    doit = next(fs for fs in tree if fs.name == 'doit<1>')
    # w/o rewrite you see create_task and gather etc as descendants of doit but not sleep as
    # it is always wrapped in a Task
    create_task, = [fs for fs in networkx.descendants(tree, doit) if fs.name.startswith('create_task')]
    gather, = [fs for fs in networkx.descendants(tree, doit) if fs.name.startswith('gather')]
    assert not [fs for fs in networkx.descendants(tree, doit) if fs.name.startswith('sleep')]
    # after rewrite sleep shows-up in the tree under create_task/gather
    [fs.name for fs in networkx.descendants(rewritten, create_task) if fs.name.startswith('sleep')]  # ['sleep<2>']
    [fs.name for fs in networkx.descendants(rewritten, gather) if fs.name.startswith('sleep')]  # ['sleep<3>', 'sleep<5>', 'sleep<4>']
    # w/o rewrite doit's own time includes the time spend awaiting tsk (none as run in parallel
    # with gather) and gather (~3. from sleep(3.), create_task and gather almost do not consume
    # time at all
    doit.ttot, doit.tsub  # ~ (3., 3.)
    [fs.ttot for fs in tree if fs.name.startswith('create_task')]  # ~ [0.]
    [fs.ttot for fs in tree if fs.name.startswith('gather')]  # ~ [0.]
    # doit's timings are unaffected by the rewrite but create_task and gather now show
    assert (doit.ttot, doit.tsub) == ((r_doit := next(fs for fs in rewritten if fs.name == 'doit<1>')).ttot, r_doit.tsub)
    [fs.ttot for fs in rewritten if fs.name.startswith('create_task')]  # ~ [1.]
    [fs.ttot for fs in rewritten if fs.name.startswith('gather')]  # ~ [3.]
    # sleep's timings are not affected by rewrite
    # {'sleep<2>': 1.,
    #  'sleep<3>': 1.,
    #  'sleep<4>': 2.,
    #  'sleep<5>': 3.}
    assert {
        fs.name: fs.ttot for fs in sorted(tree, key=lambda fs: fs.tag) if fs.name.startswith('sleep')
    } == {
        fs.name: fs.ttot for fs in sorted(rewritten, key=lambda fs: fs.tag) if fs.name.startswith('sleep')
    }
    # w/o rewrite caller of doit/sleep is the loop (via Context.run)
    assert set(fs.name for fs in chain.from_iterable(
        tree.predecessors(fs) for fs in tree if fs.name.startswith('sleep') or fs.name.startswith('doit')
    )) == set(["<method 'run' of 'Context' objects>"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment