Skip to content

Instantly share code, notes, and snippets.

View thierryherrmann's full-sized avatar

Thierry Herrmann thierryherrmann

  • Montreal, Canada
View GitHub Profile
// Loads a SavedModel through the TensorFlow Java API and builds a single
// input/label sample for it.
// NOTE(review): this snippet is truncated in this listing — neither main()
// nor the class body is closed here, and the tensors built below are not yet
// fed to anything in the visible lines.
public class TrainAndServeSavedModel {
    public static void main(String[] args) throws Exception {
        // args[0]: saved model directory
        // Load the SavedModel using the "serve" tag set.
        SavedModelBundle savedModel = SavedModelBundle.load(args[0], "serve");
        // Signature definitions exported with the model (key -> SignatureDef);
        // not used in the visible lines.
        Map<String, SignatureDef> signatureMap = savedModel.metaGraphDef().getSignatureDefMap();
        // One sample with 8 features, all set to 1.0 -> a [1, 8] float tensor.
        Tensor<TFloat32> inputTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] { { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f } }));
        // The matching label for that single sample -> a [1] float tensor.
        // NOTE(review): savedModel and these tensors wrap native resources;
        // confirm they are closed (e.g. try-with-resources) in the part of
        // the program not shown here.
        Tensor<TFloat32> labelTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] { 1.0f }));
# Fetch the restored placeholders and both outputs of the restored call op by
# tensor name ("<op name>:<output index>"), then evaluate them on one sample.
# NOTE(review): relies on `graph`, `session`, `X_train` and `y_train` being
# defined by earlier code that is not part of this snippet.
input_X = graph.get_tensor_by_name('my_train_X:0')
input_y = graph.get_tensor_by_name('my_train_y:0')
output_1 = graph.get_tensor_by_name('StatefulPartitionedCall_1:0')
output_2 = graph.get_tensor_by_name('StatefulPartitionedCall_1:1')
# Slicing with [0:1] (rather than [0]) keeps the leading batch axis, assuming
# X_train / y_train are array-like — so a batch containing one sample is fed.
out_val_1, out_val_2 = session.run([output_1, output_2],
                                   feed_dict={input_X: X_train[0:1], input_y: y_train[0:1]})
def train_predict_serve(model_dir):
    """Load the SavedModel in ``model_dir`` into a TF1-style session and look
    up its input/output tensors by name.

    NOTE(review): this listing appears truncated. Despite the function name,
    the visible body only loads the model and fetches tensor handles; it does
    not train, predict, return anything, or close ``session``. Confirm against
    the full source before relying on it.

    Args:
        model_dir: export directory of the SavedModel, loaded with the
            SERVING tag.
    """
    # Clear the default graph so this load starts from an empty graph; the
    # Session created next binds to that default graph.
    tf.compat.v1.reset_default_graph()
    session = tf.compat.v1.Session()
    # Restore the MetaGraph tagged SERVING (graph + variables) into `session`.
    tf.compat.v1.saved_model.loader.load(session, tags=[tf.saved_model.SERVING], export_dir=model_dir)
    graph = session.graph
    # Not used in the visible lines; presumably kept for inspecting op names.
    operations=graph.get_operations()
    # Tensors are addressed as "<op name>:<output index>".
    input_X = graph.get_tensor_by_name('my_train_X:0')
    input_y = graph.get_tensor_by_name('my_train_y:0')
    # Output 0 of the restored call op, treated here as the loss.
    output_loss = graph.get_tensor_by_name('StatefulPartitionedCall_1:0')
tensor : model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE (30,)
[ 3.4107354e-05 -6.1855838e-04 -1.6536651e-06 1.6930330e-06
-9.1597438e-05 -3.8669934e-04 5.6557164e-05 7.1755665e-08
-1.2517045e-04 -1.0449246e-03 5.9954262e-05 7.3613039e-05
6.6272205e-06 -5.7156640e-04 5.4908687e-06 -7.3699164e-05
-8.7973615e-04 -3.6661630e-04 5.2946081e-05 -5.7122961e-04
-8.7792240e-04 -4.1600107e-04 -1.2562575e-03 -2.4318745e-06
7.0880642e-06 9.7999236e-06 -6.5629813e-04 1.1121790e-05
-1.3819840e-03 6.7142719e-06]
# Drop the in-memory module, reload the model from model_dir, run it through
# train_module, plot the returned loss history, and save it back to model_dir.
# Unbinding the old name presumably demonstrates that the reloaded object alone
# is enough to continue.
del new_module
# NOTE(review): this cell reloads with the Keras loader, while the printed
# output in this listing shows a module restored as a generic `_UserObject`;
# confirm the SavedModel in model_dir is meant to be loaded via tf.keras.
new_module_2 = tf.keras.models.load_model(model_dir)
loss_hist = train_module(new_module_2, train_dataset, valid_dataset)
plot_loss(loss_hist)
# Writes to the same directory it was loaded from (presumably replacing the
# previous export — depends on save_module, not shown here).
save_module(new_module_2, model_dir)
tensor : model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE (30,)
[ 1.3162548e-04 -1.0862495e-03 8.3323405e-04 8.4080239e-06
-1.6426330e-04 -9.0881845e-04 4.7971989e-04 -6.0352772e-06
-9.3550794e-04 -3.1544755e-03 5.4244534e-04 1.0909925e-03
1.3340317e-03 -1.0700974e-03 3.7469756e-04 -1.5879219e-03
-2.1641832e-03 -1.7716389e-03 2.8458738e-04 -6.3899945e-04
-2.9655998e-03 -1.7114554e-03 -3.9885961e-03 2.6567639e-05
-3.6036890e-05 6.1224034e-04 -1.0181948e-03 1.6523007e-04
-4.8340447e-03 1.5539475e-03]
# Inspect one saved variable: the "m" optimizer slot of the first weighted
# layer's bias (presumably an Adam-style first-moment estimate), read from the
# checkpoint stored inside the SavedModel.
# NOTE(review): the 'variables/variables' checkpoint prefix assumes the
# standard SavedModel directory layout; os.path.join would be more portable
# than string concatenation.
inspect_checkpoint(model_dir + '/variables/variables', print_values=True,
                   variables=['model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE'])
# Run the in-memory module through train_module, plot the returned loss
# history, and save the module to model_dir.
# NOTE(review): if the snippets in this listing run in the order shown,
# `new_module` has already been removed by the earlier `del new_module`;
# confirm the intended cell order.
loss_hist = train_module(new_module, train_dataset, valid_dataset)
plot_loss(loss_hist)
save_module(new_module, model_dir)
type of reloaded module: <class 'tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject'>
type of instantiated module: <class '__main__.CustomModule'>
my_train function: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7fc1c5872390>
__call__ function: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7fc1c58cbfd0>
sample prediction: [[0.54084957]]
# Compare the object restored from the SavedModel with the original Python
# class, show the restored functions, and run one sample through the module.
print('type of reloaded module:', type(new_module))
# Printing the class directly renders exactly like type(CustomModule()), but
# avoids constructing a throwaway module instance just to read its type.
print('type of instantiated module:', CustomModule)
print('my_train function:', new_module.my_train)
print('__call__ function:', new_module.__call__)
# Demo a call to the module (invokes its restored __call__). The [0:1] slice
# keeps the batch axis, so a batch of one sample is passed.
# No trailing space in the label: print() already inserts one separator,
# matching the other labels above.
print('sample prediction:', new_module(X_train[0:1]).numpy())