Skip to content

Instantly share code, notes, and snippets.

@rkfg
Created March 30, 2017 08:27
Show Gist options
  • Save rkfg/f740bea6afc0106e0dff04d37f018627 to your computer and use it in GitHub Desktop.
Save rkfg/f740bea6afc0106e0dff04d37f018627 to your computer and use it in GitHub Desktop.
package dlchat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
/**
 * Iterates over a tokenized corpus in minibatches that are grouped into "macrobatches".
 *
 * Motivation: I want to get asynchronous data iteration while not blocking on net.fit() until the end of epoch. I want to checkpoint
 * the network, show intermediate test results and some stats, it would be harder to achieve with listeners I think so this is how I
 * solved the problem. This way the learn process is asynchronous inside one macrobatch and synchronous across all the macrobatches.
 *
 * Macrobatch is a group of minibatches. The iterator is modified so that it reports the end of data when it exhausts a macrobatch. Then
 * it advances (manually, see {@link #nextMacroBatch()}) to the next macrobatch.
 *
 * Every training sample is a pair of consecutive corpus rows: row k is the encoder input and row k + 1 is the expected output.
 */
@SuppressWarnings("serial")
public class CorpusIterator implements MultiDataSetIterator {
    /** Token index appended to every target row as the end-of-sentence marker. */
    private static final double EOS_TOKEN = 1.0;
    /** Token index placed at the first time step of the decoder input. */
    private static final int GO_TOKEN = 2;

    private final List<List<Double>> corpus;
    private final int batchSize;
    private final int batchesPerMacrobatch;
    private final int totalBatches;
    private final int totalMacroBatches;
    private int currentBatch = 0;
    private int currentMacroBatch = 0;
    private final int dictSize;
    private final int rowSize;

    /**
     * @param corpus               rows of token indexes; consecutive rows form (input, expected output) pairs
     * @param batchSize            number of samples per minibatch, must be positive
     * @param batchesPerMacrobatch number of minibatches per macrobatch, must be positive
     * @param dictSize             dictionary size, i.e. the width of the one-hot vectors
     * @param rowSize              maximum sequence length in tokens; rows are assumed to hold at most rowSize - 1 tokens
     * @throws IllegalArgumentException if batchSize or batchesPerMacrobatch is not positive
     */
    public CorpusIterator(List<List<Double>> corpus, int batchSize, int batchesPerMacrobatch, int dictSize, int rowSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        if (batchesPerMacrobatch <= 0) {
            throw new IllegalArgumentException("batchesPerMacrobatch must be positive, got " + batchesPerMacrobatch);
        }
        this.corpus = corpus;
        this.batchSize = batchSize;
        this.batchesPerMacrobatch = batchesPerMacrobatch;
        this.dictSize = dictSize;
        this.rowSize = rowSize;
        // n rows yield only n - 1 (input, output) pairs; counting rows instead of pairs would produce a trailing batch with zero
        // samples whenever corpus.size() % batchSize == 1
        int totalSamples = Math.max(0, corpus.size() - 1);
        totalBatches = (int) Math.ceil((double) totalSamples / batchSize);
        totalMacroBatches = (int) Math.ceil((double) totalBatches / batchesPerMacrobatch);
    }

    /** Returns true while there are minibatches left in the current macrobatch. */
    @Override
    public boolean hasNext() {
        return currentBatch < totalBatches && getMacroBatchByCurrentBatch() == currentMacroBatch;
    }

    private int getMacroBatchByCurrentBatch() {
        return currentBatch / batchesPerMacrobatch;
    }

    @Override
    public MultiDataSet next() {
        return next(batchSize);
    }

    /**
     * Builds the next minibatch and advances the batch counter.
     *
     * @param num ignored, the batch size passed to the constructor is always used
     * @return a dataset with features {input, decoder input}, labels {prediction} and the matching masks
     * @throws java.util.NoSuchElementException if the whole corpus has already been consumed
     */
    @Override
    public MultiDataSet next(int num) {
        if (currentBatch >= totalBatches) {
            throw new java.util.NoSuchElementException("No batches left: " + currentBatch + " of " + totalBatches + " consumed");
        }
        int firstRow = currentBatch * batchSize;
        int currentBatchSize = Math.min(batchSize, corpus.size() - firstRow - 1);
        // the longest row defines the time dimension; "<=" is intentional, the row after the last input is used as a target
        int sequenceLength = 0;
        for (int j = 0; j <= currentBatchSize; ++j) {
            sequenceLength = Math.max(sequenceLength, corpus.get(firstRow + j).size());
        }
        sequenceLength = Math.min(rowSize, sequenceLength + 1); // +1 for the <eos> token
        INDArray input = Nd4j.zeros(currentBatchSize, 1, sequenceLength);
        INDArray prediction = Nd4j.zeros(currentBatchSize, dictSize, sequenceLength);
        INDArray decode = Nd4j.zeros(currentBatchSize, dictSize, sequenceLength);
        INDArray inputMask = Nd4j.zeros(currentBatchSize, sequenceLength);
        // this mask is also used for the decoder input, the length is the same
        INDArray predictionMask = Nd4j.zeros(currentBatchSize, sequenceLength);
        for (int j = 0; j < currentBatchSize; ++j) {
            List<Double> rowIn = new ArrayList<>(corpus.get(firstRow + j));
            Collections.reverse(rowIn);
            List<Double> rowPred = new ArrayList<>(corpus.get(firstRow + j + 1));
            rowPred.add(EOS_TOKEN);
            int predLength = rowPred.size();
            // replace the entire row in "input" using NDArrayIndex, it's faster than putScalar(); input is NOT made of one-hot vectors
            // because of the embedding layer that accepts token indexes directly
            input.put(new INDArrayIndex[] { NDArrayIndex.point(j), NDArrayIndex.point(0), NDArrayIndex.interval(0, rowIn.size()) },
                    Nd4j.create(ArrayUtils.toPrimitive(rowIn.toArray(new Double[0]))));
            inputMask.put(new INDArrayIndex[] { NDArrayIndex.point(j), NDArrayIndex.interval(0, rowIn.size()) },
                    Nd4j.ones(rowIn.size()));
            predictionMask.put(new INDArrayIndex[] { NDArrayIndex.point(j), NDArrayIndex.interval(0, predLength) },
                    Nd4j.ones(predLength));
            // prediction (output) and decode ARE one-hots though, I couldn't add an embedding layer on top of the decoder and I'm not
            // sure it's a good idea either
            double[][] predOneHot = new double[dictSize][predLength];
            double[][] decodeOneHot = new double[dictSize][predLength];
            decodeOneHot[GO_TOKEN][0] = 1;
            for (int t = 0; t < predLength; ++t) {
                int token = rowPred.get(t).intValue();
                predOneHot[token][t] = 1;
                if (t < predLength - 1) { // put the same vals to decode with +1 offset except the last token that is <eos>
                    decodeOneHot[token][t + 1] = 1;
                }
            }
            prediction.put(new INDArrayIndex[] { NDArrayIndex.point(j), NDArrayIndex.interval(0, dictSize),
                    NDArrayIndex.interval(0, predLength) }, Nd4j.create(predOneHot));
            decode.put(new INDArrayIndex[] { NDArrayIndex.point(j), NDArrayIndex.interval(0, dictSize),
                    NDArrayIndex.interval(0, predLength) }, Nd4j.create(decodeOneHot));
        }
        ++currentBatch;
        return new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { input, decode }, new INDArray[] { prediction },
                new INDArray[] { inputMask, predictionMask }, new INDArray[] { predictionMask });
    }

    /** Preprocessors are not supported, the argument is ignored. */
    @Override
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
    }

    @Override
    public boolean resetSupported() {
        // we don't want this iterator to be reset on each macrobatch pseudo-epoch
        return false;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        // but we still can do it manually before the epoch starts
        currentBatch = 0;
        currentMacroBatch = 0;
    }

    /** Returns the index of the next minibatch to be produced. */
    public int batch() {
        return currentBatch;
    }

    /** Returns the number of minibatches in the whole corpus. */
    public int totalBatches() {
        return totalBatches;
    }

    /** Jumps to the given minibatch and to the macrobatch that contains it. */
    public void setCurrentBatch(int currentBatch) {
        this.currentBatch = currentBatch;
        currentMacroBatch = getMacroBatchByCurrentBatch();
    }

    /** Returns true if both the batch position and the macrobatch counter are still inside the corpus. */
    public boolean hasNextMacrobatch() {
        return getMacroBatchByCurrentBatch() < totalMacroBatches && currentMacroBatch < totalMacroBatches;
    }

    /** Advances to the next macrobatch so that {@link #hasNext()} can report data again. */
    public void nextMacroBatch() {
        ++currentMacroBatch;
    }
}
package dlchat;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
/**
 * Reads a dialog corpus line by line, joins consecutive lines of the same speaker and hands every resulting utterance (lowercased)
 * to {@link #processLine(String)}. The default implementation collects the words into a set and, optionally, counts their
 * frequencies; subclasses override {@code processLine} to do something else with the utterance.
 */
public class CorpusProcessor {
    public static final String SPECIALS = "!\"#$;%^:?*()[]{}<>«»,.–—=+…";
    /** Word delimiter used by {@link #tokenizeLine}, compiled once instead of on every call. */
    private static final Pattern WORD_DELIMITER = Pattern.compile("[ \t]");

    private final Set<String> dictSet = new HashSet<>();
    private final Map<String, Double> freq = new HashMap<>();
    private Map<String, Double> dict = new HashMap<>();
    private final boolean countFreq;
    private final InputStream is;
    private final int rowSize;
    private String separator = " +++$+++ ";
    private int fieldsCount = 5;
    private int nameFieldIdx = 1;
    private int textFieldIdx = 4;

    /**
     * @param filename  corpus file to read
     * @param rowSize   maximum row length in tokens, see {@link #wordsToIndexes}
     * @param countFreq whether to count token frequencies while tokenizing
     * @throws FileNotFoundException if the file cannot be opened
     */
    public CorpusProcessor(String filename, int rowSize, boolean countFreq) throws FileNotFoundException {
        this(new FileInputStream(filename), rowSize, countFreq);
    }

    /**
     * @param is        UTF-8 encoded corpus stream; it is closed by {@link #start()}
     * @param rowSize   maximum row length in tokens, see {@link #wordsToIndexes}
     * @param countFreq whether to count token frequencies while tokenizing
     */
    public CorpusProcessor(InputStream is, int rowSize, boolean countFreq) {
        this.is = is;
        this.rowSize = rowSize;
        this.countFreq = countFreq;
    }

    /**
     * Overrides the corpus line format.
     *
     * @param separator    literal field separator (not a regex)
     * @param fieldsCount  number of fields per line; lines with fewer fields are skipped
     * @param nameFieldIdx index of the speaker name field
     * @param textFieldIdx index of the utterance text field
     */
    public void setFormatParams(String separator, int fieldsCount, int nameFieldIdx, int textFieldIdx) {
        this.separator = separator;
        this.fieldsCount = fieldsCount;
        this.nameFieldIdx = nameFieldIdx;
        this.textFieldIdx = textFieldIdx;
    }

    /**
     * Reads the whole stream and calls {@link #processLine(String)} once per non-empty utterance. The stream is closed afterwards so
     * this method can only be called once per instance.
     *
     * @throws IOException if reading fails
     */
    public void start() throws IOException {
        // the separator is a literal, quote and compile it once instead of on every line
        Pattern fieldSeparator = Pattern.compile(Pattern.quote(separator));
        try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            String line;
            String lastName = "";
            String lastLine = "";
            while ((line = br.readLine()) != null) {
                String[] fields = fieldSeparator.split(line, fieldsCount);
                if (fields.length < fieldsCount) {
                    continue; // malformed line
                }
                // join consecutive lines from the same speaker
                String curLine = fields[textFieldIdx];
                String curName = fields[nameFieldIdx];
                if (curName.equals(lastName)) {
                    if (lastLine.isEmpty()) {
                        lastLine = curLine;
                    } else {
                        // if the previous line doesn't end with a special symbol, append a comma and the current line
                        if (SPECIALS.indexOf(lastLine.charAt(lastLine.length() - 1)) < 0) {
                            lastLine += ",";
                        }
                        lastLine += " " + curLine;
                    }
                } else {
                    emitLine(lastLine);
                    lastLine = curLine;
                    lastName = curName;
                }
            }
            emitLine(lastLine);
        }
    }

    /** Passes a non-empty utterance to processLine; lowercasing uses a fixed locale so the result doesn't depend on the machine. */
    private void emitLine(String utterance) {
        if (!utterance.isEmpty()) {
            processLine(utterance.toLowerCase(java.util.Locale.ROOT));
        }
    }

    /** Handles one lowercased utterance; by default its words (without punctuation) are added to the dictionary set. */
    protected void processLine(String lastLine) {
        tokenizeLine(lastLine, dictSet, false);
    }

    /**
     * Splits a line into tokens; here we not only split the words but also store punctuation marks.
     *
     * @param lastLine         text to tokenize, words are separated by spaces or tabs
     * @param resultCollection receives the tokens in order, may be null to only count frequencies
     * @param addSpecials      whether every character from {@link #SPECIALS} becomes a token of its own or is just dropped
     */
    protected void tokenizeLine(String lastLine, Collection<String> resultCollection, boolean addSpecials) {
        for (String word : WORD_DELIMITER.split(lastLine)) {
            int tokenStart = 0;
            for (int i = 0; i < word.length(); ++i) {
                char c = word.charAt(i);
                if (SPECIALS.indexOf(c) < 0) {
                    continue;
                }
                if (i > tokenStart) {
                    addWord(resultCollection, word.substring(tokenStart, i));
                }
                if (addSpecials) {
                    addWord(resultCollection, String.valueOf(c));
                }
                tokenStart = i + 1;
            }
            if (tokenStart < word.length()) {
                addWord(resultCollection, word.substring(tokenStart));
            }
        }
    }

    private void addWord(Collection<String> coll, String word) {
        if (coll != null) {
            coll.add(word);
        }
        if (countFreq) {
            Double count = freq.get(word);
            freq.put(word, count == null ? 1.0 : count + 1);
        }
    }

    /** Returns the live set of words collected by the default {@link #processLine(String)}. */
    public Set<String> getDictSet() {
        return dictSet;
    }

    /** Returns the live token frequency map; it stays empty unless countFreq was enabled. */
    public Map<String, Double> getFreq() {
        return freq;
    }

    /** Sets the token to index mapping used by {@link #wordsToIndexes}. */
    public void setDict(Map<String, Double> dict) {
        this.dict = dict;
    }

    /**
     * Appends the dictionary indexes of at most rowSize - 1 leading words to wordIdxs; words missing from the dictionary are
     * stored as 0.0. A rowSize of 1 or less yields no indexes at all.
     *
     * @param words    tokens to convert
     * @param wordIdxs receives the indexes
     * @return true if wordIdxs is not empty afterwards
     */
    protected boolean wordsToIndexes(Collection<String> words, List<Double> wordIdxs) {
        int limit = rowSize - 1;
        int added = 0;
        for (String word : words) {
            if (added >= limit) {
                break;
            }
            Double wordIdx = dict.get(word);
            wordIdxs.add(wordIdx != null ? wordIdx : 0.0);
            ++added;
        }
        return !wordIdxs.isEmpty();
    }
}
package dlchat;
import java.io.BufferedWriter;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class EncoderDecoderLSTM {
/*
* This is a seq2seq encoder-decoder LSTM model made according to the Google's paper: [1] The model tries to predict the next dialog
* line using the provided one. It learns on the Cornell Movie Dialogs corpus. Unlike simple char RNNs this model is more sophisticated
* and theoretically, given enough time and data, can deduce facts from raw text. Your mileage may vary. This particular code is based
* on AdditionRNN but heavily changed to be used with a huge amount of possible tokens (10-20k), it also utilizes the decoder input
* unlike AdditionRNN.
*
* Use the get_data.sh script to download, extract and optimize the train data. It's been only tested on Linux, it could work on OS X or
* even on Windows 10 in the Ubuntu shell.
*
* Special tokens used:
*
* <unk> - replaces any word or other token that's not in the dictionary (too rare to be included or completely unknown)
*
* <eos> - end of sentence, used only in the output to stop the processing; the model input and output length is limited by the ROW_SIZE
* constant.
*
* <go> - used only in the decoder input as the first token before the model produced anything
*
* The architecture is like this: Input => Embedding Layer => Encoder => Decoder => Output (softmax)
*
* The encoder layer produces a so called "thought vector" that contains a compressed representation of the input. Depending on that
* vector the model produces different sentences even if they start with the same token. There's one more input, connected directly to
* the decoder layer, it's used to provide the previous token of the output. For the very first output token we send a special <go>
* token there, on the next iteration we use the token that the model produced the last time. On the training stage everything is
* simple, we know the desired output a priori so the decoder input would be the same token set prepended with the <go> token and without
* the last <eos> token. Example:
*
* Input: "how" "do" "you" "do" "?"
*
* Output: "I'm" "fine" "," "thanks" "!" "<eos>"
*
* Decoder: "<go>" "I'm" "fine" "," "thanks" "!"
*
* Actually, the input is reversed as per [2], the most important words are usually in the beginning of the phrase and they would get
* more weight if supplied last (the model "forgets" tokens that were supplied "long ago"). The output and decoder input sequence
* lengths are always equal. The input and output could be of any length (less than ROW_SIZE) so for purpose of batching we mask the
* unused part of the row. The encoder and decoder networks work sequentially. First the encoder creates the thought vector, that is the
* last activations of the layer. Those activations are then duplicated for as many time steps as there are elements in the output so
* that every output element can have its own copy of the thought vector. Then the decoder starts working. It receives two inputs, the
* thought vector made by the encoder and the token that it _should have produced_ (but usually it outputs something else so we have our
* loss metric and can compute gradients for the backward pass) on the previous step (or <go> for the very first step). These two
* vectors are simply concatenated by the merge vertex. The decoder's output goes to the softmax layer and that's it.
*
* The test phase is much more tricky. We don't know the decoder input because we don't know the output yet (unlike in the train phase),
* it could be anything. So we can't use methods like outputSingle() and have to do some manual work. Actually, we can but it would
* require full restarts of the entire process, which is super slow and inefficient.
*
* First, we do a single feed forward pass for the input with a single decoder element, <go>. We don't need the actual activations
* except the "thought vector". It resides in the second merge vertex input (named "dup"). So we get it and store for the entire
* response generation time. Then we put the decoder input (<go> for the first iteration) and the thought vector to the merge vertex
* inputs and feed it forward. The result goes to the decoder layer, now with rnnTimeStep() method so that the internal layer state is
* updated for the next iteration. The result is fed to the output softmax layer and then we sample it randomly (not with argMax(), it
* tends to give a lot of same tokens in a row). The resulting token is looked up in the dictionary, printed to the stdout and then it
* goes to the next iteration as the decoder input and so on until we get <eos>.
*
* To continue the training process from a specific batch number, enter it when prompted; batch numbers are printed after each processed
* macrobatch. If you've changed the minibatch size after the last launch, recalculate the number accordingly, i.e. if you doubled the
* minibatch size, specify half of the value and so on.
*
* [1] https://arxiv.org/abs/1506.05869 A Neural Conversational Model
*
* [2] https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf Sequence to Sequence Learning with
* Neural Networks
*/
/** Handshake states between the training loop and the shutdown hook that requests a final save. */
public enum SaveState {
    /** Initial state, training has not started yet. */
    NONE,
    /** Training is running and no save is in progress. */
    READY,
    /** The model is being written to disk. */
    SAVING,
    /** A save has been requested and should happen after the current macrobatch. */
    SAVENOW
}
// token -> index and index -> token lookups, filled by createDictionary()
private final Map<String, Double> dict = new HashMap<>();
private final Map<Double, String> revDict = new HashMap<>();
// single-character tokens that always get a dictionary entry
private static final String CHARS = "-\\/_&" + CorpusProcessor.SPECIALS;
// the tokenized corpus, one list of token indexes per utterance
private List<List<Double>> corpus = new ArrayList<>();
private static final int HIDDEN_LAYER_WIDTH = 1024; // this is purely empirical, affects performance and VRAM requirement
private static final int EMBEDDING_WIDTH = 128; // one-hot vectors will be embedded to more dense vectors with this width
private static final String CORPUS_FILENAME = "movie_lines.txt"; // filename of data corpus to learn
private static final String MODEL_FILENAME = "rnn_train_movies.zip"; // filename of the model
private static final String BACKUP_MODEL_FILENAME = "rnn_train_movies.bak.zip"; // filename of the previous version of the model (backup)
private static final String DICTIONARY_FILENAME = "dictionary.txt";
private static final int MINIBATCH_SIZE = 16;
private static final Random rnd = new Random(new Date().getTime());
private static final long SAVE_EACH_MS = TimeUnit.MINUTES.toMillis(10); // save the model with this period
private static final long TEST_EACH_MS = TimeUnit.MINUTES.toMillis(1); // test the model with this period
private static final int MAX_DICT = 40000; // this number of most frequent words will be used, unknown words (that are not in the
// dictionary) are replaced with <unk> token
private static final int TBPTT_SIZE = 25;
private static final double LEARNING_RATE = 1e-2;
private static final double RMS_DECAY = 0.95;
private static final int ROW_SIZE = 20; // maximum line length in tokens
private static final int MACROBATCH_SIZE = 20; // see CorpusIterator
private static final boolean TMP_DATA_DIR = false;
// written by the training thread and by the shutdown hook thread, volatile makes each write visible to the other one
private volatile SaveState saveState = SaveState.NONE;
private static final boolean SAVE_ON_EXIT = false;
private ComputationGraph net;
/** Program entry point: creates the application object and runs it with the command line arguments. */
public static void main(String[] args) throws Exception {
    EncoderDecoderLSTM app = new EncoderDecoderLSTM();
    app.run(args);
}
/**
 * Builds the dictionary, then either restores the saved network (offering a dialog mode or a training restart point) or creates a
 * new one, and finally starts training.
 *
 * @param args command line arguments, currently unused
 * @throws Exception if the corpus or the model cannot be read or written
 */
private void run(String[] args) throws Exception {
    File networkFile = new File(toTempPath(MODEL_FILENAME));
    Runtime.getRuntime().addShutdownHook(new Thread() {
        @Override
        public void run() {
            if (!SAVE_ON_EXIT || saveState != SaveState.READY) {
                return;
            }
            saveState = SaveState.SAVENOW;
            System.out.println(
                    "Wait for the current macrobatch to end, then the model will be saved and the program will terminate.");
            // the training loop saves the model and flips the state back to READY
            while (saveState != SaveState.READY) {
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    // keep the interrupt flag and stop waiting instead of swallowing the interruption
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
    });
    Nd4j.getMemoryManager().togglePeriodicGc(false);
    createDictionary();
    int offset = 0;
    if (networkFile.exists()) {
        System.out.println("Loading the existing network...");
        net = ModelSerializer.restoreComputationGraph(networkFile);
        offset = net.getConfiguration().getIterationCount();
        System.out.print("Enter d to start dialog or a number to continue training from that minibatch (press Enter to start from ["
                + offset + "]: ");
        try (Scanner scanner = new Scanner(System.in)) {
            // a closed stdin (non-interactive run) is treated like pressing Enter
            String input = scanner.hasNextLine() ? scanner.nextLine().trim() : "";
            if (input.equalsIgnoreCase("d")) {
                startDialog(scanner);
            } else {
                if (!input.isEmpty()) {
                    try {
                        offset = Integer.parseInt(input);
                    } catch (NumberFormatException e) {
                        System.out.println("'" + input + "' is not a number, continuing from " + offset);
                    }
                }
                net.getConfiguration().setIterationCount(offset);
                test();
            }
        }
    } else {
        System.out.println("Creating a new network...");
        createComputationGraph();
    }
    System.out.println("Number of parameters: " + net.numParams());
    net.setListeners(new ScoreIterationListener(1));
    train(networkFile, offset);
}
/**
 * Creates and initializes a fresh network: inputLine goes through an embedding layer and the encoder LSTM, the encoder's last time
 * step is duplicated along the decoder input, merged with it and fed to the decoder LSTM followed by a softmax output layer.
 */
private void createComputationGraph() {
    final int vocabSize = dict.size();
    NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
    builder.iterations(1).learningRate(LEARNING_RATE).rmsDecay(RMS_DECAY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).miniBatch(true).updater(Updater.RMSPROP)
            .weightInit(WeightInit.XAVIER).gradientNormalization(GradientNormalization.RenormalizeL2PerLayer);

    // layer configurations, named after the graph nodes they are attached to
    EmbeddingLayer embeddingEncoder = new EmbeddingLayer.Builder().nIn(vocabSize).nOut(EMBEDDING_WIDTH).build();
    GravesLSTM encoder = new GravesLSTM.Builder().nIn(EMBEDDING_WIDTH).nOut(HIDDEN_LAYER_WIDTH).activation(Activation.TANH)
            .build();
    // the decoder sees the one-hot decoder input concatenated with the duplicated thought vector
    GravesLSTM decoder = new GravesLSTM.Builder().nIn(vocabSize + HIDDEN_LAYER_WIDTH).nOut(HIDDEN_LAYER_WIDTH)
            .activation(Activation.TANH).build();
    RnnOutputLayer output = new RnnOutputLayer.Builder().nIn(HIDDEN_LAYER_WIDTH).nOut(vocabSize).activation(Activation.SOFTMAX)
            .lossFunction(LossFunctions.LossFunction.MCXENT).build();

    GraphBuilder graphBuilder = builder.graphBuilder().pretrain(false).backprop(true).backpropType(BackpropType.Standard)
            .tBPTTBackwardLength(TBPTT_SIZE).tBPTTForwardLength(TBPTT_SIZE);
    graphBuilder.addInputs("inputLine", "decoderInput")
            .setInputTypes(InputType.recurrent(vocabSize), InputType.recurrent(vocabSize))
            .addLayer("embeddingEncoder", embeddingEncoder, "inputLine")
            .addLayer("encoder", encoder, "embeddingEncoder")
            .addVertex("thoughtVector", new LastTimeStepVertex("inputLine"), "encoder")
            .addVertex("dup", new DuplicateToTimeSeriesVertex("decoderInput"), "thoughtVector")
            .addVertex("merge", new MergeVertex(), "decoderInput", "dup")
            .addLayer("decoder", decoder, "merge")
            .addLayer("output", output, "decoder")
            .setOutputs("output");
    net = new ComputationGraph(graphBuilder.build());
    net.init();
}
/**
 * Trains the network macrobatch by macrobatch; between macrobatches it reports progress, saves the model every SAVE_EACH_MS
 * (or right away when a save was requested, after which training stops) and prints a test sample every TEST_EACH_MS.
 *
 * @param networkFile file the model is saved to
 * @param offset      minibatch to start the first epoch from
 * @throws Exception if training or saving fails
 */
private void train(File networkFile, int offset) throws Exception {
    saveState = SaveState.READY;
    long savedAt = System.currentTimeMillis();
    long testedAt = System.currentTimeMillis();
    CorpusIterator batches = new CorpusIterator(corpus, MINIBATCH_SIZE, MACROBATCH_SIZE, dict.size(), ROW_SIZE);
    int epoch = 1;
    while (epoch < 10000) {
        System.out.println("Epoch " + epoch);
        // only the very first epoch resumes from the requested minibatch
        if (epoch > 1) {
            batches.reset();
        } else {
            batches.setCurrentBatch(offset);
        }
        int shownPercent = 0;
        while (batches.hasNextMacrobatch()) {
            long fitStarted = System.currentTimeMillis();
            net.fit(batches);
            long fitFinished = System.currentTimeMillis();
            int batch = batches.batch();
            System.out.println("Batch = " + batch + " / " + batches.totalBatches() + " time = " + (fitFinished - fitStarted));
            batches.nextMacroBatch();
            int percent = batch * 100 / batches.totalBatches();
            if (percent != shownPercent) {
                System.out.println("Epoch complete: " + percent + "%");
                shownPercent = percent;
            }
            if (saveState == SaveState.SAVENOW) {
                // a final save was requested, write the model and stop training
                saveModel(networkFile, batch);
                return;
            }
            if (System.currentTimeMillis() - savedAt > SAVE_EACH_MS) {
                saveModel(networkFile, batch);
                savedAt = System.currentTimeMillis();
            }
            if (System.currentTimeMillis() - testedAt > TEST_EACH_MS) {
                test();
                testedAt = System.currentTimeMillis();
            }
        }
        ++epoch;
    }
}
/**
 * Runs the interactive dialog: reads a line, tokenizes it exactly like the corpus and prints the network's reply. The loop has no
 * exit condition, the method only ends with an exception (for example when the input is exhausted).
 *
 * @param scanner source of the user's lines
 * @throws IOException if processing an input line fails
 */
private void startDialog(Scanner scanner) throws IOException {
    System.out.println("Dialog started.");
    for (;;) {
        System.out.print("In> ");
        // the user's text is wrapped into a record that conforms to the corpus format
        String record = appendInputLine(scanner.nextLine());
        byte[] recordBytes = record.getBytes(StandardCharsets.UTF_8);
        CorpusProcessor dialogProcessor = new CorpusProcessor(new ByteArrayInputStream(recordBytes), ROW_SIZE, false) {
            @Override
            protected void processLine(String lastLine) {
                List<String> tokens = new ArrayList<>();
                tokenizeLine(lastLine, tokens, true);
                List<Double> tokenIdxs = new ArrayList<>();
                if (!wordsToIndexes(tokens, tokenIdxs)) {
                    return; // nothing to answer to
                }
                System.out.print("Got words: ");
                for (Double tokenIdx : tokenIdxs) {
                    System.out.print(revDict.get(tokenIdx) + " ");
                }
                System.out.println();
                System.out.print("Out> ");
                output(tokenIdxs, true);
            }
        };
        setupCorpusProcessor(dialogProcessor);
        dialogProcessor.setDict(dict);
        dialogProcessor.start();
    }
}
/**
 * Wraps a user's line into a record of the default corpus format (five fields separated by " +++$+++ ", the text being the last
 * one) and terminates it with a newline. For the alternative two-field format the record would be "me¦" + line + "\n".
 */
private String appendInputLine(String line) {
    final String sep = " +++$+++ ";
    return "1" + sep + "u11" + sep + "m0" + sep + "WALTER" + sep + line + "\n";
}
/**
 * Writes the network (including the updater state) to networkFile; an existing model file is first renamed to the backup file.
 * The save state is SAVING while the method runs and READY once it has completed successfully.
 *
 * @param networkFile destination file
 * @param batch       current minibatch number, currently unused
 * @throws IOException if the model cannot be written
 */
private void saveModel(File networkFile, int batch) throws IOException {
    saveState = SaveState.SAVING;
    System.out.println("Saving the model...");
    System.gc();
    File backup = new File(toTempPath(BACKUP_MODEL_FILENAME));
    if (networkFile.exists()) {
        // File.delete()/renameTo() report failures only through their return values, so check them
        if (backup.exists() && !backup.delete()) {
            System.out.println("Warning: could not delete the previous backup " + backup);
        }
        if (!networkFile.renameTo(backup)) {
            System.out.println("Warning: could not back up the model to " + backup + ", the existing file will be overwritten.");
        }
    }
    ModelSerializer.writeModel(net, networkFile, true);
    System.gc();
    System.out.println("Done.");
    saveState = SaveState.READY;
}
/** Picks a random corpus row, prints it and then prints the reply the network generates for it. */
private void test() {
    System.out.println("======================== TEST ========================");
    // work on a copy of the randomly chosen row so the corpus itself is never modified
    List<Double> sample = new ArrayList<>(corpus.get(rnd.nextInt(corpus.size())));
    StringBuilder inputText = new StringBuilder("In: ");
    for (Double tokenIdx : sample) {
        inputText.append(revDict.get(tokenIdx)).append(" ");
    }
    System.out.println(inputText);
    System.out.print("Out: ");
    output(sample, true);
    System.out.println("====================== TEST END ======================");
}
/**
 * Generates and prints a reply to the given input row. A single forward pass yields the thought vector (the second input of the
 * "merge" vertex); after that the decoder is stepped manually, each sampled token becoming the next decoder input, until the
 * end-of-sentence token (index 1) is sampled or ROW_SIZE tokens have been produced.
 *
 * @param rowIn         token indexes of the input line in natural order; the list is not modified
 * @param printUnknowns whether to print the token with index 0 when it is sampled
 */
private void output(List<Double> rowIn, boolean printUnknowns) {
    net.rnnClearPreviousState();
    // the encoder expects the input reversed; reverse a copy so the caller's list stays intact
    List<Double> reversed = new ArrayList<>(rowIn);
    Collections.reverse(reversed);
    INDArray in = Nd4j.create(ArrayUtils.toPrimitive(reversed.toArray(new Double[0])), new int[] { 1, 1, reversed.size() });
    double[] decodeArr = new double[dict.size()];
    decodeArr[2] = 1; // <go> token
    INDArray decode = Nd4j.create(decodeArr, new int[] { 1, dict.size(), 1 });
    net.feedForward(new INDArray[] { in, decode }, false);
    org.deeplearning4j.nn.layers.recurrent.GravesLSTM decoder = (org.deeplearning4j.nn.layers.recurrent.GravesLSTM) net
            .getLayer("decoder");
    Layer output = net.getLayer("output");
    GraphVertex mergeVertex = net.getVertex("merge");
    INDArray thoughtVector = mergeVertex.getInputs()[1];
    for (int row = 0; row < ROW_SIZE; ++row) {
        mergeVertex.setInputs(decode, thoughtVector);
        INDArray merged = mergeVertex.doForward(false);
        INDArray activateDec = decoder.rnnTimeStep(merged);
        INDArray out = output.activate(activateDec, false);
        // sample a token from the softmax distribution by walking its cumulative sum
        double d = rnd.nextDouble();
        double sum = 0.0;
        int idx = -1;
        int lastIdx = -1;
        for (int s = 0; s < out.size(1); s++) {
            lastIdx = s;
            sum += out.getDouble(0, s, 0);
            if (d <= sum) {
                idx = s;
                break;
            }
        }
        if (idx < 0) {
            // rounding can leave the total slightly below d; fall back to the last token instead of indexing with -1
            idx = lastIdx;
        }
        if (idx < 0) {
            break; // the output layer produced no values at all
        }
        if (printUnknowns || idx != 0) {
            System.out.print(revDict.get((double) idx) + " ");
        }
        if (idx == 1) { // <eos> token
            break;
        }
        double[] newDecodeArr = new double[dict.size()];
        newDecodeArr[idx] = 1;
        decode = Nd4j.create(newDecodeArr, new int[] { 1, dict.size(), 1 });
    }
    System.out.println();
}
/**
 * Builds the dictionary and the tokenized corpus in two passes over the corpus file. The first pass counts word frequencies; the
 * MAX_DICT most frequent words join the three service tokens and the single-character tokens from CHARS. The resulting word list
 * is also written to DICTIONARY_FILENAME. The second pass converts every utterance to a row of token indexes.
 *
 * @throws IOException           if the corpus cannot be read or the dictionary file cannot be written
 * @throws FileNotFoundException if the corpus file does not exist
 */
private void createDictionary() throws IOException, FileNotFoundException {
    double idx = 3.0;
    dict.put("<unk>", 0.0);
    revDict.put(0.0, "<unk>");
    dict.put("<eos>", 1.0);
    revDict.put(1.0, "<eos>");
    dict.put("<go>", 2.0);
    revDict.put(2.0, "<go>");
    for (char c : CHARS.toCharArray()) {
        // the keys are Strings, looking a char up would never match and duplicates would get several indexes
        String token = String.valueOf(c);
        if (!dict.containsKey(token)) {
            dict.put(token, idx);
            revDict.put(idx, token);
            ++idx;
        }
    }
    System.out.println("Building the dictionary...");
    CorpusProcessor corpusProcessor = new CorpusProcessor(toTempPath(CORPUS_FILENAME), ROW_SIZE, true);
    setupCorpusProcessor(corpusProcessor);
    corpusProcessor.start();
    Map<String, Double> freqs = corpusProcessor.getFreq();
    Set<String> dictSet = new TreeSet<>(); // the tokens order is preserved for TreeSet
    Map<Double, Set<String>> freqMap = new TreeMap<>(new Comparator<Double>() {
        @Override
        public int compare(Double o1, Double o2) {
            // Double.compare keeps the comparator contract, a truncated subtraction does not
            return Double.compare(o2, o1);
        }
    }); // tokens of the same frequency fall under the same key, the order is reversed so the most frequent tokens go first
    for (Entry<String, Double> entry : freqs.entrySet()) {
        Set<String> set = freqMap.get(entry.getValue());
        if (set == null) {
            set = new TreeSet<>(); // tokens of the same frequency would be sorted alphabetically
            freqMap.put(entry.getValue(), set);
        }
        set.add(entry.getKey());
    }
    int cnt = 0;
    dictSet.addAll(dict.keySet());
    // get most frequent tokens and put them to dictSet
    for (Entry<Double, Set<String>> entry : freqMap.entrySet()) {
        for (String val : entry.getValue()) {
            if (dictSet.add(val) && ++cnt >= MAX_DICT) {
                break;
            }
        }
        if (cnt >= MAX_DICT) {
            break;
        }
    }
    // all of the above means that the dictionary with the same MAX_DICT constraint and made from the same source file will always be
    // the same, the tokens always correspond to the same number so we don't need to save/restore the dictionary
    System.out.println("Dictionary is ready, size is " + dictSet.size());
    // index the dictionary and build the reverse dictionary for lookups
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(DICTIONARY_FILENAME))) {
        for (String word : dictSet) {
            bw.write(word + "\n");
            if (!dict.containsKey(word)) {
                dict.put(word, idx);
                revDict.put(idx, word);
                ++idx;
            }
        }
    }
    System.out.println("Total dictionary size is " + dict.size() + ". Processing the dataset...");
    corpusProcessor = new CorpusProcessor(toTempPath(CORPUS_FILENAME), ROW_SIZE, false) {
        @Override
        protected void processLine(String lastLine) {
            List<String> words = new ArrayList<>();
            tokenizeLine(lastLine, words, true);
            if (!words.isEmpty()) {
                List<Double> wordIdxs = new ArrayList<>();
                if (wordsToIndexes(words, wordIdxs)) {
                    corpus.add(wordIdxs);
                }
            }
        }
    };
    setupCorpusProcessor(corpusProcessor);
    corpusProcessor.setDict(dict);
    corpusProcessor.start();
    System.out.println("Done. Corpus size is " + corpus.size());
}
/**
 * Hook for adjusting the corpus format of a processor before it is started. It currently does nothing, so the processor's own
 * default format is used; the commented-out call below shows how a two-field, "¦"-separated format would be selected.
 */
private void setupCorpusProcessor(CorpusProcessor corpusProcessor) {
// corpusProcessor.setFormatParams("¦", 2, 0, 1);
}
/**
 * Resolves a data file name: returned unchanged unless TMP_DATA_DIR is set, in which case it is placed under the directory named
 * by the "java.io.tmpdir" system property.
 */
private String toTempPath(String path) {
    return TMP_DATA_DIR ? System.getProperty("java.io.tmpdir") + "/" + path : path;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment