Skip to content

Instantly share code, notes, and snippets.

@jzuern
Last active June 30, 2020 13:12
Show Gist options
  • Save jzuern/8f5db63901e948255d3f35766fa633e1 to your computer and use it in GitHub Desktop.
Save jzuern/8f5db63901e948255d3f35766fa633e1 to your computer and use it in GitHub Desktop.
from tensorflow.python.platform import gfile
import tensorflow as tf
from tensorflow.contrib import tensorrt as trt
# Path of the frozen (constants-only) ResNet graph to convert.
graph_filename = 'resnetV150_frozen.pb'
# define graph def object
frozen_graph_def = tf.GraphDef()
# Read the serialized graph inside a context manager so the file handle is
# closed as soon as parsing finishes, or if reading/parsing raises.
with gfile.FastGFile(graph_filename, 'rb') as f:
    # store frozen graph from pb file
    frozen_graph_def.ParseFromString(f.read())
# Conversion settings.
batch_size = 1
precision = "FP32"
workspace_size = 1 << 30  # passed as max_workspace_size_bytes (2**30 bytes)
output_node_name = "resnet_v1_50/predictions/Reshape_1"

# Keyword options forwarded unchanged to the TensorRT converter.
trt_conversion_options = {
    "max_batch_size": batch_size,
    "max_workspace_size_bytes": workspace_size,
    "precision_mode": precision,
}

# Build the converted graph def from the frozen graph, using the single
# named output node as the graph output.
trt_graph = trt.create_inference_graph(
    frozen_graph_def,
    [output_node_name],
    **trt_conversion_options
)
# write modified graph def to disk
graph_filename_converted = 'resnetV150_frozen_tensorrt.pb'
# Binary mode: the payload is the serialized protobuf bytes of the graph def.
# The context manager closes the output file even if the write fails.
with gfile.FastGFile(graph_filename_converted, 'wb') as s:
    s.write(trt_graph.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment