import scipy.stats as stats
import theano
import theano.tensor as tt
import pymc3 as pm
from symbolic_pymc.theano.pymc3 import model_graph, graph_model
from symbolic_pymc.theano.utils import canonicalize
# These just make things quicker (e.g. no computations or compilation).
# An empty `cxx` means no C++ compiler is used, "FAST_COMPILE" selects the
# quick compilation mode, and test values are ignored while graphs are built.
# All three flags are set through `theano.config` for consistency.
theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
theano.config.compute_test_value = "ignore"
# Synthetic inputs: 300 normal draws each for the raw data and the observations.
rawvals = stats.norm.rvs(scale=2, loc=1, size=300)
obsvals = stats.norm.rvs(loc=6, scale=1, size=300)

# Everything created inside the `with` block is registered on `model`, so the
# body must be indented under it.
with pm.Model() as model:
    raw = pm.Data("raw", rawvals)
    # Standardize the raw data using its sample mean and standard deviation.
    z = pm.Deterministic("z(raw)", (raw - rawvals.mean()) / rawvals.std())
    # Latent normal variable centered on the standardized data.
    z_latent = pm.Normal("Latent", mu=z, shape=300)
    # Observed normal variable whose mean is a scaled copy of the latent term.
    obs = pm.Normal("Observed", mu=6 * z_latent, observed=obsvals)
    # Deterministic function of the latent variable.
    f_latent = pm.Deterministic("f(latent)", 22 * z_latent)
Convert the PyMC3 graph into a symbolic-pymc graph:
# Build the symbolic-pymc graph for `model`; later cells query its
# `variables` and `inputs`.
fgraph = model_graph(model)
Print the graph:
# Alias Theano's debugprint so graphs can be dumped as text.
from theano.printing import debugprint as tt_dprint
# Dump the converted model graph.
tt_dprint(fgraph)
<symbolic_pymc.theano.random_variables.Observed object at 0x7f24b941d518> [id A] '' 9
|TensorConstant{[5.4752971...46160106]} [id B]
|normal_rv.1 [id C] 'Observed' 8
|Elemwise{mul,no_inplace} [id D] '' 7
| |InplaceDimShuffle{x} [id E] '' 6
| | |TensorConstant{6} [id F]
| |normal_rv.1 [id G] 'Latent' 5
| |Elemwise{identity} [id H] 'z(raw)' 4
| | |Elemwise{true_div,no_inplace} [id I] '' 3
| | |Elemwise{sub,no_inplace} [id J] '' 2
| | | |raw [id K]
| | | |InplaceDimShuffle{x} [id L] '' 1
| | | |TensorConstant{0.898279247770425} [id M]
| | |InplaceDimShuffle{x} [id N] '' 0
| | |TensorConstant{1.9465271641490767} [id O]
| |TensorConstant{1.0} [id P]
| |TensorConstant{(1,) of 300} [id Q]
| |<RandomStateType> [id R]
|TensorConstant{1.0} [id P]
|TensorConstant{[]} [id S]
|<RandomStateType> [id R]
Get the dependencies for the random variable named Latent and filter them for shared variables and constants:
# Look up graph variables by name; the single-element unpacking fails loudly
# unless exactly one variable matches.
(raw_rv,) = [var for var in fgraph.variables if var.name == 'raw']
(latent_rv,) = [var for var in fgraph.variables if var.name == 'Latent']

# Everything that `Latent` depends on in the graph.
latent_deps = tt.gof.graph.ancestors([latent_rv])

# Keep only the dependencies that are shared variables or constants.
latent_const_deps = [
    dep
    for dep in latent_deps
    if isinstance(dep, (tt.sharedvar.SharedVariable, tt.Constant))
]

tt_dprint(latent_deps)
normal_rv.1 [id A] 'Latent'
|Elemwise{identity} [id B] 'z(raw)'
| |Elemwise{true_div,no_inplace} [id C] ''
| |Elemwise{sub,no_inplace} [id D] ''
| | |raw [id E]
| | |InplaceDimShuffle{x} [id F] ''
| | |TensorConstant{0.898279247770425} [id G]
| |InplaceDimShuffle{x} [id H] ''
| |TensorConstant{1.9465271641490767} [id I]
|TensorConstant{1.0} [id J]
|TensorConstant{(1,) of 300} [id K]
|<RandomStateType> [id L]
Elemwise{identity} [id B] 'z(raw)'
Elemwise{true_div,no_inplace} [id C] ''
Elemwise{sub,no_inplace} [id D] ''
raw [id E]
InplaceDimShuffle{x} [id F] ''
TensorConstant{0.898279247770425} [id G]
InplaceDimShuffle{x} [id H] ''
TensorConstant{1.9465271641490767} [id I]
TensorConstant{1.0} [id J]
TensorConstant{(1,) of 300} [id K]
<RandomStateType> [id L]
# Dump only the shared-variable/constant dependencies of Latent.
tt_dprint(latent_const_deps)
raw [id A]
TensorConstant{0.898279247770425} [id B]
TensorConstant{1.9465271641490767} [id C]
TensorConstant{1.0} [id D]
TensorConstant{(1,) of 300} [id E]
<RandomStateType> [id F]
Now, let's say you have a non-random variable term and you want to determine if it's dependent on a random variable. We'll use the 6 * z_latent term (i.e. the mean of Observed) as an example.
from symbolic_pymc.theano.ops import RandomVariable

# Find the graph variable named 'Observed'; the unpacking requires exactly
# one match.
(observed_tf,) = [var for var in fgraph.variables if var.name == 'Observed']

# First input of the node that produced it (the Elemwise{mul} term in the
# printed model graph).
mu_latent = observed_tf.owner.inputs[0]

# Op nodes between the graph inputs and that term.
mu_latent_dep_ops = list(tt.gof.graph.ops(fgraph.inputs, [mu_latent]))

# Keep only the nodes whose op is a RandomVariable.
mu_dep_rvs = [
    node for node in mu_latent_dep_ops if isinstance(node.op, RandomVariable)
]
The following are random variable nodes upon which mu_latent depends:
# Dump the RandomVariable nodes collected in mu_dep_rvs.
tt_dprint(mu_dep_rvs)
normal_rv.0 [id A] ''
|Elemwise{identity} [id B] 'z(raw)'
| |Elemwise{true_div,no_inplace} [id C] ''
| |Elemwise{sub,no_inplace} [id D] ''
| | |raw [id E]
| | |InplaceDimShuffle{x} [id F] ''
| | |TensorConstant{0.898279247770425} [id G]
| |InplaceDimShuffle{x} [id H] ''
| |TensorConstant{1.9465271641490767} [id I]
|TensorConstant{1.0} [id J]
|TensorConstant{(1,) of 300} [id K]
|<RandomStateType> [id L]
normal_rv.1 [id A] 'Latent'