Skip to content

Instantly share code, notes, and snippets.

@dilipbobby
Created December 12, 2016 05:23
Show Gist options
  • Save dilipbobby/bb38acae7d6ca33ba49130a23e9b7e0b to your computer and use it in GitHub Desktop.
Save dilipbobby/bb38acae7d6ca33ba49130a23e9b7e0b to your computer and use it in GitHub Desktop.
LatentDirichletAllocation test code.
import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.classification.SoftMaxRegression;
import com.datumbox.framework.core.machinelearning.topicmodeling.LatentDirichletAllocation;
import com.datumbox.framework.core.utilities.text.extractors.UniqueWordSequenceExtractor;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
/**
 * Command-line demo that fits a {@link LatentDirichletAllocation} model on two
 * labelled text datasets loaded from the classpath, turns every record's
 * predicted topic probabilities into a new feature record, cross-validates a
 * {@link SoftMaxRegression} classifier on those records and prints the
 * resulting macro F1 score.
 */
public class LatentDirichletAllocationTest {

    /** Name handed to both the LDA and the SoftMaxRegression model constructors. */
    private static final String DB_NAME = "LatentDirichletAllocation";

    /**
     * Runs the whole pipeline once and prints each scored record followed by
     * the macro F1 of the classifier.
     *
     * @param args ignored
     * @throws IllegalStateException if a dataset directory is missing from the classpath
     * @throws RuntimeException if a dataset location cannot be converted to a {@link URI}
     */
    public static void main(String[] args) {
        Configuration conf = Configuration.getConfiguration();

        // Label -> location of the text files carrying that label.
        Map<Object, URI> dataset = new HashMap<>();
        dataset.put("positive", resourceUri("datasets/sentipos"));
        dataset.put("negative", resourceUri("datasets/sentineg"));

        UniqueWordSequenceExtractor wsExtractor =
                new UniqueWordSequenceExtractor(new UniqueWordSequenceExtractor.Parameters());

        // Declared outside the try so the finally block can delete whatever
        // was actually created, even when a later step throws.
        Dataframe trainingData = null;
        Dataframe reducedTrainingData = null;
        LatentDirichletAllocation lda = null;
        SoftMaxRegression smr = null;
        try {
            trainingData = Dataframe.Builder.parseTextFiles(dataset, wsExtractor, conf);

            lda = new LatentDirichletAllocation(DB_NAME, conf);
            LatentDirichletAllocation.TrainingParameters trainingParameters =
                    new LatentDirichletAllocation.TrainingParameters();
            trainingParameters.setMaxIterations(15);
            trainingParameters.setAlpha(0.01);
            trainingParameters.setBeta(0.01);
            trainingParameters.setK(25);

            lda.fit(trainingData, trainingParameters);
            // NOTE(review): validate() is presumably what fills in the predicted
            // probabilities read in the loop below - confirm against the library.
            lda.validate(trainingData);

            reducedTrainingData = new Dataframe(conf);
            for (Record r : trainingData) {
                // Take the topic assignments and convert them into a new Record
                // that keeps the original label.
                reducedTrainingData.add(new Record(r.getYPredictedProbabilities(), r.getY()));
                System.out.println(r);
            }

            smr = new SoftMaxRegression(DB_NAME, conf);
            SoftMaxRegression.TrainingParameters tp = new SoftMaxRegression.TrainingParameters();
            tp.setLearningRate(1.0);
            tp.setTotalIterations(50);

            SoftMaxRegression.ValidationMetrics vm = smr.kFoldCrossValidation(reducedTrainingData, tp, 1);
            System.out.println(vm.getMacroF1());
        } finally {
            // Same deletion order as the original success path; null checks
            // cover the case where an earlier step failed before creation.
            if (smr != null) {
                smr.delete();
            }
            if (lda != null) {
                lda.delete();
            }
            if (reducedTrainingData != null) {
                reducedTrainingData.delete();
            }
            if (trainingData != null) {
                trainingData.delete();
            }
        }
    }

    /**
     * Resolves a classpath resource to a {@link URI}.
     *
     * @param path resource path relative to the classpath root
     * @return the location of the resource
     * @throws IllegalStateException if no such resource is on the classpath
     * @throws RuntimeException if the resource location is not a valid URI
     */
    private static URI resourceUri(String path) {
        // getResource() returns null for a missing resource; fail with a
        // message that names it instead of an anonymous NullPointerException.
        java.net.URL resource = LatentDirichletAllocationTest.class.getClassLoader().getResource(path);
        if (resource == null) {
            throw new IllegalStateException("Dataset resource not found on classpath: " + path);
        }
        try {
            return resource.toURI();
        } catch (UncheckedIOException | URISyntaxException ex) {
            throw new RuntimeException("Unable to resolve dataset resource: " + path, ex);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment