Skip to content

Instantly share code, notes, and snippets.

@thierryherrmann
Created August 3, 2020 00:00
Show Gist options
  • Save thierryherrmann/740878d7be2387aacd4a968544229466 to your computer and use it in GitHub Desktop.
Save thierryherrmann/740878d7be2387aacd4a968544229466 to your computer and use it in GitHub Desktop.
def train_predict_serve(model_dir):
    """Load a SavedModel into a TF1 session, then run train, predict and save endpoints.

    The SavedModel at ``model_dir`` is loaded (``tags=[tf.saved_model.SERVING]``)
    into a fresh default graph and session. Three graph endpoints are then run
    in order:

    1. Training endpoint: fed ``X_train[0:batch_size]`` / ``y_train[0:batch_size]``
       through the ``my_train_X`` / ``my_train_y`` placeholders. The fetched
       loss is printed. (Presumably this also applies an optimizer update;
       that depends on how the model was exported and cannot be confirmed here.)
    2. Serving endpoint: fed ``X_train[0:1]`` through ``my_serve_X``. The
       prediction is printed.
    3. Save endpoint: fed the checkpoint prefix
       ``model_dir + '/variables/variables'`` through ``saver_filename``.
       Presumably it writes the current variable values to that prefix,
       i.e. back into the export directory.

    ``X_train``, ``y_train`` and ``batch_size`` are read from module-level
    globals defined outside this function.

    Args:
        model_dir: Path of the SavedModel export directory.
    """
    tf.compat.v1.reset_default_graph()
    session = tf.compat.v1.Session()
    try:
        tf.compat.v1.saved_model.loader.load(
            session, tags=[tf.saved_model.SERVING], export_dir=model_dir)
        graph = session.graph

        # NOTE(review): the 'StatefulPartitionedCall', 'StatefulPartitionedCall_1'
        # and 'StatefulPartitionedCall_2' names are generated at export time.
        # They are assumed to map to the serve / train / save functions
        # respectively; re-check them (e.g. with saved_model_cli) if the
        # export code changes.

        # Training step on the first mini-batch; prints the fetched loss.
        input_X = graph.get_tensor_by_name('my_train_X:0')
        input_y = graph.get_tensor_by_name('my_train_y:0')
        output_loss = graph.get_tensor_by_name('StatefulPartitionedCall_1:0')
        loss = session.run(output_loss, feed_dict={
            input_X: X_train[0:batch_size],
            input_y: y_train[0:batch_size],
        })
        print('loss:', loss)

        # Single-sample prediction through the serving endpoint.
        input_X_serve = graph.get_tensor_by_name('my_serve_X:0')
        output_pred = graph.get_tensor_by_name('StatefulPartitionedCall:0')
        pred = session.run(output_pred, feed_dict={input_X_serve: X_train[0:1]})
        print('prediction:', pred)

        # Run the save endpoint with the checkpoint prefix inside model_dir.
        saver_filename = graph.get_tensor_by_name('saver_filename:0')
        save_op = graph.get_tensor_by_name('StatefulPartitionedCall_2:0')
        session.run(save_op,
                    feed_dict={saver_filename: model_dir + '/variables/variables'})
        print('checkpoint saved')
    finally:
        # Always release the session, even if loading or any run step raises.
        session.close()
# Run the load -> train step -> predict -> save-checkpoint round trip on the
# exported model directory.
train_predict_serve(model_dir=model_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment