# Source: GitHub Gist ricardoV94/e8902b4c35c26e87e189ab477f8d9288
# Author: @ricardoV94 — created May 22, 2024 15:36
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Constant, io_toposort
from pytensor.compile import SharedVariable
from pytensor.compile.mode import get_mode
from pytensor.printing import _debugprint
# Define the graph: elementwise log(x), except that every entry where x > 1
# is replaced by NaN.
x = pt.vector(name="x")
out = pt.switch(pt.gt(x, 1), -pt.nan, pt.log(x))
# Compile a function from `x` to `out` using the FAST_RUN rewrites minus
# everything tagged "inplace" or "fusion", then keep only the rewritten
# output variables of the compiled function graph.
# NOTE(review): presumably these tags are excluded so that no op overwrites a
# buffer the hand-driven evaluation still reads, and so that intermediate
# Elemwise nodes are not fused away — confirm.
# Note that `out` is rebound here: from the symbolic output to the list of
# compiled-graph outputs.
mode = get_mode("FAST_RUN")
mode = mode.excluding("inplace", "fusion")
out = pytensor.function([x], out, mode=mode).maker.fgraph.outputs
# Seed the graph input with concrete test values (fixed RNG seed, so the
# statistics printed at the end are reproducible).
# pytensor.function would filter input values against the input's type, but
# the evaluation loop calls Op.perform directly and skips that step. Cast
# explicitly to x's declared dtype (config.floatX by default for pt.vector);
# with float64 this leaves the values unchanged.
evaled_vars = {
    x: np.random.default_rng(37).normal(size=(50,)).astype(x.type.dtype),
}
# Compute every intermediate variable by running each node's Op.perform by
# hand, in topological order from the seeded input `x` to the outputs `out`.
for node in io_toposort([x], out):
    input_values = []
    for inp in node.inputs:
        if isinstance(inp, Constant):
            input_values.append(inp.data)
        elif isinstance(inp, SharedVariable):
            input_values.append(inp.get_value(borrow=True))
        else:
            # Produced by an earlier node (or seeded), thanks to toposort order.
            input_values.append(evaled_vars[inp])
    # One single-slot storage cell per output; perform() fills slot 0.
    output_storage = [[None] for _ in node.outputs]
    node.op.perform(node, input_values, output_storage)
    # The loop variable must not be named `out`: that would shadow the list
    # of graph outputs that is printed below.
    for out_var, (out_value,) in zip(node.outputs, output_storage):
        evaled_vars[out_var] = out_value

# Build the annotation printed next to each evaluated variable:
# boolean results report the fraction that is True, floating/complex results
# the fraction that is NaN. Other dtypes get no annotation.
extra_info = {}
for var, value in evaled_vars.items():
    # Non-tensor types may not define a dtype at all.
    dtype = getattr(var.type, "dtype", "")
    if dtype == "bool":
        extra_info[var] = f"true={np.mean(value):.2%}"
    elif dtype.startswith(("float", "complex")):
        extra_info[var] = f"nan={np.mean(np.isnan(value)):.2%}"

# Print the graph of every output, annotated with the summaries above.
for out_var in out:
    _debugprint(out_var, storage_map=extra_info)
# Switch [id A] nan=60.00%
# ├─ Gt [id B] true=10.00%
# │ ├─ x [id C] nan=0.00%
# │ └─ [1] [id D]
# ├─ [nan] [id E]
# └─ Log [id F] nan=50.00%
# └─ x [id C] nan=0.00%
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment