Created
February 20, 2015 20:59
-
-
Save lzamparo/7b34ffd6de800bd8be56 to your computer and use it in GitHub Desktop.
numba.jit fails silently on Gibbs sampling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy import bincount, log, log2, seterr, unique, zeros | |
from scipy.special import gammaln | |
from math_utils import log_sample, vi | |
def iteration(V, D, N_DV, N_D, alpha, beta, z_D, inv_z_T, active_topics, inactive_topics, N_TV, N_T, D_T): | |
""" | |
Performs a single iteration of Radford Neal's Algorithm 3. | |
""" | |
for d in xrange(D): | |
# retain the previous cluster indicator of d | |
old_t = z_D[d] | |
# make sure z_D[d] is no longer part of the set | |
# of points associated with old_t | |
if inv_z_T is not None: | |
inv_z_T[old_t].remove(d) | |
# remove the data for d from the sum of the elements assigned to component old_t | |
N_TV[old_t, :] -= N_DV[d, :] | |
# remove sum along data features for component old_t (N_T = N_DV.sum(1)) | |
N_T[old_t] -= N_D[d] | |
# decrease the appearances of old_t in z_D | |
D_T[old_t] -= 1 | |
# compute partial log probability of assigning the data point to component | |
seterr(divide='ignore') | |
log_dist = log(D_T) | |
seterr(divide='warn') | |
# if this component was a singleton, keep the index. Otherwise, activate a new component | |
idx = old_t if D_T[old_t] == 0 else inactive_topics.pop() | |
active_topics.add(idx) | |
# log probability of assigning this point to the new component. | |
log_dist[idx] = log(alpha) | |
# compute log remaining log probability of assigning d over components | |
# note: gammaln(x) := ln(abs(gamma(x))) | |
for t in active_topics: | |
log_dist[t] += gammaln(N_T[t] + beta) | |
log_dist[t] -= gammaln(N_D[d] + N_T[t] + beta) | |
tmp = N_TV[t, :] + beta / V | |
log_dist[t] += gammaln(N_DV[d, :] + tmp).sum() | |
log_dist[t] -= gammaln(tmp).sum() | |
# sample from log_dist to get the component for d | |
[t] = log_sample(log_dist) | |
# assign component t as responsible for point d | |
z_D[d] = t | |
# assign point d as part of component t | |
if inv_z_T is not None: | |
inv_z_T[t].add(d) | |
# adjust the sufficient statistics for component t | |
# to account for the addition of d | |
N_TV[t, :] += N_DV[d, :] | |
N_T[t] += N_D[d] | |
D_T[t] += 1 | |
# accounting of active topics: | |
if t != idx: | |
active_topics.remove(idx) | |
inactive_topics.add(idx) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy import argsort, bincount, ones, where, zeros | |
from numpy.random import poisson, seed | |
from numpy.random.mtrand import dirichlet | |
from math_utils import sample | |
def generate_data(V, D, l, alpha, beta): | |
""" | |
Generates a synthetic corpus of documents from a Dirichlet process | |
mixture model with multinomial mixture components (topics). The | |
mixture components are drawn from a symmetric Dirichlet prior. | |
Arguments: | |
V -- vocabulary size | |
D -- number of documents | |
l -- average document length | |
alpha -- concentration parameter for the Dirichlet process | |
beta -- concentration parameter for the symmetric Dirichlet prior | |
""" | |
T = D # maximum number of topics | |
phi_TV = zeros((T, V)) | |
z_D = zeros(D, dtype=int) | |
N_DV = zeros((D, V), dtype=int) | |
for d in xrange(D): | |
# draw a topic assignment for this document | |
dist = bincount(z_D).astype(float) | |
dist[0] = alpha | |
[t] = sample(dist) | |
t = len(dist) if t == 0 else t | |
z_D[d] = t | |
# if it's a new topic, draw the parameters for that topic | |
if t == len(dist): | |
phi_TV[t - 1, :] = dirichlet(beta * ones(V) / V) | |
# draw the tokens from the topic | |
for v in sample(phi_TV[t - 1, :], num_samples=poisson(l)): | |
N_DV[d, v] += 1 | |
z_D = z_D - 1 | |
return phi_TV, z_D, N_DV |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
In [5]: from generative_process import generate_data | |
In [6]: from numba import double | |
In [7]: from numba.decorators import jit | |
In [8]: V = 10 | |
In [9]: D = 20000 | |
In [10]: l = 1000 | |
In [11]: alpha = 1.0 | |
In [12]: beta = 0.5 | |
In [13]: num_itns = 10 | |
In [14]: seed = 0 | |
In [15]: s = 0 | |
In [16]: from numpy.random import seed | |
In [17]: seed(s) | |
In [18]: phi_TV, z_D, N_DV = generate_data(V, D, l, alpha, beta) | |
In [19]: from kale.math_utils import log_sample, vi | |
In [20]: from scipy.special import gammaln | |
In [21]: from numpy import bincount, log, log2, seterr, unique, zeros | |
In [22]: from algorithm_3 import iteration | |
In [23]: D, V = N_DV.shape | |
In [24]: T = D | |
In [25]: N_D = N_DV.sum(1) | |
In [26]: active_topics = set(unique(z_D)) | |
In [27]: inactive_topics = set(xrange(T)) - active_topics | |
In [28]: N_TV = zeros((T, V), dtype=int) | |
In [29]: N_T = zeros(T, dtype=int) | |
In [30]: for d in xrange(D): | |
....: N_TV[z_D[d], :] += N_DV[d, :] | |
....: N_T[z_D[d]] += N_D[d] | |
....: | |
In [31]: D_T = bincount(z_D, minlength=T) | |
In [32]: ### ready to start timing | |
In [33]: %timeit iteration(V, D, N_DV, N_D, alpha, beta, z_D, None, active_topics, inactive_topics, N_TV, N_T, D_T) | |
1 loops, best of 3: 33.6 s per loop | |
In [34]: %timeit -n 5 -r 5 iteration(V, D, N_DV, N_D, alpha, beta, z_D, None, active_topics, inactive_topics, N_TV, N_T, D_T) | |
5 loops, best of 5: 33.6 s per loop | |
In [35]: iteration_numba = jit(iteration) | |
In [36]: %timeit iteration_numba(V, D, N_DV, N_D, alpha, beta, z_D, None, active_topics, inactive_topics, N_TV, N_T, D_T) | |
1 loops, best of 3: 33.5 s per loop |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter | |
from numpy import asarray, cumsum, log, log2, exp, searchsorted, sqrt | |
from numpy.random import uniform | |
from scipy.spatial.distance import euclidean | |
def sample(dist, num_samples=1): | |
""" | |
Uses the inverse CDF method to return samples drawn from the | |
specified (unnormalized) discrete distribution. | |
Arguments: | |
dist -- (unnormalized) distribution | |
Keyword arguments: | |
num_samples -- number of samples to draw | |
""" | |
cdf = cumsum(dist) | |
r = uniform(size=num_samples) * cdf[-1] | |
return cdf.searchsorted(r) | |
def log_sample(log_dist): | |
return sample(exp(log_dist - log_dist.max())) | |
def log_sum_exp(x): | |
""" | |
Returns log(sum(exp(x))). | |
If the elements of x are log probabilities, they should not be | |
exponentiated directly because of underflow. The ratio exp(x[i]) / | |
exp(x[j]) = exp(x[i] - x[j]) is not susceptible to underflow, | |
however. For any scalar m, log(sum(exp(x))) = log(sum(exp(x) * | |
exp(m) / exp(m))) = log(sum(exp(x - m) * exp(m)) = log(exp(m) * | |
sum(exp(x - m))) = m + log(sum(exp(x - m))). If m is some element | |
of x, this expression involves only ratios of the form exp(x[i]) / | |
exp(x[j]) as desired. Setting m = max(x) reduces underflow, while | |
avoiding overflow: max(x) is shifted to zero, while all other | |
elements of x remain negative, but less so than before. Even in | |
the worst case scenario, where exp(x - max(x)) results in | |
underflow for the other elements of x, max(x) will be | |
returned. Since sum(exp(x)) is dominated by exp(max(x)), max(x) is | |
a reasonable approximation to log(sum(exp(x))). | |
""" | |
m = x.max() | |
return m + log((exp(x - m)).sum()) | |
def mean_relative_error(p, q, normalized=True): | |
""" | |
Returns the mean relative error between a discrete distribution | |
and some approximation to it. | |
Arguments: | |
p -- distribution | |
q -- approximate distribution | |
Keyword arguments: | |
normalized -- whether the distributions are normalized | |
""" | |
assert len(p) == len(q) | |
p, q = asarray(p, dtype=float), asarray(q, dtype=float) | |
if not normalized: | |
p /= p.sum() | |
q /= q.sum() | |
return (abs(q - p) / p).mean() | |
def entropy(p, normalized=True): | |
""" | |
Returns the entropy of a discrete distribution. | |
Arguments: | |
p -- distribution | |
Keyword arguments: | |
normalized -- whether the distribution is normalized | |
""" | |
p = asarray(p, dtype=float) | |
if not normalized: | |
p /= p.sum() | |
return -(p * log2(p)).sum() | |
def kl(p, q, normalized=True): | |
""" | |
Returns the Kullback--Leibler divergence between a discrete | |
distribution and some approximation to it. | |
Arguments: | |
p -- distribution | |
q -- approximate distribution | |
Keyword arguments: | |
normalized -- whether the distributions are normalized | |
""" | |
assert len(p) == len(q) | |
p, q = asarray(p, dtype=float), asarray(q, dtype=float) | |
if not normalized: | |
p /= p.sum() | |
q /= q.sum() | |
return (p * log2(p / q)).sum() | |
def js(p, q, normalized=True): | |
""" | |
Returns the Jensen--Shannon divergence (a form of symmetricized KL | |
divergence) between two discrete distributions. | |
Arguments: | |
p -- first distribution | |
q -- second distribution | |
Keyword arguments: | |
normalized -- whether the distributions are normalized | |
""" | |
assert len(p) == len(q) | |
p, q = asarray(p, dtype=float), asarray(q, dtype=float) | |
if not normalized: | |
p /= p.sum() | |
q /= q.sum() | |
m = 0.5 * (p + q) | |
return 0.5 * kl(p, m) + 0.5 * kl(q, m) | |
def hellinger(p, q, normalized=True): | |
""" | |
Returns the Hellinger distance between two discrete distributions. | |
Arguments: | |
p -- distribution | |
q -- distribution | |
Keyword arguments: | |
normalized -- whether the distributions are normalized | |
""" | |
assert len(p) == len(q) | |
p, q = asarray(p, dtype=float), asarray(q, dtype=float) | |
if not normalized: | |
p /= p.sum() | |
q /= q.sum() | |
return euclidean(sqrt(p), sqrt(q)) / sqrt(2) | |
def vi(y, z): | |
""" | |
Returns the variation of information (in bits) between two | |
partitions (clusterings) of n data points. The maximum attainable | |
value is log_2(n) bits. For example, vi(y=zeros(8, dtype=int), | |
z=xrange(8)) will return a value of 3.0. | |
y -- first partition | |
z -- second partition | |
""" | |
assert len(y) == len(z) | |
D = 1.0 * len(y) | |
vi = 0.0 | |
p_y = Counter(y) | |
for i in p_y.keys(): | |
p_y[i] /= D | |
vi -= p_y[i] * log2(p_y[i]) | |
p_z = Counter(z) | |
for j in p_z.keys(): | |
p_z[j] /= D | |
vi -= p_z[j] * log2(p_z[j]) | |
p_yz = Counter(zip(y, z)) | |
for i, j in p_yz.keys(): | |
p_yz[(i, j)] /= D | |
vi -= (2 * p_yz[(i, j)] * | |
log2(p_yz[(i, j)] / (p_y[i] * p_z[j]))) | |
return vi |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment