Created
February 8, 2016 11:01
-
-
Save rohanar/8581a57e7d94839b28e8 to your computer and use it in GitHub Desktop.
ParagraphVectors deserializing issue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.paragraphvectors; | |
import org.canova.api.util.ClassPathResource; | |
import org.deeplearning4j.berkeley.Pair; | |
import org.deeplearning4j.clustering.kmeans.KMeansClustering; | |
import org.deeplearning4j.examples.paragraphvectors.tools.FileLabelAwareIterator; | |
import org.deeplearning4j.examples.paragraphvectors.tools.LabelSeeker; | |
import org.deeplearning4j.examples.paragraphvectors.tools.MeansBuilder; | |
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; | |
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; | |
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; | |
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; | |
import org.deeplearning4j.models.word2vec.VocabWord; | |
import org.deeplearning4j.text.documentiterator.LabelAwareIterator; | |
import org.deeplearning4j.text.documentiterator.LabelledDocument; | |
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator; | |
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor; | |
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory; | |
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import org.synthesis.java.extension.ParagraphVectorSerializer; | |
import java.io.File; | |
import java.util.List; | |
/** | |
* This is basic example for documents classification done with DL4j ParagraphVectors. | |
* The overall idea is to use ParagraphVectors in the same way we use LDA: topic space modelling. | |
* | |
* In this example we assume we have few labeled categories that we can use for training, and few unlabeled documents. And our goal is to determine, which category these unlabeled documents fall into | |
* | |
* | |
* Please note: This example could be improved by using learning cascade for higher accuracy, but that's beyond basic example paradigm. | |
* | |
* @author raver119@gmail.com | |
*/ | |
public class ParagraphVectorsClassifierExample { | |
private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsClassifierExample.class); | |
public static void main(String[] args) throws Exception { | |
ClassPathResource resource = new ClassPathResource("paravec/labeled"); | |
// build a iterator for our dataset | |
LabelAwareIterator iterator = new FileLabelAwareIterator.Builder() | |
.addSourceFolder(resource.getFile()) | |
.build(); | |
TokenizerFactory t = new DefaultTokenizerFactory(); | |
t.setTokenPreProcessor(new CommonPreprocessor()); | |
// ParagraphVectors training configuration | |
ParagraphVectors paragraphVectors = new ParagraphVectors.Builder() | |
.learningRate(0.025) | |
.minLearningRate(0.001) | |
.batchSize(1000) | |
.epochs(10) | |
.iterate(iterator) | |
.trainWordVectors(true) | |
.tokenizerFactory(t) | |
.build(); | |
// Start model training | |
paragraphVectors.fit(); | |
//Serialising/Deserializing example | |
//This works too. Needs updated version of WordVectorSerializer.java from here: https://gist.github.com/raver119/4f2e3aa550bea7b7b74c (shared by raver119) | |
// ParagraphVectorSerializer.writeWordVectors(paragraphVectors, "test.txt"); | |
// WordVectorSerializer.writeWordVectors(paragraphVectors, "test.txt"); | |
// paragraphVectors = null; | |
// //paragraphVectors = ParagraphVectorSerializer.readParagraphVectorsFromText("test.txt"); | |
// paragraphVectors = WordVectorSerializer.readParagraphVectorsFromText("test.txt"); | |
//This works too. Needs updated version of WordVectorSerializer.java from here: https://gist.github.com/raver119/4f2e3aa550bea7b7b74c (shared by raver119) | |
//ParagraphVectorSerializer.writeWordVectors(paragraphVectors, new File("test.txt")); | |
WordVectorSerializer.writeWordVectors(paragraphVectors, new File("test.txt")); | |
paragraphVectors = null; | |
// paragraphVectors = ParagraphVectorSerializer.readParagraphVectorsFromText(new File("test.txt")); | |
paragraphVectors = WordVectorSerializer.readParagraphVectorsFromText(new File("test.txt")); | |
/* | |
At this point we assume that we have model built and we can check, which categories our unlabeled document falls into | |
So we'll start loading our unlabeled documents and checking them | |
*/ | |
ClassPathResource unlabeledResource = new ClassPathResource("paravec/unlabeled"); | |
FileLabelAwareIterator unlabeledIterator = new FileLabelAwareIterator.Builder() | |
.addSourceFolder(unlabeledResource.getFile()) | |
.build(); | |
/* | |
Now we'll iterate over unlabeled data, and check which label it could be assigned to | |
Please note: for many domains it's normal to have 1 document fall into few labels at once, with different "weight" for each. | |
*/ | |
MeansBuilder meansBuilder = new MeansBuilder((InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable(), t); | |
LabelSeeker seeker = new LabelSeeker(iterator.getLabelsSource().getLabels(), (InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable()); | |
while (unlabeledIterator.hasNextDocument()) { | |
LabelledDocument document = unlabeledIterator.nextDocument(); | |
INDArray documentAsCentroid = meansBuilder.documentAsVector(document); | |
List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid); | |
/* | |
please note, document.getLabel() is used just to show which document we're looking at now, as a substitute for printing out the whole document itself. | |
So, labels on these two documents are used like titles, just to visualize our classification done properly | |
*/ | |
log.info("Document '" + document.getLabel() + "' falls into the following categories: "); | |
for (Pair<String, Double> score: scores) { | |
log.info(" " + score.getFirst() + ": " + score.getSecond()); | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment