Skip to content

Instantly share code, notes, and snippets.

@dacr
Last active January 29, 2024 09:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dacr/8ab38d35575ff6ad8aa9f962f6bf9b87 to your computer and use it in GitHub Desktop.
Save dacr/8ab38d35575ff6ad8aa9f962f6bf9b87 to your computer and use it in GitHub Desktop.
image classification compare models / published by https://github.com/dacr/code-examples-manager #f1c02983-8800-4c08-812e-14b732be895d/ba35a5e64fe7c54eecb279958b76e0db310484e1
// summary : image classification compare models
// keywords : djl, machine-learning, tutorial, detection, ai, @testable
// publish : gist
// authors : David Crosson
// license : Apache NON-AI License Version 2.0 (https://raw.githubusercontent.com/non-ai-licenses/non-ai-licenses/main/NON-AI-APACHE2)
// id : f1c02983-8800-4c08-812e-14b732be895d
// created-on : 2024-01-28T13:24:06+01:00
// managed-by : https://github.com/dacr/code-examples-manager
// run-with : scala-cli $file
// ---------------------
//> using scala "3.3.1"
//> using dep "org.slf4j:slf4j-api:2.0.11"
//> using dep "org.slf4j:slf4j-simple:2.0.11"
//> using dep "net.java.dev.jna:jna:5.14.0"
//> using dep "ai.djl:api:0.26.0"
//> using dep "ai.djl:basicdataset:0.26.0"
//> using dep "ai.djl:model-zoo:0.26.0"
//> using dep "ai.djl.huggingface:tokenizers:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-engine:0.26.0"
//> using dep "ai.djl.mxnet:mxnet-model-zoo:0.26.0"
//> using dep "ai.djl.pytorch:pytorch-engine:0.26.0"
//> using dep "ai.djl.pytorch:pytorch-model-zoo:0.26.0"
//> using dep "ai.djl.tensorflow:tensorflow-engine:0.26.0"
//> using dep "ai.djl.tensorflow:tensorflow-model-zoo:0.26.0"
//> using dep "ai.djl.paddlepaddle:paddlepaddle-engine:0.26.0"
//> using dep "ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.26.0"
//> using dep "ai.djl.onnxruntime:onnxruntime-engine:0.26.0"
// ---------------------
// Restrict slf4j-simple output to ERROR-level messages. Kept as the very first statement so the
// property is presumably in place before any logger gets initialised.
locally {
  val simpleLoggerLevelKey = "org.slf4j.simpleLogger.defaultLogLevel"
  System.setProperty(simpleLoggerLevelKey, "error")
}
import ai.djl.Application
import ai.djl.engine.Engine
import ai.djl.repository.Artifact
import ai.djl.repository.zoo.{Criteria, ModelNotFoundException, ModelZoo, ModelZooResolver, ZooModel}
import ai.djl.training.util.ProgressBar
import ai.djl.modality.Classifications
import ai.djl.modality.Classifications.Classification
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import java.net.{URI, URL}
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.util.UUID
import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters.*
import scala.io.AnsiColor.{BLUE, BOLD, CYAN, GREEN, MAGENTA, RED, RESET, UNDERLINED, YELLOW}
/** Scala view over a DJL model-zoo [[Artifact]].
  *
  * `uuid` is a name-based UUID (`UUID.nameUUIDFromBytes`). It is computed from the group id,
  * artifact id, version and the sorted artifact properties, so property variants of the same
  * artifact normally get distinct ids. `ident` / `toString` only render
  * `group:artifact:version`, so they do NOT distinguish property variants.
  */
case class ModelArtifact(artifact: Artifact) {
  import java.nio.charset.StandardCharsets

  // Encode explicitly as UTF-8: the no-arg getBytes uses the platform default charset, which would
  // make the same artifact hash to different UUIDs on differently configured JVMs.
  val uuid: UUID = UUID.nameUUIDFromBytes(
    s"$groupId$artifactId$version${properties.toList.sorted}".getBytes(StandardCharsets.UTF_8)
  )

  def groupId: String = artifact.getMetadata.getGroupId

  def artifactId: String = artifact.getMetadata.getArtifactId

  def version: String = artifact.getVersion

  /** Artifact properties (used as model-zoo filters); empty when the Java side returns null. */
  def properties: Map[String, String] =
    Option(artifact.getProperties).map(_.asScala.toMap).getOrElse(Map.empty)

  /** Human-readable identifier, also the key used for blacklisting. */
  def ident: String = toString

  override def toString: String = s"$groupId:$artifactId:$version"
}
// ----------------------------------------------------------------------------------------------
/** A single class label predicted for an image, with the probability the model gave it.
  *
  * @param classification class (label) name reported by the model
  * @param probability    probability assigned to that class
  */
case class ImageClassification(classification: String, probability: Double)
/** Outcome of evaluating one model on one input image.
  *
  * @param inputImageSource     URL the image was loaded from
  * @param modelArtifact        zoo artifact that was evaluated
  * @param selectedModelPath    path of the model that was actually loaded
  * @param responseTime         time measured around the prediction call
  * @param imageClassifications classes retained for this image
  */
case class ModelResult(
  inputImageSource:     URL,
  modelArtifact:        ModelArtifact,
  selectedModelPath:    Path,
  responseTime:         Duration,
  imageClassifications: List[ImageClassification]
)
/** Artifacts (`group:artifact:version`, compared with `ModelArtifact.ident`) excluded from the run
  * because they fail at prediction time.
  */
val blackListed: Set[String] = Set(
  // ai.djl.translate.TranslateException: java.io.FileNotFoundException: File not found:
  //   /home/dcr/.djl.ai/cache/repo/model/cv/image_classification/ai/djl/pytorch/resnet18_embedding/0.0.1/synset.txt
  "ai.djl.pytorch:resnet18_embedding:0.0.1",
  // ai.djl.translate.TranslateException: ai.djl.engine.EngineException: MXNet engine call failed:
  //   MXNetError: Check failed: src.Size() % known_dim_size_prod == 0 (672 vs. 0) :
  //   Cannot reshape array of size 2056320 into shape [-1,784]
  "ai.djl.zoo:mlp:0.0.3"
)
/** Loads one image-classification model from the model zoo and classifies every input image with it.
  *
  * The model and its predictor are `AutoCloseable`. Both are closed once all images have been
  * processed, so looping over many zoo models does not keep every loaded model alive.
  *
  * @param modelArtifact     zoo artifact to load (group id, artifact id and property filters)
  * @param inputImageSources images to classify, each fetched with `ImageFactory.fromUrl`
  * @param minProbability    classes whose probability is not above this value are dropped; many
  *                          classes are reported with ~0 probability. Defaults to 0.01.
  * @return one [[ModelResult]] per input image, in input order. Returns `Nil` when no matching
  *         model is found, or when loading or running the model fails; the problem is printed
  *         and the model is skipped.
  */
def testModel(
  modelArtifact: ModelArtifact,
  inputImageSources: List[URL],
  minProbability: Double = 0.01
): List[ModelResult] = {
  import scala.util.Using
  import scala.util.control.NonFatal

  println(s"${RED}TESTING MODEL $modelArtifact$RESET")
  val criteria =
    Criteria
      .builder()
      .setTypes(classOf[Image], classOf[Classifications])
      .optApplication(Application.CV.IMAGE_CLASSIFICATION)
      .optGroupId(modelArtifact.groupId)
      .optArtifactId(modelArtifact.artifactId)
      .optFilters(modelArtifact.properties.asJava)
      .optProgress(new ProgressBar)
      .build()
  try {
    Using.resource(ModelZoo.loadModel(criteria)) { model =>
      Using.resource(model.newPredictor()) { predictor =>
        // List.map is strict, so every prediction completes before predictor and model are closed.
        inputImageSources.map { inputImageSource =>
          val inputImage = ImageFactory.getInstance().fromUrl(inputImageSource)
          // nanoTime is monotonic (unaffected by wall-clock adjustments), unlike currentTimeMillis.
          val started = System.nanoTime()
          val detected: Classifications = predictor.predict(inputImage)
          val elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - started)
          val imageClassifications = detected
            .items[Classification]()
            .asScala
            .toList
            .filter(_.getProbability > minProbability) // many classes come back with ~0 probability
            .map(c => ImageClassification(c.getClassName, c.getProbability))
          ModelResult(
            inputImageSource = inputImageSource,
            modelArtifact = modelArtifact,
            selectedModelPath = model.getModelPath,
            responseTime = Duration(elapsedMillis, TimeUnit.MILLISECONDS),
            imageClassifications = imageClassifications
          )
        }
      }
    }
  } catch {
    case err: ModelNotFoundException =>
      println(s"No matching model for $modelArtifact : ${err.getMessage}")
      Nil
    case NonFatal(err) =>
      // e.g. a missing synset file or an incompatible input shape: report it and let the
      // comparison continue with the next model instead of aborting the whole run.
      println(s"${RED}Model $modelArtifact failed : ${err.getClass.getName}: ${err.getMessage}$RESET")
      Nil
  }
}
/** Prints every result grouped by input image, with images in the order they first appear in `results`.
  *
  * For each image and each model that processed it, this prints:
  *   - the model identifier and the loaded model path;
  *   - the number of retained classes and the prediction time;
  *   - the classes, sorted by decreasing probability.
  *
  * @param results results to display (typically the concatenated output of `testModel`)
  */
def showResults(results: Seq[ModelResult]): Unit = {
  // Key on the URL text rather than on java.net.URL itself, because URL.equals/hashCode may
  // resolve host names over the network. Walking the distinct keys keeps images in input order;
  // a groupBy Map's iteration order is unspecified.
  val resultsByImage = results.groupBy(_.inputImageSource.toString)
  val imageURLs = results.map(_.inputImageSource.toString).distinct
  imageURLs.foreach { imageURL =>
    val resultsForImage = resultsByImage.getOrElse(imageURL, Seq.empty)
    println(s"${BLUE}${BOLD}==========================================================================$RESET")
    println(s"${BLUE}${BOLD}RESULTS FOR $imageURL$RESET")
    resultsForImage.foreach { result =>
      import result.*
      println(s"${BLUE}${BOLD}--------------------------------------------------------------------------$RESET")
      println(s"${BLUE}${BOLD}MODEL ${modelArtifact.ident}$RESET")
      println(s"${BLUE}PATH $selectedModelPath$RESET")
      println(s"${GREEN}Number of detected image classes : ${imageClassifications.size} in $responseTime$RESET")
      imageClassifications.sortBy(-_.probability).foreach { detected =>
        println(f" $YELLOW$BOLD${detected.classification} ${RED} ${detected.probability}%1.2f$RESET")
      }
    }
  }
}
/** Sample images every model is evaluated on: example-001.jpg up to example-016.jpg. */
val inputImageSources =
  List.tabulate(16) { index =>
    val n = index + 1
    URI.create(f"https://mapland.fr/data/ai/images-samples/example-$n%03d.jpg").toURL
  }
// Every image-classification artifact known to the registered model zoos (empty when none is
// registered for that application). The name reflects the content: these are classification
// models, not object-detection ones.
val imageClassificationArtifacts: List[Artifact] =
  ModelZoo
    .listModels()
    .asScala
    .get(Application.CV.IMAGE_CLASSIFICATION)
    .map(_.asScala.toList)
    .getOrElse(Nil)

// Run every non-blacklisted model on every sample image, then print the comparison.
val results =
  imageClassificationArtifacts
    .map(ModelArtifact.apply)
    .filterNot(modelArtifact => blackListed.contains(modelArtifact.ident))
    .flatMap(modelArtifact => testModel(modelArtifact, inputImageSources))

showResults(results)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment