Skip to content

Instantly share code, notes, and snippets.

@ferrine
Last active November 23, 2022 07:59
Show Gist options
  • Save ferrine/70cbcf6d3b6f033ac070d70b10ac8d25 to your computer and use it in GitHub Desktop.
Save ferrine/70cbcf6d3b6f033ac070d70b10ac8d25 to your computer and use it in GitHub Desktop.
Example: clone a PyTensor graph while substituting variables (clone_substitute)
import pytensor.tensor as at
import pytensor
import numpy as np
from typing import Collection, Optional, Union, Iterable, Tuple, Dict
from pytensor.graph.basic import Apply, Variable
def is_in_ancestors(node, search, *, known_independent=None, known_dependent=None):
    """Return True if ``node`` or any of its ancestors is contained in ``search``.

    The walk uses ``pytensor.graph.basic.ancestors([node])``, which also
    yields ``node`` itself, so ``is_in_ancestors(x, [x])`` is True.

    Parameters
    ----------
    node:
        Variable whose ancestry is inspected.
    search:
        Variables to look for. Any iterable is accepted. Unless it is already
        a set, frozenset or dict (whose keys are searched), it is converted
        to a set once per call.
    known_independent, known_dependent:
        Optional caches shared across calls; both are mutated in place.
        Variables found NOT to reach ``search`` are added to
        ``known_independent``, and variables found to reach it are added to
        ``known_dependent``. The caches are only valid for one fixed
        ``search``; reusing them with a different ``search`` gives wrong
        answers.

    Returns
    -------
    bool
        True if ``search`` intersects the ancestry of ``node``.
    """
    if known_independent is None:
        known_independent = set()
    if known_dependent is None:
        known_dependent = set()
    # Cached answers. The previous version returned True here for a
    # known-independent node, i.e. the opposite of what was cached.
    if node in known_independent:
        return False
    if node in known_dependent:
        return True
    if not isinstance(search, (set, frozenset, dict)):
        # One conversion per call gives O(1) membership tests below and keeps
        # a generator from being exhausted by the first ``in`` test.
        search = set(search)
    # Sub-graphs rooted at known-independent variables were already proven
    # not to reach ``search``. Blocking there avoids re-walking them: the
    # blockers are still yielded, but their owners' inputs are not expanded.
    ancestors = pytensor.graph.basic.ancestors(
        [node], blockers=known_independent
    )
    for candidate in ancestors:
        if candidate in known_dependent or candidate in search:
            known_dependent.add(node)
            return True
    known_independent.add(node)
    return False
def independent_apply_nodes_between(ins, outs):
    """Yield graph objects between ``ins`` and ``outs`` that do not depend on ``ins``.

    The function walks the apply nodes from
    ``pytensor.graph.basic.io_toposort(ins, outs)`` in order. For each apply
    node none of whose inputs has an ancestor in ``ins``, it yields:

    - the node's input variables,
    - the apply node itself,
    - the node's output variables.

    ``clone_substitute`` maps these objects to themselves so they are reused
    rather than cloned.

    An object may be yielded more than once, e.g. a variable that feeds
    several independent apply nodes.

    Parameters
    ----------
    ins:
        Iterable of variables treated as the substituted inputs. A dict is
        accepted, in which case its keys are used.
    outs:
        Output variables of the graph.
    """
    # Materialise once. ``ins`` is needed both for the toposort and for every
    # dependency check, so a one-shot iterable would otherwise be exhausted.
    ins = list(ins)
    # A set gives O(1) membership inside ``is_in_ancestors``.
    search = set(ins)
    # Caches shared across all checks in this walk (same ``search``
    # throughout, as ``is_in_ancestors`` requires).
    known_independent = set()
    known_dependent = set()
    for apply in pytensor.graph.basic.io_toposort(ins, outs):
        depends_on_ins = any(
            is_in_ancestors(
                inp,
                search,
                known_independent=known_independent,
                known_dependent=known_dependent,
            )
            for inp in apply.inputs
        )
        if not depends_on_ins:
            yield from apply.inputs
            yield apply
            yield from apply.outputs
def clone_substitute(
    output: Collection[Variable],
    replace: Optional[
        Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
    ] = None,
    **kwargs,
):
    """Clone ``output`` with the variables in ``replace`` substituted.

    Only the part of the graph that depends on a replaced variable is cloned.
    Objects yielded by ``independent_apply_nodes_between`` are mapped to
    themselves, so the returned outputs share those sub-graphs with the
    original graph.

    Parameters
    ----------
    output:
        Output variables of the graph to clone.
    replace:
        Mapping ``{old: new}`` or an iterable of ``(old, new)`` pairs.
        ``None`` (the default) means no substitution.
    **kwargs:
        Accepted for signature compatibility; currently ignored.

    Returns
    -------
    list
        The counterparts of ``output`` in the cloned graph, in the same order.
    """
    # Normalise once. The previous version passed ``replace`` through
    # unchanged, so it crashed for the default ``None`` and treated
    # ``(old, new)`` tuples as graph variables.
    replace = {} if replace is None else dict(replace)
    # ``output`` is traversed twice below; a one-shot iterable would break.
    output = list(output)
    # Independent objects map to themselves so they are not copied.
    memo = {n: n for n in independent_apply_nodes_between(replace, output)}
    memo.update(replace)
    # clone_get_equiv fills ``memo`` in place. This relies on it keeping the
    # replacement entries already present for the inputs rather than
    # overwriting them; NOTE(review): confirm against the pytensor version
    # in use.
    pytensor.graph.basic.clone_get_equiv(list(replace), output, memo=memo)
    return [memo[o] for o in output]
# Demo / self-check. Build a small graph in which ``a2`` is computed only
# from ``a`` and ``b2`` only from ``b``; ``d`` combines both branches.
a = at.scalar("a")
b = at.scalar("b")
b2 = b * 2
a2 = a * 2
d = (a2 ** 2 + b2 ** 2).flatten()
# Dependency queries: ``b2`` and ``d`` reach ``b`` through their ancestry;
# ``a2`` does not.
assert is_in_ancestors(b2, [b])
assert is_in_ancestors(d, [b])
assert not is_in_ancestors(a2, [b])
# With ``b`` treated as the substituted input, objects of the ``a`` branch
# are reported as independent (reusable), while ``b2`` is not.
assert a in independent_apply_nodes_between([b], [d])
assert a2 in independent_apply_nodes_between([b], [d])
assert b2 not in independent_apply_nodes_between([b], [d])
# Substitute ``b`` with a fresh clone: the cloned graph no longer contains
# the original ``b2``, but still contains the original ``a`` and ``a2``.
d_clone = clone_substitute([d], {b: b.clone()})[0]
assert not is_in_ancestors(d_clone, [b2])
assert is_in_ancestors(d_clone, [a])
assert is_in_ancestors(d_clone, [a2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment