Skip to content

Instantly share code, notes, and snippets.

@mhlr
Last active January 15, 2018 02:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mhlr/09fce5cd38834cbcdf2e1354a9c7cc1d to your computer and use it in GitHub Desktop.
Save mhlr/09fce5cd38834cbcdf2e1354a9c7cc1d to your computer and use it in GitHub Desktop.
from pylab import *
import scipy
from scipy import stats
from numba import vectorize, guvectorize, float64, jit, njit
@guvectorize(['void(float64[:], float64[:], float64[:])'],
             '(i),(i)->()',
             nopython=True)
def kl(p, q, result):
    """Kullback-Leibler divergence KL(p || q) over the last axis.

    Generalised ufunc with layout ``(i),(i)->()``: for each pair of 1-D
    slices ``p`` and ``q`` it stores ``sum_i p_i * (log p_i - log q_i)`` in
    ``result[0]``; any leading axes broadcast like an ordinary NumPy ufunc.

    Terms whose ``p_i`` is not strictly positive are skipped, which applies
    the usual ``0 * log 0 = 0`` convention.  A ``q_i`` of zero paired with a
    positive ``p_i`` makes the sum infinite.
    """
    total = 0.0
    for j in range(p.shape[0]):
        pj = p[j]
        # Guard clause: zero-probability terms contribute nothing.
        if not pj > 0:
            continue
        total += pj * (log(pj) - log(q[j]))
    result[0] = total
# NOTE: deliberately not decorated with @njit -- the body is whole-array
# NumPy code (little to gain from compilation), it calls the `kl` gufunc
# (not reliably callable from nopython code across numba versions), and
# numba's RNG state is separate from NumPy's, which would make
# `np.random.seed` ineffective for reproducing results.
def ib(data, k, b=1., eps=1e-12):
    """Soft-cluster the rows of a 2-D table with the iterative Information
    Bottleneck algorithm.

    Parameters
    ----------
    data : array_like, shape (n_x, n_y)
        Co-occurrence weights, normalised here into a joint distribution
        p(x, y).  Expected to be non-negative; every row must have a positive
        sum, otherwise p(y|x) is undefined (division by zero -> NaN).
    k : int
        Number of clusters t.
    b : float, default 1.
        Trade-off parameter beta.  Assignments are proportional to
        ``q(t) * exp(-b * KL(p(y|x) || q(y|t)))``, so larger values give
        harder assignments.
    eps : float, default 1e-12
        Convergence tolerance.  Iteration stops once the largest per-row
        ``KL(q_new(t|x) || q_old(t|x))`` is no longer greater than ``eps``.

    Returns
    -------
    qt_x : ndarray, shape (n_x, k)
        Soft cluster assignments q(t|x).
    qy_t : ndarray, shape (k, n_y)
        Cluster-conditional distributions q(y|t).

    Notes
    -----
    The initial q(t|x) is drawn from ``np.random.dirichlet``, so results
    depend on NumPy's global random state.  As in the original loop test,
    a NaN divergence (e.g. from an all-zero row) ends the iteration rather
    than raising.
    """
    pxy = data / data.sum()                     # joint p(x, y)
    px = pxy.sum(axis=1)[:, np.newaxis]         # marginal p(x), shape (n_x, 1)
    py_x = pxy / px                             # conditional p(y|x)

    # Random soft initialisation of q(t|x), one Dirichlet draw per row.
    qt_x = np.random.dirichlet(np.ones(k), pxy.shape[0])

    # Do-while: always perform at least one update so qy_t is defined even
    # when the initialisation is already a fixed point (e.g. k == 1).
    while True:
        qt_x0 = qt_x
        qt = px.T.dot(qt_x)                     # q(t) = sum_x p(x) q(t|x), shape (1, k)
        qx_t = (px * qt_x / qt).T               # q(x|t) via Bayes, shape (k, n_x)
        qy_t = qx_t.dot(py_x)                   # q(y|t) = sum_x q(x|t) p(y|x), shape (k, n_y)
        # Pairwise KL(p(y|x) || q(y|t)): gufunc broadcasting of
        # (n_x, 1, n_y) against (k, n_y) yields shape (n_x, k).
        Bxt = np.exp(-b * kl(py_x[:, np.newaxis], qy_t))
        Zx = Bxt.dot(qt.T)                      # per-row normaliser, shape (n_x, 1)
        qt_x = qt * Bxt / Zx                    # updated q(t|x)
        # kl returns one divergence per row; reduce to a scalar before
        # testing.  Written as `not (... > eps)` so NaN also terminates.
        if not kl(qt_x, qt_x0).max() > eps:
            break
    return qt_x, qy_t
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment