Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@ZachisGit
Last active January 6, 2023 13:51
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save ZachisGit/71137ab43fddee6b782592e1427e016d to your computer and use it in GitHub Desktop.
Save ZachisGit/71137ab43fddee6b782592e1427e016d to your computer and use it in GitHub Desktop.
Make your model TensorFlow Serving compatible, and modify it to accept PNG-encoded images as input and return PNG-encoded images as output.
import requests
import cv2
import json
import base64
# Client: send one PNG image to the model served by TF Serving's REST API
# and write the PNG image it returns to disk.

# Load an image from file. cv2.imread returns None (it does not raise) when
# the file is missing or unreadable, so check explicitly.
im = cv2.imread("im.png")
if im is None:
    raise IOError("could not read im.png")
# PNG-encode the pixels, then base64 encode the PNG bytes
# (you can also load the raw png data from file using open instead of opencv)
ok, png_buf = cv2.imencode(".png", im)
if not ok:
    raise ValueError("could not PNG-encode im.png")
# ndarray.tostring() is deprecated since NumPy 1.19; tobytes() is the
# drop-in replacement returning the same raw bytes.
encoded = base64.b64encode(png_buf.tobytes())
# Wrap it in json (tf-serving compatible with instances and b64)
instance = [{"b64": encoded.decode("utf-8")}]
data = json.dumps({"instances": instance})
# Standard port for tf-serving rest interface is 8501
# Request Format: http://domain:[tf-serving-port]/v1/models/[model_name]:predict
resp = requests.post("http://127.0.0.1:8501/v1/models/test_model:predict", data)
# Surface HTTP errors (e.g. wrong model name) directly instead of failing
# later with a confusing KeyError on the response body.
resp.raise_for_status()
# Load base 64 encoded image string from response
prediction = resp.json()["predictions"][0]["b64"]
# Decode base64 to binary string
png_str = base64.b64decode(prediction)
# Write image to disk
with open("pred.png", "wb") as out_file:
    out_file.write(png_str)
''' ExportModel.py - TF-Serving
# Basically we are wrapping your pretrained model
# in a tensorflow serving compatible format.
# This accepts base64 encoded png images and uses
# them as input to your model. Then we convert
# your model's output into a png encoded image and
# it gets returned by tensorflow serving base64 encoded.
'''
import tensorflow as tf
# Your model (and other stuff)
def model(input_tensor, reuse=True):
    """Build your network on top of ``input_tensor`` and return its output.

    This is a placeholder to be replaced with your own graph construction.
    Its return value is handed to ``pred_to_png``, which clips it to [0, 1]
    and converts it to a uint8 PNG.

    Args:
        input_tensor: decoded image batch produced by ``png_to_input_tensor``.
        reuse: forwarded to your variable scopes if you need variable reuse.

    Raises:
        NotImplementedError: until the body is replaced with a real model.
    """
    # ... build your graph on input_tensor here and `return output_tensor` ...
    raise NotImplementedError(
        "model() is a placeholder: build your network and return its output tensor"
    )
# Png encoded image placeholder
# NOTE(review): these statements run at import time, but png_to_input_tensor,
# pred_to_png and export_for_tf_serving are defined further down this file,
# so in this order they raise NameError. Move this section below those
# definitions (e.g. into a main() called under `if __name__ == "__main__":`).
input_bytes = tf.placeholder(tf.string, shape=[], name="input_bytes")
input_tensor = png_to_input_tensor(input_bytes)
output_tensor = model(input_tensor)
output_bytes = pred_to_png(output_tensor)
# Export the model in tf serving compatible format. The signature input must
# be the PNG string placeholder the client feeds, not the decoded image tensor.
export_for_tf_serving(input_bytes, output_bytes)
# THIS IS WHERE THE MAGIC HAPPENS
def export_for_tf_serving(input_bytes, output_bytes, export_path="path/to/export/the/model/to"):
    """Restore the latest checkpoint and export the graph as a SavedModel.

    The exported signature (under the default serving key) maps the input
    alias ``images_in`` to ``input_bytes`` and the output alias
    ``output_bytes`` to ``output_bytes`` with a leading batch dim added.

    Args:
        input_bytes: ``tf.string`` placeholder fed with the PNG bytes.
        output_bytes: ``tf.string`` tensor holding the PNG-encoded prediction.
        export_path: directory the SavedModel is written to.

    Prints "COULD NOT LOAD MODEL!" and exits with status 1 if no checkpoint
    is found in the checkpoint directory.
    """
    import sys

    # Create SavedModelBuilder
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    # Build the signature_def_map
    # SUPER IMPORTANT: the output name HAS to end in "_bytes" (both the
    # signature alias 'output_bytes' below and this tensor's name do)!
    # If it doesn't, tensorflow serving will not interpret the png encoded
    # string as binary: instead of base64 encoding it in the json response
    # it writes the raw png string (this is lossy and doesn't work)!
    # expand_dims adds a leading batch dimension to the PNG string.
    output_bytes = tf.expand_dims(output_bytes, 0, name="output_bytes")

    # Let the saver know what to expect as input and what to return as output
    tensor_info_x = tf.saved_model.utils.build_tensor_info(input_bytes)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(output_bytes)

    # Prediction signature
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'images_in': tensor_info_x},
        outputs={'output_bytes': tensor_info_y},  # "_bytes" VERY IMPORTANT!
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
    )

    with tf.Session() as sess:
        # Initialize all variables, then overwrite them from the checkpoint
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state("your/checkpoint/file/or/dir")
        if ckpt is None:
            # print(...) with one argument works on both Python 2 and 3.
            print("COULD NOT LOAD MODEL!")
            # Exit with a non-zero status so calling scripts can detect
            # that the export failed.
            sys.exit(1)
        saver.restore(sess, ckpt.model_checkpoint_path)

        # Build the tensorflow serving compatible model
        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
            },
        )

    # Save the model - you are all done
    builder.save()
# Load png encoded image from string placeholder
def png_to_input_tensor(png_placeholder, width=128, height=128, color_channels=3):
    """Decode a PNG string tensor into a float32 batch containing one image.

    Args:
        png_placeholder: ``tf.string`` tensor holding a single PNG. It is
            reshaped to a scalar first, so a one-element tensor (e.g. shape
            ``[1]``) is accepted as well.
        width: expected image width in pixels.
        height: expected image height in pixels.
        color_channels: number of channels requested from ``decode_png``.

    Returns:
        float32 tensor of shape ``[1, height, width, color_channels]`` with
        pixel values scaled to [0, 1]. The decoded image is reshaped to
        ``[height, width, color_channels]``, so its pixel count must match.
    """
    png_scalar = tf.reshape(png_placeholder, [])
    decoded = tf.image.decode_png(png_scalar, channels=color_channels)
    # Convert image to float and bring values in the range of 0-1
    as_float = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    # Reshape and add "batch" dimension (this expects a single image NOT in a list)
    image = tf.reshape(as_float, [height, width, color_channels])
    return tf.expand_dims(image, 0)
# Convert the model output "pred" to a png encoded image
def pred_to_png(pred):
    """Encode a model prediction as a PNG string tensor.

    Args:
        pred: float tensor whose first dimension has size 1 (it is removed
            with ``tf.squeeze(..., [0])``). Values are clipped to [0, 1].

    Returns:
        Scalar ``tf.string`` tensor named "output_bytes" holding the PNG.
    """
    # Make sure the values are in range before the dtype conversion
    clipped = tf.clip_by_value(pred, 0.0, 1.0)
    # This converts the float32 tensor into uint8 format while scaling the
    # values to 0-255; then the extra batch dimension is removed
    as_uint8 = tf.image.convert_image_dtype(clipped, dtype=tf.uint8)
    image = tf.squeeze(as_uint8, [0])
    # Encode the tensor to a png image and clone it under a fixed name
    png = tf.image.encode_png(image)
    return tf.identity(png, name="output_bytes")
@christian-steinmeyer
Copy link

Very helpful! Thanks!
Out of curiosity: how did you come across the hint in line 50 — that adding "_bytes" to the output name makes it get encoded automatically?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment