Training an RNN to generate text, with word tokens as input and one-hot vectors as output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Model | |
*/ | |
private MultiLayerNetwork createModel(int vocabSize, int embeddingSize, int rnnSize) { | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.activation(Activation.RELU) | |
.learningRate(0.01) | |
.list() | |
.layer(0, new EmbeddingLayer.Builder().nIn(vocabSize).nOut(300).build()) | |
.layer(1, new GravesLSTM.Builder().nIn(300).nOut(512).activation(Activation.SOFTSIGN).build()) | |
.layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(512).nOut(vocabSize).activation(Activation.SOFTMAX).build()) | |
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) | |
.inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) | |
.build(); | |
MultiLayerNetwork net = new MultiLayerNetwork(conf); | |
net.init(); | |
return net; | |
} | |
/*
 * train
 */
// Use wordToInt.size() instead of vocab.size() so we include the special start & end sequence tokens
// (EMBEDDING_SIZE, RNN_SIZE etc. are constants declared elsewhere in this file — not visible in this chunk)
net = createModel(wordToInt.size(), EMBEDDING_SIZE, RNN_SIZE);
// create the iterator which generates the batches of input data and labels
MeetupDescriptionIterator data = new MeetupDescriptionIterator(GO_ID, EOS_ID, SEQUENCE_LENGTH, BATCH_SIZE, wordToInt.size(), intWords);
// In-memory stats storage feeds the DL4J training UI; ScoreIterationListener logs the score
// every STATS_ITERATIONS iterations to stdout.
StatsStorage statsStorage = new InMemoryStatsStorage();
net.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(STATS_ITERATIONS));
// train here
System.out.println("Start training (" + EPOCHS + " epochs)\n");
// Training start timestamp — presumably used after this snippet to report elapsed time (TODO confirm)
Date startTraining = new Date();
for (int i = 0; i < EPOCHS; i++) {
    System.out.println("Epoch=====================" + i);
    // fit(DataSetIterator) consumes the iterator once per epoch; DL4J resets it between epochs
    net.fit(data);
}
/* | |
* iterator | |
*/ | |
/** | |
* Convert a list of token sequences into feature and label data. | |
* The dataset is quite small so we assume we can contain it entirely in memory and so construct all data up-front. | |
 * For each sequence we generate an input and an output. The input is prefixed with |go| (0) and the output has |eos| (1)
 * appended. This produces input and output sequences of the same length but with the output shifted by one token.
 * This assumes that tokenization has already taken place and so the input is a list of integer token Ids rather than words.
 * Masking is used to pad sequences less than the specified sequence length.
* | |
* Sequence: 2 3 4 5 6 | |
* Input: 0 2 3 4 5 6 | |
* Output 2 3 4 5 6 1 | |
*/ | |
public static class MeetupDescriptionIterator implements DataSetIterator { | |
private final int startSequenceId; | |
private final int endSequenceId; | |
private final int batchSize; | |
private final int dictSize; | |
private final int sequenceLength; | |
private final List<List<Double>> inputs; | |
private final List<List<Double>> outputs; | |
private int totalBatches; | |
private int currentBatch = 0; | |
private DataSetPreProcessor preProcessor; | |
public MeetupDescriptionIterator(int startSequenceId, int endSequenceId, int sequenceLength, int batchSize, int dictSize, List<List<Integer>> words) { | |
this.startSequenceId = startSequenceId; | |
this.endSequenceId = endSequenceId; | |
this.sequenceLength = sequenceLength; | |
this.batchSize = batchSize; | |
this.dictSize = dictSize; | |
// split up long examples to make multiple sequences | |
List<List<Double>>[] io = makeSequences(words); | |
this.inputs = io[0]; | |
this.outputs = io[1]; | |
this.currentBatch = 0; | |
this.totalBatches = (int) Math.ceil((double) inputs.size() / batchSize); | |
} | |
/** | |
* Break token sequences into sequences no more than sequenceLength long and convert to doubles. | |
* Since we prepend |go| to input sequences and append |eos| to output sequences we generate two lists | |
* one for input and one for output. This is wasteful but avoids us needing to track whether a sequence is | |
* the start or end of an example. | |
*/ | |
private List<List<Double>>[] makeSequences(List<List<Integer>> words) { | |
List<List<Double>> in = new ArrayList<>(); | |
List<List<Double>> out = new ArrayList<>(); | |
for (List<Integer> row : words) { | |
List<Integer> i = new ArrayList<>(row); | |
i.add(0, startSequenceId); | |
splitSequence(in, i); | |
List<Integer> o = new ArrayList<>(row); | |
o.add(endSequenceId); | |
splitSequence(out, o); | |
} | |
return new List[]{in, out}; | |
} | |
private void splitSequence(List<List<Double>> accumulator, List<Integer> input) { | |
int pos = 0; | |
while ((input.size() - pos) > sequenceLength) { | |
accumulator.add(doublerize(input.subList(pos, pos + sequenceLength))); | |
pos += sequenceLength; | |
} | |
if ((input.size() - pos) > 0) { | |
accumulator.add(doublerize(input.subList(pos, input.size()))); | |
} | |
} | |
private List<Double> doublerize(List<Integer> ints) { | |
return ints.stream().map(i -> new Double((double) i)).collect(Collectors.toList()); | |
} | |
/** | |
* Implementation adapted from org.deeplearning4j.examples.recurrent.encdec.CorpusIterator | |
*/ | |
@Override | |
public DataSet next(int num) { | |
int i = currentBatch * batchSize; | |
int currentBatchSize = Math.min(batchSize, inputs.size() - i - 1); | |
INDArray input = Nd4j.zeros(currentBatchSize, 1, sequenceLength); | |
INDArray prediction = Nd4j.zeros(currentBatchSize, dictSize, sequenceLength); | |
INDArray inputMask = Nd4j.zeros(currentBatchSize, sequenceLength); | |
// this mask is also used for the decoder input, the length is the same | |
INDArray predictionMask = Nd4j.zeros(currentBatchSize, sequenceLength); | |
for (int j = 0; j < currentBatchSize; j++) { | |
List<Double> rowIn = new ArrayList<>(inputs.get(i)); | |
// shift and |eos| token already added | |
List<Double> rowPred = new ArrayList<>(outputs.get(i)); | |
// replace the entire row in "input" using NDArrayIndex, it's faster than putScalar(); input is NOT made of one-hot vectors | |
// because of the embedding layer that accepts token indexes directly | |
input.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.point(0), NDArrayIndex.interval(0, rowIn.size())}, | |
Nd4j.create(ArrayUtils.toPrimitive(rowIn.toArray(new Double[0])))); | |
inputMask.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, rowIn.size())}, Nd4j.ones(rowIn.size())); | |
predictionMask.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, rowPred.size())}, | |
Nd4j.ones(rowPred.size())); | |
// prediction (output) IS one-hot though | |
double predOneHot[][] = new double[dictSize][rowPred.size()]; | |
int predIdx = 0; | |
for (Double pred : rowPred) { | |
predOneHot[pred.intValue()][predIdx] = 1; | |
++predIdx; | |
} | |
prediction.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, dictSize), | |
NDArrayIndex.interval(0, rowPred.size())}, Nd4j.create(predOneHot)); | |
++i; | |
} | |
++currentBatch; | |
return new DataSet(input, prediction, inputMask, predictionMask); | |
} | |
@Override | |
public int totalExamples() { | |
return totalBatches; | |
} | |
@Override | |
public int inputColumns() { | |
return 1; | |
} | |
@Override | |
public int totalOutcomes() { | |
return dictSize; | |
} | |
@Override | |
public boolean resetSupported() { | |
return true; | |
} | |
@Override | |
public boolean asyncSupported() { | |
return false; | |
} | |
@Override | |
public void reset() { | |
currentBatch = 0; | |
} | |
@Override | |
public int batch() { | |
return currentBatch; | |
} | |
@Override | |
public int cursor() { | |
return 0; | |
} | |
@Override | |
public int numExamples() { | |
return totalBatches; | |
} | |
@Override | |
public boolean hasNext() { | |
return currentBatch < totalBatches; | |
} | |
@Override | |
public DataSet next() { | |
return next(batchSize); | |
} | |
@Override | |
public List<String> getLabels() { | |
return Collections.emptyList(); | |
} | |
@Override | |
public void setPreProcessor(DataSetPreProcessor preProcessor) { | |
this.preProcessor = preProcessor; | |
} | |
@Override | |
public DataSetPreProcessor getPreProcessor() { | |
return preProcessor; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment