Skip to content

Instantly share code, notes, and snippets.

@sd12832
Created May 10, 2021 23:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sd12832/ebc45aae468bf9080b6e1203a6148c62 to your computer and use it in GitHub Desktop.
import os
import shutil
from typing import List, Optional

import numpy as np
import tensorflow as tf
import tritonclient.grpc.model_config_pb2 as trt_model_config
from absl import logging
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

from dc_common.modeling.model_config import ModelConfig, ConfigSection, ConfigParam
from trainer import model
def get_gcs_output_dir(num_classes: int,
                       dataset_dir: str,
                       model_name: str,
                       global_steps: int,
                       base_dir='gs://dc_models/Models'):
    """Build the GCS directory path for an exported model.

    Layout is <base_dir>/<num_classes>-class/<dataset>_<model>_<steps>k,
    where a leading 'Set_' prefix on the dataset directory name is dropped
    and steps are expressed in thousands (integer-truncated).
    """
    dataset_name = os.path.basename(dataset_dir)
    if dataset_name.startswith('Set_'):
        dataset_name = dataset_name[4:]
    steps_in_k = int(global_steps / 1000)
    leaf_dir = '{}_{}_{}k'.format(dataset_name, model_name, steps_in_k)
    return os.path.join(base_dir, '{}-class'.format(num_classes), leaf_dir)
def export_model(output_dir: str, checkpoint_dir: str, checkpoint_file: str):
    """Freeze a trained checkpoint into a GraphDef .pb under output_dir.

    Rebuilds the network from the training config found in checkpoint_dir,
    restores the (optionally EMA-shadowed) weights, converts variables to
    constants, writes <output_dir>/<basename(output_dir)>.pb, and copies the
    raw checkpoint files plus the config alongside it.
    """
    logging.set_verbosity(logging.INFO)
    if output_dir.endswith('/'):
        output_dir = output_dir[:-1]
    # NOTE(review): an existing output dir is only logged, not fatal -- only
    # an already-existing output .pb aborts the export. Confirm intended.
    if tf.io.gfile.exists(output_dir):
        logging.error('Output dir already exists! %s', str(output_dir))
    output_file = os.path.join(output_dir, os.path.basename(output_dir) + '.pb')
    if tf.io.gfile.exists(output_file):
        logging.error('Output file already exists!')
        return
    train_config = ModelConfig(checkpoint_dir, file_must_exists=True)
    train_config.log_config()
    # Pull the export-relevant hyperparameters out of the training config.
    model_name = train_config.get_param(ConfigSection.TRAINING, ConfigParam.network)
    input_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.input_layer)
    output_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    use_tpu = train_config.get_bool(ConfigSection.TRAINING, ConfigParam.use_tpu)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))
    moving_average_decay = float(train_config.get_param(ConfigSection.TRAINING, ConfigParam.moving_average_decay))
    # Resolve the checkpoint: explicit file > latest in dir > config fallback.
    if not checkpoint_file:
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
    if checkpoint_path is None:
        checkpoint_path = train_config.get_param(ConfigSection.TRAINING, ConfigParam.checkpoint)
    logging.info('Checkpoint path is %s', checkpoint_path)
    # Record the resolved checkpoint back into the config that gets exported.
    train_config.set_param(ConfigSection.TRAINING, ConfigParam.checkpoint, checkpoint_path)
    with tf.Graph().as_default() as graph:
        # Rebuild the inference graph fed by a float32 NHWC placeholder.
        network_fn = model.get_network_fn(model_name, num_classes, weight_decay=0.0, use_tpu=use_tpu)
        placeholder = tf.compat.v1.placeholder(name=input_layer, dtype=tf.float32,
                                               shape=[None, image_size, image_size, 3])
        network_fn(placeholder)
        if moving_average_decay:
            # Restore EMA shadow variables in place of the raw weights.
            ema = tf.train.ExponentialMovingAverage(moving_average_decay)
            shadow_vars = ema.variables_to_restore()
            saver = tf.compat.v1.train.Saver(shadow_vars, reshape=True)
        else:
            saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session() as sess:
            saver.restore(sess, checkpoint_path)
            logging.info('model restored...')
            # Debug check: raises if the hard-coded "input" op is missing.
            op = tf.compat.v1.get_default_graph().get_operation_by_name("input")
            logging.info(op)
            graph_def = graph.as_graph_def()
            # Freeze: bake the restored variable values into graph constants.
            output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, graph_def, [output_layer])
            tf.io.gfile.makedirs(output_dir)
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.write(output_graph_def.SerializeToString())
            logging.info('Copying %s to %s\nThis can take a few minutes', checkpoint_path, output_dir)
            # Ship the raw checkpoint shards next to the frozen graph.
            for filepath in tf.io.gfile.glob(checkpoint_path + '*'):
                filename = os.path.basename(filepath)
                tf.io.gfile.copy(filepath, os.path.join(output_dir, filename))
            train_config.save_file(output_dir)
def export_graph_def_to_trt_model(model_dir: str, max_batch_size: int = 2):
    """Convert a frozen GraphDef export into a TF-TRT SavedModel laid out as a
    Triton model-repository entry, then write the matching config.pbtxt.

    Repository layout produced under the current working directory:
        models/<model_name>/1/model.savedmodel

    Args:
        model_dir: directory holding <basename>.pb (as written by
            export_model) and the training config file.
        max_batch_size: Triton max_batch_size written into the config.
    """
    def check_and_make_dirs(dirs: List[str]):
        # Create each directory if missing; ordering guarantees parents first.
        for dir_path in dirs:
            if not os.path.isdir(dir_path):
                os.makedirs(dir_path, exist_ok=True)

    # Fix: use the LAST path component, consistent with write_trt_config's
    # mc.name (the original used model_dir.split('/')[0], i.e. the first
    # component, which is wrong whenever model_dir is a nested path).
    model_name = os.path.basename(model_dir)
    repo_dir = os.path.join(os.getcwd(), 'models')
    new_model_dir = os.path.join(repo_dir, model_name)
    version_dir = os.path.join(new_model_dir, '1')
    saved_model_dir = os.path.join(version_dir, 'model.savedmodel')
    check_and_make_dirs([repo_dir, new_model_dir, version_dir, saved_model_dir])
    logging.info(f'SAVED MODEL DIR IS {[model_dir, repo_dir, new_model_dir, version_dir, saved_model_dir]}')

    train_config = ModelConfig(model_dir, file_must_exists=True)
    output_node = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))

    # Fix: export_model writes <model_dir>/<basename>.pb; joining the full
    # model_dir onto itself broke for any path containing a separator.
    pb_file = os.path.join(model_dir, model_name + '.pb')

    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode="FP32")

    # GraphDef -> SavedModel first: TrtGraphConverterV2 only accepts SavedModels.
    convert_from_graph_to_saved_model(pb_file, f'{model_dir}/saved', output_node)
    converter = trt.TrtGraphConverterV2(input_saved_model_dir=f'{model_dir}/saved',
                                        conversion_params=conversion_params)
    converter.convert()
    # NOTE(review): converter.build() is intentionally skipped -- engines will
    # be built lazily at first inference. Confirm that is acceptable.
    converter.save(saved_model_dir)

    write_trt_config(model_dir=model_dir, new_model_dir=new_model_dir, platform='tensorflow_savedmodel',
                     output_node=output_node, num_classes=num_classes, channel_first=False,
                     max_batch_size=max_batch_size)
def convert_from_graph_to_saved_model(input_pb: str, output_dir: str, output_node: str):
    """
    Convert a frozen GraphDef .pb (primarily used with TF1) to a SavedModel
    usable by TF2. Necessary since the TF2 TensorRT path (Triton Server) only
    accepts TF2 SavedModels, but models built with TF1 Slim were historically
    saved as GraphDef protobufs.

    Args:
        input_pb: path to the frozen GraphDef file.
        output_dir: destination SavedModel directory; removed first if present.
        output_node: name of the graph's output op (without the ':0' suffix).
    """
    if os.path.exists(output_dir):
        # Fix: shutil.rmtree instead of os.system(f"rm -rf {output_dir}") --
        # no shell is involved, so paths with spaces/metacharacters are safe
        # and failures raise instead of being silently ignored.
        shutil.rmtree(output_dir)
    builder = tf.compat.v1.saved_model.Builder(output_dir)
    with tf.compat.v1.gfile.GFile(input_pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    sigs = {}
    with tf.compat.v1.Session(graph=tf.compat.v1.Graph()) as sess:
        tf.compat.v1.import_graph_def(graph_def, name="")
        g = tf.compat.v1.get_default_graph()
        nodes = [f'{n.name}:0' for n in graph_def.node]
        for node in nodes:
            logging.info(f'Node {node}')
        # Assumes the first node in the GraphDef is the input placeholder
        # (true for graphs frozen by export_model) -- TODO confirm for
        # graphs from other sources.
        inp = g.get_tensor_by_name(nodes[0])
        logging.info(f'INPUT Is {inp}')
        out = g.get_tensor_by_name(f'{output_node}:0')
        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
            tf.compat.v1.saved_model.signature_def_utils.predict_signature_def(
                {"in": inp}, {"out": out})
        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)
    builder.save()
def write_trt_config(model_dir: str,
                     new_model_dir: str,
                     platform: str,
                     output_node: str,
                     num_classes: int,
                     channel_first: bool = False,
                     max_batch_size: int = 2):
    """Emit a Triton Inference Server config.pbtxt describing the model.

    Reads the network image size from the training config in model_dir and
    writes the protobuf text format to <new_model_dir>/config.pbtxt.
    """
    train_config = ModelConfig(model_dir, file_must_exists=True)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))

    config = trt_model_config.ModelConfig()
    config.name = os.path.basename(model_dir)
    config.platform = platform
    config.max_batch_size = max_batch_size

    # Single FP32 image input; dimension order follows the channel layout.
    model_input = trt_model_config.ModelInput()
    model_input.name = 'input'
    model_input.data_type = trt_model_config.TYPE_FP32
    if channel_first:
        model_input.format = trt_model_config.ModelInput.FORMAT_NCHW
        model_input.dims.extend([3, image_size, image_size])
    else:
        model_input.format = trt_model_config.ModelInput.FORMAT_NHWC
        model_input.dims.extend([image_size, image_size, 3])
    config.input.extend([model_input])

    # Single FP32 output of class scores; tensorrt_plan models are given the
    # extra unit dims [1, num_classes, 1] -- presumably to match the plan's
    # output layout (verify against the engine).
    model_output = trt_model_config.ModelOutput()
    model_output.name = 'out'
    model_output.data_type = trt_model_config.TYPE_FP32
    output_dims = [1, num_classes, 1] if platform == 'tensorrt_plan' else [num_classes]
    model_output.dims.extend(output_dims)
    config.output.extend([model_output])

    # Serve three GPU instances of this model.
    instance_group = trt_model_config.ModelInstanceGroup()
    instance_group.kind = trt_model_config.ModelInstanceGroup.KIND_GPU
    instance_group.count = 3
    config.instance_group.extend([instance_group])

    output_file = os.path.join(new_model_dir, 'config.pbtxt')
    logging.info('Writing %s', output_file)
    with tf.io.gfile.GFile(output_file, 'w') as f:
        f.write(str(config))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment