Created February 2, 2018 21:10
-
-
Save Jeraldy/7dd548cbe70329934763bd831f9a092a to your computer and use it in GitHub Desktop.
Load a TensorFlow model in Java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package tfjava; | |
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
/** | |
* | |
* @author Deus | |
*/ | |
public class NewClass { | |
private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) { | |
try (Graph g = new Graph()) { | |
g.importGraphDef(graphDef); | |
try (Session s = new Session(g); | |
Tensor result = s.runner().feed("input:0", image).fetch("final_result:0").run().get(0)) { | |
final long[] rshape = result.shape(); | |
if (result.numDimensions() != 2 || rshape[0] != 1) { | |
throw new RuntimeException(String.format("Expected model to produce a [1 N] ", Arrays.toString(rshape))); | |
} | |
int nlabels = (int) rshape[1]; | |
return result.copyTo(new float[1][nlabels])[0]; | |
} | |
} | |
} | |
private void method(String imagepath) { | |
byte[] imageBytes = readAllBytesOrExit(Paths.get(imagepath)); | |
List<String> labels = readAllLinesOrExit(Paths.get("labels.txt")); | |
try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) { | |
float[] res = executeInceptionGraph(readAllBytesOrExit(Paths.get("model.pb")), image); | |
int bestLabelIdx = maxIndex(res); | |
System.out.println( | |
String.format("BEST MATCH: %s (%.2f%% likely)", labels.get(bestLabelIdx), | |
res[bestLabelIdx] * 100f)); | |
} | |
} | |
private static int maxIndex(float[] probabilities) { | |
int best = 0; | |
for (int i = 1; i < probabilities.length; ++i) { | |
if (probabilities[i] > probabilities[best]) { | |
best = i; | |
} | |
} | |
return best; | |
} | |
private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) { | |
try (Graph g = new Graph()) { | |
GraphBuilder b = new GraphBuilder(g); | |
final int H = 224; | |
final int W = 224; | |
final float mean = 117f; | |
final float scale = 1f; | |
final Output input = b.constant("input", imageBytes); | |
final Output output | |
= b.div( | |
b.sub( | |
b.resizeBilinear( | |
b.expandDims( | |
b.cast(b.decodeJpeg(input, 3), DataType.FLOAT), | |
b.constant("make_batch", 0)), | |
b.constant("size", new int[]{H, W})), | |
b.constant("mean", mean)), | |
b.constant("scale", scale)); | |
try (Session s = new Session(g)) { | |
return s.runner().fetch(output.op().name()).run().get(0); | |
} | |
} | |
} | |
private static byte[] readAllBytesOrExit(Path path) { | |
try { | |
return Files.readAllBytes(path); | |
} catch (IOException e) { | |
System.err.println("Failed to read [" + path + "]: " + e.getMessage()); | |
System.exit(1); | |
} | |
return null; | |
} | |
private static List<String> readAllLinesOrExit(Path path) { | |
try { | |
return Files.readAllLines(path, Charset.forName("UTF-8")); | |
} catch (IOException e) { | |
System.err.println("Failed to read [" + path + "]: " + e.getMessage()); | |
System.exit(0); | |
} | |
return null; | |
} | |
static class GraphBuilder { | |
GraphBuilder(Graph g) { | |
this.g = g; | |
} | |
Output div(Output x, Output y) { | |
return binaryOp("Div", x, y); | |
} | |
Output sub(Output x, Output y) { | |
return binaryOp("Sub", x, y); | |
} | |
Output resizeBilinear(Output images, Output size) { | |
return binaryOp("ResizeBilinear", images, size); | |
} | |
Output expandDims(Output input, Output dim) { | |
return binaryOp("ExpandDims", input, dim); | |
} | |
Output cast(Output value, DataType dtype) { | |
return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0); | |
} | |
Output decodeJpeg(Output contents, long channels) { | |
return g.opBuilder("DecodeJpeg", "DecodeJpeg") | |
.addInput(contents) | |
.setAttr("channels", channels) | |
.build() | |
.output(0); | |
} | |
Output constant(String name, Object value) { | |
try (Tensor t = Tensor.create(value)) { | |
return g.opBuilder("Const", name) | |
.setAttr("dtype", t.dataType()) | |
.setAttr("value", t) | |
.build() | |
.output(0); | |
} | |
} | |
private Output binaryOp(String type, Output in1, Output in2) { | |
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0); | |
} | |
private Graph g; | |
} | |
public static void main(String args[]) { | |
NewClass n = new NewClass(); | |
//n.method("sun.jpg"); | |
n.method("daisy.jpg"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment