Skip to content

Instantly share code, notes, and snippets.

@rsandler00
Created November 15, 2019 01:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rsandler00/10d120bd28c249a60f77f00e4b4a0618 to your computer and use it in GitHub Desktop.
Save rsandler00/10d120bd28c249a60f77f00e4b4a0618 to your computer and use it in GitHub Desktop.
saved_model_to_trt
import numpy as np
import tensorflow as tf
from ipdb import set_trace
from tensorflow.python.compiler.tensorrt import trt_convert as trt
# SavedModel to be converted, and the directory the TF-TRT result is saved to.
INPUT_SAVED_MODEL_DIR = 'tst'
OUTPUT_SAVED_MODEL_DIR = 'tst_out'

# Turn on eager execution up front, before any of the functions below run.
# NOTE(review): `tf.enable_eager_execution` is the TF 1.x-style entry point;
# this script presumably targets TF 1.14/1.15 (it also uses `load_v2`) — confirm.
tf.enable_eager_execution()
def load_run_savedmodel(saved_model_dir=None, batch_size=32):
    """Load the unconverted SavedModel and run it once on an all-ones batch.

    Args:
        saved_model_dir: Directory of the SavedModel to load. ``None`` (the
            default) means ``INPUT_SAVED_MODEL_DIR``, which replaces the
            previously duplicated hard-coded ``'tst'``.
        batch_size: Leading dimension of the dummy input (default 32, as before).

    Returns:
        The result of calling the loaded object on the dummy input. This was
        previously computed and thrown away.
    """
    # Resolve the default at call time so the module constant is looked up lazily.
    if saved_model_dir is None:
        saved_model_dir = INPUT_SAVED_MODEL_DIR
    mod = tf.saved_model.load_v2(saved_model_dir)
    # Dummy input of shape (batch_size, 18, 63, 8). The trailing dims come from
    # the original script and are assumed to match the model signature (confirm).
    inp = tf.convert_to_tensor(np.ones((batch_size, 18, 63, 8)), dtype=tf.float32)
    return mod(inp)
def convert_savedmodel():
    """Convert the input SavedModel with TF-TRT (FP16, dynamic op) and reload it.

    Reads from ``INPUT_SAVED_MODEL_DIR`` and saves the converted model to
    ``OUTPUT_SAVED_MODEL_DIR``. It then calls ``load_infer_savedmodel()`` on
    the saved output. Returns None.
    """
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode='FP16',
        is_dynamic_op=True,
    )
    trt_converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=INPUT_SAVED_MODEL_DIR,
        conversion_params=conversion_params,
    )
    trt_converter.convert()
    trt_converter.save(OUTPUT_SAVED_MODEL_DIR)
    # Immediately exercise the freshly written model.
    load_infer_savedmodel()
def load_infer_savedmodel(saved_model_dir=None, batch_size=2):
    """Load the TF-TRT converted SavedModel, freeze it, and run one inference.

    Args:
        saved_model_dir: Directory of the converted SavedModel. ``None`` (the
            default) means ``OUTPUT_SAVED_MODEL_DIR``, as before.
        batch_size: Leading dimension of the dummy all-ones input (default 2,
            as before).

    Returns:
        The ``.numpy()`` value of the first output of the frozen default
        serving function. This was previously computed and thrown away.
    """
    # Resolve the default at call time so the module constant is looked up lazily.
    if saved_model_dir is None:
        saved_model_dir = OUTPUT_SAVED_MODEL_DIR
    saved_model_loaded = tf.saved_model.load_v2(
        saved_model_dir, tags=[tf.saved_model.tag_constants.SERVING])
    graph_func = saved_model_loaded.signatures[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    # Replace the signature's variables with constants so it can be called directly.
    frozen_func = trt.convert_to_constants.convert_variables_to_constants_v2(graph_func)
    # Dummy input of shape (batch_size, 18, 63, 8); assumed to match the model (confirm).
    input_data = tf.convert_to_tensor(np.ones((batch_size, 18, 63, 8)), dtype=tf.float32)
    # Assumes frozen_func has one output tensor; keep only the first.
    return frozen_func(input_data)[0].numpy()
if __name__ == '__main__':
    # convert_savedmodel() already calls load_infer_savedmodel() after saving,
    # so a separate inference call is not needed here.
    convert_savedmodel()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment