Skip to content

Instantly share code, notes, and snippets.

@eggie5
Created April 8, 2019 03:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eggie5/5812ad0b7a41d55309c995d009a85322 to your computer and use it in GitHub Desktop.
Save eggie5/5812ad0b7a41d55309c995d009a85322 to your computer and use it in GitHub Desktop.
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.SavedModelBundle.Loader;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.example.*;
import java.util.Map;
import java.util.Arrays;
/**
 * Loads a TensorFlow SavedModel export, prints its {@code serving_default} signature, and scores
 * a single {@code tf.Example} that carries one int64 feature named {@code "rid"}.
 *
 * <p>Usage: {@code java HelloTensorFlow [exportDir] [rid]}. With no arguments the default export
 * directory and default {@code rid} value below are used.
 */
public class HelloTensorFlow {
  /** Export directory used when none is passed on the command line. */
  private static final String DEFAULT_EXPORT_DIR = "./exports/1554425274/";

  /**
   * Value of the {@code "rid"} feature used when none is passed on the command line.
   *
   * <p>NOTE(review): placeholder. The original code referenced an undefined variable {@code i}
   * here, so the intended value is unknown; confirm.
   */
  private static final long DEFAULT_RID = 1L;

  /** SavedModel tag set to load. */
  private static final String SERVE_TAG = "serve";

  /** Key of the signature whose inputs and outputs are printed. */
  private static final String SIGNATURE_KEY = "serving_default";

  /** Graph node fed with the batch of serialized Examples. */
  private static final String INPUT_TENSOR = "input_example_tensor";

  /** Graph node fetched as the model's scores. */
  private static final String OUTPUT_TENSOR = "groupwise_dnn_v2/accumulate_scores/truediv";

  /**
   * Prints the inputs and outputs of the model's {@code serving_default} signature.
   *
   * @param model a loaded SavedModel bundle
   * @throws Exception if the bundle's MetaGraphDef cannot be parsed, or (via
   *     {@code getSignatureDefOrThrow}) if it has no {@code serving_default} signature
   */
  private static void printSignature(SavedModelBundle model) throws Exception {
    MetaGraphDef metaGraph = MetaGraphDef.parseFrom(model.metaGraphDef());
    SignatureDef sig = metaGraph.getSignatureDefOrThrow(SIGNATURE_KEY);
    System.out.println("MODEL SIGNATURE");
    printTensorInfos("Inputs:", sig.getInputsMap());
    printTensorInfos("Outputs:", sig.getOutputsMap());
    System.out.println("-----------------------------------------------");
  }

  /** Prints a heading followed by one numbered line per signature entry. */
  private static void printTensorInfos(String heading, Map<String, TensorInfo> infos) {
    System.out.println(heading);
    int index = 1;
    for (Map.Entry<String, TensorInfo> entry : infos.entrySet()) {
      TensorInfo info = entry.getValue();
      System.out.printf(
          "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n",
          index++, infos.size(), entry.getKey(), info.getName(), info.getDtype());
    }
  }

  /** Builds a {@code tf.Example} whose only feature is the int64 {@code "rid"} set to {@code rid}. */
  private static Example buildExample(long rid) {
    Int64List values = Int64List.newBuilder().addValue(rid).build();
    Feature ridFeature = Feature.newBuilder().setInt64List(values).build();
    Features features = Features.newBuilder().putFeature("rid", ridFeature).build();
    return Example.newBuilder().setFeatures(features).build();
  }

  /**
   * Copies a rank-2 float tensor into a Java array sized from the tensor's own shape.
   *
   * @param tensor the fetched output tensor
   * @return the tensor's contents as {@code float[rows][cols]}
   * @throws IllegalStateException if the tensor is not rank 2
   */
  private static float[][] toMatrix(Tensor<Float> tensor) {
    long[] shape = tensor.shape();
    if (shape.length != 2) {
      throw new IllegalStateException(
          "Expected a rank-2 output from " + OUTPUT_TENSOR + " but got shape "
              + Arrays.toString(shape));
    }
    // Size the destination from the real shape; a fixed [1][1] buffer fails for larger outputs.
    return tensor.copyTo(new float[(int) shape[0]][(int) shape[1]]);
  }

  /**
   * Loads the model, prints diagnostic information, and prints the scores for one Example.
   *
   * @param args optional {@code [exportDir] [rid]} overrides
   * @throws Exception if loading the model, parsing its signature, or running the session fails
   */
  public static void main(String[] args) throws Exception {
    String exportDir = args.length > 0 ? args[0] : DEFAULT_EXPORT_DIR;
    long rid = args.length > 1 ? Long.parseLong(args[1]) : DEFAULT_RID;

    // The bundle, input and output tensors hold native memory; close them deterministically.
    try (SavedModelBundle bundle = SavedModelBundle.load(exportDir, SERVE_TAG)) {
      System.out.println(bundle);
      printSignature(bundle);
      Graph graph = bundle.graph();
      System.out.println(graph);

      Example example = buildExample(rid);
      System.out.println(example);

      // A batch containing exactly one serialized Example.
      try (Tensor<String> inputBatch = Tensors.create(new byte[][] {example.toByteArray()})) {
        System.out.println(inputBatch);
        // Only one node is fetched, so get(0) is the sole output tensor and it is closed here.
        try (Tensor<Float> result =
            bundle.session().runner()
                .feed(INPUT_TENSOR, inputBatch)
                .fetch(OUTPUT_TENSOR)
                .run()
                .get(0)
                .expect(Float.class)) {
          System.out.println(Arrays.deepToString(toMatrix(result)));
        }
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment