Skip to content

Instantly share code, notes, and snippets.

@ckrapu
Created April 24, 2021 16:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ckrapu/0bfd8ab515110e230d4063102f7d33b5 to your computer and use it in GitHub Desktop.
Conjugate step for linear regression in PyMC3
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
from pymc3.step_methods.arraystep import BlockedStep
# ### Create simulated data under linear regression model
rng = np.random.default_rng(827)
n, p = 40, 5
sigma = 0.5

# The draw order (coefficients, then design matrix, then noise) fixes the
# exact values produced from the seed above, so keep it unchanged.
beta = rng.standard_normal(p)
X = rng.standard_normal((n, p))
eps = sigma * rng.standard_normal(n)

# Responses from the linear model y = X beta + noise.
y = X @ beta + eps
# Show the observations against the first predictor, with the true
# coefficient of that predictor drawn as a line through the origin.
plt.scatter(X[:, 0], y, color='k', facecolor='none', label='Observations')
domain = np.linspace(-2, 2)
plt.plot(domain, beta[0] * domain, color='k', label='True slope')
plt.ylabel('$y$', fontsize=14)
plt.xlabel('$x_1$', fontsize=14)
plt.legend()
# ### Define multiple regression model
# coefs ~ Normal(0, coef_var), y ~ Normal(X @ coefs, error_var), with
# InverseGamma(1, 1) priors on both variances. pm.Normal is parameterized by
# a standard deviation, hence the square roots of the variances.
with pm.Model() as model:
    coef_var = pm.InverseGamma('coef_var', alpha=1, beta=1)
    # Use `sigma=` (as the likelihood below does) rather than the deprecated
    # `sd=` alias; both name the same standard-deviation parameter.
    coefs = pm.Normal('coefs', sigma=coef_var**0.5, shape=p)
    error_var = pm.InverseGamma('error_var', alpha=1, beta=1)
    likelihood = pm.Normal('likelihood', mu=X @ coefs, sigma=error_var**0.5, observed=y)
# Baseline run: let PyMC3 choose the sampler for every free variable,
# returning an ArviZ InferenceData object.
with model:
    trace = pm.sample(return_inferencedata=True)

pm.summary(trace)
def undo_transform(point):
    '''
    Add natural-scale copies of log-transformed variables to ``point``.

    For every key ending in ``'_log__'`` (e.g. ``'coef_var_log__'``), the
    exponentiated value is stored under the name with that suffix removed
    (e.g. ``'coef_var'``). Existing entries are left in place.

    Parameters
    ----------
    point : dict
        Mapping from variable name to value. It is modified in place.

    Returns
    -------
    dict
        The same ``point`` object, now also holding the back-transformed
        entries.
    '''
    transform_marker = '_log__'
    # Snapshot the keys: new keys are inserted while looping.
    for varname in list(point):
        # Only a trailing marker denotes a transformed variable; a name that
        # merely contains the marker elsewhere must not be exponentiated.
        if varname.endswith(transform_marker):
            point[varname[:-len(transform_marker)]] = np.exp(point[varname])
    return point
class ConjugateStep(BlockedStep):
    '''
    Gibbs step drawing linear-regression coefficients from their exact
    conditional posterior given the prior and error variances.

    Assumed model (the one built with ``pm.Model()`` in this script)::

        coefs     ~ Normal(0, coef_var * I)
        y | coefs ~ Normal(X @ coefs, error_var * I)

    Conditional on ``coef_var`` and ``error_var`` this gives::

        coefs | y ~ Normal(mean, inv(prec))
        prec = X.T @ X / error_var + I / coef_var
        mean = inv(prec) @ X.T @ y / error_var

    Parameters
    ----------
    coefs : PyMC3 free random variable
        The coefficient vector updated by this step.
    X : array-like, shape (n, p)
        Design matrix.
    y : array-like, shape (n,)
        Observed responses.
    rng : numpy.random.Generator, optional
        Source of randomness. A fresh default generator is used if omitted.
    '''

    def __init__(self, coefs, X, y, rng=None):
        self.vars = [coefs]
        self.X = np.asarray(X)
        self.y = np.asarray(y)
        # One coefficient per column of X.
        self.p = self.X.shape[1]
        # These products depend only on the data, so compute them once.
        self.XTX = self.X.T @ self.X
        self.XTy = self.X.T @ self.y
        self.rng = rng if rng is not None else np.random.default_rng()

    def step(self, point: dict):
        '''
        Return a copy of ``point`` with ``'coefs'`` replaced by a new draw.

        ``point`` is expected to hold ``coef_var`` and ``error_var`` as
        ``'<name>_log__'`` entries, which ``undo_transform`` maps back to the
        natural scale. The input dict is not modified, and no derived keys are
        added to the returned point: it keeps exactly the keys it came with.
        '''
        # Back-transform on a copy so the caller's point is left untouched.
        values = undo_transform(dict(point))
        coef_var = values['coef_var']
        error_var = values['error_var']

        # Posterior precision: the prior precision goes on the diagonal only.
        prec = self.XTX / error_var + np.eye(self.p) / coef_var
        mean = np.linalg.solve(prec, self.XTy / error_var)

        # With prec = L @ L.T and z ~ N(0, I), solve(L.T, z) has covariance
        # inv(prec), so no explicit matrix inverse is needed.
        chol = np.linalg.cholesky(prec)
        noise = np.linalg.solve(chol.T, self.rng.standard_normal(self.p))

        new_point = dict(point)
        new_point['coefs'] = mean + noise
        return new_point
# Sample `coefs` with the conjugate Gibbs update, leaving PyMC3 to choose
# samplers for the remaining free variables.
with model:
    step = ConjugateStep(coefs, X, y, rng=rng)
    trace = pm.sample(step=step, chains=2, cores=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment