Created
June 6, 2023 21:06
-
-
Save amakelov/c48d1bfb2eec75385dd5df2d81dcd759 to your computer and use it in GitHub Desktop.
Code for "Practical dependency tracking for Python function calls"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
################################################################################
### Proposed solution
################################################################################
import copy
from functools import update_wrapper
from functools import wraps
from types import FunctionType
from typing import Any, Callable, Optional
class TracerState:
    """Global slot holding whichever ``Tracer`` is currently active.

    ``current`` is ``None`` whenever no tracer context is open.
    """

    # The active tracer (set by Tracer.__enter__), or None.
    current: Optional["Tracer"] = None
def track(f: Callable):
    """Decorator: report calls to ``f`` to the currently active ``Tracer``.

    ``f`` is first replaced by ``make_tracked_copy(f)``, so the function that
    actually runs looks up its globals through a ``TrackedDict``. When no
    tracer is active (``TracerState.current is None``), the wrapper just
    forwards the call.

    Args:
        f: the function to track.

    Returns:
        A wrapper carrying ``f``'s metadata (via ``functools.wraps``).
    """
    f = make_tracked_copy(f)

    @wraps(f)  # to make the wrapped function look like `f`
    def wrapper(*args, **kwargs):
        tracer = TracerState.current
        if tracer is None:
            return f(*args, **kwargs)
        tracer.register_call(func=f)  # put call to `f` on stack
        try:
            return f(*args, **kwargs)
        finally:
            # Pop the frame even if `f` raises. Otherwise a caught exception
            # would leave a stale entry on the tracer's stack, and later
            # calls would be attributed to the wrong caller.
            tracer.register_return()

    return wrapper
class Tracer:
    """Context manager that records the dynamic call graph of tracked functions.

    While the ``with`` block is active, calls made through ``track`` wrappers
    push/pop entries on ``stack`` and append events to ``graph``. Tracers may
    be nested: on exit, the previously active tracer is restored.
    """

    def __init__(self):
        # call stack of (module name, qualified function/method name) tuples
        self.stack = []
        # Ordered list of recorded events. Two tuple shapes are appended:
        #   (caller module, caller qualname, callee module, callee qualname)
        #       -- a tracked call made from another tracked call;
        #   (caller module, caller qualname, {global name: value})
        #       -- a global read performed by the innermost tracked call.
        self.graph = []
        # Tracer that was active when this one was entered; restored on exit
        # so that nested `with Tracer()` blocks do not clobber each other.
        self._previous = None

    def register_call(self, func: Callable):
        """Push ``func`` on the stack and record an edge from its caller.

        The outermost tracked call has no tracked caller, so no edge is
        recorded for it.
        """
        module_name, qual_name = func.__module__, func.__qualname__
        self.stack.append((module_name, qual_name))
        if len(self.stack) > 1:
            caller_module, caller_qual_name = self.stack[-2]
            self.graph.append((caller_module, caller_qual_name,
                               module_name, qual_name))

    def register_global_access(self, key: str, value):
        """Record that the innermost tracked call read global ``key``.

        Requires a non-empty stack. This is checked with ``assert``, so it is
        not enforced under ``python -O``.
        """
        assert len(self.stack) > 0
        caller_module, caller_qual_name = self.stack[-1]
        self.graph.append((caller_module, caller_qual_name, {key: value}))

    def register_return(self):
        """Pop the innermost tracked call off the stack."""
        self.stack.pop()

    def __enter__(self):
        # Remember the outer tracer (if any) before taking over.
        self._previous = TracerState.current
        TracerState.current = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore rather than blindly clearing, so an enclosing tracer keeps
        # recording after a nested one finishes. Exceptions are not suppressed.
        TracerState.current = self._previous
        self._previous = None
class TrackedDict(dict):
    """Globals stand-in that forwards reads to a real namespace and reports them.

    The dict itself is intentionally left empty. Item reads are served from
    the wrapped ``original`` mapping. While a tracer is active, each
    successful read is reported through ``register_global_access``.
    """

    def __init__(self, original: dict):
        # NOTE: `original` is deliberately not passed to dict.__init__; this
        # mapping stays empty and every subscript goes through __getitem__.
        self.__original__ = original

    def __getitem__(self, __key: str) -> Any:
        # A missing key raises KeyError here, before anything is recorded.
        found = self.__original__[__key]
        active = TracerState.current
        if active is not None:
            active.register_global_access(key=__key, value=found)
        return found
def make_tracked_copy(f: FunctionType) -> FunctionType:
    """Return a copy of ``f`` that resolves its globals through a ``TrackedDict``.

    The copy shares ``f``'s code object, positional defaults and closure
    cells. Its globals mapping, however, is a ``TrackedDict`` wrapping
    ``f.__globals__``, so global reads can be reported to the active tracer.

    Args:
        f: a plain Python function (must expose ``__code__``, ``__globals__``,
            ``__defaults__``, ``__kwdefaults__`` and ``__closure__``).

    Returns:
        The new function object, with ``f``'s metadata copied by
        ``functools.update_wrapper`` (which also sets ``__wrapped__`` to ``f``).
    """
    result = FunctionType(
        code=f.__code__,
        globals=TrackedDict(f.__globals__),
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    result = update_wrapper(result, f)
    result.__module__ = f.__module__
    # Give the copy its own kw-default and annotation dicts, but share the
    # values, exactly like the positional defaults above. Deep-copying the
    # values (as before) raised at decoration time for objects copy.deepcopy
    # cannot handle, e.g. a threading.Lock used as a keyword-only default.
    kwdefaults = f.__kwdefaults__
    result.__kwdefaults__ = None if kwdefaults is None else dict(kwdefaults)
    result.__annotations__ = dict(f.__annotations__)
    return result
################################################################################
### Example
################################################################################
# Module-level constants read by the tracked example functions.
A, B = 23, 42
@track
def f(x):
    """Return ``x`` plus the module-level constant ``A`` (a tracked global read)."""
    shifted = x + A
    return shifted
class C:
    """Example class with tracked methods and a nested tracked class ``D``."""

    @track
    def __init__(self, x):
        # Reads the global `B`.
        offset_x = x + B
        self.x = offset_x

    @track
    def m(self, y):
        # Uses only instance state and the argument; no global reads.
        total = self.x + y
        return total

    class D:
        """Nested example class; its methods' qualnames start with ``C.D``."""

        @track
        def __init__(self, x):
            # Reads the global `f` and calls it (a tracked call).
            extra = f(x)
            self.x = x + extra

        @track
        def m(self, y):
            # Reads the global `A`.
            total = y + A
            return total
@track
def g(x):
    """Build ``C(x)`` for even ``x`` or ``C.D(x)`` for odd ``x``, then return ``.m(x)``."""
    factory = C if x % 2 == 0 else C.D
    return factory(x).m(x)
if __name__ == '__main__':
    # Trace one odd and one even input: the odd input goes through the nested
    # class C.D (which calls f), the even input through C.
    for arg in (23, 42):
        with Tracer() as t:
            g(arg)
        print(t.graph)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment