Skip to content

Instantly share code, notes, and snippets.

@rohanar
Created February 8, 2016 11:01
Show Gist options
  • Save rohanar/8581a57e7d94839b28e8 to your computer and use it in GitHub Desktop.
Save rohanar/8581a57e7d94839b28e8 to your computer and use it in GitHub Desktop.
ParagraphVectors deserialization issue
package org.deeplearning4j.examples.paragraphvectors;
import org.canova.api.util.ClassPathResource;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.kmeans.KMeansClustering;
import org.deeplearning4j.examples.paragraphvectors.tools.FileLabelAwareIterator;
import org.deeplearning4j.examples.paragraphvectors.tools.LabelSeeker;
import org.deeplearning4j.examples.paragraphvectors.tools.MeansBuilder;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.synthesis.java.extension.ParagraphVectorSerializer;
import java.io.File;
import java.util.List;
/**
 * Basic example of document classification with DL4J {@link ParagraphVectors}.
 * The overall idea is to use ParagraphVectors the same way LDA is used: topic-space modelling.
 *
 * <p>The example assumes a few labeled categories are available for training, plus a few
 * unlabeled documents. The goal is to determine which category each unlabeled document falls into.
 *
 * <p>After training, the model is written to disk and read back. Classification then runs against
 * the deserialized copy rather than the in-memory one.
 *
 * <p>Please note: accuracy could be improved by using a learning cascade, but that is beyond the
 * scope of this basic example.
 *
 * @author raver119@gmail.com
 */
public class ParagraphVectorsClassifierExample {

    private static final Logger log = LoggerFactory.getLogger(ParagraphVectorsClassifierExample.class);

    /**
     * File the trained model is written to and then read back from.
     * It is relative to the working directory, and any existing file there is overwritten.
     */
    private static final File MODEL_FILE = new File("test.txt");

    /**
     * Trains ParagraphVectors on the {@code paravec/labeled} classpath folder and round-trips the
     * model through {@link WordVectorSerializer}. It then logs per-label scores for every
     * document in the {@code paravec/unlabeled} classpath folder.
     *
     * @param args unused
     * @throws Exception propagated from resource lookup, training, or model file I/O
     */
    public static void main(String[] args) throws Exception {
        ClassPathResource resource = new ClassPathResource("paravec/labeled");

        // Build an iterator over the labeled training documents.
        LabelAwareIterator iterator = new FileLabelAwareIterator.Builder()
                .addSourceFolder(resource.getFile())
                .build();

        TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

        // ParagraphVectors training configuration
        ParagraphVectors paragraphVectors = new ParagraphVectors.Builder()
                .learningRate(0.025)
                .minLearningRate(0.001)
                .batchSize(1000)
                .epochs(10)
                .iterate(iterator)
                .trainWordVectors(true)
                .tokenizerFactory(tokenizerFactory)
                .build();

        // Start model training
        paragraphVectors.fit();

        // Serialization round-trip. The trained model is replaced by the copy read back from
        // MODEL_FILE, so everything below exercises the deserialized model.
        // Alternatives: the project-local ParagraphVectorSerializer (writeWordVectors /
        // readParagraphVectorsFromText), or an updated WordVectorSerializer.java from
        // https://gist.github.com/raver119/4f2e3aa550bea7b7b74c (shared by raver119).
        WordVectorSerializer.writeWordVectors(paragraphVectors, MODEL_FILE);
        paragraphVectors = WordVectorSerializer.readParagraphVectorsFromText(MODEL_FILE);

        /*
         At this point we assume the model is built, and we can check which categories our
         unlabeled documents fall into. So we load the unlabeled documents and score each one.
         */
        ClassPathResource unlabeledResource = new ClassPathResource("paravec/unlabeled");
        FileLabelAwareIterator unlabeledIterator = new FileLabelAwareIterator.Builder()
                .addSourceFolder(unlabeledResource.getFile())
                .build();

        // Cast once. Both MeansBuilder and LabelSeeker take the concrete in-memory table.
        @SuppressWarnings("unchecked") // getLookupTable() is not typed as InMemoryLookupTable<VocabWord>
        InMemoryLookupTable<VocabWord> lookupTable =
                (InMemoryLookupTable<VocabWord>) paragraphVectors.getLookupTable();

        /*
         Now we iterate over the unlabeled data and check which labels each document could be
         assigned to. Please note: in many domains it is normal for one document to fall into
         several labels at once, with a different "weight" for each.
         */
        MeansBuilder meansBuilder = new MeansBuilder(lookupTable, tokenizerFactory);
        // NOTE(review): the label names come from the *training* iterator, but the vectors come from
        // the *deserialized* lookup table. If deserialization does not restore the label
        // vectors, LabelSeeker will fail or misbehave here. This is the likely source of the
        // "deserializing issue"; confirm against the serializer in use.
        LabelSeeker seeker = new LabelSeeker(iterator.getLabelsSource().getLabels(), lookupTable);

        while (unlabeledIterator.hasNextDocument()) {
            LabelledDocument document = unlabeledIterator.nextDocument();
            INDArray documentAsCentroid = meansBuilder.documentAsVector(document);
            List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid);

            /*
             Please note: document.getLabel() is used only to show which document we are looking
             at, instead of printing out the whole document. The labels on these unlabeled
             documents act as titles, to help confirm the classification visually.
             */
            log.info("Document '{}' falls into the following categories: ", document.getLabel());
            for (Pair<String, Double> score : scores) {
                log.info(" {}: {}", score.getFirst(), score.getSecond());
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment