Skip to content

Instantly share code, notes, and snippets.

@madanh
Created August 9, 2017 11:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save madanh/8bb0cff8e83c1be99ebd4147c866e218 to your computer and use it in GitHub Desktop.
Save madanh/8bb0cff8e83c1be99ebd4147c866e218 to your computer and use it in GitHub Desktop.
Test case: strange output shape of the internal logp function pm.stats._log_post_trace()
import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
import theano
# Report the versions of the libraries involved so the test case can be
# reproduced against the same stack.
for label, module in (
    ("numpy version", np),
    ("theano version", theano),
    ("pymc version", pm),
):
    print(label, module.__version__)

# Float dtype name taken from theano's configuration; reused below when
# building the data arrays. (Original note: "float32".)
floatX = theano.config.floatX
# Uncomment to make the random draws repeatable:
# np.random.seed(1337)
## ALEPH: GENERATE DATA
# Draw n points from a 2-D Gaussian with zero mean and identity covariance,
# then cast the draws to the dtype named by floatX.
n = 2222
mean_t = np.zeros(2, dtype=floatX)
cov_t = np.identity(2, dtype=floatX)
data = np.random.multivariate_normal(mean=mean_t, cov=cov_t, size=n)
data = data.astype(floatX)

# Look at the data: scatter the first coordinate against the second, with an
# equal aspect ratio so both axes share the same scale.
first_coord, second_coord = data.T
plt.plot(first_coord, second_coord, 'o')
axes = plt.gca()
axes.set_aspect('equal')
## BETH: FORMULATE THE MODEL
# Model: the 2-D mean `mu` is unknown, the covariance is fixed to `cov_t`, and
# `data` is the observed sample.
mvg_model = pm.Model()
with mvg_model:
    # The statements below must be indented under the `with` so the random
    # variables are created inside the model context.
    # Prior on the mean: uniform on [-1, 1], two components.
    mu = pm.Uniform('mu', lower=-1, upper=1, shape=2)
    # Likelihood: multivariate normal around `mu`, conditioned on `data`.
    y = pm.MvNormal('y', mu=mu, cov=cov_t, observed=data)

with mvg_model:
    # Sample with 2000 as the first argument and njobs=2.
    # NOTE(review): njobs presumably starts two parallel jobs / chains; the
    # code further down iterates over trace.chains -- confirm against the
    # installed PyMC3 version.
    trace = pm.sample(2000, njobs=2)
    # Plot the trace; the returned value is not needed.
    _ = pm.traceplot(trace, combined=True)

# Display all figures created so far (data scatter and trace plot).
plt.show()
# Evaluate the model's logp at every stored point, chain by chain, and collect
# the results into a single array.
point_logps = []
for chain_id in trace.chains:
    for draw_idx in range(len(trace)):
        sample_point = trace.point(draw_idx, chain=chain_id)
        point_logps.append(mvg_model.logp(sample_point))
lnp = np.array(point_logps)
print('lnp.shape ',lnp.shape)

# Run the same trace through the internal helper and print the shape of its
# result for comparison with the array above.
lnp_strange = pm.stats._log_post_trace(trace,mvg_model)
print('lnp_strange.shape ',lnp_strange.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment