Created August 9, 2017 11:33
-
-
Save madanh/8bb0cff8e83c1be99ebd4147c866e218 to your computer and use it in GitHub Desktop.
Strange shape of the internal log-probability function `pymc3.stats._log_post_trace()`: test case
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
import theano

# Print library versions so the output records exactly which releases
# produced the observed shapes.
print("numpy version", np.__version__)
print("theano version", theano.__version__)
print("pymc version", pm.__version__)

# dtype used for every generated array; follows Theano's configured float
# type (e.g. "float32" when Theano runs with floatX=float32).
floatX = theano.config.floatX

# Set to an integer (the original experiment used 1337) for a reproducible
# data set; None leaves NumPy's global RNG unseeded.
SEED = None
if SEED is not None:
    np.random.seed(SEED)
## ALEPH: GENERATE DATA
# Draw n points from a bivariate normal with zero mean and identity
# covariance; these are the "true" parameters the model should recover.
mean_t = np.array((0, 0), dtype=floatX)
cov_t = np.eye(2, dtype=floatX)
n = 2222
# Cast the samples to floatX so the observed data matches Theano's float type.
data = np.random.multivariate_normal(
    mean=mean_t, cov=cov_t, size=n
).astype(floatX)

# look at the data: scatter of the two coordinates, equal axis scaling
plt.plot(data[:, 0], data[:, 1], 'o')
plt.gca().set_aspect('equal')
## BETH: FORMULATE THE MODEL
# The covariance is fixed to the true value cov_t, so the 2-vector mean mu
# (with a Uniform(-1, 1) prior on each component) is the only free parameter.
mvg_model = pm.Model()
with mvg_model:
    mu = pm.Uniform('mu', lower=-1, upper=1, shape=2)
    y = pm.MvNormal('y', mu=mu, cov=cov_t, observed=data)
# Draw 2000 posterior samples; njobs=2 requests two parallel jobs (the
# pymc3 3.x keyword for parallel chains -- confirm against installed version).
with mvg_model:
    trace = pm.sample(2000, njobs=2)

# Trace plot with chains combined; the return value is not needed.
_ = pm.traceplot(trace, combined=True)
plt.show()
# Reference computation: evaluate the model's log-probability at every
# posterior point, chain by chain. Assuming model.logp returns a scalar,
# lnp is 1-D with length n_chains * len(trace) (len(trace) = draws per chain).
lnp = np.array([
    mvg_model.logp(trace.point(i, chain=c))
    for c in trace.chains
    for i in range(len(trace))
])
print('lnp.shape ', lnp.shape)

# Private pymc3 helper whose output shape is under investigation.
# NOTE(review): presumably this is a pointwise (per-observation)
# log-likelihood, shape (n_samples, n_observations), rather than one joint
# logp per sample -- confirm against the installed pymc3.stats source.
lnp_strange = pm.stats._log_post_trace(trace, mvg_model)
print('lnp_strange.shape ', lnp_strange.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.