Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active January 6, 2022 08:58
Show Gist options
  • Save ricardoV94/a4da531a60a0f16facf8dbb5416a2cd4 to your computer and use it in GitHub Desktop.
import numpy as np
import scipy.stats as st
import aesara
import aesara.tensor as at
import pymc as pm
from aeppl import joint_logprob
# Generative graph: a random walk where each step is Normal(previous + mu, sigma)
sigma = 0.01  # scale of every step
mu = 1  # drift added to the previous value at each step
size = 5  # number of scan iterations
rng = at.random.RandomStream(1234)  # fixed seed so the draws are reproducible
def scan_step(prev_value):
    """Draw the next state of the walk from Normal(prev_value + mu, sigma).

    Relies on the module-level ``rng`` RandomStream, ``mu`` and ``sigma``.
    The RNG update produced here is what ``aesara.scan`` returns as
    ``updates``.
    """
    return rng.normal(prev_value + mu, sigma)
rv, updates = aesara.scan(
    fn=scan_step,
    outputs_info=np.array(0.0),  # initial state of the walk
    n_steps=size,
    strict=True,
)

# Random draws: passing the scan ``updates`` to the compiled function makes the
# RNG state advance on every call, so successive calls give different draws.
f = aesara.function([], rv, updates=updates)
print(f())
print(f())
# [0.98457646 1.98781195 2.99168834 3.99588533 4.98836932]
# [1.02558257 2.03999037 3.03123116 4.0277299  5.02523961]

# Logp graph: ``vv`` is a clone of ``rv`` (same type) standing in for observed
# values of the walk.
vv = rv.clone()
logp = joint_logprob({rv: vv})
test_value = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(logp.eval({vv: test_value}))
# 18.43115826391709


def reference_logp(values):
    """Return the SciPy log-density of ``values`` under the scan model.

    Step ``t`` is Normal(values[t - 1] + mu, sigma); the value before the first
    step is the scan's initial state 0.0.
    """
    values = np.asarray(values, dtype=float)
    prev_values = np.concatenate([[0.0], values[:-1]])
    return st.norm(prev_values + mu, sigma).logpdf(values).sum()


# Confirm it is correct. The means are derived from ``mu`` and the shifted
# values rather than hard-coded, and an off-mean value is checked as well:
# at ``test_value`` every step sits exactly on its mean, which on its own
# cannot tell whether the logp conditions on the previous values.
print(reference_logp(test_value))
# 18.43115826391709
off_mean_value = test_value + np.array([0.01, -0.02, 0.0, 0.005, -0.01])
for value in (test_value, off_mean_value):
    np.testing.assert_allclose(logp.eval({vv: value}), reference_logp(value))
import numpy as np
import scipy.stats as st
import aesara
import aesara.tensor as at
import pymc as pm
from aeppl import joint_logprob
def seed_pymc_dist(dist, rng):
    """Rebuild a PyMC distribution variable so that it draws from ``rng``.

    This is the "black magic" needed to seed a PyMC distribution inside a
    Scan. The random-variable op, ``size`` and distribution parameters of
    ``dist`` are reused. The RNG input ``dist`` was built with is discarded,
    and ``rng.gen`` creates the replacement variable, tying the draw to
    ``rng``'s seed.

    Parameters
    ----------
    dist : TensorVariable
        Output of a ``pm.<Distribution>.dist(...)`` call. Its owner node's
        inputs are assumed to be laid out as ``(rng, size, dtype, *params)``;
        the unpacking below relies on that layout.
    rng : RandomStream
        Stream used to create the replacement variable.

    Returns
    -------
    Whatever ``rng.gen`` returns for the same op, parameters and size.
    """
    rv_op = dist.owner.op
    # The original RNG and dtype inputs are not forwarded to ``rng.gen``.
    # The local is named ``rv_size`` so it does not shadow the module-level
    # ``size``.
    _, rv_size, _, *params = dist.owner.inputs
    return rng.gen(rv_op, *params, size=rv_size)
# Generative graph: same random walk, but each step is built from a PyMC
# distribution.
size, mu, sigma = 5, 1, 0.01  # steps, per-step drift, per-step noise scale
rng = at.random.RandomStream(1234)
def scan_step(prev_value):
    """Draw the next state of the walk through a PyMC ``Normal`` distribution.

    ``pm.Normal.dist`` builds the step variable. ``seed_pymc_dist`` then
    rebuilds it with the module-level ``rng`` RandomStream, replacing its
    original RNG input, so the draw is seeded.
    """
    step_dist = pm.Normal.dist(mu=prev_value + mu, sigma=sigma)
    return seed_pymc_dist(step_dist, rng=rng)
rv, updates = aesara.scan(
    fn=scan_step,
    outputs_info=np.array(0.0),  # initial state of the walk
    n_steps=size,
    strict=True,
)

# Random draws: passing the scan ``updates`` to the compiled function makes the
# RNG state advance on every call, so successive calls give different draws.
f = aesara.function([], rv, updates=updates)
print(f())
print(f())
# Recorded output (the same numbers as the pure-aesara script, which also seeds
# RandomStream with 1234):
# [0.98457646 1.98781195 2.99168834 3.99588533 4.98836932]
# [1.02558257 2.03999037 3.03123116 4.0277299  5.02523961]

# Logp graph: ``vv`` is a clone of ``rv`` (same type) standing in for observed
# values of the walk.
vv = rv.clone()
logp = joint_logprob({rv: vv})
test_value = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(logp.eval({vv: test_value}))
# 18.43115826391709


def reference_logp(values):
    """Return the SciPy log-density of ``values`` under the scan model.

    Step ``t`` is Normal(values[t - 1] + mu, sigma); the value before the first
    step is the scan's initial state 0.0.
    """
    values = np.asarray(values, dtype=float)
    prev_values = np.concatenate([[0.0], values[:-1]])
    return st.norm(prev_values + mu, sigma).logpdf(values).sum()


# Confirm it is correct. The means are derived from ``mu`` and the shifted
# values rather than hard-coded, and an off-mean value is checked as well:
# at ``test_value`` every step sits exactly on its mean, which on its own
# cannot tell whether the logp conditions on the previous values.
print(reference_logp(test_value))
# 18.43115826391709
off_mean_value = test_value + np.array([0.01, -0.02, 0.0, 0.005, -0.01])
for value in (test_value, off_mean_value):
    np.testing.assert_allclose(logp.eval({vv: value}), reference_logp(value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment