Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active June 6, 2020 18:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/8be593cd7b29e936017d5031dc3b4984 to your computer and use it in GitHub Desktop.
Save brandonwillard/8be593cd7b29e936017d5031dc3b4984 to your computer and use it in GitHub Desktop.
PyMC3 Toposorting
from itertools import chain, filterfalse

import numpy as np
import pymc3 as pm
import theano
import theano.tensor as tt
# Select Theano's quick-to-compile optimizer set. The setting has to be made
# through `theano.config`; assigning `tt.optimizer` would only attach an unused
# attribute to the `theano.tensor` module and leave the configuration as-is.
theano.config.optimizer = 'fast_compile'
def unique_everseen(iterable, key=None):
    """Yield the distinct elements of *iterable* in first-seen order.

    Every element (or its ``key(element)`` when *key* is given) is remembered
    for the lifetime of the generator, so a later repeat is dropped even if it
    is not adjacent to the earlier occurrence. Elements, or their keys, must
    be hashable.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBcCAD', str.lower))
    ['A', 'B', 'c', 'D']
    """
    already_seen = set()
    remember = already_seen.add
    for item in iterable:
        # Compare on the element itself unless a key function was supplied.
        marker = item if key is None else key(item)
        if marker in already_seen:
            continue
        remember(marker)
        yield item
#
# Create a model with a lot of dependencies
#
# The variables are deliberately chained (shared scale `c`, `b` feeding `ab`,
# weights `w` reused by `e`, ...) so that the random variables have a
# non-trivial dependency order for the toposort below. Statement order here
# is the order in which the RVs are registered on the model.
with pm.Model() as test_model:
    # Half-Cauchy scale, used as sigma by `b`, `ab` and the first mixture component.
    c = pm.HalfCauchy('c', 1)
    b = pm.Normal('b', 0, c)
    ab = pm.Normal('ab', b, c)
    # Discrete RV; reused below as the Binomial trial count and a Gamma parameter.
    a1 = pm.Poisson('a1', 10)
    # Two mixture weights from a Dirichlet with concentration [1, 1].
    w = pm.Dirichlet('w', np.r_[1, 1])
    # Two-component Normal mixture whose components depend on `ab`, `b` and `c`.
    d = pm.Mixture('d', w, [pm.Normal.dist(ab, c), pm.Normal.dist(b, tt.abs_(ab))])
    e = pm.Binomial('e', a1, w[0])
    # Deterministic node, not a basic RV, so it is absent from the sorted output.
    aa = pm.Deterministic('aa', d * ab + e)
    # Extra log-probability term that ties `a1`, `e` and `c` together.
    # NOTE(review): `a1` and `e` can both be 0, which is outside the valid range
    # of the Gamma parameters; harmless for inspecting graph structure, but the
    # term may evaluate to -inf/nan at such points. Confirm this is intended.
    pm.Potential('p', pm.Gamma.dist(a1, e).logp(c))
#
# Topologically sort the random variables
#
# `io_toposort` returns the Apply nodes of the joint log-probability graph in
# dependency order, treating the model's basic RVs as the graph inputs.
topo_sorted_logp = tt.gof.graph.io_toposort(test_model.basic_RVs, [test_model.logpt])

# Evaluate `basic_RVs` once instead of once per candidate input, and use a set
# for constant-time membership. Theano variables hash and compare by identity
# (the toposort above already keys sets/dicts on them), so this matches the
# original list-membership test exactly.
basic_rvs = set(test_model.basic_RVs)

# Gather every basic RV consumed by a node, walk those uses from last to first,
# and keep each RV's first appearance in that reversed walk.
topo_sorted_rvs = list(unique_everseen(reversed(
    [x for x in chain.from_iterable(o.inputs for o in topo_sorted_logp)
     if x in basic_rvs]
)))
# Bare expression: only displays the value in an interactive/notebook session.
topo_sorted_rvs
# [c_log__, b, ab, a1, w_stickbreaking__, d, e]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment