Created September 5, 2016 22:32
Save romeokienzler/ad1ca47dc9ac1751ed681dd7b9947f39 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Pull DL4J / ND4J / Canova and commons-io into the Toree kernel's classpath.
// Maven Central no longer serves the central.maven.org host or plain-HTTP requests,
// so every artifact is fetched over HTTPS from repo1.maven.org.
%AddJar https://repo1.maven.org/maven2/org/deeplearning4j/deeplearning4j-core/0.4.0/deeplearning4j-core-0.4.0.jar
%AddJar https://repo1.maven.org/maven2/org/nd4j/nd4j-api/0.4.0/nd4j-api-0.4.0.jar
%AddJar https://repo1.maven.org/maven2/org/nd4j/nd4j-buffer/0.4.0/nd4j-buffer-0.4.0.jar
%AddJar https://repo1.maven.org/maven2/org/nd4j/canova-nd4j-image/0.0.0.17/canova-nd4j-image-0.0.0.17.jar
%AddJar https://repo1.maven.org/maven2/org/nd4j/canova-nd4j-codec/0.0.0.17/canova-nd4j-codec-0.0.0.17.jar
%AddJar https://repo1.maven.org/maven2/org/nd4j/canova-api/0.0.0.17/canova-api-0.0.0.17.jar
%AddJar https://repo1.maven.org/maven2/commons-io/commons-io/2.4/commons-io-2.4.jar
%AddJar https://repo1.maven.org/maven2/org/nd4j/nd4j-x86/0.4-rc3.8/nd4j-x86-0.4-rc3.8.jar
import java.io.File; | |
import org.apache.commons.io.FileUtils | |
import java.net.URL | |
import org.canova.api.records.reader.RecordReader; | |
import org.canova.api.records.reader.impl.CSVRecordReader; | |
import org.canova.api.split.FileSplit; | |
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.Updater; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; | |
// Hyper-parameters consumed by the data iterators and the network configuration below.
val seed: Int = 123             // passed to NeuralNetConfiguration.Builder.seed
val learningRate: Double = 0.01 // SGD step size
val batchSize: Int = 1000       // minibatch size for both RecordReaderDataSetIterators
val nEpochs: Int = 30           // NOTE(review): not referenced anywhere in this cell
val numInputs: Int = 3          // nIn of the first DenseLayer
val numOutputs: Int = 7         // nOut of the softmax OutputLayer (number of classes)
val numHiddenNodes: Int = 3     // width of every hidden DenseLayer
// Download the training CSV once (cached in the JVM temp directory) and expose it
// through DL4J dataset iterators for training (trainIter) and evaluation (testIter).
val rr = new CSVRecordReader()
// Use HTTPS directly: GitHub redirects http:// to https://, and HttpURLConnection
// (used by FileUtils.copyURLToFile) does not follow redirects that switch protocol.
val url = "https://github.com/romeokienzler/developerWorks/raw/master/data_dl4j.csv"
val tempDir = System.getProperty("java.io.tmpdir")
// java.io.tmpdir need not end with a separator (e.g. "/tmp" on Linux), so join the
// path with File(parent, child) rather than string concatenation.
val fileLocation = new File(tempDir, "data_dl4j.csv").getPath
val f = new File(fileLocation)
if (!f.exists()) {
  FileUtils.copyURLToFile(new URL(url), f)
  println("File downloaded to " + f.getAbsolutePath())
} else {
  println("Using existing text file at " + f.getAbsolutePath())
}
// labelIndex = 0 (first CSV column), numPossibleLabels = numOutputs, matching the
// width of the network's output layer.
rr.initialize(new FileSplit(f))
val trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, numOutputs)
// NOTE(review): the evaluation iterator reads the very same file as training, so any
// evaluation on testIter measures training-set fit rather than generalisation.
val rrTest = new CSVRecordReader()
rrTest.initialize(new FileSplit(f))
val testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, numOutputs)
// Feed-forward classifier: numInputs -> DenseLayer(relu) -> DenseLayer(sigmoid)
// -> DenseLayer(sigmoid) -> softmax OutputLayer with numOutputs units, optimised by
// stochastic gradient descent with the Nesterov-momentum updater (momentum 0.9).
val conf = new NeuralNetConfiguration.Builder()
  .seed(seed)
  .iterations(1)
  .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  .learningRate(learningRate)
  .updater(Updater.NESTEROVS)
  .momentum(0.9)
  .list()
  .layer(0, new DenseLayer.Builder()
    .nIn(numInputs)
    .nOut(numHiddenNodes)
    .weightInit(WeightInit.XAVIER)
    .activation("relu")
    .build())
  .layer(1, new DenseLayer.Builder()
    .nIn(numHiddenNodes)
    .nOut(numHiddenNodes)
    .weightInit(WeightInit.XAVIER)
    .activation("sigmoid")
    .build())
  .layer(2, new DenseLayer.Builder()
    .nIn(numHiddenNodes)
    .nOut(numHiddenNodes)
    .weightInit(WeightInit.XAVIER)
    .activation("sigmoid")
    .build())
  .layer(3, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
    .weightInit(WeightInit.XAVIER)
    .activation("softmax")
    .nIn(numHiddenNodes)
    .nOut(numOutputs)
    .build())
  .pretrain(false)
  .backprop(true)
  .build()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.