Skip to content

Instantly share code, notes, and snippets.

@tgreiser
Created November 15, 2017 01:35
Show Gist options
  • Save tgreiser/d91db7553cfd3122fd4b342710416f29 to your computer and use it in GitHub Desktop.
paragraph_vector_example.scala
/*
Ported from: https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/paragraphvectors/ParagraphVectorsClassifierExample.java
NOTE - you must download the paravec resources data to /opt/data/paravec on your SKIL server.
https://github.com/deeplearning4j/dl4j-examples/tree/master/dl4j-examples/src/main/resources/paravec
*/
import scala.collection.JavaConversions._

import java.io.FileNotFoundException;
import java.util.concurrent.atomic.AtomicInteger;

import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
// Labeled training corpus: one sub-folder per label under this root.
// NOTE(review): despite the ClassPathResource wrapper, an absolute path is
// used — the data must exist at /opt/data/paravec on the SKIL server
// (see the header comment above).
val iterator =
  new FileLabelAwareIterator.Builder()
    .addSourceFolder(new ClassPathResource("/opt/data/paravec/labeled").getFile())
    .build()

// Default tokenizer with common lowercasing/punctuation preprocessing.
val tokenizerFactory = new DefaultTokenizerFactory()
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor())

// ParagraphVectors (doc2vec) training configuration.
val paragraphVectors =
  new ParagraphVectors.Builder()
    .learningRate(0.025)    // initial learning rate
    .minLearningRate(0.001) // floor for learning-rate decay
    .batchSize(1000)
    .epochs(20)
    .iterate(iterator)
    .trainWordVectors(true) // learn word vectors alongside paragraph vectors
    .tokenizerFactory(tokenizerFactory)
    .build()

// Train the model on the labeled corpus.
paragraphVectors.fit()
/**
 * Computes a centroid vector for a document: the row-wise mean of the word
 * vectors of every in-vocabulary token in the document's content.
 *
 * @param lookupTable      trained lookup table holding the model's word vectors
 * @param tokenizerFactory tokenizer used to split the document content into tokens
 * @param document         document whose content is vectorized
 * @return mean of the known tokens' vectors (shape: 1 x layerSize)
 */
def documentAsVector(lookupTable: InMemoryLookupTable[VocabWord], tokenizerFactory: TokenizerFactory, document: LabelledDocument) : INDArray = {
  val vocabCache = lookupTable.getVocab()
  // Keep only tokens the model knows; out-of-vocabulary words have no vector.
  // (Replaces the original's two passes plus AtomicInteger counter.)
  val knownTokens = tokenizerFactory.create(document.getContent()).getTokens().filter(vocabCache.containsWord(_))
  val allWords = Nd4j.create(knownTokens.size, lookupTable.layerSize())
  for ((word, row) <- knownTokens.zipWithIndex) {
    allWords.putRow(row, lookupTable.vector(word))
  }
  allWords.mean(0)
}
/**
 * Scores a document vector against every known label's learned vector using
 * cosine similarity.
 *
 * @param labelsUsed  labels present in the training corpus; must be non-empty
 * @param lookupTable lookup table of the trained ParagraphVectors model
 */
class LabelSeeker(labelsUsed: List[String], lookupTable: org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable[
org.deeplearning4j.models.word2vec.VocabWord
]) {
if (labelsUsed.isEmpty) throw new IllegalStateException("You can't have 0 labels used for ParagraphVectors");
/**
 * Returns one (label, cosineSimilarity) pair per known label.
 *
 * BUG FIX: the original built a zero-length Array and did `result :+ (...)`,
 * which allocates a new array and DISCARDS it — so getScores always returned
 * an empty array. Build the result with map/toArray instead.
 *
 * @throws IllegalStateException if a label has no vector in the lookup table
 */
def getScores(vector: org.nd4j.linalg.api.ndarray.INDArray) : Array[Tuple2[String, Double]] = {
  labelsUsed.map { label =>
    val vecLabel = lookupTable.vector(label)
    if (vecLabel == null) throw new IllegalStateException("Label '"+ label+"' has no known vector!")
    (label, org.nd4j.linalg.ops.transforms.Transforms.cosineSim(vector, vecLabel))
  }.toArray
}
}
/*
At this point we assume that we have model built and we can check
which categories our unlabeled document falls into.
So we'll start loading our unlabeled documents and checking them
*/
// Unlabeled documents to classify against the trained model.
val unClassifiedResource = new ClassPathResource("/opt/data/paravec/unlabeled")
val unClassifiedIterator =
  new FileLabelAwareIterator.Builder()
    .addSourceFolder(unClassifiedResource.getFile())
    .build()

/*
Iterate over the unlabeled data and check which label(s) each document could
be assigned to. In many domains it is normal for one document to fall under
several labels at once, each with a different "weight".
*/
var labels = iterator.getLabelsSource().getLabels().toList
val seeker = new LabelSeeker(labels, paragraphVectors.getLookupTable().asInstanceOf[InMemoryLookupTable[VocabWord]])

unClassifiedIterator.reset()
// NOTE(review): the original author observed this loop never executing — the
// iterator reported labels but apparently no documents. Verify that
// /opt/data/paravec/unlabeled is actually populated on the server.
while (unClassifiedIterator.hasNextDocument()) {
  val document = unClassifiedIterator.nextDocument()
  val documentAsCentroid = documentAsVector(paragraphVectors.getLookupTable().asInstanceOf[InMemoryLookupTable[VocabWord]], tokenizerFactory, document)
  /*
  document.getLabel() is only a human-readable stand-in for the document name
  here — the labels on these documents act as titles, just to visualize that
  the classification was done properly.
  */
  for ((label, similarity) <- seeker.getScores(documentAsCentroid)) {
    print(s" " + label + ": " + similarity)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment