Complete meetup event description generator code
@davesnowdon, created October 21, 2017
package com.davesnowdon.meetup;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Trained on a list of event descriptions, this network generates new event descriptions
* Input sequence: |go| This is our monthly social gathering which incorporates members of The London Java Community, ...
* Target sequence: This is our monthly social gathering which incorporates members of The London Java Community, ... |eos|
*/
public class MeetupEventDescriptionGenerator {
private static final String MEETUP_DESCRIPTIONS_FILE = "ljc-meetup-descriptions.txt";
private static final String MEETUP_DESCRIPTIONS_MODEL = "ljc-meetup-descriptions";
private static final Logger LOG = LoggerFactory.getLogger(MeetupEventDescriptionGenerator.class);
private static final String LINK_REGEX = "(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]";
private static final String NUMBER_TOKEN = "|number|";
private static final String LINK_TOKEN = "|link|";
private static final String TIME_TOKEN = "|time|";
private static final String GO_TOKEN = "|go|";
private static final int GO_ID = 0;
private static final String END_OF_SEQUENCE_TOKEN = "|eos|";
private static final int EOS_ID = 1;
private static final String MODEL_FILE_SUFFIX = "_model.zip";
private static final String VOCAB_FILE_SUFFIX = "_vocab.txt";
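// network hyperparameters: dimensionality of the learned word embeddings, the LSTM hidden state size
// and the maximum number of tokens per training sequence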
private static final int EMBEDDING_SIZE = 300;
private static final int RNN_SIZE = 512;
private static final int SEQUENCE_LENGTH = 200;
// adjust depending on amount of GPU memory
private static final int BATCH_SIZE = 32;
private static final int EPOCHS = 200;
private static final int STATS_ITERATIONS = 50;
private static final int MAX_GENERATED_LENGTH = 2048;
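// punctuation <-> token mappings plus the word <-> integer id vocabulary built during training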
private Map<String, String> punctuationToToken;
private Map<String, String> tokenToPunctuation;
private Map<String, Integer> wordToInt;
private Map<Integer, String> intToWord;
private MultiLayerNetwork net;
public static void main(String[] args) throws Exception {
String filePath = new ClassPathResource(MEETUP_DESCRIPTIONS_FILE).getFile().getAbsolutePath();
MeetupEventDescriptionGenerator generator = new MeetupEventDescriptionGenerator();
try {
if (1 == args.length) {
switch (args[0]) {
case "train":
generator.train(filePath);
break;
case "generate":
String output = generator.generate();
System.out.println("Your meetup description:\n\n" + output);
break;
default:
usage();
break;
}
} else {
usage();
}
System.exit(0);
} catch (Exception e) {
System.err.println("Caught error: " + e.toString());
e.printStackTrace();
System.exit(1);
}
}
private static void usage() {
System.err.println("Usage: train|generate");
}
public MeetupEventDescriptionGenerator() {
buildPunctuationMappings();
}
public void train(String filepath) throws Exception {
List<String> examples = readText(filepath);
System.out.println("Number of examples: " + examples.size());
List<List<String>> words = splitWords(examples);
int minSequenceWords = words.stream().mapToInt(l -> l.size()).min().getAsInt();
System.out.println("Shortest sequence: " + minSequenceWords + " words");
int maxSequenceWords = words.stream().mapToInt(l -> l.size()).max().getAsInt();
System.out.println("Longest sequence: " + maxSequenceWords + " words");
double meanSequenceWords = words.stream().mapToInt(l -> l.size()).average().getAsDouble();
System.out.println("Average sequence: " + meanSequenceWords + " words");
Set<String> vocab = getVocabulary(words);
System.out.println("Unique words: " + vocab.size());
buildVocabularyMapping(vocab);
List<List<Integer>> intWords = mapWords(wordToInt, words);
// Use wordToInt.size() instead of vocab.size() so we include the special start & end sequence tokens
net = createModel(wordToInt.size(), EMBEDDING_SIZE, RNN_SIZE);
// create the iterator which generates the batches of input data and labels
MeetupDescriptionIterator data = new MeetupDescriptionIterator(GO_ID, EOS_ID, SEQUENCE_LENGTH, BATCH_SIZE, wordToInt.size(), intWords);
StatsStorage statsStorage = new InMemoryStatsStorage();
net.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(STATS_ITERATIONS));
// train here
System.out.println("Start training (" + EPOCHS + " epochs)\n");
Date startTraining = new Date();
for (int i = 0; i < EPOCHS; i++) {
System.out.println("Epoch=====================" + i);
net.fit(data);
}
Date endTraining = new Date();
long trainingSeconds = (endTraining.getTime() - startTraining.getTime()) / 1000;
System.out.println("End training. " + EPOCHS + " in " + trainingSeconds + " seconds");
saveModel(MEETUP_DESCRIPTIONS_MODEL);
}
// transform the string to integer representation for training
public List<List<Integer>> mapWords(Map<String, Integer> wordMap, List<List<String>> words) {
return words.stream().map(l -> l.stream().map(w -> wordMap.get(w)).collect(Collectors.toList())).collect(Collectors.toList());
}
public String generate() throws Exception {
loadModel(MEETUP_DESCRIPTIONS_MODEL);
net.rnnClearPreviousState();
// create array we'll use to sample from when generating tokens
double[] t = new double[intToWord.size()];
for (int i=0; i<intToWord.size(); ++i) {
t[i] = (double) i;
}
INDArray tokenIdSource = Nd4j.create(t);
List<String> words = new ArrayList<>();
// feed in one word per time step
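// shape [miniBatchSize, features, timeSteps] = [1, 1, 1]: a single token id per time step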
INDArray input = Nd4j.zeros(1, 1, 1);
// start with |go|
input.putScalar(0, 0, 0, (double) GO_ID);
// generate next word until we get |eos| or generate max length text
boolean gotEos = false;
do {
// input is token ID, output is one-hot encoded vector
INDArray output = net.rnnTimeStep(input);
// sample from the list of tokenIds using the probability distribution generated by the RNN
INDArray selected = Nd4j.choice(tokenIdSource, output, 1);
int tokenId = selected.getInt(0);
// convert token id to string and add to output
words.add(intToWord.get(tokenId));
// put chosen token back into input for next iteration
input.putScalar(0, 0, 0, (double) tokenId);
// check whether we have an end of sequence marker
gotEos = (tokenId == EOS_ID);
} while (!gotEos && (words.size() < MAX_GENERATED_LENGTH));
// convert back to string
String generatedText = words.stream().collect(Collectors.joining(" "));
// convert punctuation back into symbols
generatedText = applyMapping(generatedText, tokenToPunctuation, false);
return generatedText;
}
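/**
 * Network: embedding layer -> LSTM -> softmax output over the vocabulary, trained with cross-entropy loss.
 * The pre-processors reshape activations between the feed-forward embedding layer and the recurrent layers.
 */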
private MultiLayerNetwork createModel(int vocabSize, int embeddingSize, int rnnSize) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.RELU)
.learningRate(0.01)
.list()
.layer(0, new EmbeddingLayer.Builder().nIn(vocabSize).nOut(embeddingSize).build())
.layer(1, new GravesLSTM.Builder().nIn(embeddingSize).nOut(rnnSize).activation(Activation.SOFTSIGN).build())
.layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(rnnSize).nOut(vocabSize).activation(Activation.SOFTMAX).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(1, new FeedForwardToRnnPreProcessor())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
return net;
}
/**
* Save model and mapping from integers to words
*/
private void saveModel(String modelPrefix) throws IOException {
final File locationToSave = new File(modelPrefix + MODEL_FILE_SUFFIX);
System.out.println("Saving model to " + locationToSave);
ModelSerializer.writeModel(net, locationToSave, true);
saveVocabularyMapping(modelPrefix + VOCAB_FILE_SUFFIX);
System.out.println("Model saved");
}
/**
* Load trained model from file and set of mappings from integers to words
*/
private void loadModel(String modelPrefix) throws IOException {
final File locationToSave = new File(modelPrefix + MODEL_FILE_SUFFIX);
System.out.println("Loading model from " + locationToSave);
net = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
loadVocabularyMapping(modelPrefix + VOCAB_FILE_SUFFIX);
System.out.println("Model loaded");
}
private List<String> readText(String filename) throws IOException {
File file = new File(filename);
return FileUtils.readLines(file, "UTF-8");
}
public List<List<String>> splitWords(List<String> text) {
return text.stream()
// lowercase
.map(String::toLowerCase)
// remove all instances of pipe character
.map(t -> t.replace("|", ""))
// replace URLs with placeholder
.map(t -> replaceLinks(t))
// consider & as just "and"
.map(t -> t.replace("&", " and "))
// remove everything except letters, digits, whitespace and basic punctuation (-,;:.!?()%'|")
.map(t -> t.replaceAll("[^A-Za-z0-9\\-,;:.!?()%'|\"\\s]*", ""))
// replace time of day with special token |time|
.map(t -> t.replaceAll("\\b[0-9][0-9]:[0-9][0-9]\\b", TIME_TOKEN))
.map(t -> t.replaceAll("\\b[0-9][0-9]?(am|pm)\\b", TIME_TOKEN))
// replace numbers by the special token |number|
.map(t -> t.replaceAll("\\b[0-9]+(st|nd|rd|th)?\\b", NUMBER_TOKEN))
// replace punctuation with tokens
.map(t -> applyMapping(t, punctuationToToken, true))
// separate words into separate strings
.map(t -> Arrays.asList(t.split("\\s")))
// remove empty strings
.map(wl -> wl.stream().filter(StringUtils::isNotBlank).collect(Collectors.toList()))
.collect(Collectors.toList());
}
public Set<String> getVocabulary(List<List<String>> words) {
Set<String> uniqueWords = new HashSet<>();
words.forEach(l -> uniqueWords.addAll(l));
return uniqueWords;
}
// TODO use word frequency and discard words with fewer than a specified minimum number of occurrences
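// assign each unique word an integer id, reserving 0 and 1 for the |go| and |eos| tokens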
private void buildVocabularyMapping(Set<String> vocab) {
wordToInt = new HashMap<>();
intToWord = new LinkedHashMap<>();
addVocabularyMapping(GO_ID, GO_TOKEN);
addVocabularyMapping(EOS_ID, END_OF_SEQUENCE_TOKEN);
int tokenId = EOS_ID + 1;
for (String word : vocab) {
addVocabularyMapping(tokenId++, word);
}
}
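/**
 * Write the integer-to-word mapping as one "id word" pair per line so it can be reloaded at generation time
 */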
private void saveVocabularyMapping(String filename) throws IOException {
File file = new File(filename);
try (PrintWriter out = new PrintWriter(new FileWriter(file))) {
intToWord.forEach((k, v) -> out.println(k + " " + v));
}
}
private void loadVocabularyMapping(String filename) throws IOException {
wordToInt = new HashMap<>();
intToWord = new LinkedHashMap<>();
List<String> mappings = readText(filename);
for (String mapping : mappings) {
String[] parts = mapping.split(" ");
addVocabularyMapping(Integer.valueOf(parts[0]), parts[1]);
}
}
private void addVocabularyMapping(int tokenId, String word) {
final Integer id = Integer.valueOf(tokenId);
wordToInt.put(word, id);
intToWord.put(id, word);
}
public String replaceLinks(String text) {
return text.replaceAll(LINK_REGEX, LINK_TOKEN + " ");
}
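/**
 * Replace every occurrence of each mapping key with its value. Padding with spaces ensures that
 * punctuation tokens end up as separate words when the text is later split on whitespace.
 */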
private String applyMapping(String text, Map<String, String> mapping, boolean withPadding) {
for (Map.Entry<String, String> e : mapping.entrySet()) {
final String v = withPadding ? " " + e.getValue() + " " : e.getValue();
text = text.replace(e.getKey(), v);
}
return text;
}
private void buildPunctuationMappings() {
punctuationToToken = new HashMap<>();
punctuationToToken.put(".", "|dot|");
punctuationToToken.put(",", "|comma|");
punctuationToToken.put(";", "|semicolon|");
punctuationToToken.put(":", "|colon|");
punctuationToToken.put("!", "|bang|");
punctuationToToken.put("?", "|question|");
punctuationToToken.put("(", "|leftparens|");
punctuationToToken.put(")", "|rightparens|");
punctuationToToken.put("\"", "|doublequote|");
punctuationToToken.put("'", "|quote|");
punctuationToToken.put("-", "|dash|");
punctuationToToken.put("%", "|percent|");
tokenToPunctuation = punctuationToToken.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
}
/**
* Convert a list of token sequences into feature and label data.
* The dataset is small enough to hold entirely in memory, so all data is constructed up-front.
* This assumes that tokenization has already taken place, so the input is a list of integer token ids rather than words.
* For each sequence we generate an input and an output: the input is prefixed with |go| (0) and the output has |eos| (1)
* appended. This produces input and output sequences of the same length, with the output shifted by one token.
* Masking is used to pad sequences shorter than the specified sequence length.
*
* Sequence: 2 3 4 5 6
* Input: 0 2 3 4 5 6
* Output 2 3 4 5 6 1
*/
public static class MeetupDescriptionIterator implements DataSetIterator {
private final int startSequenceId;
private final int endSequenceId;
private final int batchSize;
private final int dictSize;
private final int sequenceLength;
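// input and output token sequences, already split to at most sequenceLength tokens and converted to doubles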
private final List<List<Double>> inputs;
private final List<List<Double>> outputs;
private int totalBatches;
private int currentBatch = 0;
private DataSetPreProcessor preProcessor;
public MeetupDescriptionIterator(int startSequenceId, int endSequenceId, int sequenceLength, int batchSize, int dictSize, List<List<Integer>> words) {
this.startSequenceId = startSequenceId;
this.endSequenceId = endSequenceId;
this.sequenceLength = sequenceLength;
this.batchSize = batchSize;
this.dictSize = dictSize;
// split up long examples to make multiple sequences
List<List<Double>>[] io = makeSequences(words);
this.inputs = io[0];
this.outputs = io[1];
this.currentBatch = 0;
this.totalBatches = (int) Math.ceil((double) inputs.size() / batchSize);
}
/**
* Break token sequences into sequences no more than sequenceLength long and convert to doubles.
* Since we prepend |go| to input sequences and append |eos| to output sequences we generate two lists
* one for input and one for output. This is wasteful but avoids us needing to track whether a sequence is
* the start or end of an example.
*/
private List<List<Double>>[] makeSequences(List<List<Integer>> words) {
List<List<Double>> in = new ArrayList<>();
List<List<Double>> out = new ArrayList<>();
for (List<Integer> row : words) {
List<Integer> i = new ArrayList<>(row);
i.add(0, startSequenceId);
splitSequence(in, i);
List<Integer> o = new ArrayList<>(row);
o.add(endSequenceId);
splitSequence(out, o);
}
return new List[]{in, out};
}
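// chop a single example into chunks of at most sequenceLength tokens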
private void splitSequence(List<List<Double>> accumulator, List<Integer> input) {
int pos = 0;
while ((input.size() - pos) > sequenceLength) {
accumulator.add(doublerize(input.subList(pos, pos + sequenceLength)));
pos += sequenceLength;
}
if ((input.size() - pos) > 0) {
accumulator.add(doublerize(input.subList(pos, input.size())));
}
}
private List<Double> doublerize(List<Integer> ints) {
return ints.stream().map(Integer::doubleValue).collect(Collectors.toList());
}
/**
* Implementation adapted from org.deeplearning4j.examples.recurrent.encdec.CorpusIterator
*/
@Override
public DataSet next(int num) {
int i = currentBatch * batchSize;
int currentBatchSize = Math.min(batchSize, inputs.size() - i - 1);
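// features hold raw token ids, shape [batchSize, 1, sequenceLength]; labels are one-hot, shape [batchSize, dictSize, sequenceLength]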
INDArray input = Nd4j.zeros(currentBatchSize, 1, sequenceLength);
INDArray prediction = Nd4j.zeros(currentBatchSize, dictSize, sequenceLength);
INDArray inputMask = Nd4j.zeros(currentBatchSize, sequenceLength);
// this mask is also used for the decoder input, the length is the same
INDArray predictionMask = Nd4j.zeros(currentBatchSize, sequenceLength);
for (int j = 0; j < currentBatchSize; j++) {
List<Double> rowIn = new ArrayList<>(inputs.get(i));
// the one-token shift and the trailing |eos| token were already applied in makeSequences()
List<Double> rowPred = new ArrayList<>(outputs.get(i));
// replace the entire row in "input" using NDArrayIndex, it's faster than putScalar(); input is NOT made of one-hot vectors
// because of the embedding layer that accepts token indexes directly
input.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.point(0), NDArrayIndex.interval(0, rowIn.size())},
Nd4j.create(ArrayUtils.toPrimitive(rowIn.toArray(new Double[0]))));
inputMask.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, rowIn.size())}, Nd4j.ones(rowIn.size()));
predictionMask.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, rowPred.size())},
Nd4j.ones(rowPred.size()));
// prediction (output) IS one-hot though
double[][] predOneHot = new double[dictSize][rowPred.size()];
int predIdx = 0;
for (Double pred : rowPred) {
predOneHot[pred.intValue()][predIdx] = 1;
++predIdx;
}
prediction.put(new INDArrayIndex[]{NDArrayIndex.point(j), NDArrayIndex.interval(0, dictSize),
NDArrayIndex.interval(0, rowPred.size())}, Nd4j.create(predOneHot));
++i;
}
++currentBatch;
return new DataSet(input, prediction, inputMask, predictionMask);
}
@Override
public int totalExamples() {
return totalBatches;
}
@Override
public int inputColumns() {
return 1;
}
@Override
public int totalOutcomes() {
return dictSize;
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return false;
}
@Override
public void reset() {
currentBatch = 0;
}
@Override
public int batch() {
return currentBatch;
}
@Override
public int cursor() {
return 0;
}
@Override
public int numExamples() {
return totalBatches;
}
@Override
public boolean hasNext() {
return currentBatch < totalBatches;
}
@Override
public DataSet next() {
return next(batchSize);
}
@Override
public List<String> getLabels() {
return Collections.emptyList();
}
@Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
this.preProcessor = preProcessor;
}
@Override
public DataSetPreProcessor getPreProcessor() {
return preProcessor;
}
}
}
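// Example invocation (a sketch, assuming the compiled class and its DL4J/ND4J dependencies are on the
// classpath and ljc-meetup-descriptions.txt is available as a classpath resource):
//   java com.davesnowdon.meetup.MeetupEventDescriptionGenerator train
//   java com.davesnowdon.meetup.MeetupEventDescriptionGenerator generate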