Created
February 9, 2016 10:54
-
-
Save rohanar/73e5a2b2ee47e23ccb51 to your computer and use it in GitHub Desktop.
LabelAwareIterator for ParagraphVectors learning test code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.synthesis.java.extension; | |
import java.util.ArrayList; | |
import java.util.Iterator; | |
import java.util.List; | |
import java.util.Map; | |
import lombok.NonNull; | |
import org.deeplearning4j.text.documentiterator.LabelAwareIterator; | |
import org.deeplearning4j.text.documentiterator.LabelledDocument; | |
import org.deeplearning4j.text.documentiterator.LabelsSource; | |
public class SolrDocLabelAwareIterator2 implements LabelAwareIterator { | |
private static final long serialVersionUID = 1L; | |
private Iterator<String> iter; | |
private Map<String,String> docContentsMap; | |
protected LabelsSource labelsSource; | |
public SolrDocLabelAwareIterator2(Map<String,String> doccontentsmap, @NonNull LabelsSource source) { | |
docContentsMap = doccontentsmap; | |
iter = docContentsMap.keySet().iterator(); | |
labelsSource = source; | |
} | |
@Override | |
public boolean hasNextDocument() { | |
return iter.hasNext(); | |
} | |
@Override | |
public LabelledDocument nextDocument() { | |
LabelledDocument document = new LabelledDocument(); | |
if(iter.hasNext()) { | |
String label = iter.next(); | |
String txt = docContentsMap.get(label); | |
if (txt.length() < 30) { | |
txt = " insufficient content to cluster this document"; | |
} | |
document.setContent(txt); | |
document.setLabel(label); | |
} | |
return document; | |
} | |
@Override | |
public void reset() { | |
iter = docContentsMap.keySet().iterator(); | |
} | |
@Override | |
public LabelsSource getLabelsSource() { | |
return labelsSource; | |
} | |
public static class Builder { | |
public Builder() { | |
} | |
public SolrDocLabelAwareIterator2 build(Map<String,String> doccontentsmap) { | |
List<String> labels = new ArrayList<>(); | |
for (String docid: doccontentsmap.keySet()) { | |
labels.add(docid); | |
} | |
LabelsSource source = new LabelsSource(labels); | |
SolrDocLabelAwareIterator2 iterator = new SolrDocLabelAwareIterator2(doccontentsmap, source); | |
return iterator; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment