Last active
March 25, 2017 09:16
-
-
Save Joshuaalbert/9df97e1ffd9f36a5bbac0e85362a9dc7 to your computer and use it in GitHub Desktop.
lstm stack cuda error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.tactico.tm.asyncRL;

import java.io.File;
import java.io.IOException;

import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class RecurrentClassifierComputationGraph { | |
public Configuration conf; | |
public ComputationGraph buildCG (int numInputs, int numClasses, double learningRate){ | |
conf = new Configuration(numInputs); | |
conf.learningRate = learningRate; | |
ComputationGraphConfiguration.GraphBuilder confBuilder = new NeuralNetConfiguration.Builder() | |
.iterations(1) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.learningRate(conf.learningRate) | |
.updater(Updater.ADAM).adamMeanDecay(1. - 1./10.).adamVarDecay(1. - 1./50.) | |
.weightInit(WeightInit.XAVIER_UNIFORM) | |
.regularization(true) | |
.l2(conf.l2).l1(conf.l1).dropOut(0.9) | |
.graphBuilder() | |
.setInputTypes(InputType.recurrent(numInputs)) | |
.addInputs("input"); | |
String[][] lstmStacks = new String[conf.numStacks][conf.depthStack]; | |
for (int stackIdx = 0; stackIdx < conf.numStacks; stackIdx++){ | |
for (int i = 0; i < conf.depthStack; i++) { | |
lstmStacks[stackIdx][i] = "lstm-"+stackIdx+"-"+i; | |
GravesLSTM lstmLayer = new GravesLSTM.Builder().gateActivationFunction(Activation.SIGMOID) | |
.nIn(i==0 ? numInputs : conf.numHiddenNodes[i-1]) | |
.nOut(conf.numHiddenNodes[i]).forgetGateBiasInit(5.) | |
.activation(Activation.TANH) | |
.build(); | |
if (i == 0){ | |
confBuilder.addLayer(lstmStacks[stackIdx][i], lstmLayer, "input"); | |
} else { | |
confBuilder.addLayer(lstmStacks[stackIdx][i], lstmLayer, lstmStacks[stackIdx][i-1]); | |
} | |
} | |
} | |
String[] stackEnds = new String[conf.numStacks]; | |
for (int stackIdx = 0; stackIdx < conf.numStacks; stackIdx++){ | |
stackEnds[stackIdx] = lstmStacks[stackIdx][conf.depthStack-1]; | |
} | |
confBuilder.addLayer("class", new RnnOutputLayer.Builder(LossFunctions.LossFunction.KL_DIVERGENCE) | |
.nIn(conf.numHiddenNodes[conf.depthStack-1]*conf.numStacks) | |
.nOut(numClasses) | |
.activation(Activation.SOFTMAX) | |
.build(), stackEnds); | |
confBuilder.setOutputs("class"); | |
ComputationGraphConfiguration cgconf = confBuilder.pretrain(false).backprop(true).build(); | |
cgconf.addPreProcessors(InputType.recurrent(numInputs)); | |
ComputationGraph model = new ComputationGraph(cgconf); | |
model.init(); | |
UIServer uiServer = UIServer.getInstance(); | |
StatsStorage statsStorage = new InMemoryStatsStorage(); | |
int listenerFrequency = 1; | |
model.setListeners(new StatsListener(statsStorage, listenerFrequency), | |
new ScoreIterationListener(Constants.NEURAL_NET_ITERATION_LISTENER)); | |
uiServer.attach(statsStorage); | |
System.out.println("Machine number of params:"+(model.numParams())); | |
System.out.println("Learning rate: "+learningRate); | |
return model; | |
} | |
class Configuration { | |
int numStacks = 1; | |
int depthStack; | |
int[] numHiddenNodes; | |
public double learningRate=1e-2; | |
double l2 = 0.001;//minimal weight gaussian prior | |
double l1 = 0.0001;//sparse weights laplace prior | |
public Configuration(int numInput){ | |
this.numHiddenNodes = new int[] {20*numInput,15*numInput,10*numInput,5*numInput,5*numInput}; | |
this.depthStack = this.numHiddenNodes.length; | |
} | |
} | |
public static void main (String[] args){ | |
ComputationGraph cg = new RecurrentClassifierComputationGraph().buildCG(9,3,1e-2); | |
String modelFile = String.format("IntermediateModels/bbrp_rc_KL_%1.3e",1e-2); | |
MultiDataSet dataSet = new MultiDataSet(); | |
for (int pass = 0; pass < 10; pass++){ | |
System.out.println("Starting pass: "+(pass+1)); | |
for (int i = 1; i <= 10; i++){ | |
try { | |
dataSet.load(new File("trainingSets/bbrp_"+i)); | |
} catch (IOException e1) { | |
e1.printStackTrace(); | |
} | |
INDArray[] dataSetInputs = dataSet.getFeatures(); | |
cg.setInputs(dataSetInputs); | |
double score0 = cg.score(dataSet,true); | |
double score = score0; | |
System.out.println("Initial KL score: "+score0); | |
int iter = 0; | |
while (iter++ < 50 && score > 0.5*score0){ | |
cg.fit(dataSet); | |
score = cg.score(); | |
} | |
System.out.println("Number of iterations used: "+iter); | |
System.out.println("Final KL score: "+score); | |
try { | |
System.out.println("Saving model to: "+modelFile); | |
ModelSerializer.writeModel(cg, modelFile, true); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment