This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def EM_Gaussian(mu_init, sd_init, pi_init, n_iter=1000): | |
# EM for Gaussian mixture models | |
mu = mu_init | |
sd = sd_init | |
pi = pi_init | |
K = len(pi_init) # number of Gaussians | |
for n in range(n_iter): | |
# Expectation step | |
# calculate responsibilities for each Gaussian |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# With known labels (samples already split into x1 and x2), the model
# parameters can be estimated directly from each group.
mu1_mle = np.mean(x1)
mu2_mle = np.mean(x2)
# ddof=0 divides by N rather than N-1.
sd1_mle = np.std(x1, ddof=0)
sd2_mle = np.std(x2, ddof=0)
# Mixing weight of the first group: its share of all samples.
pi_est = len(x1) / (len(x1) + len(x2))
print(f'mu_1: {mu1_mle:.2f}; sd_1: {sd1_mle:.2f}; mu_2: {mu2_mle:.2f}; sd_2: {sd2_mle:.2f}; pi: {pi_est:.2f}; ')
# 100 evenly spaced points on [-7, 20]; its use is not visible in this snippet.
t = np.linspace(-7, 20, 100)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# With known model params we can estimate z, i.e. which Gaussian each data
# point came from, as the posterior probability of each component given x.
# Each weighted component density is evaluated once and reused.
lik_1 = pi_1 * st.norm(mu1, sd1).pdf(x)        # weight * density, component 1
lik_2 = (1 - pi_1) * st.norm(mu2, sd2).pdf(x)  # weight * density, component 2
norm_factor = lik_1 + lik_2                    # mixture density at each x
pz_1 = lik_1 / norm_factor
pz_2 = lik_2 / norm_factor

# Thin strip plot: one faint vertical line per point, coloured by whichever
# component has posterior probability above 0.5.
# NOTE(review): trailing ';' kept from the original; in a notebook it
# suppresses the echo of the call's return value.
fig = plt.figure(figsize=(8, 0.5))
plt.vlines(x[pz_1 > 0.5], 0, 0.01, color='orange', alpha=0.01);
plt.vlines(x[pz_2 > 0.5], 0, 0.01, color='steelblue', alpha=0.01);
plt.yticks([]);
plt.xlabel('x');
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# Histogram of x, overlaid with one faint vertical line (height 0..50) per
# sample in x1 (orange) and x2 (steelblue).
# NOTE(review): trailing ';' kept from the original; in a notebook it
# suppresses the echo of the call's return value.
fig = plt.figure(figsize=(8, 4))
sns.histplot(x=x)
plt.vlines(x1, 0, 50, color='orange', alpha=0.1);
plt.vlines(x2, 0, 50, color='steelblue', alpha=0.1);
plt.xlabel('x');
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scipy.stats as st | |
import numpy as np | |
# Parameters of a two-component Gaussian mixture.
mu1 = -1    # mean of component 1
mu2 = 10    # mean of component 2
sd1 = 2     # standard deviation of component 1
sd2 = 3     # standard deviation of component 2
pi_1 = 0.2  # mixing weight of component 1 (component 2 gets 1 - pi_1)
# 30000 Bernoulli(pi_1) draws, each 0 or 1.
# NOTE(review): presumably the per-sample component labels; confirm against
# the code that consumes k.
k = st.bernoulli(pi_1).rvs(30000)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scipy.stats as st | |
def calc_theta(VE, r=1):
    """Calculate the case rate (theta) given VE and the surveillance time ratio.

    Computes ``theta = r * (1 - VE) / (1 + r * (1 - VE))``. Only arithmetic
    operators are used, so NumPy arrays are handled elementwise.

    Args:
        VE: Efficacy value(s) to convert.
        r: Surveillance time ratio (defaults to 1).

    Returns:
        The corresponding theta value(s).

    Raises:
        ZeroDivisionError: For plain Python numbers when ``r * (1 - VE) == -1``.
    """
    scaled = r * (1 - VE)
    return scaled / (1 + scaled)
def calc_VE(theta, r=1):
    """Calculate VE given the case rate (theta) and the surveillance time ratio.

    Computes ``VE = 1 + theta / (r * (theta - 1))``, which is the algebraic
    inverse of ``theta = r * (1 - VE) / (1 + r * (1 - VE))``. Only arithmetic
    operators are used, so NumPy arrays are handled elementwise.

    Args:
        theta: Case rate value(s) to convert.
        r: Surveillance time ratio (defaults to 1).

    Returns:
        The corresponding VE value(s).

    Raises:
        ZeroDivisionError: For plain Python numbers when ``theta == 1`` or
            ``r == 0``.
    """
    return 1 + theta / (r * (theta - 1))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.stats as st | |
import matplotlib.pyplot as plt | |
c_vs = np.array([10, 25, 50, 75])  # no. cases in vaccinated group (one scenario per entry)
c_ps = 150 - c_vs                  # no. cases in placebo group; each pair sums to 150
a = 0.700102                       # param a in prior
b = 1                              # param b in prior
# One row of three panels.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scipy.stats as st | |
class func_mvn: | |
# multivariate normal given mean and covariance | |
def __init__(self, mu, cov): | |
self.mu = mu | |
self.cov = cov | |
self.func = st.multivariate_normal(mu, cov) | |
def calc_pdf(self, x): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pymc3 as pm | |
# Normal model for the mean of X with known unit variance, sampled with a
# Metropolis step.
with pm.Model() as model:
    prior = pm.Normal('mu', mu=0, sigma=1)  # prior
    obs = pm.Normal('obs', mu=prior, sigma=1, observed=X)  # likelihood
    step = pm.Metropolis()
    # 3 independent Markov chains of 50000 draws each.
    trace = pm.sample(draws=50000, chains=3, step=step, return_inferencedata=True)

# NOTE(review): the original indentation was lost; the plot is placed outside
# the model context on the assumption that an InferenceData trace
# (return_inferencedata=True) needs no active model. Confirm.
pm.traceplot(trace)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from statsmodels.graphics.tsaplots import plot_acf | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
def plot_res(xs, burn_in, x_name): | |
# plot trace (based on xs) | |
# plot distribution | |
xs_kept = xs[burn_in:] | |
NewerOlder