Skip to content

Instantly share code, notes, and snippets.

@Jeraldy
Created February 2, 2018 21:10
Show Gist options
  • Save Jeraldy/7dd548cbe70329934763bd831f9a092a to your computer and use it in GitHub Desktop.
Save Jeraldy/7dd548cbe70329934763bd831f9a092a to your computer and use it in GitHub Desktop.
Load a TensorFlow model in Java
package tfjava;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
/**
 * Command-line demo that classifies a single image with a serialized TensorFlow graph.
 *
 * <p>The image bytes are decoded and normalized by a small helper graph. The resulting tensor is
 * fed to the model graph read from disk, and the highest-scoring output index is looked up in a
 * label file (one label per line) and printed to standard output.
 *
 * @author Deus
 */
public class NewClass {

  /** Default location of the serialized model {@code GraphDef}. */
  private static final String DEFAULT_MODEL_PATH = "model.pb";

  /** Default location of the label file; line {@code i} names model output index {@code i}. */
  private static final String DEFAULT_LABELS_PATH = "labels.txt";

  /** Image classified by {@link #main} when no path is given on the command line. */
  private static final String DEFAULT_IMAGE_PATH = "daisy.jpg";

  /** Tensor the normalized image is fed into. */
  private static final String INPUT_TENSOR = "input:0";

  /**
   * Tensor holding the per-label scores. NOTE(review): this name presumably comes from the graph
   * exported by TensorFlow's image-retraining script; confirm it against the model actually used.
   */
  private static final String OUTPUT_TENSOR = "final_result:0";

  /**
   * Imports the serialized model graph and runs it on one pre-processed image.
   *
   * @param graphDef serialized {@code GraphDef} of the model
   * @param image normalized image tensor, fed to {@code input:0}
   * @return the {@code N} scores fetched from {@code final_result:0}
   * @throws RuntimeException if the fetched tensor does not have shape {@code [1, N]}
   */
  private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g);
          Tensor result =
              s.runner().feed(INPUT_TENSOR, image).fetch(OUTPUT_TENSOR).run().get(0)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          // The format string must actually reference the shape; previously it was dropped.
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of"
                      + " labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
      }
    }
  }

  /**
   * Classifies the image at {@code imagepath} using the default model and label files, and
   * prints the best match to standard output.
   *
   * @param imagepath path to a JPEG-encoded image
   */
  private void method(String imagepath) {
    method(imagepath, DEFAULT_MODEL_PATH, DEFAULT_LABELS_PATH);
  }

  /**
   * Classifies an image and prints the best-matching label together with its score as a
   * percentage.
   *
   * @param imagepath path to a JPEG-encoded image
   * @param modelPath path to the serialized model {@code GraphDef}
   * @param labelsPath path to the label file, one label per line in model-output order
   * @throws IllegalStateException if the model produced no scores, or its best index has no
   *     matching line in the label file
   */
  private void method(String imagepath, String modelPath, String labelsPath) {
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imagepath));
    List<String> labels = readAllLinesOrExit(Paths.get(labelsPath));
    // Read the model before allocating the image tensor, so a missing file fails early.
    byte[] graphDef = readAllBytesOrExit(Paths.get(modelPath));
    try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      float[] res = executeInceptionGraph(graphDef, image);
      int bestLabelIdx = maxIndex(res);
      if (bestLabelIdx < 0) {
        throw new IllegalStateException("Model produced no scores for " + imagepath);
      }
      if (bestLabelIdx >= labels.size()) {
        // A label file shorter than the model output means the two do not belong together.
        throw new IllegalStateException(
            String.format(
                "Best index %d has no entry in %s (%d labels); does the label file match the"
                    + " model?",
                bestLabelIdx, labelsPath, labels.size()));
      }
      System.out.println(
          String.format(
              "BEST MATCH: %s (%.2f%% likely)",
              labels.get(bestLabelIdx), res[bestLabelIdx] * 100f));
    }
  }

  /**
   * Returns the index of the largest value. On ties, the earliest index wins.
   *
   * @param probabilities scores to search
   * @return the index of the maximum, or {@code -1} if {@code probabilities} is empty
   */
  private static int maxIndex(float[] probabilities) {
    if (probabilities.length == 0) {
      return -1;
    }
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }

  /**
   * Builds and runs a throw-away graph that turns raw JPEG bytes into a normalized float batch.
   *
   * <p>Pipeline: decode with 3 channels, cast to float, add a leading batch dimension, resize
   * bilinearly to 224x224, subtract {@code mean} (117) and divide by {@code scale} (1).
   *
   * @param imageBytes JPEG-encoded image bytes
   * @return the normalized image tensor; the caller owns it and must close it
   */
  private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;
      final Output input = b.constant("input", imageBytes);
      final Output output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        return s.runner().fetch(output.op().name()).run().get(0);
      }
    }
  }

  /**
   * Reads a whole file as bytes, or terminates the JVM with status 1 on failure.
   *
   * @param path file to read
   * @return the file contents (the trailing {@code null} is unreachable after {@code exit})
   */
  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null;
  }

  /**
   * Reads a UTF-8 text file as lines, or terminates the JVM with status 1 on failure.
   *
   * @param path file to read
   * @return the file's lines (the trailing {@code null} is unreachable after {@code exit})
   */
  private static List<String> readAllLinesOrExit(Path path) {
    try {
      return Files.readAllLines(path, Charset.forName("UTF-8"));
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      // Non-zero status: a read failure is an error, matching readAllBytesOrExit.
      System.exit(1);
    }
    return null;
  }

  /** Minimal helper for adding TensorFlow ops to a {@link Graph}. */
  static class GraphBuilder {
    private final Graph g;

    GraphBuilder(Graph g) {
      this.g = g;
    }

    /** Adds an element-wise {@code Div} op. */
    Output div(Output x, Output y) {
      return binaryOp("Div", x, y);
    }

    /** Adds an element-wise {@code Sub} op. */
    Output sub(Output x, Output y) {
      return binaryOp("Sub", x, y);
    }

    /** Adds a {@code ResizeBilinear} op resizing {@code images} to {@code size}. */
    Output resizeBilinear(Output images, Output size) {
      return binaryOp("ResizeBilinear", images, size);
    }

    /** Adds an {@code ExpandDims} op inserting a dimension at index {@code dim}. */
    Output expandDims(Output input, Output dim) {
      return binaryOp("ExpandDims", input, dim);
    }

    /** Adds a {@code Cast} op converting {@code value} to {@code dtype}. */
    Output cast(Output value, DataType dtype) {
      return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
    }

    /** Adds a {@code DecodeJpeg} op producing an image with {@code channels} channels. */
    Output decodeJpeg(Output contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
          .addInput(contents)
          .setAttr("channels", channels)
          .build()
          .output(0);
    }

    /**
     * Adds a {@code Const} op holding {@code value}. The temporary tensor is closed after its
     * contents have been copied into the op's {@code value} attribute.
     */
    Output constant(String name, Object value) {
      try (Tensor t = Tensor.create(value)) {
        return g.opBuilder("Const", name)
            .setAttr("dtype", t.dataType())
            .setAttr("value", t)
            .build()
            .output(0);
      }
    }

    /** Adds a two-input op of the given type, using the type string as the op name. */
    private Output binaryOp(String type, Output in1, Output in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
    }
  }

  /**
   * Entry point.
   *
   * @param args optional first argument: path of the image to classify (defaults to
   *     {@code daisy.jpg})
   */
  public static void main(String[] args) {
    String imagePath = args.length > 0 ? args[0] : DEFAULT_IMAGE_PATH;
    new NewClass().method(imagePath);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment