# Gist by @astropenguin (gist id eedd3a29e398f524e67fe0cf214838c5),
# last active September 23, 2017.
import numpy as np
import matplotlib.pyplot as plt
from numba import jit
from scipy.special import gammaln
from sklearn import decomposition
def lnpU(k, d=None):
    """Return the log prior term ln p(U) for k principal directions in d dims.

    Computes sum_{i=1..k} [ -ln 2 + ln Gamma(q_i) - q_i * ln(pi) ] with
    q_i = (d - i + 1) / 2 (this matches the p(U) term of Minka's PCA evidence
    approximation).

    Args:
        k: number of retained components. Meaningful for 1 <= k <= d; for
            k > d one q_i equals 0, gammaln(0) is inf, and so is the result.
        d: data dimensionality. Defaults to the module-level global ``D`` so
            existing ``lnpU(k)`` calls behave exactly as before.

    Returns:
        The log prior as a float.
    """
    if d is None:
        d = D
    # q_i for i = 1..k, i.e. (d - 0)/2, (d - 1)/2, ..., (d - k + 1)/2
    q = (d - np.arange(k)) / 2
    return np.sum(gammaln(q) - np.log(np.pi) * q) - k * np.log(2)
@jit
def lnAz(k, lmd, lmd_hat):
    """Return ln|A_Z| for the Laplace evidence approximation.

    Accumulates, over every pair (a, b) with 0 <= a < k and a < b < D:
        ln N + ln(1/lmd_hat[b] - 1/lmd_hat[a]) + ln(lmd[a] - lmd[b])

    Both differences must be positive (e.g. lmd sorted in decreasing order);
    otherwise np.log yields nan or -inf.

    Reads the module globals N and D. Under numba these are treated as
    compile-time constants, i.e. frozen at the first call.
    """
    log_n = np.log(N)
    total = 0.0
    a = 0
    while a < k:
        # Row-invariant values, identical to recomputing them per column.
        inv_hat_a = 1 / lmd_hat[a]
        lmd_a = lmd[a]
        for b in range(a + 1, D):
            total += log_n
            total += np.log(1 / lmd_hat[b] - inv_hat_a)
            total += np.log(lmd_a - lmd[b])
        a += 1
    return total
def _check_rank(k):
    """Raise ValueError unless 1 <= k <= min(D, N) (module globals).

    An explicit check instead of ``assert`` so it survives ``python -O``. The
    D bound also rejects k > D, for which lnpU hits gammaln(0) = inf and
    lmd[:k] silently truncates.
    """
    upper = min(D, N)
    if not 0 < k <= upper:
        raise ValueError(
            "k must satisfy 0 < k <= min(D, N) = {}, got {!r}".format(upper, k))


def _ln_likelihood(k):
    """Return the k-dependent log-likelihood term shared by both criteria.

    -N/2 * sum(ln lmd[:k]) - N*(D-k)/2 * ln(mean(lmd[k:])), using the module
    globals N, D and lmd. When k == D the second factor has exponent zero, so
    it contributes 0 rather than 0 * ln(mean of an empty slice) = nan.
    """
    ln_lik = -N / 2 * np.sum(np.log(lmd[:k]))
    if k < D:
        ln_lik -= N * (D - k) / 2 * np.log(np.mean(lmd[k:]))
    return ln_lik


def prob_laplace(k):
    """Return the Laplace-approximation log evidence for a rank-k PCA model.

    Uses the module globals D (dimension), N (sample count) and lmd (the
    length-D eigenvalue spectrum; lnAz needs it in decreasing order).

    Raises:
        ValueError: if k is not in 1..min(D, N).
    """
    _check_rank(k)
    # D*k - k(k+1)/2: dimension of the space of k orthonormal D-vectors.
    m = D * k - k * (k + 1) / 2
    # Leading k eigenvalues kept; trailing ones replaced by their mean.
    lmd_hat = lmd.copy()
    if k < D:
        lmd_hat[k:] = np.mean(lmd[k:])
    lnprob = lnpU(k) + _ln_likelihood(k)
    lnprob += (m + k) / 2 * np.log(2 * np.pi)
    lnprob -= lnAz(k, lmd, lmd_hat) / 2
    lnprob -= k / 2 * np.log(N)
    return lnprob


def prob_bic(k):
    """Return the BIC approximation to the log evidence for a rank-k PCA model.

    The log-likelihood term minus (m + k)/2 * ln N, where m = D*k - k(k+1)/2.
    Uses the module globals D, N and lmd.

    Raises:
        ValueError: if k is not in 1..min(D, N).
    """
    _check_rank(k)
    m = D * k - k * (k + 1) / 2
    return _ln_likelihood(k) - (m + k) / 2 * np.log(N)
if __name__ == '__main__':
    # X is expected as an (n_samples, n_features) array -- the layout
    # scikit-learn's PCA assumes -- already present in this namespace (e.g.
    # when run with IPython's ``%run -i``). If it is absent, use a synthetic
    # demo so the script can run standalone.
    if 'X' not in globals():
        rng = np.random.RandomState(0)
        # 200 samples in 20 dimensions: a 5-dimensional signal plus
        # unit-variance isotropic noise.
        X = rng.randn(200, 5) @ (3 * rng.randn(5, 20)) + rng.randn(200, 20)

    # PCA treats rows as samples, so the sample count N is the first axis and
    # the dimensionality D (length of the eigenvalue spectrum) the second.
    N, D = X.shape
    K = min(D, N)
    model = decomposition.PCA(K)
    C = model.fit_transform(X)  # (N, K) component scores
    P = model.components_       # (K, D) principal directions
    # Length-D spectrum: per-component score variances, padded with the mean
    # residual variance when K < D.
    lmd = np.zeros(D)
    lmd[:K] = np.var(C, 0)
    lmd[K:] = np.mean(np.var(X - C @ P, 0))

    # k = K is skipped: for D <= N no trailing eigenvalues remain to estimate
    # the noise from, and for N <= D the K-th eigenvalue of the centred data
    # is ~0.
    ks = list(range(1, K))
    plt.plot(ks, [prob_laplace(k) for k in ks], label='Laplace')
    plt.plot(ks, [prob_bic(k) for k in ks], label='BIC')
    plt.xlabel('k (number of components)')
    plt.ylabel('log evidence')
    plt.legend()
    plt.show()
# End of gist by @astropenguin.