Skip to content

Instantly share code, notes, and snippets.

@frankfliu
Last active January 5, 2022 02:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save frankfliu/da70a1ac9cf1d423be034cfce673d55e to your computer and use it in GitHub Desktop.
Save frankfliu/da70a1ac9cf1d423be034cfce673d55e to your computer and use it in GitHub Desktop.
Load PyTorch model from HDFS
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
/**
 * Loads a TorchScript ResNet-18 image-classification model from HDFS with DJL and classifies
 * a single image.
 *
 * <p>Setup:
 *
 * <ol>
 *   <li>Make sure your project includes the dependency {@code ai.djl.hadoop:hadoop:0.14.0}.
 *   <li>Download the PyTorch ResNet-18 model from
 *       https://resources.djl.ai/test-models/pytorch/resnet18_jit.tar.gz
 *   <li>Put the {@code resnet18_jit.tar.gz} file in your HDFS.
 * </ol>
 *
 * <p>Usage: {@code LoadModelFromHdfs [modelUrl] [imageUrl]}. Both arguments are optional;
 * missing or blank arguments fall back to the defaults below.
 */
public class LoadModelFromHdfs {

    /**
     * Default HDFS location of the model archive, used when no first argument is given.
     * NOTE(review): the port looks like one from a local/mini test cluster; pass the real
     * namenode URL as the first argument instead of editing this constant.
     */
    private static final String DEFAULT_MODEL_URL = "hdfs://localhost:57545/resnet18_jit.tar.gz";

    /** Default image to classify, used when no second argument is given. */
    private static final String DEFAULT_IMAGE_URL = "https://resources.djl.ai/images/kitten.jpg";

    /**
     * Classifies one image with the HDFS-hosted model and prints the result to stdout.
     *
     * @param args optional {@code [modelUrl] [imageUrl]} overrides
     * @throws IOException if the image cannot be read
     * @throws ModelException if the model cannot be located or loaded
     * @throws TranslateException if pre/post-processing or inference fails
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        String modelUrl = argOrDefault(args, 0, DEFAULT_MODEL_URL);
        String imageFile = argOrDefault(args, 1, DEFAULT_IMAGE_URL);

        Image img = ImageFactory.getInstance().fromUrl(imageFile);

        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(Image.class, Classifications.class)
                        // An hdfs:// URL is resolved through the ai.djl.hadoop extension (see setup step 1).
                        .optModelUrls(modelUrl)
                        .optEngine("PyTorch")
                        .build();

        // Both the model and the predictor hold native resources; close them deterministically.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }

    /**
     * Returns {@code args[index]} when it is present and non-blank, otherwise {@code fallback}.
     *
     * @param args the program arguments (may be {@code null})
     * @param index the argument position to read
     * @param fallback the value to use when the argument is absent or blank
     * @return the chosen value
     */
    private static String argOrDefault(String[] args, int index, String fallback) {
        if (args == null || index >= args.length || args[index] == null) {
            return fallback;
        }
        String value = args[index].trim();
        return value.isEmpty() ? fallback : value;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment