Skip to content

Instantly share code, notes, and snippets.

@steven-mi
Created October 3, 2020 09:44
Show Gist options
  • Save steven-mi/c4578d312546bd2414b152331100ac26 to your computer and use it in GitHub Desktop.
Running a SavedModel in Java
package org.example;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.IntNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.types.TInt32;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
/**
 * Minimal example that loads a TensorFlow SavedModel from the classpath and runs its
 * {@code "predict"} function on a single hard-coded input row.
 */
public class Inference {

    /** Classpath location of the SavedModel directory. */
    private static final String MODEL_RESOURCE = "tfmodel";

    /** Tag set passed to {@link SavedModelBundle#load}. */
    private static final String SERVE_TAG = "serve";

    /** Name of the exported function to invoke. */
    private static final String FUNCTION_NAME = "predict";

    /** Name of the input argument fed to the function. */
    private static final String INPUT_NAME = "context";

    /**
     * Loads the model, feeds a {@code [1, 9]} int32 tensor to the {@code predict} function,
     * and prints the input shape followed by the function's output map.
     *
     * @param arg unused command-line arguments
     * @throws URISyntaxException if the resource URL cannot be converted to a URI
     * @throws IllegalStateException if the {@code tfmodel} resource is not on the classpath
     */
    public static void main(String[] arg) throws URISyntaxException {
        // Locate the model folder in the classpath resources.
        URL modelUrl = Inference.class.getClassLoader().getResource(MODEL_RESOURCE);
        if (modelUrl == null) {
            throw new IllegalStateException(
                    "SavedModel resource '" + MODEL_RESOURCE + "' not found on the classpath");
        }
        // Paths.get(URI) only works for resources that live on the file system
        // (e.g. an exploded resources directory), not for entries packed inside a JAR.
        String modelPath = Paths.get(modelUrl.toURI()).toString();

        // Build a 1x9 int32 matrix and fill its single row with the input values.
        IntNdArray inputMatrix = NdArrays.ofInts(Shape.of(1, 9));
        inputMatrix.set(NdArrays.vectorOf(1, 2, 3, 5, 7, 21, 23, 43, 123), 0);

        // The bundle and the tensor are AutoCloseable (they wrap native resources), so
        // release them deterministically instead of relying on GC.
        try (SavedModelBundle model = SavedModelBundle.load(modelPath, SERVE_TAG);
             Tensor<TInt32> inputTensor = TInt32.tensorOf(inputMatrix)) {
            System.out.println(inputTensor.shape());

            Map<String, Tensor<?>> feedDict = new HashMap<>();
            feedDict.put(INPUT_NAME, inputTensor);

            Map<String, Tensor<?>> outputs = model.function(FUNCTION_NAME).call(feedDict);
            try {
                System.out.println(outputs);
            } finally {
                // NOTE(review): presumably the caller owns the returned tensors; close them
                // so their native memory is not leaked. Confirm against the TF Java version in use.
                outputs.values().forEach(Tensor::close);
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment