Created
September 27, 2021 15:29
-
-
Save danyaljj/2243ff5ca199353bf09690a8ec430657 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import seaborn as sns | |
import scipy | |
import random | |
class NormalGammaPrior:
    """Conjugate Normal-Gamma prior over the mean and precision of a Gaussian.

    Suppose X is normally distributed, X ~ N(mu, tau^{-1}), where tau is the
    precision (inverse variance). The joint prior over (mu, tau) is a
    [Normal-Gamma distribution](https://en.wikipedia.org/wiki/Normal-gamma_distribution):

        (mu, tau) ~ NormalGamma(mu0, lambda0, alpha0, beta0)

    The prior is conjugate, so the posterior is again Normal-Gamma:
    https://en.wikipedia.org/wiki/Normal-gamma_distribution#Posterior_distribution_of_the_parameters

    The marginal distribution of mu is a Student-t distribution:
    https://en.wikipedia.org/wiki/Normal-gamma_distribution#Marginal_distributions

        mu ~ t_{2 alpha}(mu, beta / (lambda * alpha))

    Here the second parameter is the *squared* scale.

    The instance stores the current hyper-parameters. ``fit`` overwrites them
    with the posterior, so repeated calls perform sequential Bayesian updating.
    """

    def __init__(self, mu, lambdaa, alpha, beta):
        """Store the hyper-parameters (mu0, lambda0, alpha0, beta0).

        ``lambdaa`` has a doubled 'a' because ``lambda`` is a Python keyword.
        """
        self.mu = mu
        self.lambdaa = lambdaa
        self.alpha = alpha
        self.beta = beta

    def fit(self, x: List[float]):
        """Replace the hyper-parameters with the posterior given observations ``x``.

        Every term uses the *prior* values of mu and lambda. Fitting two
        disjoint batches one after the other therefore gives the same result
        as fitting their concatenation once. Pass only observations that have
        not been fitted yet, or they will be counted twice.

        An empty ``x`` leaves the parameters unchanged.
        """
        n = len(x)
        if n == 0:
            # No evidence: the posterior equals the prior (and np.mean([]) is nan).
            return
        x_mean = float(np.mean(x))
        # np.var is the population variance, so n * s is the sum of squared
        # deviations from the sample mean.
        s = float(np.var(x))

        # Capture the prior values first. The beta update depends on mu0 and
        # lambda0, not on the posterior values assigned just below.
        mu0, lambda0 = self.mu, self.lambdaa

        self.mu = (lambda0 * mu0 + n * x_mean) / (lambda0 + n)
        self.lambdaa = lambda0 + n
        self.alpha = self.alpha + n / 2
        self.beta = self.beta + 0.5 * (
            n * s + lambda0 * n * (x_mean - mu0) ** 2 / (lambda0 + n)
        )

    def print_params(self):
        """Print the current hyper-parameters, one per line."""
        print(f"mu: {self.mu}\nlambda: {self.lambdaa}\nalpha: {self.alpha}\nbeta: {self.beta}\n")

    def get_marginal_mu_dist(self):
        """Return the frozen scipy Student-t marginal distribution of mu.

        The distribution is mu ~ t(df=2*alpha, loc=mu, scale=sqrt(beta / (alpha * lambda))).
        scipy's ``scale`` is a standard-deviation-like parameter, which is why
        the square root of beta / (alpha * lambda) is taken.
        """
        # `import scipy` alone may not load the stats subpackage on older SciPy releases.
        import scipy.stats
        scale = np.sqrt(self.beta / (self.alpha * self.lambdaa))
        return scipy.stats.t(df=2 * self.alpha, loc=self.mu, scale=scale)

    def get_marginal_mu_mean_variance(self):
        """Summarise the marginal distribution of mu.

        Returns:
            dict with keys:
                'mean' (float), 'variance' (float), 'std' (float), and
                'interval95' (a ``(lower, upper)`` central 95% interval).
            For small alpha (2*alpha <= 2) scipy's Student-t moments may be
            non-finite.
        """
        dist = self.get_marginal_mu_dist()
        mean, variance = dist.stats('mv')
        variance = float(variance)
        return {
            'mean': float(mean),
            'variance': variance,
            'std': variance ** 0.5,
            'interval95': dist.interval(0.95),
        }
def plot(mu=1, p=0.7, n_samples=20, seed=None):
    """Plot how the NormalGamma posterior over the mean tracks Bernoulli samples.

    Draws ``n_samples`` Bernoulli(p) observations one at a time. After each
    draw it updates a ``NormalGammaPrior`` and records the marginal posterior
    mean of mu together with its central 95% interval, then plots the series.

    Args:
        mu: prior mean ``mu0`` of the NormalGamma prior.
        p: success probability of the simulated Bernoulli observations.
        n_samples: number of observations to draw.
        seed: optional seed for reproducible draws. ``None`` uses the global
            ``random`` module, as before.
    """
    sns.set()
    dist = NormalGammaPrior(
        mu=mu,
        lambdaa=50,
        alpha=5, beta=200
    )
    rng = random if seed is None else random.Random(seed)

    # x[i] is the number of observations fitted when the i-th summary is taken.
    x = list(range(1, n_samples + 1))
    mean = []
    interval_lower = []
    interval_upper = []
    for _ in x:
        sample = 1.0 if rng.random() < p else 0.0
        # fit() updates the hyper-parameters in place, so pass only the new
        # observation. Refitting the full history would count old samples again.
        dist.fit([sample])
        summary = dist.get_marginal_mu_mean_variance()
        mean.append(summary['mean'])
        interval_lower.append(summary['interval95'][0])
        interval_upper.append(summary['interval95'][1])

    plt.plot(x, mean, '-o', color='gray', label='mean and 95% prob intervals')
    plt.fill_between(x, interval_lower, interval_upper, color='gray', alpha=0.2)
    plt.xlim([0, n_samples])
    plt.ylim([0, 1])
    plt.xlabel('number of observations')
    plt.ylabel('mu')
    plt.legend()
    plt.title(f'GammaNormal (mu={mu}) prior for estimating x~Bernoulli(p={p})')
    plt.show()
if __name__ == "__main__":
    # Draw the demo plot only when run as a script, not when the module is imported.
    plot()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Two samples:
prior mu = 0:
prior mu = 1: