Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active February 3, 2024 13:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dacr/9e12eee05eaa3fe2e37d45d71cbc8602 to your computer and use it in GitHub Desktop.
Save dacr/9e12eee05eaa3fe2e37d45d71cbc8602 to your computer and use it in GitHub Desktop.
Playing with Java Deep Learning (DJL), tutorial-03 / published by https://github.com/dacr/code-examples-manager #2b57959f-ea11-4c27-b715-333efdf5582e/230ac2bef739d85f436f3aad026d7e4568e94472
// summary : Playing with Deep Java Library (DJL), tutorial-03
// keywords : djl, machine-learning, tutorial, ai
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : 2b57959f-ea11-4c27-b715-333efdf5582e
// created-on : 2021-03-05T09:23:01Z
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.3.1"
//> using dep "org.slf4j:slf4j-api:2.0.11"
//> using dep "org.slf4j:slf4j-simple:2.0.11"
//> using dep "ai.djl:api:0.26.0"
//> using dep "ai.djl:model-zoo:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-engine:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-model-zoo:0.26.0"
////> using dep "net.java.dev.jna:jna:5.13.0"
// ---------------------
// inspired by https://docs.djl.ai/jupyter/tutorial/03_image_classification_with_your_model.html
// Turn on debug output for the slf4j-simple backend.
// Presumably placed first so the level is in effect before DJL / MXNet create their loggers
// (slf4j-simple reads this property when it initializes) — keep it ahead of any DJL call.
locally {
  val levelProperty = "org.slf4j.simpleLogger.defaultLogLevel"
  System.setProperty(levelProperty, "debug")
}
import java.awt.image._
import java.nio.file._
import java.util._
import java.util.stream._
import ai.djl._
import ai.djl.basicmodelzoo.basic._
import ai.djl.ndarray._
import ai.djl.modality._
import ai.djl.modality.cv._
import ai.djl.modality.cv.util.NDImageUtils
import ai.djl.translate._
println("----------------- Load your model")
// Restore previously trained parameters from build/mlp into a freshly built MLP.
// The block has to be attached before load() so the stored parameters can be mapped onto it,
// which means its shape must match the trained network: 28 * 28 inputs, hidden layers of
// 128 and 64 units, and 10 outputs (one per digit class).
// NOTE(review): build/mlp is presumably produced by the preceding DJL training tutorial — confirm.
val modelDir = Paths.get("build/mlp")
val model = {
  val loaded = Model.newInstance("mlp")
  val hiddenLayerSizes = Array(128, 64)
  loaded.setBlock(new Mlp(28 * 28, 10, hiddenLayerSizes))
  loaded.load(modelDir)
  loaded
}
println("----------------- Create a Translator")
// Bridges the image world and the raw network: Image -> NDList on the way in,
// NDList -> Classifications (labels "0".."9" with probabilities) on the way out.
val translator = new Translator[Image, Classifications] {
  import scala.jdk.CollectionConverters._

  // Output index i of the network is labelled with digit i.
  // Built once here instead of on every processOutput call. asJava over an immutable Seq gives a
  // read-only java.util.List, so sharing it between Classifications instances is safe.
  private val classNames: java.util.List[String] = (0 until 10).map(_.toString).asJava

  /** Converts one input image to the tensor fed to the MLP.
    *
    * The image is read as a single grayscale channel, then NDImageUtils.toTensor converts it
    * to the network's tensor layout (per DJL docs: HWC in [0, 255] -> float CHW in [0, 1]).
    * NOTE(review): the MLP was built with 28 * 28 inputs, so this assumes the source images
    * are already 28x28 pixels — no resizing happens here; confirm for the images used.
    */
  override def processInput(ctx: TranslatorContext, input: Image): NDList = {
    val array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE)
    new NDList(NDImageUtils.toTensor(array))
  }

  /** Wraps the network output for one image into labelled probabilities.
    *
    * `list` presumably holds a single 1-D vector of 10 scores (the STACK batchifier
    * unstacks the batch before this is called); softmax along axis 0 normalizes it.
    */
  override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
    val probabilities = list.singletonOrThrow().softmax(0)
    new Classifications(classNames, probabilities)
  }

  // The Batchifier describes how to combine a batch together.
  // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single
  // [N, X1, X2, ...] array.
  override def getBatchifier(): Batchifier =
    Batchifier.STACK
}
println("----------------- Create Predictor")
val predictor = model.newPredictor(translator)
println("----------------- Run inference")
// Predictor and Model are AutoCloseable (they own engine-side resources): release them once
// every digit has been classified, and also when a download or a prediction throws.
try {
  val imageFactory = ImageFactory.getInstance()
  // One remote test image per digit. The file name is presumably the digit the image depicts,
  // so the printed "result for N" can be checked against the top classification by eye.
  for (num <- 0 to 9) {
    val url = s"https://mapland.fr/data/ai/images-numbers/$num.png"
    val img = imageFactory.fromUrl(url)
    val classifications = predictor.predict(img)
    println(s"*** result for $num ($url)")
    println(classifications)
  }
} finally {
  // Close the predictor first; the nested finally still closes the model if that throws.
  try predictor.close()
  finally model.close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment