Skip to content

Instantly share code, notes, and snippets.

@danyaljj
Created September 27, 2021 15:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danyaljj/2243ff5ca199353bf09690a8ec430657 to your computer and use it in GitHub Desktop.
import random
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats
import seaborn as sns
class NormalGammaPrior:
    """
    Normal-Gamma conjugate prior over the mean and precision of a normal likelihood.

    Suppose X is distributed according to a normal distribution: X ~ N(mu, tau^{-1}),
    and that the prior over mu and tau has a
    [Normal-Gamma distribution](https://en.wikipedia.org/wiki/Normal-gamma_distribution):
        (mu, tau) ~ NormalGamma(mu0, lambda0, alpha0, beta0)
    The posterior of this construction is again Normal-Gamma:
    https://en.wikipedia.org/wiki/Normal-gamma_distribution#Posterior_distribution_of_the_parameters
    The marginal distribution of mu (integrating out tau) is a Student-t distribution:
    https://en.wikipedia.org/wiki/Normal-gamma_distribution#Marginal_distributions

    The four parameters live in the attributes ``mu``, ``lambdaa``, ``alpha`` and
    ``beta``; :meth:`fit` overwrites them in place with the posterior parameters.
    ----------
    """

    def __init__(self, mu, lambdaa, alpha, beta):
        """Store the prior parameters (mu0, lambda0, alpha0, beta0).

        ``lambdaa`` stands for lambda0 (``lambda`` is a reserved word in Python).
        """
        self.mu = mu
        self.lambdaa = lambdaa
        self.alpha = alpha
        self.beta = beta

    def fit(self, x: List[float]):
        """Replace the current parameters with the posterior after observing ``x``.

        Uses the standard Normal-Gamma conjugate update:
            mu_n     = (lambda0 * mu0 + n * xbar) / (lambda0 + n)
            lambda_n = lambda0 + n
            alpha_n  = alpha0 + n / 2
            beta_n   = beta0 + 0.5 * (sum((x_i - xbar)^2)
                                      + lambda0 * n * (xbar - mu0)^2 / (lambda0 + n))
        Because the update is conjugate, fitting disjoint batches one after another
        gives the same result as fitting them all at once. Passing the same data
        twice counts it twice.

        Args:
            x: observed samples. An empty sequence leaves the parameters unchanged.
        """
        n = len(x)
        if n == 0:
            # No evidence: the posterior equals the prior (np.mean([]) would yield NaN).
            return
        x_mean = float(np.mean(x))
        # np.var uses ddof=0, so n * s is the sum of squared deviations from x_mean.
        s = float(np.var(x))
        # Every right-hand side below must use the *prior* mu0/lambda0, so capture
        # them before anything is overwritten.
        mu0, lambda0 = self.mu, self.lambdaa
        self.mu = (lambda0 * mu0 + n * x_mean) / (lambda0 + n)
        self.lambdaa = lambda0 + n
        self.alpha = self.alpha + n / 2
        self.beta = self.beta + 0.5 * (
            n * s + lambda0 * n * (x_mean - mu0) ** 2 / (lambda0 + n)
        )

    def print_params(self):
        """Print the current parameters, one per line."""
        print(f"mu: {self.mu}\nlambda: {self.lambdaa}\nalpha: {self.alpha}\nbeta: {self.beta}\n")

    def get_marginal_mu_dist(self):
        """Return the frozen Student-t marginal of mu under the current parameters.

        mu ~ t_{2*alpha}(loc=mu, squared scale = beta / (alpha * lambda)).
        scipy's ``scale`` is the square root of that squared scale.
        """
        scale = np.sqrt(self.beta / (self.alpha * self.lambdaa))
        return scipy.stats.t(df=2 * self.alpha, loc=self.mu, scale=scale)

    def get_marginal_mu_mean_variance(self):
        """Summarise the marginal of mu.

        Returns:
            dict with float ``mean``, ``variance`` and ``std``, and ``interval95``,
            the (low, high) central 95% interval from ``dist.interval(0.95)``.
            ``mean`` is NaN when 2*alpha <= 1 and ``variance``/``std`` are infinite
            when 1 < 2*alpha <= 2 (Student-t moments).
        """
        dist = self.get_marginal_mu_dist()
        mean, variance = dist.stats('mv')
        return {
            'mean': float(mean),
            'variance': float(variance),
            'std': float(np.sqrt(variance)),
            'interval95': dist.interval(0.95),
        }
def plot(mu=1, *, n_steps=20, p=0.7, seed=None):
    """Plot how the Normal-Gamma posterior over ``mu`` evolves on Bernoulli(p) data.

    The prior is NormalGammaPrior(mu=mu, lambdaa=50, alpha=5, beta=200). At each of
    ``n_steps`` steps one sample (1.0 with probability ``p``, else 0.0) is drawn and
    folded into the posterior. The posterior mean of mu and its 95% interval (from
    the Student-t marginal) are then plotted against the step index.

    Args:
        mu: prior mean mu0.
        n_steps: number of observations to draw (and points to plot).
        p: success probability of the simulated Bernoulli data.
        seed: optional seed for reproducible draws; ``None`` gives fresh randomness.
    """
    sns.set()
    rng = random.Random(seed)
    dist = NormalGammaPrior(
        mu=mu,
        lambdaa=50,
        alpha=5, beta=200
    )

    steps = list(range(n_steps))
    mean = []
    intervald = []
    intervalu = []
    for _ in steps:
        obs = 1.0 if rng.random() < p else 0.0
        # fit() overwrites the parameters with the posterior, so pass only the new
        # observation. Refitting the full history each step would count earlier
        # samples again and shrink the interval far too quickly.
        dist.fit([obs])
        summary = dist.get_marginal_mu_mean_variance()
        mean.append(summary['mean'])
        low, high = summary['interval95']
        intervald.append(low)
        intervalu.append(high)

    plt.plot(steps, mean, '-o', color='gray', label='mean and 95% prob intervals')
    plt.fill_between(steps, intervald, intervalu, color='gray', alpha=0.2)
    plt.xlim([0, n_steps])
    plt.ylim([0, 1])
    plt.legend()
    plt.title(f'GammaNormal (mu={mu}) prior for estimating x~Bernoulli(p={p})')
    plt.show()
# Only draw the demo plot when run as a script, not on import.
if __name__ == "__main__":
    plot()
@danyaljj
Copy link
Author

Two sample runs (plots not included in this copy):

  • prior mu = 0:
    [image omitted]

  • prior mu = 1:
    [image omitted]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment