@davesnowdon
Last active October 22, 2017
Training an RNN to generate text with word tokens as input and one-hot vectors as output
/*
 * Imports (not shown in the original gist; package paths assume the DL4J/ND4J 0.9.x layout)
 */
import java.util.*;
import java.util.stream.Collectors;

import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/*
 * Model
 */
private MultiLayerNetwork createModel(int vocabSize, int embeddingSize, int rnnSize) {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .activation(Activation.RELU)
            .learningRate(0.01)
            .list()
            // the embedding layer maps token indexes to dense vectors, so inputs are indexes, not one-hot vectors
            .layer(0, new EmbeddingLayer.Builder().nIn(vocabSize).nOut(embeddingSize).build())
            .layer(1, new GravesLSTM.Builder().nIn(embeddingSize).nOut(rnnSize).activation(Activation.SOFTSIGN).build())
            .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(rnnSize).nOut(vocabSize).activation(Activation.SOFTMAX).build())
            // EmbeddingLayer is a feed-forward layer, so convert the RNN input to feed-forward
            // format before it and back to RNN format before the LSTM
            .inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
            .inputPreProcessor(1, new FeedForwardToRnnPreProcessor())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    return net;
}
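/*
 * For reference, the shapes flowing through the network for a minibatch of
 * size B and sequence length T (EMBEDDING_SIZE and RNN_SIZE are assumed to be
 * 300 and 512, matching the values the gist originally hard-coded):
 *   input             [B, 1, T]               token indexes
 *   embedding output  [B, EMBEDDING_SIZE, T]  dense vectors
 *   LSTM output       [B, RNN_SIZE, T]
 *   softmax output    [B, vocabSize, T]       per-timestep distribution over the vocabulary
 */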
/*
* train
*/
// Use wordToInt.size() instead of vocab.size() so we include the special start & end sequence tokens
net = createModel(wordToInt.size(), EMBEDDING_SIZE, RNN_SIZE);
// create the iterator which generates the batches of input data and labels
MeetupDescriptionIterator data = new MeetupDescriptionIterator(GO_ID, EOS_ID, SEQUENCE_LENGTH, BATCH_SIZE, wordToInt.size(), intWords);
StatsStorage statsStorage = new InMemoryStatsStorage();
net.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(STATS_ITERATIONS));
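// To actually view these stats one would typically attach the storage to the
// DL4J training UI; this line is not in the original gist and assumes the
// deeplearning4j-ui dependency is on the classpath:
// UIServer.getInstance().attach(statsStorage);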
// train here
System.out.println("Start training (" + EPOCHS + " epochs)\n");
Date startTraining = new Date();
for (int i = 0; i < EPOCHS; i++) {
    System.out.println("Epoch=====================" + i);
    net.fit(data);
}
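/*
 * Generation (sketch). The gist ends at training, so the following only
 * illustrates how the trained network might be sampled, one token per
 * timestep via rnnTimeStep(). intToWord (the inverse of wordToInt), rnd
 * (a java.util.Random) and maxTokens are assumed names that do not appear above.
 */
private String sample(MultiLayerNetwork net, Map<Integer, String> intToWord, Random rnd, int maxTokens) {
    StringBuilder text = new StringBuilder();
    net.rnnClearPreviousState();
    int token = GO_ID; // seed with the start-of-sequence token
    for (int t = 0; t < maxTokens; t++) {
        // one timestep of input: shape [miniBatch=1, nIn=1, timeSeriesLength=1]
        INDArray out = net.rnnTimeStep(Nd4j.create(new double[]{token}, new int[]{1, 1, 1}));
        // out has shape [1, vocabSize, 1]: a softmax distribution over the vocabulary
        double r = rnd.nextDouble();
        double cumulative = 0.0;
        token = 0;
        for (int k = 0; k < out.size(1); k++) {
            cumulative += out.getDouble(0, k, 0);
            if (r < cumulative) {
                token = k;
                break;
            }
        }
        if (token == EOS_ID) {
            break; // end-of-sequence token generated
        }
        text.append(intToWord.get(token)).append(' ');
    }
    return text.toString().trim();
}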
/*
* iterator
*/
/**
 * Convert a list of token sequences into feature and label data.
 * The dataset is quite small so we assume we can hold it entirely in memory and construct all data up-front.
 * For each sequence we generate an input and an output. The input is prefixed with |go| (0) and the output has |eos| (1)
 * appended. This produces input and output sequences of the same length but with the output shifted by one token.
 * This assumes that tokenization has already taken place, so the input is a list of integer token IDs rather than words.
 * Masking is used to pad sequences shorter than the specified sequence length.
 *
 * Sequence: 2 3 4 5 6
 * Input:    0 2 3 4 5 6
 * Output:   2 3 4 5 6 1
 */
public static class MeetupDescriptionIterator implements DataSetIterator {
    private final int startSequenceId;
    private final int endSequenceId;
    private final int batchSize;
    private final int dictSize;
    private final int sequenceLength;
    private final List<List<Double>> inputs;
    private final List<List<Double>> outputs;
    private int totalBatches;
    private int currentBatch = 0;
    private DataSetPreProcessor preProcessor;

    public MeetupDescriptionIterator(int startSequenceId, int endSequenceId, int sequenceLength, int batchSize, int dictSize, List<List<Integer>> words) {
        this.startSequenceId = startSequenceId;
        this.endSequenceId = endSequenceId;
        this.sequenceLength = sequenceLength;
        this.batchSize = batchSize;
        this.dictSize = dictSize;
        // split up long examples to make multiple sequences
        List<List<Double>>[] io = makeSequences(words);
        this.inputs = io[0];
        this.outputs = io[1];
        this.currentBatch = 0;
        this.totalBatches = (int) Math.ceil((double) inputs.size() / batchSize);
    }
    /**
     * Break token sequences into sequences no more than sequenceLength long and convert to doubles.
     * Since we prepend |go| to input sequences and append |eos| to output sequences we generate two lists,
     * one for input and one for output. This is wasteful but avoids needing to track whether a sequence is
     * the start or end of an example.
     */
    private List<List<Double>>[] makeSequences(List<List<Integer>> words) {
        List<List<Double>> in = new ArrayList<>();
        List<List<Double>> out = new ArrayList<>();
        for (List<Integer> row : words) {
            List<Integer> i = new ArrayList<>(row);
            i.add(0, startSequenceId);
            splitSequence(in, i);
            List<Integer> o = new ArrayList<>(row);
            o.add(endSequenceId);
            splitSequence(out, o);
        }
        return new List[]{in, out};
    }
    private void splitSequence(List<List<Double>> accumulator, List<Integer> input) {
        int pos = 0;
        while ((input.size() - pos) > sequenceLength) {
            accumulator.add(doublerize(input.subList(pos, pos + sequenceLength)));
            pos += sequenceLength;
        }
        if ((input.size() - pos) > 0) {
            accumulator.add(doublerize(input.subList(pos, input.size())));
        }
    }
    private List<Double> doublerize(List<Integer> ints) {
        // Integer::doubleValue avoids the now-deprecated new Double(...) constructor
        return ints.stream().map(Integer::doubleValue).collect(Collectors.toList());
    }
    /**
     * Implementation adapted from org.deeplearning4j.examples.recurrent.encdec.CorpusIterator
     */
    @Override
    public DataSet next(int num) {
        int i = currentBatch * batchSize;
        // no "- 1" here, otherwise the final example of the last batch is dropped
        int currentBatchSize = Math.min(batchSize, inputs.size() - i);
        INDArray input = Nd4j.zeros(currentBatchSize, 1, sequenceLength);
        INDArray prediction = Nd4j.zeros(currentBatchSize, dictSize, sequenceLength);
        INDArray inputMask = Nd4j.zeros(currentBatchSize, sequenceLength);
        // this mask is also used for the decoder input, the length is the same
        INDArray predictionMask = Nd4j.zeros(currentBatchSize, sequenceLength);
        for (int j = 0; j < currentBatchSize; j++) {
            List<Double> rowIn = new ArrayList<>(inputs.get(i));
            // shift and |eos| token already added
            List<Double> rowPred = new ArrayList<>(outputs.get(i));
            // replace the entire row in "input" using NDArrayIndex, it's faster than putScalar(); input is NOT made of one-hot vectors
            // because of the embedding layer that accepts token indexes directly
            input.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.point(0), NDArrayIndex.interval(0, rowIn.size())},
                    Nd4j.create(ArrayUtils.toPrimitive(rowIn.toArray(new Double[0]))));
            inputMask.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, rowIn.size())}, Nd4j.ones(rowIn.size()));
            predictionMask.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, rowPred.size())},
                    Nd4j.ones(rowPred.size()));
            // prediction (output) IS one-hot though
            double[][] predOneHot = new double[dictSize][rowPred.size()];
            int predIdx = 0;
            for (Double pred : rowPred) {
                predOneHot[pred.intValue()][predIdx] = 1;
                ++predIdx;
            }
            prediction.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, dictSize),
                    NDArrayIndex.interval(0, rowPred.size())}, Nd4j.create(predOneHot));
            ++i;
        }
        ++currentBatch;
        DataSet ds = new DataSet(input, prediction, inputMask, predictionMask);
        // apply any pre-processor registered via setPreProcessor()
        if (preProcessor != null) {
            preProcessor.preProcess(ds);
        }
        return ds;
    }
    @Override
    public int totalExamples() {
        // total number of examples, not batches
        return inputs.size();
    }

    @Override
    public int inputColumns() {
        return 1;
    }

    @Override
    public int totalOutcomes() {
        return dictSize;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    @Override
    public void reset() {
        currentBatch = 0;
    }

    @Override
    public int batch() {
        // the DataSetIterator contract is that batch() returns the batch size
        return batchSize;
    }

    @Override
    public int cursor() {
        // index of the first example of the next batch
        return currentBatch * batchSize;
    }

    @Override
    public int numExamples() {
        return inputs.size();
    }

    @Override
    public boolean hasNext() {
        return currentBatch < totalBatches;
    }

    @Override
    public DataSet next() {
        return next(batchSize);
    }

    @Override
    public List<String> getLabels() {
        return Collections.emptyList();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }
}
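/*
 * Hypothetical smoke test for the iterator (not part of the original gist):
 * checks that batches have the expected shapes before starting a long training
 * run. The constants are the same ones used in the train snippet above.
 */
MeetupDescriptionIterator check = new MeetupDescriptionIterator(GO_ID, EOS_ID, SEQUENCE_LENGTH, BATCH_SIZE, wordToInt.size(), intWords);
while (check.hasNext()) {
    DataSet ds = check.next();
    // features [batch, 1, seqLen]; labels [batch, vocabSize, seqLen]
    System.out.println(Arrays.toString(ds.getFeatures().shape()) + " -> " + Arrays.toString(ds.getLabels().shape()));
}
check.reset();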