

@amakelov
Created June 6, 2023 21:06
Code for "Practical dependency tracking for Python function calls"
################################################################################
### Proposed solution
################################################################################
from types import FunctionType
from functools import wraps, update_wrapper
from typing import Optional, Callable, Any
import copy


class TracerState:
    # the tracer currently collecting dependencies (None when tracing is off)
    current: Optional['Tracer'] = None


def track(f: Callable):
    # re-create `f` so that its global lookups go through a TrackedDict
    f = make_tracked_copy(f)

    @wraps(f)  # to make the wrapped function look like `f`
    def wrapper(*args, **kwargs):
        tracer = TracerState.current
        if tracer is not None:
            tracer.register_call(func=f)  # put call to `f` on stack
            try:
                return f(*args, **kwargs)
            finally:
                # pop call to `f` from stack, even if `f` raised
                tracer.register_return()
        else:
            return f(*args, **kwargs)
    return wrapper


class Tracer:
    def __init__(self):
        # call stack of (module name, qualified function/method name) tuples
        self.stack = []
        # list of dependency records:
        # - (caller module, caller qualname, callee module, callee qualname)
        #   tuples for calls, and
        # - (caller module, caller qualname, {global name: value}) tuples
        #   for global variable accesses
        self.graph = []

    def register_call(self, func: Callable):
        # Add a call to the stack and the graph
        module_name, qual_name = func.__module__, func.__qualname__
        self.stack.append((module_name, qual_name))
        if len(self.stack) > 1:
            caller_module, caller_qual_name = self.stack[-2]
            self.graph.append((caller_module, caller_qual_name,
                               module_name, qual_name))

    def register_global_access(self, key: str, value):
        # Attribute a global variable access to the function on top of the stack
        assert len(self.stack) > 0
        caller_module, caller_qual_name = self.stack[-1]
        self.graph.append((caller_module, caller_qual_name, {key: value}))

    def register_return(self):
        self.stack.pop()

    def __enter__(self):
        TracerState.current = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracerState.current = None


class TrackedDict(dict):
    # The dict itself is left empty; the real globals live in `__original__`.
    # Because this is not an exact `dict`, CPython resolves global names in a
    # tracked function through `__getitem__`, which forwards the lookup and
    # reports the access. A KeyError (e.g. for a builtin such as `print`)
    # makes the interpreter fall back to the builtins, so builtins are not
    # recorded.
    def __init__(self, original: dict):
        self.__original__ = original

    def __getitem__(self, __key: str) -> Any:
        value = self.__original__.__getitem__(__key)
        if TracerState.current is not None:
            tracer = TracerState.current
            tracer.register_global_access(key=__key, value=value)
        return value


def make_tracked_copy(f: FunctionType) -> FunctionType:
    # same code, defaults and closure as `f`, but with tracked globals
    result = FunctionType(
        code=f.__code__,
        globals=TrackedDict(f.__globals__),
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    result = update_wrapper(result, f)
    result.__module__ = f.__module__
    result.__kwdefaults__ = copy.deepcopy(f.__kwdefaults__)
    result.__annotations__ = copy.deepcopy(f.__annotations__)
    return result
################################################################################
### Example
################################################################################
A = 23
B = 42


@track
def f(x):
    return x + A


class C:
    @track
    def __init__(self, x):
        self.x = x + B

    @track
    def m(self, y):
        return self.x + y

    class D:
        @track
        def __init__(self, x):
            self.x = x + f(x)

        @track
        def m(self, y):
            return y + A


@track
def g(x):
    if x % 2 == 0:
        return C(x).m(x)
    else:
        return C.D(x).m(x)


if __name__ == '__main__':
    with Tracer() as t:
        g(23)
    print(t.graph)
    with Tracer() as t:
        g(42)
    print(t.graph)
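`Tracer.graph` mixes two record shapes: 4-tuples for call edges and 3-tuples ending in a `{name: value}` dict for global accesses, and repeated calls or lookups produce repeated records. A small post-processing step can split and deduplicate them. The helper below is a sketch that only assumes those two shapes; `summarize` is a hypothetical name (not part of the gist), and the sample records are hand-written in that shape, not output captured from running the example.

```python
from typing import Dict, List, Set, Tuple

CallEdge = Tuple[str, str, str, str]
Caller = Tuple[str, str]


def summarize(graph: List[tuple]) -> Tuple[Set[CallEdge], Dict[Caller, Set[str]]]:
    """Hypothetical helper: unique call edges plus global names read per caller."""
    calls: Set[CallEdge] = set()
    globals_used: Dict[Caller, Set[str]] = {}
    for record in graph:
        if len(record) == 4:
            # (caller module, caller qualname, callee module, callee qualname)
            calls.add(record)
        else:
            # (caller module, caller qualname, {global name: value});
            # only the names are kept, since values may be unhashable
            module, qualname, accessed = record
            globals_used.setdefault((module, qualname), set()).update(accessed)
    return calls, globals_used


# Hand-written records in the shape Tracer produces (illustrative only)
sample_graph = [
    ("__main__", "g", "__main__", "C.D.__init__"),
    ("__main__", "C.D.__init__", "__main__", "f"),
    ("__main__", "f", {"A": 23}),
    ("__main__", "g", "__main__", "C.D.m"),
    ("__main__", "C.D.m", {"A": 23}),
    ("__main__", "f", {"A": 23}),  # a repeated access is collapsed
]
calls, globals_used = summarize(sample_graph)
print(sorted(calls))
print(globals_used)
```

Keeping only names (not values) makes the summary hashable and compact; if you need to detect changed values, you would store a hash or repr of each value instead.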
@jgarvin

jgarvin commented Oct 19, 2024

How does this deal with the problem that if a user imports a module foo and then calls foo.bar(), this will be seen as an access to the entire foo module object? If the goal is fine-grained dependencies, do you just filter out accesses to globals that are modules?

@amakelov
Author

Thanks, great question. The code shown here is an MVP of a more complex (and still buggy) system that lives under https://github.com/amakelov/mandala/tree/master/mandala/deps. In general, "function dependencies" and "value dependencies" are tracked differently, which creates some potential issues in corner cases where the distinction is not very clear:

  • for functions (and other callable things), we explicitly use the @track decorator, and each call is pushed onto the stack when it happens. This is pretty straightforward, and is the behavior you'd want;
  • for value-like things, we instead use the "tracked globals dict" mechanism to catch when they're accessed. When doing this, it makes sense to skip over module-like and function-like values, unless you're doing something unusual like inspecting a property of a function or a module, which (in the current version of the code) would go unnoticed. There are probably reasonable ways to work around this, but it's not really a crux for my use cases.

If you want to use the "tracked globals dict" to keep track of function-like dependencies, things get a little wonky. Assuming you have a file foo.py containing a function bar decorated with @track, there is a difference between the following two cases:

import foo

@track
def f(x):
    return foo.bar(x)  # here we register a global access to the module `foo`, but beyond that we don't know what is accessed

and

from foo import bar

@track
def f(x):
    return bar(x)  # here we know exactly which function we're accessing, and the access to `bar` is registered
