Skip to content

Instantly share code, notes, and snippets.

@frankfliu
Created January 5, 2022 03:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save frankfliu/ec7ac22cedd7540a9d95c90a440847ac to your computer and use it in GitHub Desktop.
Save frankfliu/ec7ac22cedd7540a9d95c90a440847ac to your computer and use it in GitHub Desktop.
Load PyTorch model from local file
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;
public class LoadModelFromLocalFile {

    /** Model directory loaded when no path is passed as the first command-line argument. */
    private static final String DEFAULT_MODEL_PATH = "downloads/resnet18_jit";

    /** Image classified when no URL is passed as the second command-line argument. */
    private static final String DEFAULT_IMAGE_URL = "https://resources.djl.ai/images/kitten.jpg";

    /**
     * Loads a PyTorch (TorchScript) ResNet-18 image classifier from the local file system and
     * prints its classification of a single image.
     *
     * <p>Setup:
     *
     * <ol>
     *   <li>Download the PyTorch resnet18 model from
     *       https://resources.djl.ai/test-models/pytorch/resnet18_jit.tar.gz
     *   <li>Untar resnet18_jit.tar.gz into the folder {@code downloads/resnet18_jit} (relative to
     *       the working directory), or pass the folder you extracted it to as the first argument.
     * </ol>
     *
     * @param args optional overrides: {@code args[0]} is the model directory (default {@code
     *     downloads/resnet18_jit}); {@code args[1]} is the image URL (default: the DJL kitten image)
     * @throws IOException if the image or the model files cannot be read
     * @throws ModelException if DJL cannot locate or load the model
     * @throws TranslateException if pre-/post-processing or inference fails
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // With no arguments this behaves exactly like the original hard-coded example.
        String modelPath = args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;
        String imageFile = args.length > 1 ? args[1] : DEFAULT_IMAGE_URL;

        Image img = ImageFactory.getInstance().fromUrl(imageFile);

        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(Image.class, Classifications.class)
                        // Alternative ways to point DJL at the same model:
                        // .optModelUrls("file:resnet18_jit") // use relative directory
                        // .optModelUrls("file:/downloads/resnet18_jit") // use absolute path
                        // .optModelUrls("file:///downloads/resnet18_jit") // use absolute path
                        // .optModelUrls("file:///downloads/resnet18_jit.tar.gz") // load from archive file
                        // .optModelPath(Paths.get("downloads/resnet18_jit/resnet18_jit.pt"))
                        .optModelPath(Paths.get(modelPath))
                        .optEngine("PyTorch")
                        .build();

        // Both the model and the predictor hold native resources; close them deterministically.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment