Gist by @AlexDBlack, last active September 3, 2018.

package org.deeplearning4j;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.util.HashMap;
import java.util.Map;

public class Debug6316 {
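
    // Repro for DL4J issue 6316 (per the class name): builds a CNN -> LSTM
    // encoder-decoder ComputationGraph and fits a single synthetic minibatch.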
    @Test
    public void test() {
        ComputationGraph cg = getComputationGraph();
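        // Recurrent (NCW) shapes: features are [minibatch, featureSize, timeSeriesLength],
        // masks are [minibatch, timeSeriesLength]; here a single example with 3 time steps.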
        INDArray in = Nd4j.create(1, 501 * 501, 3);
        INDArray label = Nd4j.create(1, 501 * 501, 3);
        INDArray inMask = Nd4j.ones(1, 3);
        INDArray lMask = Nd4j.ones(1, 3);
        cg.fit(new INDArray[]{in}, new INDArray[]{label}, new INDArray[]{inMask}, new INDArray[]{lMask});
    }
    public static ComputationGraph getComputationGraph() {
        ComputationGraph multiLayerNetwork;
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.seed(140);
        builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        builder.weightInit(WeightInit.XAVIER);
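
        // Iteration-based learning rate schedule for AdaGrad: 1e-2 from iteration 0,
        // 1e-3 from iteration 400, 1e-4 from iteration 3000 onwards.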
        Map<Integer, Double> lrSchedule = new HashMap<>();
        lrSchedule.put(0, 1e-2);
        lrSchedule.put(400, 1e-3);
        lrSchedule.put(3000, 1e-4);
        MapSchedule mapSchedule = new MapSchedule(ScheduleType.ITERATION, lrSchedule);
        builder.updater(new AdaGrad(mapSchedule));

        int lstmHiddenCount = 200;
        int cnnStride1 = 5;
        int kernelSize1 = 3;
        int cnnStride3 = 5;
        int kernelSize3 = 3;
        int channels = 1;
        int padding = 1;
        int samplingSize = 1;
        int samplingStride = 1;
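
        // Feature-map sizes as computed in this gist: (in - kernel + padding) / stride + 1.
        // Note the standard convolution formula uses 2 * padding, so DL4J's actual cnn1
        // output would be (501 + 2*1 - 3) / 5 + 1 = 101; the expressions below come out
        // to 100, 100, 20, 20 and 4 under integer division.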
        int cnn1Output = (501 - kernelSize1 + padding) / cnnStride1 + 1;        // 100
        int cnn2Output = (cnn1Output - samplingSize + 0) / samplingStride + 1;  // 100
        int cnn3Output = (cnn2Output - kernelSize3 + padding) / cnnStride3 + 1; // 20
        int lstmInWidth = cnn3Output;                                           // overwritten below
        int cnn4Output = (cnn3Output - samplingSize + 0) / samplingStride + 1;  // 20
        int cnn5Output = (cnn4Output - kernelSize3 + padding) / cnnStride3 + 1; // 4
        lstmInWidth = cnn5Output + 1; // output of cnn: 5
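
        // RnnToCnnPreProcessor reshapes [mb, h*w*c, t] sequences into [mb*t, c, h, w]
        // images for the CNN stack; CnnToRnnPreProcessor flattens CNN activations back
        // into per-time-step feature vectors for the LSTM.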
        Map<String, InputPreProcessor> inputPreProcessors = new HashMap<>();
        inputPreProcessors.put("cnn1", new RnnToCnnPreProcessor(501, 501, channels));
        inputPreProcessors.put("lstm1", new CnnToRnnPreProcessor(lstmInWidth, lstmInWidth, 128));

        ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder()
                .pretrain(false).backprop(true)
                .backpropType(BackpropType.Standard)
                .addInputs("inputs")
                // cnn
                .addLayer("cnn1",
                        new ConvolutionLayer.Builder(new int[]{kernelSize1, kernelSize1},
                                new int[]{cnnStride1, cnnStride1},
                                new int[]{padding, padding})
                                .nIn(channels)
                                .nOut(501)
                                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                                .gradientNormalizationThreshold(10)
                                .updater(new AdaGrad(mapSchedule))
                                .weightInit(WeightInit.RELU)
                                .activation(Activation.RELU).build(), "inputs")
                // Output: (501 - kernelSize1 + padding) / cnnStride1 + 1 = 100 (see note above)
                .addLayer("cnn2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX,
                        new int[]{samplingSize, samplingSize},
                        new int[]{samplingStride, samplingStride}).build(), "cnn1")
                // Output: (100 - 1 + 0) / 1 + 1 = 100
                .addLayer("cnn3",
                        new ConvolutionLayer.Builder(new int[]{kernelSize3, kernelSize3},
                                new int[]{cnnStride3, cnnStride3},
                                new int[]{padding, padding})
                                .nIn(501)
                                .nOut(128)
                                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                                .gradientNormalizationThreshold(10)
                                .updater(new AdaGrad(mapSchedule))
                                .weightInit(WeightInit.RELU)
                                .activation(Activation.RELU).build(), "cnn2");
        // Output: (100 - kernelSize3 + padding) / cnnStride3 + 1 = 20
        graphBuilder = graphBuilder.addLayer("cnn4", new SubsamplingLayer.Builder(
                        SubsamplingLayer.PoolingType.MAX,
                        new int[]{samplingSize, samplingSize},
                        new int[]{samplingStride, samplingStride}).build(), "cnn3")
                .addLayer("cnn5",
                        new ConvolutionLayer.Builder(new int[]{kernelSize3, kernelSize3},
                                new int[]{cnnStride3, cnnStride3},
                                new int[]{padding, padding})
                                .nIn(128)
                                .nOut(128)
                                .updater(new AdaGrad(mapSchedule))
                                .weightInit(WeightInit.RELU)
                                .activation(Activation.RELU).build(), "cnn4");
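
        // Encoder LSTM: CnnToRnnPreProcessor (the "lstm1" entry above) flattens each
        // [128, lstmInWidth, lstmInWidth] activation map into a vector of
        // lstmInWidth * lstmInWidth * 128 features per time step.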
        graphBuilder = graphBuilder.addLayer("lstm1", new LSTM.Builder()
                .activation(Activation.SOFTSIGN)
                .nIn(lstmInWidth * lstmInWidth * 128)
                .nOut(lstmHiddenCount)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(10)
                .updater(new AdaGrad(mapSchedule))
                .build(), "cnn5");
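
        // Sequence-to-sequence bridge: LastTimeStepVertex keeps only the final
        // (non-masked) encoder step, and DuplicateToTimeSeriesVertex copies that
        // "thought vector" across every decoder time step, with lengths taken
        // from the "inputs" mask.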
        graphBuilder = graphBuilder.addVertex("thoughtVector", new LastTimeStepVertex("inputs"), "lstm1");
        graphBuilder = graphBuilder.addVertex("dup", new DuplicateToTimeSeriesVertex("inputs"), "thoughtVector");
        graphBuilder = graphBuilder.addLayer("lstmDecode1", new LSTM.Builder()
                .activation(Activation.SOFTSIGN)
                .nIn(lstmHiddenCount)
                .nOut(lstmHiddenCount)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(10)
                .updater(new AdaGrad(mapSchedule))
                .build(), "dup")
                .addLayer("output", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.RELU)
                        .nIn(lstmHiddenCount)
                        .nOut(501 * 501)
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10)
                        .updater(new AdaGrad(mapSchedule))
                        .build(), "lstmDecode1");
        graphBuilder = graphBuilder.setOutputs("output");
        graphBuilder.setInputPreProcessors(inputPreProcessors);
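
        // The declared sequence length here (inputSize = 60) differs from the 3-step
        // input fed in the test; setInputTypes also derives sizes/preprocessors from
        // this InputType, which may interact with the manual preprocessors above.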
        int inputSize = 30 * 2;
        graphBuilder.setInputTypes(InputType.recurrent(501 * 501, inputSize));

        multiLayerNetwork = new ComputationGraph(graphBuilder.build());
        multiLayerNetwork.init();
        return multiLayerNetwork;
    }
}