Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active February 3, 2024 13:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dacr/218db9ad82f7c1e26db006cf0797f244 to your computer and use it in GitHub Desktop.
Save dacr/218db9ad82f7c1e26db006cf0797f244 to your computer and use it in GitHub Desktop.
Playing with Java Deep Learning (DJL), tutorial-02 & tutorial-03 combined to a standalone executable script / published by https://github.com/dacr/code-examples-manager #250223b8-c123-4c71-9f24-c57f207d371e/dca35ae666b53733fe16f963383c141217464617
// summary : Playing with Java Deep Learning (DJL), tutorial-02 & tutorial-03 combined to a standalone executable script
// keywords : djl, machine-learning, tutorial, ai, @testable
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : 250223b8-c123-4c71-9f24-c57f207d371e
// created-on : 2021-03-05T09:23:01Z
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.3.1"
//> using dep "org.slf4j:slf4j-api:2.0.11"
//> using dep "org.slf4j:slf4j-simple:2.0.11"
//> using dep "ai.djl:api:0.26.0"
//> using dep "ai.djl:basicdataset:0.26.0"
//> using dep "ai.djl:model-zoo:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-engine:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-model-zoo:0.26.0"
////> using dep "net.java.dev.jna:jna:5.13.0"
// ---------------------
// inspired by https://docs.djl.ai/jupyter/tutorial/03_image_classification_with_your_model.html
// Set the slf4j-simple default log level property to "debug".
// NOTE(review): presumably this has to run before the first logger is created — confirm.
sys.props("org.slf4j.simpleLogger.defaultLogLevel") = "debug"
import java.awt.image._
import java.nio.file._
import java.util._
import java.util.stream._
import ai.djl._
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.types._
import ai.djl.training._
import ai.djl.training.dataset._
import ai.djl.training.initializer._
import ai.djl.training.loss._
import ai.djl.training.listener._
import ai.djl.training.evaluator._
import ai.djl.training.optimizer._
import ai.djl.training.util._
import ai.djl.basicmodelzoo.cv.classification._
import ai.djl.basicmodelzoo.basic._
import ai.djl.ndarray._
import ai.djl.modality._
import ai.djl.modality.cv._
import ai.djl.modality.cv.util.NDImageUtils
import ai.djl.translate._
// Directory where the trained "mlp" model is saved and later loaded from.
val modelPath = "build/mlp"
val modelDir = Paths.get(modelPath)

// Training step: executed only when the model directory does not exist yet, so a
// previously saved model is reused on later runs.
if (!Files.exists(modelDir)) {
  println("----------------- Prepare MNIST dataset for training")
  val batchSize = 32
  val mnist = Mnist.builder.setSampling(batchSize, true).build
  mnist.prepare(new ProgressBar)

  println("----------------- Create your Model")
  // Model and Trainer are AutoCloseable: Using.resource closes them as soon as the
  // model has been saved (or training failed) instead of leaving them open for the
  // rest of the script.
  scala.util.Using.resource(Model.newInstance("mlp")) { trainingModel =>
    trainingModel.setBlock(new Mlp(28 * 28, 10, Array(128, 64)))

    println("----------------- Create a Trainer")
    // softmaxCrossEntropyLoss is a standard loss for classification problems
    val config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
      .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
      .addTrainingListeners(TrainingListener.Defaults.logging(): _*)

    scala.util.Using.resource(trainingModel.newTrainer(config)) { trainer =>
      println("----------------- Initialize Training")
      trainer.initialize(new Shape(1, 28 * 28))

      println("----------------- Train your model")
      val epoch = 5
      // No validation dataset is supplied (null), as in the original script.
      EasyTrain.fit(trainer, epoch, mnist, null)

      println("----------------- Save your model")
      // The directory is created only now, right before saving, so an interrupted
      // training run does not leave a directory that would skip training next time.
      Files.createDirectories(modelDir)
      trainingModel.setProperty("Epoch", String.valueOf(epoch))
      trainingModel.save(modelDir, "mlp")
    }
  }
}
// ========================================================================================
println("----------------- Load your model")
// Build a fresh "mlp" model whose Mlp block takes the same arguments as the one used
// in the training step, then load the saved model from modelDir.
val model = {
  val restored = Model.newInstance("mlp")
  val hiddenSizes = Array(128, 64)
  restored.setBlock(new Mlp(28 * 28, 10, hiddenSizes))
  restored.load(modelDir)
  restored
}
println("----------------- Create a Translator")
// Translator between the predictor's public types (Image in, Classifications out)
// and the NDList values exchanged with the model.
val translator = new Translator[Image, Classifications] {

  /** Turns the input image into a grayscale NDArray, passes it through
    * NDImageUtils.toTensor and wraps the result in an NDList.
    */
  override def processInput(ctx: TranslatorContext, input: Image): NDList = {
    val grayscale = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE)
    val tensor = NDImageUtils.toTensor(grayscale)
    new NDList(tensor)
  }

  /** Builds a Classifications from the single output array: softmax(0) yields the
    * probabilities, and the class names are the strings "0" to "9".
    */
  override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
    val rawOutput = list.singletonOrThrow()
    val probabilities = rawOutput.softmax(0)
    val digitLabels =
      IntStream
        .rangeClosed(0, 9)
        .mapToObj(digit => digit.toString)
        .collect(Collectors.toList())
    new Classifications(digitLabels, probabilities)
  }

  /** The Batchifier describes how to combine a batch together.
    * Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single
    * [N, X1, X2, ...] array.
    */
  override def getBatchifier(): Batchifier = Batchifier.STACK
}
println("----------------- Create Predictor")
val predictor = model.newPredictor(translator)
try {
  println("----------------- Run inference")
  // Classify one image per digit (0 to 9), each downloaded from the URL built below.
  // A download or prediction failure is not caught: it aborts the script.
  (0 to 9).foreach { num =>
    val url = s"https://mapland.fr/data/ai/images-numbers/$num.png"
    val img = ImageFactory.getInstance().fromUrl(url)
    val classifications = predictor.predict(img)
    println(s"*** result for $num ($url)")
    println(classifications)
  }
} finally {
  // Inference is the last use of both objects in this script, so release them here:
  // the predictor first, then the model it was created from (even if the first close fails).
  try predictor.close()
  finally model.close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment