Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active January 29, 2024 09:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dacr/936ea1aca3765cd9dc882a4526098088 to your computer and use it in GitHub Desktop.
Save dacr/936ea1aca3765cd9dc882a4526098088 to your computer and use it in GitHub Desktop.
Things (objects, people, animals) detection using DJL - compare models efficiency / published by https://github.com/dacr/code-examples-manager #bd813b80-9e47-489d-9a1d-86c5fb5c828e/70d3af45e506dc25d11d9bbd87ba62bc80a7a419
// summary : Things (objects, people, animals) detection using DJL - compare models efficiency
// keywords : djl, machine-learning, tutorial, detection, ai, @testable
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : bd813b80-9e47-489d-9a1d-86c5fb5c828e
// created-on : 2024-01-27T14:36:36+01:00
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.3.1"
//> using dep "org.slf4j:slf4j-api:2.0.11"
//> using dep "org.slf4j:slf4j-simple:2.0.11"
//> using dep "net.java.dev.jna:jna:5.14.0"
//> using dep "ai.djl:api:0.26.0"
//> using dep "ai.djl:basicdataset:0.26.0"
//> using dep "ai.djl:model-zoo:0.26.0"
//> using dep "ai.djl.huggingface:tokenizers:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-engine:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-model-zoo:0.26.0"
//> using dep "ai.djl.pytorch:pytorch-engine:0.26.0"
//> using dep "ai.djl.pytorch:pytorch-model-zoo:0.26.0"
//> using dep "ai.djl.tensorflow:tensorflow-engine:0.26.0"
//> using dep "ai.djl.tensorflow:tensorflow-model-zoo:0.26.0"
//> using dep "ai.djl.paddlepaddle:paddlepaddle-engine:0.26.0"
//> using dep "ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.26.0"
//> using dep "ai.djl.onnxruntime:onnxruntime-engine:0.26.0"
// ---------------------
// Restrict slf4j-simple output to errors so DJL's logging does not drown the report.
// Kept as the very first statement of the script, before any DJL call runs.
// NOTE(review): presumably slf4j-simple reads this property once when it initializes — confirm
// that it must stay ahead of any logger creation.
System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "error")
import ai.djl.Application
import ai.djl.engine.Engine
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.modality.cv.output.DetectedObjects
import ai.djl.modality.cv.output.DetectedObjects.DetectedObject
import ai.djl.repository.Artifact
import ai.djl.repository.zoo.{Criteria, ModelNotFoundException, ModelZoo, ModelZooResolver, ZooModel}
import ai.djl.training.util.ProgressBar
import java.net.{URI, URL}
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.util.UUID
import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters.*
import scala.io.AnsiColor.{BLUE, BOLD, CYAN, GREEN, MAGENTA, RED, RESET, UNDERLINED, YELLOW}
// ----------------------------------------------------------------------------------------------
/** Writes a PNG copy of `img` with the bounding boxes of `detection` drawn on it.
  *
  * The boxes are drawn on `img.duplicate()` rather than on `img` itself.
  *
  * @param img        source image
  * @param detection  detections whose bounding boxes are drawn
  * @param outputFile destination file (created, or truncated if it already exists)
  */
def saveBoundingBoxImage(img: Image, detection: DetectedObjects, outputFile: Path): Unit = {
  val newImage = img.duplicate()
  newImage.drawBoundingBoxes(detection)
  // The stream used to be left open; close it deterministically so that generating many
  // images (one per model and per input) does not leak a file handle each time.
  scala.util.Using.resource(Files.newOutputStream(outputFile)) { out =>
    newImage.save(out, "png")
  }
}
/** Returns the file name of a `/`-separated path, stripped of every extension.
  *
  * Everything up to and including the last `/` is dropped, then everything from the first `.`
  * onwards, e.g. `/a/b/photo.tar.gz` gives `photo`.
  */
def basename(filename: String): String = {
  // No '/' makes lastIndexOf return -1, so the whole string is kept.
  val lastSegment = filename.substring(filename.lastIndexOf('/') + 1)
  lastSegment.takeWhile(_ != '.')
}
// ----------------------------------------------------------------------------------------------
// Relative directory receiving the generated bounding-box images.
val outputDir = Paths.get("build", "output")
// Create it (and any missing parent) up front; this does nothing if it already exists.
Files.createDirectories(outputDir)
// ----------------------------------------------------------------------------------------------
/** Wraps a DJL zoo [[Artifact]] with Scala-friendly accessors and a stable identifier.
  *
  * @param artifact a zoo artifact, as listed by `ModelZoo.listModels()`
  */
final case class ModelArtifact(artifact: Artifact) {

  /** Name-based (type 3) UUID computed from the coordinates and the sorted properties.
    *
    * Including the properties separates variants that share group, artifact and version. The
    * input is encoded explicitly as UTF-8 so the value does not depend on the JVM's default
    * charset.
    */
  val uuid: UUID = UUID.nameUUIDFromBytes(
    s"$groupId$artifactId$version${properties.toList.sorted}".getBytes(java.nio.charset.StandardCharsets.UTF_8)
  )

  def groupId: String = artifact.getMetadata.getGroupId
  def artifactId: String = artifact.getMetadata.getArtifactId
  def version: String = artifact.getVersion
  def properties: Map[String, String] = artifact.getProperties.asScala.toMap

  /** `groupId:artifactId:version` identifier; the properties are not part of it. */
  def ident: String = toString

  override def toString: String = s"$groupId:$artifactId:$version"
}
// ----------------------------------------------------------------------------------------------
/** Outcome of running one detection model on one input image.
  *
  * @param inputImageSource               URL the input image was read from
  * @param modelArtifact                  zoo artifact of the model that produced this result
  * @param selectedModelPath              path of the model that was actually loaded
  * @param responseTime                   time spent predicting on this image
  * @param detectedObjects                objects reported by the model for this image
  * @param generatedBoundedBoxesImagePath image file with the detected bounding boxes drawn on it
  */
final case class ModelResult(
  inputImageSource: URL,
  modelArtifact: ModelArtifact,
  selectedModelPath: Path,
  responseTime: Duration,
  detectedObjects: List[DetectedObject],
  generatedBoundedBoxesImagePath: Path
)
/** Models (as `groupId:artifactId:version`) that are skipped because testing them fails with the
  * exception recorded next to each entry.
  */
val blackListed: Set[String] = Set(
  // java.lang.IndexOutOfBoundsException: Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was 4
  "ai.djl.paddlepaddle:face_detection:0.0.1",
  // java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds for length 1
  "ai.djl.zoo:ssd:0.0.2"
)
/** Runs one zoo object-detection model on every input image and collects the results.
  *
  * The model is looked up in the DJL model zoo by group id, artifact id and the artifact's
  * properties (used as filters). For each image the prediction is timed, the detected objects
  * are extracted, and an annotated PNG copy of the image is written under `outputDir`.
  *
  * @param modelArtifact     zoo artifact describing the model to load
  * @param inputImageSources URLs of the images to run detection on
  * @return one [[ModelResult]] per input image, in input order; `Nil` when no matching model is
  *         found (the `ModelNotFoundException` is reported on stdout and not rethrown)
  */
def testModel(modelArtifact: ModelArtifact, inputImageSources: List[URL]): List[ModelResult] = {
  println(s"${RED}TESTING MODEL $modelArtifact$RESET")
  val criteria =
    Criteria
      .builder()
      .setTypes(classOf[Image], classOf[DetectedObjects])
      .optApplication(Application.CV.OBJECT_DETECTION)
      .optGroupId(modelArtifact.groupId)
      .optArtifactId(modelArtifact.artifactId)
      .optFilters(modelArtifact.properties.asJava)
      .optProgress(new ProgressBar)
      .build()
  try {
    // The model and its predictor are AutoCloseable. Release them once every image has been
    // processed instead of leaking them for each tested model. List.map is strict, so every
    // prediction finishes before the resources are closed.
    scala.util.Using.resource(ModelZoo.loadModel(criteria)) { model =>
      scala.util.Using.resource(model.newPredictor()) { predictor =>
        inputImageSources.map { inputImageSource =>
          val inputImage = ImageFactory.getInstance().fromUrl(inputImageSource)
          // nanoTime is monotonic, so wall-clock adjustments cannot skew the measurement.
          val started = System.nanoTime()
          val detected: DetectedObjects = predictor.predict(inputImage)
          val elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - started)
          val duration = Duration(elapsedMillis, TimeUnit.MILLISECONDS)
          val detectedObjects = detected
            .items[DetectedObject]()
            .asScala
            .toList
          val outputImageFile = outputDir.resolve(s"${basename(inputImageSource.getFile)}-${modelArtifact.uuid}.png")
          saveBoundingBoxImage(inputImage, detected, outputImageFile)
          ModelResult(
            inputImageSource = inputImageSource,
            modelArtifact = modelArtifact,
            selectedModelPath = model.getModelPath,
            responseTime = duration,
            detectedObjects = detectedObjects,
            generatedBoundedBoxesImagePath = outputImageFile
          )
        }
      }
    }
  } catch {
    case err: ModelNotFoundException =>
      println(s"No matching model for $modelArtifact : ${err.getMessage}")
      Nil
  }
}
/** Prints a per-image report of every model result.
  *
  * Images are reported in the order in which they first appear in `results`. Within an image,
  * models keep their original order. Detected objects are listed from most to least probable.
  *
  * @param results results of every tested model on every image
  */
def showResults(results: Seq[ModelResult]): Unit = {
  // Group on the URL's string form rather than on java.net.URL itself: URL.equals/hashCode may
  // resolve host names over the network. Grouping this way also keeps the input order of the
  // images, which a plain groupBy map would lose. toExternalForm is what URL.toString prints,
  // so the report text is unchanged.
  val imageKeys = results.map(_.inputImageSource.toExternalForm).distinct
  val resultsByImage = results.groupBy(_.inputImageSource.toExternalForm)
  imageKeys.foreach { imageURL =>
    val resultsForImage = resultsByImage(imageURL)
    println(s"${BLUE}${BOLD}==========================================================================$RESET")
    println(s"${BLUE}${BOLD}RESULTS FOR $imageURL$RESET")
    resultsForImage.foreach { result =>
      import result.*
      println(s"${BLUE}${BOLD}--------------------------------------------------------------------------$RESET")
      println(s"${BLUE}${BOLD}MODEL ${modelArtifact.ident}$RESET")
      println(s"${BLUE}PATH $selectedModelPath$RESET")
      println(s"${GREEN}Number of detected object : ${detectedObjects.size} in $responseTime$RESET")
      println(s"${GREEN} look at $generatedBoundedBoxesImagePath$RESET")
      detectedObjects.sortBy(-_.getProbability).foreach { detectedObject =>
        println(f" $YELLOW$BOLD${detectedObject.getClassName} ${RED} ${detectedObject.getProbability}%1.2f$RESET")
      }
    }
  }
}
// Remote sample pictures example-001.jpg to example-016.jpg.
val inputImageSources: List[URL] =
  (1 to 16).map { n =>
    URI.create(f"https://mapland.fr/data/ai/images-samples/example-$n%03d.jpg").toURL
  }.toList

// Every zoo artifact registered for object detection (empty if the zoo lists none).
val objectDetectionsArtifacts: List[Artifact] =
  ModelZoo
    .listModels()
    .asScala
    .get(Application.CV.OBJECT_DETECTION)
    .fold(List.empty[Artifact])(_.asScala.toList)

// Test each non-blacklisted model in turn and gather all of its per-image results.
val results: List[ModelResult] =
  for {
    modelArtifact <- objectDetectionsArtifacts.map(ModelArtifact(_))
    if !blackListed.contains(modelArtifact.ident)
    result <- testModel(modelArtifact, inputImageSources)
  } yield result

showResults(results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment