Created
June 24, 2021 22:55
-
-
Save agibsonccc/a28480a111c7d1c3f62e45997dbb2f62 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.example; | |
import java.io.FileNotFoundException; | |
import java.util.LinkedList; | |
import java.util.List; | |
import java.util.Random; | |
import org.deeplearning4j.nn.conf.RNNFormat; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.api.MultiDataSet; | |
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; | |
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; | |
import org.nd4j.linalg.factory.Nd4j; | |
import java.io.IOException; | |
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.LSTM; | |
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; | |
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.learning.config.Sgd; | |
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; | |
import org.nd4j.linalg.schedule.ExponentialSchedule; | |
import org.nd4j.linalg.schedule.ScheduleType; | |
public class LSTMTest { | |
private ComputationGraph createGraphModel(int inputs,int outputs) { | |
int size = 400; | |
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(System.currentTimeMillis()) | |
.updater(new Sgd(new ExponentialSchedule(ScheduleType.EPOCH, 0.01, 0.95))) | |
.weightInit(WeightInit.XAVIER) | |
.l2(1e-4) | |
.graphBuilder() | |
.addInputs("word") | |
.addLayer("bi_lstm", new Bidirectional(new LSTM.Builder() | |
.nOut(size) | |
.activation(Activation.SOFTSIGN) | |
.build()), "word") | |
.addLayer("lstm_hide1", new LSTM.Builder() | |
.nOut(size*2) | |
.activation(Activation.SOFTSIGN) | |
.build(), "bi_lstm") | |
.addLayer("rnn_out", new RnnOutputLayer.Builder() | |
.nOut(outputs) | |
.lossFunction(LossFunction.MCXENT) | |
.build() , "lstm_hide1") | |
.setOutputs("rnn_out") | |
.setInputTypes(InputType.recurrent(inputs,size,RNNFormat.NCW)) | |
.build(); | |
ComputationGraph graph = new ComputationGraph(conf); | |
graph.init(); | |
return graph; | |
} | |
public void trainModel() throws IOException { | |
ComputationGraph model = createGraphModel(200, 56); | |
TestTokenIterator iter = new TestTokenIterator(); | |
model.addListeners(new ScoreIterationListener(1)); | |
model.fit(iter); | |
} | |
public static void main(String[] args) { | |
LSTMTest trainer = new LSTMTest(); | |
try { | |
trainer.trainModel(); | |
} catch (Exception e) { | |
e.printStackTrace(); | |
} | |
} | |
} | |
class TestTokenIterator implements MultiDataSetIterator { | |
private static final long serialVersionUID = -8725261332489394391L; | |
private int OUTPUTS; | |
private int INPUTS; | |
private int count; | |
private List<String> labels; | |
private MultiDataSetPreProcessor pre; | |
private int maxTrain; | |
private INDArray stop; | |
private Random rand; | |
public TestTokenIterator() throws FileNotFoundException { | |
this.count = 0; | |
this.maxTrain = 1500; | |
this.stop = Nd4j.ones(1,200); | |
labels = new LinkedList<>(); | |
for (int i=0;i<56;i++) { | |
labels.add("label" + i); | |
} | |
this.rand = new Random(System.currentTimeMillis()); | |
this.OUTPUTS = labels.size(); | |
this.INPUTS = 200; | |
} | |
@Override | |
public boolean hasNext() { | |
return count < maxTrain; | |
} | |
@Override | |
public MultiDataSet next() { | |
return next(1); | |
} | |
@Override | |
public MultiDataSet next(int num) { | |
int mySHIFT,TIMESTEPS; | |
int mynum = 1; | |
INDArray[] featuresList = new INDArray[mynum]; | |
INDArray[] labelsList = new INDArray[mynum]; | |
INDArray[] featureMasks = new INDArray[mynum]; | |
INDArray[] labelsMasks = new INDArray[mynum]; | |
for (int i=0;i<mynum;i++) { | |
int sent_size = 2 + rand.nextInt(15); | |
System.out.println("seq len: " + sent_size); | |
mySHIFT = sent_size; | |
TIMESTEPS = sent_size; | |
INDArray featureSteps = Nd4j.create(INPUTS, TIMESTEPS+mySHIFT+1); | |
INDArray labelSteps = Nd4j.create(OUTPUTS,TIMESTEPS+mySHIFT+1); | |
INDArray featureMask = Nd4j.zeros(TIMESTEPS+mySHIFT+1); // expand direct included | |
INDArray labelMask = Nd4j.zeros(TIMESTEPS+mySHIFT+1); // expand direct include | |
// create example | |
for (int e=0;e<TIMESTEPS;e++) { | |
int wv = rand.nextInt(200); | |
INDArray wvec = Nd4j.zeros(1,200); | |
INDArray lvec = Nd4j.zeros(labels.size()); | |
wvec.putScalar(new int[] {0, wv}, 1.0); | |
lvec.putScalar(new int[] { (wv % labels.size() ) }, 1.0); | |
featureSteps.putColumn(e,wvec); | |
labelSteps.putColumn(e + mySHIFT + 1,lvec); | |
featureMask.putScalar(new int[] { e }, 1.0); | |
labelMask.putScalar(new int[] { e + mySHIFT + 1 }, 1.0); | |
} | |
// add stop step | |
featureSteps.putColumn(TIMESTEPS, stop); | |
featureMask.putScalar(new int[] { TIMESTEPS }, 1.0); | |
featuresList[i] = Nd4j.expandDims(featureSteps, 0); | |
labelsList[i] = Nd4j.expandDims(labelSteps, 0); | |
featureMasks[i] = Nd4j.expandDims(featureMask, 0); | |
labelsMasks[i] = Nd4j.expandDims(labelMask, 0); | |
count++; | |
} | |
return new org.nd4j.linalg.dataset.MultiDataSet(featuresList, labelsList,featureMasks,labelsMasks); | |
} | |
@Override | |
public void setPreProcessor(MultiDataSetPreProcessor preProcessor) { | |
this.pre = preProcessor; | |
} | |
@Override | |
public MultiDataSetPreProcessor getPreProcessor() { | |
return pre; | |
} | |
@Override | |
public boolean resetSupported() { | |
return true; | |
} | |
@Override | |
public boolean asyncSupported() { | |
return false; | |
} | |
@Override | |
public void reset() { | |
count = 0; | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment