Skip to content

Instantly share code, notes, and snippets.

@dosht
Created May 17, 2015 07:22
Show Gist options
  • Save dosht/971ed7916543839f4eea to your computer and use it in GitHub Desktop.
DL4J Scala Example
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.deeplearning4j.nn.conf.`override`.ClassifierOverride;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.optimize.api.IterationListener;
import scala.collection.JavaConversions._
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
/**
 * Trains a small RBM-based DL4J network on the Iris dataset and prints
 * evaluation statistics for the held-out test split.
 *
 * Written against an early DL4J API (deeplearning4j-core 0.0.3.3.4.alpha1-SNAPSHOT).
 * Builder calls such as `layer`, `list`, `hiddenLayerSizes` and `override` come from
 * that API and may not exist in other releases.
 *
 * Uses an explicit `main` instead of `extends App`: `App` defers the object body via
 * `DelayedInit`, so its vals stay uninitialized until `main` runs.
 */
object Main {

  def main(args: Array[String]): Unit = {
    // NOTE(review): -1 presumably means "no limit" when printing NDArrays; confirm
    // against the Nd4j version in use.
    Nd4j.MAX_SLICES_TO_PRINT = -1
    Nd4j.MAX_ELEMENTS_PER_SLICE = -1

    // Call order matters: `.list(2)` switches to the multi-layer builder, so the
    // per-layer settings must precede it and `hiddenLayerSizes`/`override` follow it.
    val conf = new NeuralNetConfiguration.Builder()
      .iterations(100)
      .layer(new RBM())
      .weightInit(WeightInit.DISTRIBUTION)
      .dist(new UniformDistribution(0, 1))
      .activationFunction("tanh")
      .momentum(0.9)
      .optimizationAlgo(OptimizationAlgorithm.LBFGS)
      .constrainGradientToUnitNorm(true)
      .k(1)
      .regularization(true)
      .l2(2e-4)
      .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
      .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
      .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
      .learningRate(1e-1f)
      .nIn(4)
      .nOut(3)
      .list(2)
      .hiddenLayerSizes(3)
      .`override`(new ClassifierOverride(1))
      .build()

    val network = new MultiLayerNetwork(conf)
    network.init()
    // A typed singleton Java list replaces the deprecated JavaConversions
    // `seqAsJavaList` plus the `asInstanceOf[IterationListener]` upcast.
    network.setListeners(
      java.util.Collections.singletonList[IterationListener](new ScoreIterationListener(1)))

    // NOTE(review): presumably (batchSize = 150, numExamples = 150), so the single
    // batch returned by next() holds every example; confirm for this DL4J version.
    val iris = new IrisDataSetIterator(150, 150).next()

    // Written before normalization, so iris.txt holds the raw feature values.
    Nd4j.writeTxt(iris.getFeatureMatrix(), "iris.txt", "\t")

    // NOTE(review): normalization statistics are computed over the whole dataset,
    // test rows included, before the split below. This leaks test information into
    // training; consider normalizing with train-only statistics.
    iris.normalizeZeroMeanZeroUnitVariance()

    // NOTE(review): confirm whether splitTestAndTrain shuffles. If it does not, the
    // split follows the iterator's example order.
    val split = iris.splitTestAndTrain(110)
    network.fit(split.getTrain())

    val test = split.getTest()
    val eval = new Evaluation()
    eval.eval(test.getLabels(), network.output(test.getFeatureMatrix()))
    println(s"Score ${eval.stats}")
  }
}
@dosht
Copy link
Author

dosht commented May 17, 2015

This works with deeplearning4j-core version "0.0.3.3.4.alpha1-SNAPSHOT", so you may need to clone the repo locally and run `mvn install -Dmaven.test.skip=true` first.

Then the build.sbt should look something like this:

// Version of the ND4J backend artifact (nd4j-jblas). ND4J is versioned separately
// from deeplearning4j, so it gets its own val rather than a "dl4j" name.
val nd4jVersion = "0.0.3.5.5.3"

// Snapshot build expected to be installed into the local Maven repository via `mvn install`.
val deeplearnVersion = "0.0.3.3.4.alpha1-SNAPSHOT"

// Resolver.mavenLocal resolves the current user's ~/.m2/repository, so the build
// does not depend on one machine's home-directory path.
resolvers += Resolver.mavenLocal

libraryDependencies ++= Seq(
  "org.nd4j" % "nd4j-jblas" % nd4jVersion,
  "org.deeplearning4j" % "deeplearning4j-core" % deeplearnVersion
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment