Skip to content

Instantly share code, notes, and snippets.

@filipsedivy
Created May 2, 2022 11:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save filipsedivy/0ca5a8c0178e5936b3285fad0d68862d to your computer and use it in GitHub Desktop.
Save filipsedivy/0ca5a8c0178e5936b3285fad0d68862d to your computer and use it in GitHub Desktop.
package org.example;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
public class Main {
public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
.optSynset(new ArrayList<String>() {
{
add("class_1");
add("class_2");
add("class_3");
add("class_4");
add("class_5");
add("class_6");
add("class_7");
add("class_8");
}
})
.addTransform(new Resize(224, 224))
.addTransform(new ToTensor())
.build();
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelPath(Paths.get("model.onnx"))
.optTranslator(translator)
.optEngine("OnnxRuntime")
.build();
ZooModel<Image, Classifications> model = criteria.loadModel();
Predictor<Image, Classifications> predictor = model.newPredictor();
Image image = ImageFactory.getInstance().fromFile(Paths.get("picture.jpg"));
Classifications predict = predictor.predict(image);
System.out.println(predict.getAsString());
}
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>demo2</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.16.0</version>
</dependency>
<dependency>
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-model-zoo</artifactId>
<version>0.16.0</version>
</dependency>
<dependency>
<groupId>ai.djl.tensorflow</groupId>
<artifactId>tensorflow-engine</artifactId>
<version>0.16.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.onnxruntime</groupId>
<artifactId>onnxruntime-engine</artifactId>
<version>0.16.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.cassandra</groupId>
<artifactId>cassandra-all</artifactId>
<version>4.0.3</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</exclusion>
<exclusion>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
</project>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment