Latent Dirichlet Allocation
package analysis.preprocessing;
import java.util.HashMap;
import java.util.Vector;
import database.OhuraDB;
import database.indexing.LocalFeature;
/**
* Topic modeling using Latent Dirichlet
* Allocation. Inference is performed with a
* collapsed Gibbs sampler.
*
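* For reference, the collapsed Gibbs sampler draws each topic
* assignment z_i from the standard conditional
*
*   p(z_i = k | z_-i, w) ∝ (n_wk + beta) / (n_k + W*beta)
*                        * (n_mk + alpha) / (N_m - 1 + T*alpha)
*
* where n_wk, n_mk and n_k are the counts kept in word2topic,
* docu2topic and any2topic (with the current assignment removed),
* and N_m is the length of document m.
*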
* There is no hyperparameter learning and no
* automatic selection of the number of topics,
* so this is a very simple model.
*
* GO TO BLEI OR DEAL WITH IT !
*
* @author Daniel Kohlsdorf
*/
public class LatentDirichletAllocation extends DataItemization {

    /**
     * Uniform Dirichlet parameters
     */
    private double alpha, beta;

    /**
     * Number of topics
     */
    private int T;

    /**
     * Number of documents
     */
    private int D;

    /**
     * Number of words
     */
    private int W;

    /**
     * Count how often a word from
     * document d is assigned to topic k
     */
    private int docu2topic[][];

    /**
     * Count how often a specific word
     * w is assigned to topic k
     */
    private int word2topic[][];

    /**
     * Count how often any word is assigned
     * to topic k
     */
    private int any2topic[];

    /**
     * Length of each document
     */
    private int doc_sze[];

    public LatentDirichletAllocation(double alpha, double beta, int T) {
        this.alpha = alpha;
        this.beta = beta;
        this.T = T;
    }
    public int[][] gibbs_estimator(Vector<Vector<Integer>> corpus, int iterations) {
        D = corpus.size();

        // Derive the vocabulary size from the corpus if it has not been
        // set already (compute() sets W before calling this method).
        if (W == 0) {
            for (Vector<Integer> doc : corpus) {
                for (int w : doc) {
                    W = Math.max(W, w + 1);
                }
            }
        }

        /**
         * initialize counts: assign every word to a random topic
         */
        docu2topic = new int[D][T];
        word2topic = new int[W][T];
        any2topic = new int[T];
        doc_sze = new int[D];
        int z[][] = new int[D][];
        for (int i = 0; i < D; i++) {
            z[i] = new int[corpus.get(i).size()];
            doc_sze[i] = corpus.get(i).size();
            for (int j = 0; j < corpus.get(i).size(); j++) {
                double rand = Math.random();
                int topic = (int) Math.floor(rand * T);
                int w = corpus.get(i).get(j);
                z[i][j] = topic;
                word2topic[w][topic]++;
                docu2topic[i][topic]++;
                any2topic[topic]++;
            }
        }

        for (int iter = 0; iter < iterations; iter++) {
            // resample every assignment z_i
            for (int m = 0; m < D; m++) {
                for (int n = 0; n < corpus.get(m).size(); n++) {
                    int topic = z[m][n];
                    int w = corpus.get(m).get(n);

                    // remove the current assignment from all counts
                    int N = corpus.get(m).size();
                    word2topic[w][topic] -= 1;
                    docu2topic[m][topic] -= 1;
                    any2topic[topic] -= 1;
                    N -= 1;

                    // multinomial sampling from the collapsed conditional:
                    // (n_wk + beta) / (n_k + W*beta) * (n_mk + alpha) / (N + T*alpha),
                    // where N already excludes the current word
                    double scaleWords = W * beta;
                    double scaleTopics = T * alpha;
                    double mult[] = new double[T];
                    for (int k = 0; k < T; k++) {
                        mult[k] = (word2topic[w][k] + beta) / (any2topic[k] + scaleWords);
                        mult[k] *= (docu2topic[m][k] + alpha) / (N + scaleTopics);
                    }
                    // cumulate and draw a topic from the unnormalized distribution
                    for (int k = 1; k < T; k++) {
                        mult[k] += mult[k - 1];
                    }
                    double u = Math.random() * mult[T - 1];
                    for (topic = 0; topic < T - 1; topic++) {
                        if (mult[topic] > u) { break; }
                    }

                    // update counts with the new assignment
                    word2topic[w][topic] += 1;
                    docu2topic[m][topic] += 1;
                    any2topic[topic] += 1;
                    z[m][n] = topic;
                }
            }
        }
        return z;
    }
    /**
     * Topic x Doc Distribution
     */
    public double[][] computeTheta() {
        double theta[][] = new double[D][T];
        for (int m = 0; m < D; m++) {
            for (int k = 0; k < T; k++) {
                theta[m][k] = (docu2topic[m][k] + alpha) / (doc_sze[m] + T * alpha);
            }
        }
        return theta;
    }

    /**
     * Topic x Word Distribution
     */
    public double[][] computePhi() {
        double phi[][] = new double[T][W];
        for (int k = 0; k < T; k++) {
            for (int w = 0; w < W; w++) {
                phi[k][w] = (word2topic[w][k] + beta) / (any2topic[k] + W * beta);
            }
        }
        return phi;
    }
    public String serial() {
        double theta[][] = computeTheta();
        String theta_str = "THETA:\n";
        for (int i = 0; i < theta.length; i++) {
            for (int j = 0; j < theta[i].length; j++) {
                theta_str += theta[i][j] + ", ";
            }
            theta_str += "\n";
        }
        double phi[][] = computePhi();
        String phi_str = "\n\nPHI:\n";
        for (int i = 0; i < phi.length; i++) {
            for (int j = 0; j < phi[i].length; j++) {
                phi_str += phi[i][j] + ", ";
            }
            phi_str += "\n";
        }
        return theta_str + phi_str;
    }
    @Override
    public Vector<double[]> compute() {
        // collect the label sequence of every group in every document from the database
        Vector<Vector<Integer>> labels = new Vector<Vector<Integer>>();
        for (int id = 0; id < OhuraDB.spec().getDb().queryMaxDocID() + 1; id++) {
            int ngroups = OhuraDB.spec().getDb().queryMaxGrp(id);
            for (int i = 0; i < ngroups; i++) {
                Vector<Integer> lab = new Vector<Integer>();
                Vector<LocalFeature> localF;
                int k = 0;
                while ((localF = OhuraDB.spec().getDb().queryLocalByIDandGroupAndFeature(id, i, k)).size() > 0) {
                    if (localF.get(0).getLabel_id() >= 0) {
                        lab.add(localF.get(0).getLabel_id());
                    }
                    k++;
                }
                System.out.println(" " + i + " " + k);
                labels.add(lab);
            }
        }

        // remap the label ids to a dense vocabulary of word ids
        Vector<Vector<Integer>> corpus = new Vector<Vector<Integer>>();
        HashMap<Integer, Integer> ids = new HashMap<Integer, Integer>();
        int id = 0;
        for (Vector<Integer> grp : labels) {
            Vector<Integer> doc = new Vector<Integer>();
            for (Integer i : grp) {
                if (!ids.containsKey(i)) {
                    ids.put(i, id);
                    doc.add(id);
                    id++;
                } else {
                    doc.add(ids.get(i));
                }
            }
            corpus.add(doc);
        }
        this.W = id; // id distinct words were assigned ids 0 .. id - 1

        // run the sampler and use each document's topic mixture as its feature vector
        gibbs_estimator(corpus, 5000);
        Vector<double[]> features = new Vector<>();
        double[][] theta = computeTheta();
        for (int i = 0; i < theta.length; i++) {
            features.add(theta[i]);
        }
        return features;
    }
}
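
/*
 * A minimal usage sketch on a hand-made toy corpus (hypothetical example
 * class, not part of the original gist). It bypasses compute() and the
 * OhuraDB-backed corpus construction and feeds word ids directly into
 * gibbs_estimator; it assumes the sampler derives the vocabulary size
 * from the corpus as noted above, and it only compiles inside the
 * project where DataItemization is available.
 */
class LatentDirichletAllocationExample {

    public static void main(String[] args) {
        // two word "themes": ids 0-2 vs. ids 3-5
        int[][] docs = {
            { 0, 1, 2, 0, 1, 2, 0 },
            { 0, 2, 1, 1, 0, 2 },
            { 3, 4, 5, 3, 4, 5, 4 },
            { 5, 3, 4, 4, 5, 3 }
        };
        Vector<Vector<Integer>> corpus = new Vector<Vector<Integer>>();
        for (int[] d : docs) {
            Vector<Integer> doc = new Vector<Integer>();
            for (int w : d) {
                doc.add(w);
            }
            corpus.add(doc);
        }

        // two topics, symmetric priors, 1000 sweeps of the sampler
        LatentDirichletAllocation lda = new LatentDirichletAllocation(0.1, 0.01, 2);
        lda.gibbs_estimator(corpus, 1000);

        // print the per-document topic mixtures
        double[][] theta = lda.computeTheta();
        for (int m = 0; m < theta.length; m++) {
            String row = "doc " + m + ":";
            for (int k = 0; k < theta[m].length; k++) {
                row += String.format(" %.3f", theta[m][k]);
            }
            System.out.println(row);
        }
    }
}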