Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active February 3, 2024 13:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dacr/7e6a838a8d79e55b21bb7c8e68af97a7 to your computer and use it in GitHub Desktop.
Save dacr/7e6a838a8d79e55b21bb7c8e68af97a7 to your computer and use it in GitHub Desktop.
Playing with Java Deep Learning (DJL), tutorial-02 / published by https://github.com/dacr/code-examples-manager #5bad9d22-6b77-4427-894c-ec30c1a6f019/33bcf3e9db946dd3cfc5297966a995bdca006194
// summary : Playing with Java Deep Learning (DJL), tutorial-02
// keywords : djl, machine-learning, tutorial, ai
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : 5bad9d22-6b77-4427-894c-ec30c1a6f019
// created-on : 2021-03-05T09:23:01Z
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.3.1"
//> using dep "org.slf4j:slf4j-api:2.0.11"
//> using dep "org.slf4j:slf4j-simple:2.0.11"
//> using dep "ai.djl:api:0.26.0"
//> using dep "ai.djl:basicdataset:0.26.0"
//> using dep "ai.djl:model-zoo:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-engine:0.26.0"
////> using dep "net.java.dev.jna:jna:5.13.0"
// ---------------------
// inspired by https://docs.djl.ai/jupyter/tutorial/02_train_your_first_model.html
import java.nio.file._
import ai.djl._
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.types._
import ai.djl.training._
import ai.djl.training.dataset._
import ai.djl.training.initializer._
import ai.djl.training.loss._
import ai.djl.training.listener._
import ai.djl.training.evaluator._
import ai.djl.training.optimizer._
import ai.djl.training.util._
import ai.djl.basicmodelzoo.cv.classification._
import ai.djl.basicmodelzoo.basic._
// Step 1: build the MNIST dataset, sampled in batches of `batchSize`.
println("----------------- Prepare MNIST dataset for training")
val batchSize = 32
val mnist = {
  // NOTE(review): the `true` flag is presumably DJL's "random sampling" switch — confirm against the setSampling docs
  val samplingBuilder = Mnist.builder().setSampling(batchSize, true)
  samplingBuilder.build()
}
// The dataset must be prepared before use; a ProgressBar is handed over for progress reporting.
mnist.prepare(new ProgressBar())
import scala.util.Using

// Steps 2-5: create the model, train it on `mnist`, then save it to disk.
//
// Model and Trainer are both AutoCloseable. Using.resource guarantees they are
// closed (trainer first, then model) even when training or saving throws,
// instead of leaving their engine resources open until the JVM exits.
println("----------------- Create your Model")
Using.resource(Model.newInstance("mlp")) { model =>
  // NOTE(review): arguments presumably follow DJL's Mlp(input, output, hidden) order:
  // 28 * 28 inputs, 10 outputs, hidden layers of 128 and 64 units — confirm.
  model.setBlock(new Mlp(28 * 28, 10, Array(128, 64)))

  println("----------------- Create a Trainer")
  // softmaxCrossEntropyLoss is a standard loss for classification problems.
  // The Accuracy evaluator is added so we humans can understand how accurate the model is.
  // The default logging listeners are spliced in as varargs (Scala 3 `*` syntax).
  val config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    .addEvaluator(new Accuracy())
    .addTrainingListeners(TrainingListener.Defaults.logging()*)

  Using.resource(model.newTrainer(config)) { trainer =>
    println("----------------- Initialize Training")
    // Input shape handed to the trainer: 1 x (28 * 28).
    trainer.initialize(new Shape(1, 28 * 28))

    println("----------------- Train your model")
    val epoch = 2
    // No validation dataset is supplied (null), as in the upstream tutorial.
    EasyTrain.fit(trainer, epoch, mnist, null)

    // Saving happens while the trainer is still open, exactly as in the original script.
    println("----------------- Save your model")
    val modelDir = Paths.get("build/mlp")
    Files.createDirectories(modelDir)
    model.setProperty("Epoch", String.valueOf(epoch))
    model.save(modelDir, "mlp")
    println(model)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment