Skip to content

Instantly share code, notes, and snippets.

@omc8db
Created December 7, 2023 03:48
Show Gist options
  • Save omc8db/2f17e7e96b779fae1c645598f165b6c1 to your computer and use it in GitHub Desktop.
Save omc8db/2f17e7e96b779fae1c645598f165b6c1 to your computer and use it in GitHub Desktop.
pyStateGraph
import inspect
from collections import defaultdict
from collections.abc import Callable
class StateGraph:
    """Dictionary that lets you add keys calculated from other keys.

    When an entry is updated, every entry that depends on it is recalculated
    recursively. Values may be any type valid for use in a standard
    dictionary; note that ``None`` means "undefined" to calculations.
    Keys must be valid names for function arguments, because a calculation's
    inputs are taken from its parameter names.

    ``add_calculation`` adds a new entry that is recalculated on the fly as
    its dependencies update. A calculated entry is ``None`` while any of its
    inputs is ``None`` (including inputs that were never set).

    Example::

        >>> s = StateGraph()
        >>> s["foo"] = 1
        >>> # Create an entry "baz" that is automatically calculated from foo and bar
        >>> s.add_calculation("baz", lambda foo, bar: 3 * foo + 2 * bar)
        >>> print(s["baz"])  # None because not all inputs for baz are defined
        None
        >>> # Calculated entries can depend on other calculated entries
        >>> s.add_calculation("qux", lambda baz: baz * 2)
        >>> s["bar"] = 2
        >>> # These values were recalculated when their dependency bar changed
        >>> print(s["baz"])
        7
        >>> print(s["qux"])
        14
    """

    def __init__(self):
        # key name -> current value (plain or calculated)
        self._measurements = {}
        # input key -> list of calculated keys that read it
        self._deps: defaultdict[str, list[str]] = defaultdict(list)
        # calculated key -> (evaluator, ordered input names)
        self._functions: dict[str, tuple[Callable, list[str]]] = {}

    def __getitem__(self, key):
        return self._measurements[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` and recalculate everything downstream of it."""
        self._measurements[key] = value
        # .get avoids inserting an empty dependents list for every plain key.
        for dependent in self._deps.get(key, ()):
            self._calculate(dependent)

    def __str__(self):
        return self._measurements.__str__()

    def add_calculation(self, key: str, function: Callable) -> None:
        """Define ``key`` as ``function`` applied to the entries named by its parameters.

        The positional parameter names of ``function`` are the input keys, passed
        positionally in declaration order. Inputs that don't exist yet are created
        as ``None``. Calling this again for an existing ``key`` replaces the old
        calculation and its dependencies.

        Raises:
            ValueError: if the calculation would make ``key`` depend on itself,
                directly or through other calculated entries. The graph is left
                unchanged in that case.
            TypeError: if ``function``'s signature cannot be inspected.
        """
        deps = self._input_names(function)
        if self._creates_cycle(key, deps):
            raise ValueError(f"calculation for {key!r} would create a dependency cycle")
        previous = self._functions.get(key)
        if previous is not None:
            # Redefining a key: drop the old input edges so the key is not
            # recalculated from stale inputs (or more than once per update).
            for dep in previous[1]:
                self._deps[dep].remove(key)
        for dep in deps:
            self._deps[dep].append(key)
            # Referenced inputs that don't exist yet start undefined
            self._measurements.setdefault(dep, None)
        self._functions[key] = function, deps
        self._calculate(key)

    @staticmethod
    def _input_names(function: Callable):
        """Return the names of ``function``'s positional parameters, in order.

        ``inspect.signature`` is used (rather than ``getfullargspec``) so a bound
        method does not report its already-bound ``self`` as an input.
        """
        try:
            signature = inspect.signature(function)
        except ValueError as err:
            # Keep the TypeError that getfullargspec raised for such callables.
            raise TypeError(f"unsupported callable: {function!r}") from err
        positional = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        return [
            name
            for name, param in signature.parameters.items()
            if param.kind in positional
        ]

    def _creates_cycle(self, key: str, inputs) -> bool:
        """Return True if making ``key`` read ``inputs`` would close a dependency loop.

        A loop exists when ``key`` is one of its own inputs, or when any input is
        already (transitively) calculated from ``key``.
        """
        targets = set(inputs)
        if key in targets:
            return True
        seen = {key}
        pending = [key]
        while pending:
            for dependent in self._deps.get(pending.pop(), ()):
                if dependent in targets:
                    return True
                if dependent not in seen:
                    seen.add(dependent)
                    pending.append(dependent)
        return False

    def _calculate(self, key: str) -> None:
        """Recompute calculated entry ``key`` from the current input values."""
        function, input_names = self._functions[key]
        inputs = [self._measurements[name] for name in input_names]
        # Any undefined (None) input makes the result undefined; the function
        # itself is not called in that case.
        self[key] = None if any(value is None for value in inputs) else function(*inputs)
@omc8db
Copy link
Author

omc8db commented Dec 7, 2023

License is public domain, copy-paste wherever you need it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment