Last active
March 13, 2019 21:01
-
-
Save dkohlsdorf/9947ad20013993e6d54d to your computer and use it in GitHub Desktop.
Latent Dirichlet Allocation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package analysis.preprocessing; | |
import java.util.HashMap; | |
import java.util.Vector; | |
import database.OhuraDB; | |
import database.indexing.LocalFeature; | |
/** | |
* Topic Modeling using the Latent Dirichlet | |
* Allocation. Inference performed using a collapsed | |
* Gibbs Sampler. | |
* | |
* There is no hyperparameter learning | |
* and there is no automatic selection of | |
* number of topics. So this is a very simple | |
* model. | |
* | |
* GO TO BLEI OR DEAL WITH IT ! | |
* | |
* @author Daniel Kohlsdorf | |
*/ | |
public class LatentDirichletAllocation extends DataItemization { | |
/** | |
* Uniform dirichlet parameters | |
*/ | |
private double alpha, beta; | |
/** | |
* Number of topics | |
*/ | |
private int T; | |
/** | |
* Number of Documents | |
*/ | |
private int D; | |
/** | |
* Number of words | |
*/ | |
private int W; | |
/** | |
* Count how often a word from | |
* document d is assigned to topic k | |
*/ | |
private int docu2topic[][]; | |
/** | |
* Count how often a specific word | |
* w is assigned to topic k | |
*/ | |
private int word2topic[][]; | |
/** | |
* Count how often any word is assigned | |
* to topic k | |
*/ | |
private int any2topic[]; | |
/** | |
* Length of each document | |
*/ | |
private int doc_sze[]; | |
public LatentDirichletAllocation(double alpha, double beta, int T) { | |
this.alpha = alpha; | |
this.beta = beta; | |
this.T = T; | |
} | |
public int[][] gibbs_estimator(Vector< Vector<Integer> > corpus, int iterations) { | |
D = corpus.size(); | |
/** | |
* initialize counts | |
*/ | |
docu2topic = new int[D][T]; | |
word2topic = new int[W][T]; | |
any2topic = new int[T]; | |
doc_sze = new int[D]; | |
int z[][] = new int[D][]; | |
for(int i = 0; i < D; i++) { | |
z[i] = new int[corpus.get(i).size()]; | |
doc_sze[i] = corpus.get(i).size(); | |
for(int j = 0; j < corpus.get(i).size(); j++) { | |
double rand = Math.random(); | |
int topic = (int) Math.floor(rand * T); | |
int w = corpus.get(i).get(j); | |
z[i][j] = topic; | |
word2topic[w][topic]++; | |
docu2topic[i][topic]++; | |
any2topic[topic]++; | |
} | |
} | |
for(int iter = 0; iter < iterations; iter++) { | |
// for all z_i | |
for (int m = 0; m < D; m++){ | |
for (int n = 0; n < corpus.get(m).size(); n++){ | |
int topic = z[m][n]; | |
int w = corpus.get(m).get(n); | |
// remove current variable from assignment | |
int N = corpus.get(m).size(); | |
word2topic[w][topic] -= 1; | |
docu2topic[m][topic] -= 1; | |
any2topic[topic] -= 1; | |
N -= 1; | |
double scaleWords = W * beta; | |
double scaleTopics = T * alpha; | |
// multinominal sampling | |
double mult[] = new double[T]; | |
for (int k = 0; k < T; k++){ | |
mult[k] = (word2topic[w][k] + beta) / (any2topic[k] * scaleWords); | |
mult[k] *= (docu2topic[m][k] + alpha) / (N * scaleTopics); | |
} | |
for (int k = 1; k < T; k++){ | |
mult[k] += mult[k - 1]; | |
} | |
double u = Math.random() * mult[T - 1]; | |
for (topic = 0; topic < T - 1; topic++){ | |
if (mult[topic] > u) {break;} | |
} | |
// update counts | |
word2topic[w][topic] += 1; | |
docu2topic[m][topic] += 1; | |
any2topic[topic] += 1; | |
z[m][n] = topic; | |
} | |
} | |
} | |
return z; | |
} | |
/** | |
* Topic x Doc Distribution | |
*/ | |
public double[][] computeTheta(){ | |
double theta[][] = new double[D][T]; | |
for (int m = 0; m < D; m++){ | |
for (int k = 0; k < T; k++){ | |
theta[m][k] = (docu2topic[m][k] + alpha) / (doc_sze[m] + T * alpha); | |
} | |
} | |
return theta; | |
} | |
/** | |
* Topic x Word Distribution | |
*/ | |
public double[][] computePhi(){ | |
double phi[][] = new double[T][W]; | |
for (int k = 0; k < T; k++){ | |
for (int w = 0; w < W; w++){ | |
phi[k][w] = (word2topic[w][k] + beta) / (any2topic[k] + W * beta); | |
} | |
} | |
return phi; | |
} | |
public String serial() { | |
double theta[][] = computeTheta(); | |
String theta_str = "THETA:\n"; | |
for(int i = 0; i < theta.length; i++) { | |
for(int j = 0; j < theta[i].length; j++) { | |
theta_str += theta[i][j] + ", "; | |
} | |
theta_str += "\n"; | |
} | |
double phi[][] = computePhi(); | |
String phi_str = "\n\nPHI:\n"; | |
for(int i = 0; i < phi.length; i++) { | |
for(int j = 0; j < phi[i].length; j++) { | |
phi_str += phi[i][j] + ", "; | |
} | |
phi_str += "\n"; | |
} | |
return theta_str + phi_str; | |
} | |
@Override | |
public Vector<double[]> compute() { | |
Vector<Vector<Integer>> labels = new Vector<Vector<Integer>>(); | |
for (int id = 0; id < OhuraDB.spec().getDb().queryMaxDocID() + 1; id++) { | |
int ngroups = OhuraDB.spec().getDb().queryMaxGrp(id); | |
for(int i = 0; i < ngroups; i++) { | |
Vector<Integer> lab = new Vector<Integer>(); | |
Vector<LocalFeature> localF; | |
int k = 0; | |
while((localF = OhuraDB.spec().getDb().queryLocalByIDandGroupAndFeature(id, i, k)).size() > 0) { | |
if(localF.get(0).getLabel_id() >= 0) { | |
lab.add(localF.get(0).getLabel_id()); | |
} | |
k++; | |
} | |
System.out.println(" " + i + " " + k); | |
labels.add(lab); | |
} | |
} | |
Vector<Vector<Integer>> corpus = new Vector<Vector<Integer>>(); | |
HashMap<Integer, Integer> ids = new HashMap<Integer, Integer>(); | |
int id = 0; | |
for(Vector<Integer> grp : labels) { | |
Vector<Integer> doc = new Vector<Integer>(); | |
for(Integer i : grp) { | |
if(!ids.containsKey(i)) { | |
ids.put(i, id); | |
doc.add(id); | |
id++; | |
} else { | |
doc.add(ids.get(i)); | |
} | |
} | |
corpus.add(doc); | |
} | |
this.W = id + 1; | |
gibbs_estimator(corpus, 5000); | |
Vector<double[]> features = new Vector<>(); | |
double[][] theta = computeTheta(); | |
for(int i = 0; i < theta.length; i++) { | |
features.add(theta[i]); | |
} | |
return features; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment