Skip to content

Instantly share code, notes, and snippets.

@rohanar
Created January 29, 2016 15:32
Show Gist options
  • Save rohanar/3a8550cefb4435291a3f to your computer and use it in GitHub Desktop.
Save rohanar/3a8550cefb4435291a3f to your computer and use it in GitHub Desktop.
ParagraphVectorSerialiser
package org.deeplearning4j.examples.paragraphvectors;
import org.canova.api.util.ClassPathResource;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.kmeans.KMeansClustering;
import org.deeplearning4j.examples.paragraphvectors.tools.FileLabelAwareIterator;
import org.deeplearning4j.examples.paragraphvectors.tools.LabelSeeker;
import org.deeplearning4j.examples.paragraphvectors.tools.MeansBuilder;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.synthesis.java.extension.ParagraphVectorSerializer;
import java.io.File;
import java.util.List;
/**
 * Basic example of document classification done with DL4J ParagraphVectors.
 * The overall idea is to use ParagraphVectors in the same way we use LDA: topic space modelling.
 *
 * In this example we assume we have a few labeled categories that we can use for training, and a
 * few unlabeled documents. Our goal is to determine which category each unlabeled document falls into.
 *
 * This variant additionally round-trips the trained model through {@link ParagraphVectorSerializer}:
 * the model is written to a text file and the classification step runs on the copy read back from it.
 *
 * Please note: this example could be improved by using a learning cascade for higher accuracy,
 * but that is beyond the basic example paradigm.
 *
 * @author raver119@gmail.com
 */
public class ParagraphVectorsClassifierExample {
    private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsClassifierExample.class);

    /** Classpath location of the labeled documents used for training. */
    private static final String LABELED_RESOURCE = "paravec/labeled";

    /** Classpath location of the unlabeled documents to classify. */
    private static final String UNLABELED_RESOURCE = "paravec/unlabeled";

    /** File (relative to the working directory) the trained model is serialized to and read back from. */
    private static final String MODEL_FILE_NAME = "test.txt";

    /**
     * Trains a ParagraphVectors model on the labeled documents, serializes it, reads it back and
     * logs the per-label scores of every unlabeled document.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if a resource cannot be resolved, or training / (de)serialization fails
     */
    public static void main(String[] args) throws Exception {
        // build an iterator for our labeled dataset
        ClassPathResource labeledResource = new ClassPathResource(LABELED_RESOURCE);
        LabelAwareIterator iterator = new FileLabelAwareIterator.Builder()
                .addSourceFolder(labeledResource.getFile())
                .build();

        TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

        // One File instance for both directions, so the write and read targets cannot drift apart.
        File modelFile = new File(MODEL_FILE_NAME);

        // The trained instance never leaves the helper: everything below works on the deserialized copy only.
        trainAndSave(iterator, tokenizerFactory, modelFile);

        // Deserialize the serialized ParagraphVectors
        ParagraphVectors paragraphVectors = ParagraphVectorSerializer.readParagraphVectorsFromText(modelFile);

        /*
         At this point we assume that we have the model built and we can check which categories our
         unlabeled documents fall into. So we start loading our unlabeled documents and checking them.
         */
        ClassPathResource unlabeledResource = new ClassPathResource(UNLABELED_RESOURCE);
        FileLabelAwareIterator unlabeledIterator = new FileLabelAwareIterator.Builder()
                .addSourceFolder(unlabeledResource.getFile())
                .build();

        /*
         Now we'll iterate over the unlabeled data and check which label it could be assigned to.
         Please note: for many domains it's normal to have 1 document fall into a few labels at once,
         with a different "weight" for each.
         */
        InMemoryLookupTable<VocabWord> lookupTable =
                (InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable();
        MeansBuilder meansBuilder = new MeansBuilder(lookupTable, tokenizerFactory);
        // Labels come from the training iterator, which has been consumed by fit() at this point.
        LabelSeeker seeker = new LabelSeeker(iterator.getLabelsSource().getLabels(), lookupTable);

        while (unlabeledIterator.hasNextDocument()) {
            LabelledDocument document = unlabeledIterator.nextDocument();
            INDArray documentAsCentroid = meansBuilder.documentAsVector(document);
            List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid);

            /*
             Please note, document.getLabel() is used just to show which document we're looking at now,
             as a substitute for printing out the whole document itself. So the labels on these
             documents are used like titles, just to visualize that our classification was done properly.
             */
            log.info("Document '{}' falls into the following categories: ", document.getLabel());
            for (Pair<String, Double> score : scores) {
                log.info(" {}: {}", score.getFirst(), score.getSecond());
            }
        }

        /*
         Your output should be like this:

         Document 'health' falls into the following categories:
         health: 0.29721372296220205
         science: 0.011684473733853906
         finance: -0.14755302887323793
         Document 'finance' falls into the following categories:
         health: -0.17290237675941766
         science: -0.09579267574606627
         finance: 0.4460859189453788

         So now we know the categories for yet unseen documents.
         */
    }

    /**
     * Trains a ParagraphVectors model over the given documents and serializes it to {@code modelFile}.
     *
     * @param iterator         source of the labeled training documents
     * @param tokenizerFactory tokenizer factory applied to the documents during training
     * @param modelFile        destination file for the serialized model
     * @throws Exception if training or serialization fails
     */
    private static void trainAndSave(LabelAwareIterator iterator, TokenizerFactory tokenizerFactory,
                                     File modelFile) throws Exception {
        // ParagraphVectors training configuration
        ParagraphVectors trainedVectors = new ParagraphVectors.Builder()
                .learningRate(0.025)
                .minLearningRate(0.001)
                .batchSize(1000)
                .epochs(20)
                .iterate(iterator)
                .trainWordVectors(true)
                .tokenizerFactory(tokenizerFactory)
                .build();

        // Start model training
        trainedVectors.fit();

        // Serialize the learned ParagraphVectors
        ParagraphVectorSerializer.writeWordVectors(trainedVectors, modelFile);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment