Skip to content

Instantly share code, notes, and snippets.

@frankfliu
Created January 5, 2022 02:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save frankfliu/a03600323d3c7bf1334d5d8c1f0ceeed to your computer and use it in GitHub Desktop.
Load a PyTorch model from an HTTP URL
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
/**
 * Example that loads a PyTorch image-classification model from an HTTP URL with DJL and classifies
 * a single image that is also fetched over HTTP.
 *
 * <p>Usage: {@code LoadModelFromHttpURL [imageUrl [modelUrl]]}. Both arguments are optional; when
 * omitted, the kitten image and the {@code resnet18_jit.tar.gz} archive hosted on
 * {@code resources.djl.ai} are used, which reproduces the original hard-coded behavior.
 */
public class LoadModelFromHttpURL {

    /** Image classified when no image URL is given on the command line. */
    private static final String DEFAULT_IMAGE_URL = "https://resources.djl.ai/images/kitten.jpg";

    /** Model archive loaded when no model URL is given on the command line. */
    private static final String DEFAULT_MODEL_URL =
            "https://resources.djl.ai/test-models/pytorch/resnet18_jit.tar.gz";

    /**
     * Downloads the image and the model, runs one prediction and prints the resulting
     * classifications to standard output.
     *
     * @param args optional {@code [imageUrl [modelUrl]]}; missing entries fall back to the defaults
     * @throws IOException if the image cannot be read from its URL
     * @throws ModelException if no model matching the criteria can be loaded
     * @throws TranslateException if pre/post-processing or inference fails during prediction
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // Command-line overrides keep the example reusable without editing the source.
        String imageUrl = args.length > 0 ? args[0] : DEFAULT_IMAGE_URL;
        String modelUrl = args.length > 1 ? args[1] : DEFAULT_MODEL_URL;

        Image img = ImageFactory.getInstance().fromUrl(imageUrl);

        // Describe the model to load: its task, its input/output Java types, where to fetch it
        // from, and which engine must run it.
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls(modelUrl)
                        .optEngine("PyTorch")
                        .build();

        // Both the model and the predictor are AutoCloseable; try-with-resources closes them
        // (predictor first, then model) even if prediction throws.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.