Skip to content

Instantly share code, notes, and snippets.

@amakelov
Created June 6, 2023 21:06
Show Gist options
  • Save amakelov/c48d1bfb2eec75385dd5df2d81dcd759 to your computer and use it in GitHub Desktop.
Save amakelov/c48d1bfb2eec75385dd5df2d81dcd759 to your computer and use it in GitHub Desktop.
Code for "Practical dependency tracking for Python function calls"
################################################################################
### Proposed solution
################################################################################
from types import FunctionType
from functools import wraps
from typing import Optional, Callable, Any
import copy
from functools import update_wrapper
class TracerState:
    """Global slot holding the ``Tracer`` that is currently recording.

    ``current`` is ``None`` whenever no ``with Tracer():`` block is active;
    ``Tracer.__enter__`` / ``Tracer.__exit__`` set and reset it.
    """

    # Class-level (not per-instance) so every tracked wrapper reads one value.
    current: Optional['Tracer'] = None
def track(f: Callable) -> Callable:
    """Decorate ``f`` so the active tracer records its calls and global reads.

    ``f`` is replaced by a copy whose globals are a ``TrackedDict`` (built by
    ``make_tracked_copy``), so global reads are reported to the active tracer.
    The copy is then wrapped so that, while a ``Tracer`` is active, each call
    is pushed onto that tracer's stack and popped again when the call ends.

    Args:
        f: the plain Python function to track.

    Returns:
        A wrapper that behaves like ``f`` and carries its metadata.
    """
    f = make_tracked_copy(f)

    @wraps(f)  # make the wrapper look like `f` (name, qualname, docstring, ...)
    def wrapper(*args, **kwargs):
        # Capture the tracer once, so the matching `register_return` goes to
        # the same tracer that saw `register_call`.
        tracer = TracerState.current
        if tracer is None:
            return f(*args, **kwargs)
        tracer.register_call(func=f)  # push call to `f` onto the stack
        try:
            return f(*args, **kwargs)
        finally:
            # Pop even if `f` raises. Otherwise a stale frame would stay on
            # the stack, and later calls/global reads would be attributed to
            # the wrong caller.
            tracer.register_return()

    return wrapper
class Tracer:
    """Context manager that records a call graph of ``@track``-ed functions.

    While active (``with Tracer() as t:``), tracked calls push/pop frames on
    ``t.stack``. Caller->callee edges and global-variable reads are appended
    to ``t.graph`` in the order they happen.
    """

    def __init__(self):
        # call stack of (module name, qualified function/method name) tuples
        self.stack = []
        # Recorded events, in order. Two tuple shapes are appended:
        #   calls:        (caller module, caller qualname,
        #                  callee module, callee qualname)
        #   global reads: (caller module, caller qualname, {name: value})
        self.graph = []
        # Tracer that was active before this one was entered; restored on exit.
        self._previous: Optional['Tracer'] = None

    def register_call(self, func: Callable):
        """Push ``func`` onto the stack and record an edge from its caller, if any."""
        module_name, qual_name = func.__module__, func.__qualname__
        self.stack.append((module_name, qual_name))
        if len(self.stack) > 1:
            caller_module, caller_qual_name = self.stack[-2]
            self.graph.append((caller_module, caller_qual_name,
                               module_name, qual_name))

    def register_global_access(self, key: str, value):
        """Record that the innermost tracked call read global ``key`` (= ``value``)."""
        if not self.stack:
            # The read came from tracked code whose call was not registered
            # on this tracer (e.g. a tracked function that was already running
            # when this tracer was entered). There is no frame to attribute
            # it to, so skip it instead of crashing the caller.
            return
        caller_module, caller_qual_name = self.stack[-1]
        self.graph.append((caller_module, caller_qual_name, {key: value}))

    def register_return(self):
        """Pop the innermost call off the stack."""
        self.stack.pop()

    def __enter__(self):
        self._previous = TracerState.current
        TracerState.current = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the enclosing tracer (if any) rather than clearing it, so a
        # nested `with Tracer():` block does not disable the outer one.
        # Returns None, so exceptions propagate.
        TracerState.current = self._previous
        self._previous = None
class TrackedDict(dict):
    """Globals mapping that reports reads to the active tracer.

    ``make_tracked_copy`` installs an instance of this as a function copy's
    ``__globals__``. Global-name lookups made by that function are routed
    through ``__getitem__``. This relies on CPython using ``__getitem__`` for
    globals that are a ``dict`` subclass rather than an exact ``dict``.

    The inherited ``dict`` storage is left empty; reads are delegated to
    ``original``, the real namespace.
    """

    def __init__(self, original: dict):
        super().__init__()
        # The real globals dict of the wrapped function (not copied).
        self.__original__ = original

    def __getitem__(self, __key: str) -> Any:
        # Look the name up first. A missing name raises KeyError here, before
        # anything is recorded, so names absent from `original` are never
        # reported.
        value = self.__original__.__getitem__(__key)
        tracer = TracerState.current
        if tracer is not None:
            tracer.register_global_access(key=__key, value=value)
        return value

    # The inherited storage is empty. Without these overrides, Python-level
    # introspection of this mapping would see no names at all: e.g.
    # `name in globals()` or `globals().get(name)` inside a tracked function,
    # iteration, `len`, `keys()`. Delegate them to `original` too.

    def __contains__(self, key: object) -> bool:
        return key in self.__original__

    def get(self, key, default=None):
        """Like ``dict.get``, but a hit goes through ``__getitem__`` (and is tracked)."""
        try:
            return self[key]
        except KeyError:
            return default

    def keys(self):
        return self.__original__.keys()

    def __iter__(self):
        return iter(self.__original__)

    def __len__(self) -> int:
        return len(self.__original__)

    # NOTE(review): writes are not delegated. `globals()[k] = v` inside a
    # tracked function lands in the empty inherited storage, not in
    # `original`, and is then invisible to `__getitem__`. `global` statements
    # presumably behave the same way in CPython. Confirm before tracking
    # functions that assign module globals.
def make_tracked_copy(f: FunctionType) -> FunctionType:
    """Return a copy of ``f`` whose global lookups go through a ``TrackedDict``.

    The copy shares ``f``'s code object, positional defaults and closure
    cells. It resolves global names via ``TrackedDict(f.__globals__)``, so
    reads can be reported to the active tracer. Metadata is copied with
    ``functools.update_wrapper``, which also sets ``__wrapped__`` to ``f``.

    Args:
        f: a plain Python function (it must have ``__code__``/``__globals__``).

    Returns:
        The tracked copy of ``f``.
    """
    result = FunctionType(
        code=f.__code__,
        globals=TrackedDict(f.__globals__),
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    # Copies __module__, __name__, __qualname__, __doc__ and __annotations__,
    # updates __dict__, and sets __wrapped__ = f. No separate __module__
    # assignment is needed.
    result = update_wrapper(result, f)
    # Shallow copies: the copy gets its own dicts, but the default/annotation
    # *values* are shared with `f`, matching how `argdefs` shares
    # `f.__defaults__`. A deepcopy here raised on values that cannot be
    # deep-copied (e.g. a lock used as a keyword-only default), which made
    # `@track` fail on otherwise valid functions.
    result.__kwdefaults__ = (None if f.__kwdefaults__ is None
                             else dict(f.__kwdefaults__))
    result.__annotations__ = dict(f.__annotations__)
    return result
################################################################################
### Example
################################################################################
# Module-level constants read by the tracked example code below.
A, B = 23, 42
@track
def f(x):
    """Example tracked function: returns ``x`` plus the global ``A``."""
    # the read of `A` is reported as a global access while a Tracer is active
    return x + A
class C:
    """Example class whose tracked methods appear in the recorded graph."""

    @track
    def __init__(self, x):
        # reads the global `B`
        self.x = x + B

    @track
    def m(self, y):
        # uses only instance state and arguments: no global read is reported
        return self.x + y

    class D:
        """Nested example class (its methods' qualnames start with ``C.D``)."""

        @track
        def __init__(self, x):
            # reads the global `f`, then calls it; under an active Tracer this
            # records a call edge from `C.D.__init__` to `f`
            self.x = x + f(x)

        @track
        def m(self, y):
            # reads the global `A`
            return y + A
@track
def g(x):
    """Example entry point that dispatches on the parity of ``x``.

    Even ``x`` goes through ``C``; odd ``x`` goes through the nested ``C.D``.
    Both branches read the global ``C``.
    """
    if x % 2 == 0:
        return C(x).m(x)
    else:
        return C.D(x).m(x)
if __name__ == '__main__':
    # Trace one odd and one even input, so each branch of `g` is exercised,
    # and print the graph recorded for each run.
    for arg in (23, 42):
        with Tracer() as t:
            g(arg)
        print(t.graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment