lstm stack cuda error
package com.tactico.tm.asyncRL;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.api.UIServer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.util.ModelSerializer;
import java.io.File;
import java.io.IOException;
public class RecurrentClassifierComputationGraph {

    public Configuration conf;

    public ComputationGraph buildCG(int numInputs, int numClasses, double learningRate) {
        conf = new Configuration(numInputs);
        conf.learningRate = learningRate;
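        // Shared optimizer/regularization settings; graphBuilder() switches to the ComputationGraph API so several LSTM stacks can branch from one input.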
        ComputationGraphConfiguration.GraphBuilder confBuilder = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(conf.learningRate)
                .updater(Updater.ADAM).adamMeanDecay(1. - 1. / 10.).adamVarDecay(1. - 1. / 50.)
                .weightInit(WeightInit.XAVIER_UNIFORM)
                .regularization(true)
                .l2(conf.l2).l1(conf.l1).dropOut(0.9)
                .graphBuilder()
                .setInputTypes(InputType.recurrent(numInputs))
                .addInputs("input");
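        // Build numStacks independent stacks of depthStack GravesLSTM layers; layer 0 of each stack reads from "input", the rest chain upward within the stack.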
        String[][] lstmStacks = new String[conf.numStacks][conf.depthStack];
        for (int stackIdx = 0; stackIdx < conf.numStacks; stackIdx++) {
            for (int i = 0; i < conf.depthStack; i++) {
                lstmStacks[stackIdx][i] = "lstm-" + stackIdx + "-" + i;
                GravesLSTM lstmLayer = new GravesLSTM.Builder().gateActivationFunction(Activation.SIGMOID)
                        .nIn(i == 0 ? numInputs : conf.numHiddenNodes[i - 1])
                        .nOut(conf.numHiddenNodes[i]).forgetGateBiasInit(5.)
                        .activation(Activation.TANH)
                        .build();
                if (i == 0) {
                    confBuilder.addLayer(lstmStacks[stackIdx][i], lstmLayer, "input");
                } else {
                    confBuilder.addLayer(lstmStacks[stackIdx][i], lstmLayer, lstmStacks[stackIdx][i - 1]);
                }
            }
        }
        String[] stackEnds = new String[conf.numStacks];
        for (int stackIdx = 0; stackIdx < conf.numStacks; stackIdx++) {
            stackEnds[stackIdx] = lstmStacks[stackIdx][conf.depthStack - 1];
        }
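        // One softmax RNN output layer fed by the top LSTM of every stack; multiple inputs are merged by concatenation, hence nIn scales with numStacks.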
        confBuilder.addLayer("class", new RnnOutputLayer.Builder(LossFunctions.LossFunction.KL_DIVERGENCE)
                .nIn(conf.numHiddenNodes[conf.depthStack - 1] * conf.numStacks)
                .nOut(numClasses)
                .activation(Activation.SOFTMAX)
                .build(), stackEnds);
        confBuilder.setOutputs("class");
        ComputationGraphConfiguration cgconf = confBuilder.pretrain(false).backprop(true).build();
        cgconf.addPreProcessors(InputType.recurrent(numInputs));
        ComputationGraph model = new ComputationGraph(cgconf);
        model.init();
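        // Attach the DL4J training UI plus a periodic score printer.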
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        int listenerFrequency = 1;
        model.setListeners(new StatsListener(statsStorage, listenerFrequency),
                new ScoreIterationListener(10)); // score print frequency; the original referenced a project-specific Constants class not included in this gist
        uiServer.attach(statsStorage);
        System.out.println("Machine number of params: " + model.numParams());
        System.out.println("Learning rate: " + learningRate);
        return model;
    }
    class Configuration {
        int numStacks = 1;
        int depthStack;
        int[] numHiddenNodes;
        public double learningRate = 1e-2;
        double l2 = 0.001;   // minimal-weight Gaussian prior
        double l1 = 0.0001;  // sparse-weight Laplace prior

        public Configuration(int numInput) {
            this.numHiddenNodes = new int[] {20 * numInput, 15 * numInput, 10 * numInput, 5 * numInput, 5 * numInput};
            this.depthStack = this.numHiddenNodes.length;
        }
    }
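    // Training driver: load each serialized MultiDataSet and fit it until its KL score halves or 50 iterations are reached, saving the model after every file.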
    public static void main(String[] args) {
        ComputationGraph cg = new RecurrentClassifierComputationGraph().buildCG(9, 3, 1e-2);
        String modelFile = String.format("IntermediateModels/bbrp_rc_KL_%1.3e", 1e-2);
        MultiDataSet dataSet = new MultiDataSet();
        for (int pass = 0; pass < 10; pass++) {
            System.out.println("Starting pass: " + (pass + 1));
            for (int i = 1; i <= 10; i++) {
                try {
                    dataSet.load(new File("trainingSets/bbrp_" + i));
                } catch (IOException e1) {
                    e1.printStackTrace();
                }
                INDArray[] dataSetInputs = dataSet.getFeatures();
                cg.setInputs(dataSetInputs);
                double score0 = cg.score(dataSet, true);
                double score = score0;
                System.out.println("Initial KL score: " + score0);
                int iter = 0;
                while (iter++ < 50 && score > 0.5 * score0) {
                    cg.fit(dataSet);
                    score = cg.score();
                }
                System.out.println("Number of iterations used: " + iter);
                System.out.println("Final KL score: " + score);
                try {
                    System.out.println("Saving model to: " + modelFile);
                    ModelSerializer.writeModel(cg, modelFile, true);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}