Skip to content

Instantly share code, notes, and snippets.

@agibsonccc
Created June 24, 2021 22:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save agibsonccc/a28480a111c7d1c3f62e45997dbb2f62 to your computer and use it in GitHub Desktop.
package org.example;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
/**
 * Minimal end-to-end DL4J training demo: builds a bidirectional-LSTM graph
 * and fits it on the synthetic sequences produced by {@link TestTokenIterator}.
 */
public class LSTMTest {

    /**
     * Builds and initializes a ComputationGraph with the topology
     * bidirectional LSTM -> LSTM -> RNN softmax output (MCXENT loss).
     *
     * @param inputs  feature (channel) size of the recurrent "word" input
     * @param outputs number of output classes for the output layer
     * @return an initialized, untrained {@link ComputationGraph}
     */
    private ComputationGraph createGraphModel(int inputs, int outputs) {
        int size = 400; // hidden-layer width
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            // Fixed seed so runs are reproducible; the original seeded from
            // System.currentTimeMillis(), making every run incomparable.
            .seed(12345)
            .updater(new Sgd(new ExponentialSchedule(ScheduleType.EPOCH, 0.01, 0.95)))
            .weightInit(WeightInit.XAVIER)
            .l2(1e-4)
            .graphBuilder()
            .addInputs("word")
            .addLayer("bi_lstm", new Bidirectional(new LSTM.Builder()
                .nOut(size)
                .activation(Activation.SOFTSIGN)
                .build()), "word")
            .addLayer("lstm_hide1", new LSTM.Builder()
                // Bidirectional concatenates both directions -> 2 * size inputs
                .nOut(size * 2)
                .activation(Activation.SOFTSIGN)
                .build(), "bi_lstm")
            .addLayer("rnn_out", new RnnOutputLayer.Builder()
                .nOut(outputs)
                .lossFunction(LossFunction.MCXENT)
                .build(), "lstm_hide1")
            .setOutputs("rnn_out")
            // NOTE(review): InputType.recurrent's second argument is the
            // time-series length; passing 'size' (the layer width, 400) here
            // looks accidental but is preserved as-is — confirm the intent.
            .setInputTypes(InputType.recurrent(inputs, size, RNNFormat.NCW))
            .build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
        return graph;
    }

    /**
     * Trains the graph over the synthetic token iterator, logging the score
     * after every iteration.
     *
     * @throws IOException if the iterator fails to initialize
     */
    public void trainModel() throws IOException {
        ComputationGraph model = createGraphModel(200, 56);
        TestTokenIterator iter = new TestTokenIterator();
        model.addListeners(new ScoreIterationListener(1));
        model.fit(iter);
    }

    /** Entry point: runs a single training session. */
    public static void main(String[] args) {
        LSTMTest trainer = new LSTMTest();
        try {
            trainer.trainModel();
        } catch (Exception e) {
            // Best-effort demo harness: report the failure and exit.
            e.printStackTrace();
        }
    }
}
/**
 * Synthetic {@link MultiDataSetIterator} that emits one-hot "token" sequences
 * for testing. Each example is a random-length sequence (2-16 steps) of
 * one-hot vectors over a 200-entry vocabulary, followed by an all-ones "stop"
 * step; the labels are the same tokens (mod 56 classes) written SHIFT+1 time
 * steps later, with feature/label masks marking the valid positions.
 */
class TestTokenIterator implements MultiDataSetIterator {
    private static final long serialVersionUID = -8725261332489394391L;

    private final int outputs;            // number of label classes (labels.size() == 56)
    private final int inputs;             // one-hot vocabulary / feature size (200)
    private int count;                    // examples emitted since the last reset()
    private final List<String> labels;    // synthetic class names "label0".."label55"
    private MultiDataSetPreProcessor pre; // optional preprocessor, applied by the caller
    private final int maxTrain;           // examples per epoch
    private final INDArray stop;          // all-ones vector marking end of input
    private final Random rand;

    /**
     * Builds the iterator with 56 synthetic labels and a 200-wide feature space.
     *
     * @throws FileNotFoundException never thrown here; kept so existing callers
     *         that catch it still compile (removing a checked exception from
     *         the signature would break them)
     */
    public TestTokenIterator() throws FileNotFoundException {
        this.count = 0;
        this.maxTrain = 1500;
        this.stop = Nd4j.ones(1, 200);
        labels = new LinkedList<>();
        for (int i = 0; i < 56; i++) {
            labels.add("label" + i);
        }
        // NOTE(review): time-based seed means every epoch's data differs;
        // preserved as-is since the data is deliberately random.
        this.rand = new Random(System.currentTimeMillis());
        this.outputs = labels.size();
        this.inputs = 200;
    }

    @Override
    public boolean hasNext() {
        return count < maxTrain;
    }

    @Override
    public MultiDataSet next() {
        return next(1);
    }

    /**
     * Produces the next minibatch.
     *
     * @param num requested batch size; NOTE(review): ignored — this iterator
     *            always returns a minibatch of 1, as the original did
     * @return a MultiDataSet of [features, labels, featureMask, labelMask]
     * @throws NoSuchElementException if the iterator is exhausted (Iterator contract;
     *         the original silently produced data past the end of the epoch)
     */
    @Override
    public MultiDataSet next(int num) {
        if (!hasNext()) {
            throw new NoSuchElementException("Iterator exhausted: " + count + "/" + maxTrain);
        }
        int batch = 1;
        INDArray[] featuresList = new INDArray[batch];
        INDArray[] labelsList = new INDArray[batch];
        INDArray[] featureMasks = new INDArray[batch];
        INDArray[] labelsMasks = new INDArray[batch];
        for (int i = 0; i < batch; i++) {
            int sentSize = 2 + rand.nextInt(15); // sequence length in [2, 16]
            System.out.println("seq len: " + sentSize);
            int shift = sentSize;
            int timesteps = sentSize;
            // Total sequence length = input steps + stop step + shifted label steps.
            INDArray featureSteps = Nd4j.create(inputs, timesteps + shift + 1);
            INDArray labelSteps = Nd4j.create(outputs, timesteps + shift + 1);
            INDArray featureMask = Nd4j.zeros(timesteps + shift + 1);
            INDArray labelMask = Nd4j.zeros(timesteps + shift + 1);
            // Build one example: one-hot token at step e, its label at step e+shift+1.
            for (int e = 0; e < timesteps; e++) {
                int wv = rand.nextInt(inputs); // random token index
                INDArray wvec = Nd4j.zeros(1, inputs);
                INDArray lvec = Nd4j.zeros(labels.size());
                wvec.putScalar(new int[] {0, wv}, 1.0);
                lvec.putScalar(new int[] {wv % labels.size()}, 1.0);
                featureSteps.putColumn(e, wvec);
                labelSteps.putColumn(e + shift + 1, lvec);
                featureMask.putScalar(new int[] {e}, 1.0);
                labelMask.putScalar(new int[] {e + shift + 1}, 1.0);
            }
            // Stop marker terminates the input half of the sequence.
            featureSteps.putColumn(timesteps, stop);
            featureMask.putScalar(new int[] {timesteps}, 1.0);
            // Prepend a minibatch dimension of 1 to every array.
            featuresList[i] = Nd4j.expandDims(featureSteps, 0);
            labelsList[i] = Nd4j.expandDims(labelSteps, 0);
            featureMasks[i] = Nd4j.expandDims(featureMask, 0);
            labelsMasks[i] = Nd4j.expandDims(labelMask, 0);
            count++;
        }
        return new org.nd4j.linalg.dataset.MultiDataSet(featuresList, labelsList, featureMasks, labelsMasks);
    }

    @Override
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.pre = preProcessor;
    }

    @Override
    public MultiDataSetPreProcessor getPreProcessor() {
        return pre;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    /** Rewinds to the start of the (freshly randomized) epoch. */
    @Override
    public void reset() {
        count = 0;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment