Skip to content

Instantly share code, notes, and snippets.

@thierryherrmann
Created August 3, 2020 00:08
Show Gist options
  • Save thierryherrmann/31a5d6922bd54bb7ace473df96f00e90 to your computer and use it in GitHub Desktop.
Save thierryherrmann/31a5d6922bd54bb7ace473df96f00e90 to your computer and use it in GitHub Desktop.
/**
 * Loads a TensorFlow SavedModel and uses two of its signatures on a single example.
 *
 * <p>It runs the {@code my_train} signature once, then the {@code my_serve} signature once, and
 * prints the fetched loss and predictions to standard output.
 *
 * <p>Every native TensorFlow resource is released with try-with-resources, including when an
 * exception is thrown. This covers:
 *
 * <ul>
 *   <li>the model bundle;
 *   <li>the input tensors;
 *   <li>the tensors returned by {@code Session.Runner.run()}.
 * </ul>
 */
public class TrainAndServeSavedModel {

  /**
   * Entry point.
   *
   * @param args {@code args[0]} is the saved model directory, loaded with the {@code "serve"} tag
   * @throws IllegalArgumentException if no directory is given, or if a required signature or
   *     tensor key is missing from the model
   * @throws Exception propagated from loading or running the model
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 1) {
      throw new IllegalArgumentException(
          "usage: TrainAndServeSavedModel <saved model directory>");
    }
    // The bundle owns the session obtained from savedModel.session(). Closing the bundle releases
    // that session and the graph, so the session is not closed separately here.
    try (SavedModelBundle savedModel = SavedModelBundle.load(args[0], "serve");
        // One example with 8 features, all set to 1.0.
        Tensor<TFloat32> inputTensor =
            TFloat32.tensorOf(
                StdArrays.ndCopyOf(
                    new float[][] {{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}}));
        Tensor<TFloat32> labelTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {1.0f}))) {
      Map<String, SignatureDef> signatureMap = savedModel.metaGraphDef().getSignatureDefMap();
      Session session = savedModel.session();
      train(
          session,
          requireEntry(signatureMap, "my_train", "signature"),
          inputTensor,
          labelTensor);
      serve(session, requireEntry(signatureMap, "my_serve", "signature"), inputTensor);
    }
  }

  /**
   * Runs the serving signature on {@code inputTensor} and prints every predicted scalar.
   *
   * @param session session of the loaded model
   * @param modelInfo serving signature; must declare input {@code "x"} and output {@code "output_0"}
   * @param inputTensor features to feed as {@code "x"}; not closed by this method
   * @throws IllegalArgumentException if the signature lacks the expected input or output key
   */
  private static void serve(Session session, SignatureDef modelInfo, Tensor<TFloat32> inputTensor) {
    TensorInfo inputX = requireEntry(modelInfo.getInputsMap(), "x", "serving input");
    TensorInfo outputPred = requireEntry(modelInfo.getOutputsMap(), "output_0", "serving output");
    // Only one tensor is fetched, so run() returns exactly one result.
    // That result is owned by this method and must be closed to free its native memory.
    try (Tensor<TFloat32> predictions =
        session
            .runner()
            .feed(inputX.getName(), inputTensor)
            .fetch(outputPred.getName())
            .run()
            .get(0)
            .expect(TFloat32.DTYPE)) {
      predictions
          .data()
          .scalars()
          .forEachIndexed(
              (coords, scalar) -> System.out.println("prediction: " + scalar.getFloat()));
    }
  }

  /**
   * Feeds one batch of features and labels into the training signature.
   *
   * <p>It then prints the fetched {@code output_0} value as the loss.
   *
   * @param session session of the loaded model
   * @param modelInfo training signature; must declare inputs {@code "X"} and {@code "y"} and
   *     output {@code "output_0"}
   * @param inputTensor features to feed as {@code "X"}; not closed by this method
   * @param labelTensor labels to feed as {@code "y"}; not closed by this method
   * @throws IllegalArgumentException if the signature lacks an expected input or output key
   */
  private static void train(
      Session session,
      SignatureDef modelInfo,
      Tensor<TFloat32> inputTensor,
      Tensor<TFloat32> labelTensor) {
    Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
    // NOTE(review): the training signature keys its features as "X", but serving uses "x".
    // Presumably these mirror the argument names of the exported Python functions; confirm.
    TensorInfo inputX = requireEntry(inputs, "X", "training input");
    TensorInfo inputY = requireEntry(inputs, "y", "training input");
    TensorInfo outputLoss = requireEntry(modelInfo.getOutputsMap(), "output_0", "training output");
    // Close the single fetched result tensor once its value has been printed.
    try (Tensor<TFloat32> loss =
        session
            .runner()
            .feed(inputX.getName(), inputTensor)
            .feed(inputY.getName(), labelTensor)
            .fetch(outputLoss.getName())
            .run()
            .get(0)
            .expect(TFloat32.DTYPE)) {
      System.out.println("loss after training: " + loss.data().getFloat());
    }
  }

  /**
   * Looks up {@code key} in {@code map}, failing with a descriptive error when it is absent.
   *
   * @param map map to search (signatures, signature inputs or signature outputs)
   * @param key key that must be present
   * @param description human-readable kind of entry, used in the error message
   * @return the non-null value mapped to {@code key}
   * @throws IllegalArgumentException if {@code map} has no non-null value for {@code key}
   */
  private static <V> V requireEntry(Map<String, V> map, String key, String description) {
    V value = map.get(key);
    if (value == null) {
      throw new IllegalArgumentException(
          "missing " + description + " '" + key + "' (available: " + map.keySet() + ")");
    }
    return value;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment