@frankfliu
Last active January 27, 2022 09:57
Load a PyTorch model from a jar file
```java
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;

public class LoadModelFromJar {

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        /*
         * 1. Download the PyTorch resnet18 model from:
         *    https://resources.djl.ai/test-models/pytorch/resnet18_jit.tar.gz
         * 2. Put resnet18_jit.tar.gz in your project's resources folder
         *    (e.g. src/main/resources) so it is packaged at the root of the jar.
         */
        String imageFile = "https://resources.djl.ai/images/kitten.jpg";
        Image img = ImageFactory.getInstance().fromUrl(imageFile);

        // The jar:/// URL points DJL at the model archive on the classpath.
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls("jar:///resnet18_jit.tar.gz")
                        .optModelName("resnet18_jit")
                        .optEngine("PyTorch")
                        .build();

        // Close the model and predictor when done; both hold native resources.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }
}
```
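If `criteria.loadModel()` cannot find the archive, a quick first check is whether `resnet18_jit.tar.gz` actually ended up on the classpath. The sketch below is plain Java with no DJL dependency. `JarModelCheck`, `toResourceName`, and `isOnClasspath` are hypothetical names, and the sketch assumes that a `jar:///<name>` URL refers to a resource at the root of the classpath; it approximates, rather than reproduces, DJL's own URL handling.

```java
import java.net.URL;

/**
 * Hypothetical helper (not part of DJL): maps a "jar:///<name>" model URL to a
 * classpath resource name and checks whether that resource was packaged.
 */
class JarModelCheck {

    // Assumption: the resource name is everything after this prefix.
    private static final String PREFIX = "jar:///";

    /** Strips the "jar:///" prefix; rejects URLs that do not use it. */
    static String toResourceName(String modelUrl) {
        if (modelUrl == null || !modelUrl.startsWith(PREFIX)) {
            throw new IllegalArgumentException("Expected a jar:/// URL: " + modelUrl);
        }
        String name = modelUrl.substring(PREFIX.length());
        if (name.isEmpty()) {
            throw new IllegalArgumentException("Missing resource name in: " + modelUrl);
        }
        return name;
    }

    /** Returns true if the class loader can locate the named resource. */
    static boolean isOnClasspath(String resourceName) {
        // ClassLoader.getResource expects a name without a leading slash.
        URL url = JarModelCheck.class.getClassLoader().getResource(resourceName);
        return url != null;
    }

    public static void main(String[] args) {
        String resource = toResourceName("jar:///resnet18_jit.tar.gz");
        if (isOnClasspath(resource)) {
            System.out.println("Found " + resource + " on the classpath");
        } else {
            System.out.println("Missing " + resource + "; check the resources folder");
        }
    }
}
```

If the check reports the file as missing, it was most likely not placed in the resources folder or not copied into the build output, so the model load will fail for the same reason.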