Created
January 29, 2016 15:32
-
-
Save rohanar/3a8550cefb4435291a3f to your computer and use it in GitHub Desktop.
ParagraphVectorSerialiser
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.paragraphvectors; | |
import org.canova.api.util.ClassPathResource; | |
import org.deeplearning4j.berkeley.Pair; | |
import org.deeplearning4j.clustering.kmeans.KMeansClustering; | |
import org.deeplearning4j.examples.paragraphvectors.tools.FileLabelAwareIterator; | |
import org.deeplearning4j.examples.paragraphvectors.tools.LabelSeeker; | |
import org.deeplearning4j.examples.paragraphvectors.tools.MeansBuilder; | |
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; | |
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; | |
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; | |
import org.deeplearning4j.models.word2vec.VocabWord; | |
import org.deeplearning4j.text.documentiterator.LabelAwareIterator; | |
import org.deeplearning4j.text.documentiterator.LabelledDocument; | |
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator; | |
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | |
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | |
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import org.synthesis.java.extension.ParagraphVectorSerializer; | |
import java.io.File; | |
import java.util.List; | |
/** | |
* This is basic example for documents classification done with DL4j ParagraphVectors. | |
* The overall idea is to use ParagraphVectors in the same way we use LDA: topic space modelling. | |
* | |
* In this example we assume we have few labeled categories that we can use for training, and few unlabeled documents. And our goal is to determine, which category these unlabeled documents fall into | |
* | |
* | |
* Please note: This example could be improved by using learning cascade for higher accuracy, but that's beyond basic example paradigm. | |
* | |
* @author raver119@gmail.com | |
*/ | |
public class ParagraphVectorsClassifierExample { | |
private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsClassifierExample.class); | |
public static void main(String[] args) throws Exception { | |
ClassPathResource resource = new ClassPathResource("paravec/labeled"); | |
// build a iterator for our dataset | |
LabelAwareIterator iterator = new FileLabelAwareIterator.Builder() | |
.addSourceFolder(resource.getFile()) | |
.build(); | |
TokenizerFactory t = new DefaultTokenizerFactory(); | |
t.setTokenPreProcessor(new CommonPreprocessor()); | |
// ParagraphVectors training configuration | |
ParagraphVectors paragraphVectors = new ParagraphVectors.Builder() | |
.learningRate(0.025) | |
.minLearningRate(0.001) | |
.batchSize(1000) | |
.epochs(20) | |
.iterate(iterator) | |
.trainWordVectors(true) | |
.tokenizerFactory(t) | |
.build(); | |
// Start model training | |
paragraphVectors.fit(); | |
//Serialize the learned ParagraphVectors | |
ParagraphVectorSerializer.writeWordVectors(paragraphVectors, new File("test.txt")); | |
paragraphVectors = null; //Set the in-memory model null | |
//Deserialize the serialized ParagraphVectors | |
paragraphVectors = ParagraphVectorSerializer.readParagraphVectorsFromText(new File("test.txt")); | |
/* | |
At this point we assume that we have model built and we can check, which categories our unlabeled document falls into | |
So we'll start loading our unlabeled documents and checking them | |
*/ | |
ClassPathResource unlabeledResource = new ClassPathResource("paravec/unlabeled"); | |
FileLabelAwareIterator unlabeledIterator = new FileLabelAwareIterator.Builder() | |
.addSourceFolder(unlabeledResource.getFile()) | |
.build(); | |
/* | |
Now we'll iterate over unlabeled data, and check which label it could be assigned to | |
Please note: for many domains it's normal to have 1 document fall into few labels at once, with different "weight" for each. | |
*/ | |
MeansBuilder meansBuilder = new MeansBuilder((InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable(), t); | |
LabelSeeker seeker = new LabelSeeker(iterator.getLabelsSource().getLabels(), (InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable()); | |
while (unlabeledIterator.hasNextDocument()) { | |
LabelledDocument document = unlabeledIterator.nextDocument(); | |
INDArray documentAsCentroid = meansBuilder.documentAsVector(document); | |
List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid); | |
/* | |
please note, document.getLabel() is used just to show which document we're looking at now, as a substitute for printing out the whole document itself. | |
So, labels on these two documents are used like titles, just to visualize our classification done properly | |
*/ | |
log.info("Document '" + document.getLabel() + "' falls into the following categories: "); | |
for (Pair<String, Double> score: scores) { | |
log.info(" " + score.getFirst() + ": " + score.getSecond()); | |
} | |
/* | |
Your output should be like this: | |
Document 'health' falls into the following categories: | |
health: 0.29721372296220205 | |
science: 0.011684473733853906 | |
finance: -0.14755302887323793 | |
Document 'finance' falls into the following categories: | |
health: -0.17290237675941766 | |
science: -0.09579267574606627 | |
finance: 0.4460859189453788 | |
so,now we know categories for yet unseen documents | |
*/ | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment