Skip to content

Instantly share code, notes, and snippets.

@shihpeng
Created May 8, 2015 08:42
Show Gist options
  • Save shihpeng/a8d5df0f9daddeee59f0 to your computer and use it in GitHub Desktop.
Save shihpeng/a8d5df0f9daddeee59f0 to your computer and use it in GitHub Desktop.
import math
import scipy.sparse as sparse
import numpy as np
def plsa(n_dw, n_z=2, iter_num=100, tol=1e-5):
    """
    Fit a PLSA (probabilistic latent semantic analysis) model with EM.

    ``n_dw[d, w]`` is the number of occurrences of word ``w`` in document
    ``d`` and ``n_z`` is the number of topics to be discovered.  Only the
    non-zero entries of ``n_dw`` are visited, so memory use is proportional
    to ``n_z * nnz`` instead of ``n_z * n_d * n_w``.

    The starting point is drawn from NumPy's global random state, so call
    ``np.random.seed`` beforehand for reproducible results.

    :param n_dw: document-by-word count matrix of shape ``(n_d, n_w)``; a
        dense array, a nested sequence or a scipy sparse matrix holding
        non-negative counts.  It is never modified.
    :param n_z: number of latent topics (at least 1).
    :param iter_num: maximum number of EM iterations.
    :param tol: stop early once the log-likelihood changes by less than this
        amount between two consecutive iterations.
    :return: tuple ``(p_w_z, p_z_d)``.  ``p_w_z[w, z]`` is P(w|z) with shape
        ``(n_w, n_z)``; ``p_z_d[z, d]`` is P(z|d) with shape ``(n_z, n_d)``.
        Every column of both arrays sums to 1.
    :raises ValueError: if ``n_z`` is smaller than 1 or ``n_dw`` contains a
        negative count.
    """
    if n_z < 1:
        raise ValueError("n_z must be at least 1")

    # copy=True: the clean-up calls below work in place and must not touch
    # a sparse matrix owned by the caller.
    counts = sparse.csr_matrix(n_dw, dtype=float, copy=True)
    counts.sum_duplicates()
    counts.eliminate_zeros()
    counts = counts.tocoo()
    if counts.nnz and counts.data.min() < 0:
        raise ValueError("n_dw must not contain negative counts")

    n_d, n_w = counts.shape
    rows, cols, vals = counts.row, counts.col, counts.data

    # Random but properly normalised starting distributions, so that the
    # very first E-step already works with a valid posterior.
    p_z_d = np.random.rand(n_z, n_d)  # P(z_j|d_i) is p_z_d[j, i]
    p_z_d /= p_z_d.sum(axis=0, keepdims=True)
    p_w_z = np.random.rand(n_w, n_z)  # P(w_i|z_j) is p_w_z[i, j]
    p_w_z /= p_w_z.sum(axis=0, keepdims=True)

    joint, p_dw = _plsa_joint(p_w_z, p_z_d, rows, cols)
    previous = None  # log-likelihood of the previous iteration
    for _ in range(iter_num):
        # E-step: n(d,w) * P(z|d,w) for every observed (d, w) pair.
        weighted = joint * (vals / p_dw)

        # M-step, P(z|d): documents without any word carry no evidence and
        # keep their current distribution instead of turning into 0/0 = NaN.
        doc_mass = _plsa_scatter(rows, weighted, n_d)
        doc_total = doc_mass.sum(axis=0)
        seen = doc_total > 0
        p_z_d[:, seen] = doc_mass[:, seen] / doc_total[seen]

        # M-step, P(w|z): same guard for a topic that lost all of its mass.
        word_mass = _plsa_scatter(cols, weighted, n_w)
        topic_total = word_mass.sum(axis=1)
        live = topic_total > 0
        p_w_z[:, live] = (word_mass[live] / topic_total[live][:, None]).T

        # Refresh P(w|d) on the observed pairs and evaluate the likelihood.
        joint, p_dw = _plsa_joint(p_w_z, p_z_d, rows, cols)
        log_likelihood = float(np.dot(vals, np.log(p_dw)))

        if previous is not None and abs(log_likelihood - previous) < tol:
            # log-likelihood converged
            break
        previous = log_likelihood
    return p_w_z, p_z_d


def _plsa_joint(p_w_z, p_z_d, rows, cols):
    """Return P(w|z)P(z|d) per topic and its sum over topics for each pair."""
    joint = p_w_z[cols].T * p_z_d[:, rows]  # shape (n_z, nnz)
    # Clamp to the smallest positive float so that an underflow cannot cause
    # a division by zero or log(0) later on.
    p_dw = np.maximum(joint.sum(axis=0), np.finfo(float).tiny)
    return joint, p_dw


def _plsa_scatter(index, weighted, size):
    """Sum each topic's row of ``weighted`` into ``size`` bins given by ``index``."""
    return np.vstack([np.bincount(index, weights=row, minlength=size)
                      for row in weighted])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment