import os

import tensorflow as tf
from absl import logging
from tensorflow.python.framework import graph_util

from dc_common.modeling.model_config import ModelConfig, ConfigSection, ConfigParam
from trainer import model


def get_gcs_output_dir(num_classes: int,
                       dataset_dir: str,
                       model_name: str,
                       global_steps: int,
                       base_dir='gs://dc_models/Models'):
    """Build the GCS export directory: <base_dir>/<num_classes>-class/<dataset>_<model>_<steps/1000>k."""
    class_dir = os.path.join(base_dir, '{}-class'.format(num_classes))
    dataset_dir_name = os.path.basename(dataset_dir)
    # Drop the leading 'Set_' prefix from the dataset directory name.
    if dataset_dir_name.startswith('Set_'):
        dataset_dir_name = dataset_dir_name[4:]
    model_dir_name = '{}_{}_{}k'.format(dataset_dir_name, model_name, int(global_steps / 1000))
    return os.path.join(class_dir, model_dir_name)


def export_model(output_dir: str, checkpoint_dir: str, checkpoint_file: str):
    """Freeze the checkpoint into a single .pb graph and copy the checkpoint
    files and training config alongside it in output_dir."""
    logging.set_verbosity(logging.INFO)
    if output_dir.endswith('/'):
        output_dir = output_dir[:-1]
    if tf.io.gfile.exists(output_dir):
        logging.error('Output dir already exists! %s', str(output_dir))
    output_file = os.path.join(output_dir, os.path.basename(output_dir) + '.pb')
    if tf.io.gfile.exists(output_file):
        logging.error('Output file already exists!')
        return

    # Pull the export parameters from the training config stored with the checkpoint.
    train_config = ModelConfig(checkpoint_dir, file_must_exists=True)
    train_config.log_config()
    model_name = train_config.get_param(ConfigSection.TRAINING, ConfigParam.network)
    input_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.input_layer)
    output_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    use_tpu = train_config.get_bool(ConfigSection.TRAINING, ConfigParam.use_tpu)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))
    moving_average_decay = float(train_config.get_param(ConfigSection.TRAINING, ConfigParam.moving_average_decay))

    if not checkpoint_file:
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
    if checkpoint_path is None:
        checkpoint_path = train_config.get_param(ConfigSection.TRAINING, ConfigParam.checkpoint)
    logging.info('Checkpoint path is %s', checkpoint_path)
    train_config.set_param(ConfigSection.TRAINING, ConfigParam.checkpoint, checkpoint_path)

    with tf.Graph().as_default() as graph:
        # Rebuild the network with an explicit placeholder input.
        network_fn = model.get_network_fn(model_name, num_classes, weight_decay=0.0, use_tpu=use_tpu)
        placeholder = tf.placeholder(name=input_layer, dtype=tf.float32,
                                     shape=[None, image_size, image_size, 3])
        network_fn(placeholder)

        if moving_average_decay:
            # Restore the exponential-moving-average (shadow) variables instead of
            # the raw training variables.
            ema = tf.train.ExponentialMovingAverage(moving_average_decay)
            shadow_vars = ema.variables_to_restore()
            saver = tf.train.Saver(shadow_vars, reshape=True)
        else:
            saver = tf.train.Saver()

        with tf.Session() as sess:
            saver.restore(sess, checkpoint_path)
            logging.info('model restored...')
            op = tf.get_default_graph().get_operation_by_name("input")
            logging.info(op)
            # Fold the restored variables into constants so the graph can be
            # written out as a single frozen .pb file.
            graph_def = graph.as_graph_def()
            output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [output_layer])
            tf.io.gfile.makedirs(output_dir)
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.write(output_graph_def.SerializeToString())

    # Keep the original checkpoint files and config next to the frozen graph.
    logging.info('Copying %s to %s\nThis can take a few minutes', checkpoint_path, output_dir)
    for filepath in tf.io.gfile.glob(checkpoint_path + '*'):
        filename = os.path.basename(filepath)
        tf.io.gfile.copy(filepath, os.path.join(output_dir, filename))
    train_config.save_file(output_dir)


def export_tf_trt_model(model_dir: str, max_batch_size: int = 2):
    from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt

    pb_file = os.path.join(model_dir, os.path.basename(model_dir) + '.pb')
    graph = tf.Graph()
    output_dir = os.path.join(model_dir, '1')
    tf.io.gfile.mkdir(output_dir)
    trt_file = os.path.join(output_dir, 'model.graphdef')
    train_config = ModelConfig(model_dir, file_must_exists=True)
    output_node = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)

    with graph.as_default():
        # First deserialize your frozen graph:
        with tf.io.gfile.GFile(pb_file, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # Now you can create a TensorRT inference graph from your
        # frozen graph:
        converter = tf_trt.TrtGraphConverter(
            input_graph_def=graph_def,
            nodes_blacklist=[output_node],
            max_batch_size=max_batch_size,
            max_workspace_size_bytes=1 << 30,
            precision_mode="fp32")
        trt_graph = converter.convert()
        for node in trt_graph.node:
            logging.info(f'Node: {node.op}')

    logging.info('Writing %s', trt_file)
    with tf.io.gfile.GFile(trt_file, 'wb') as f:
        f.write(trt_graph.SerializeToString())

    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    write_trt_config(model_dir=model_dir,
                     platform='tensorflow_graphdef',
                     output_node=output_node,
                     num_classes=num_classes,
                     channel_first=False,
                     max_batch_size=max_batch_size)


def export_plan_model(model_dir: str, max_batch_size: int = 2):
    import tensorrt as trt
    import uff

    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    pb_file = os.path.join(model_dir, os.path.basename(model_dir) + '.pb')
    uff_file = os.path.join(model_dir, os.path.basename(model_dir) + '.uff')
    output_dir = os.path.join(model_dir, '1')
    tf.io.gfile.mkdir(output_dir)
    plan_file = os.path.join(output_dir, 'model.plan')
    train_config = ModelConfig(model_dir, file_must_exists=True)
    output_node = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))

    # Convert the frozen TensorFlow graph to UFF, then build a serialized TensorRT engine.
    logging.info('Writing %s', uff_file)
    uff.from_tensorflow_frozen_model(
        pb_file,
        [output_node],
        output_filename=str(uff_file)
    )
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        parser.register_input("input", (image_size, image_size, 3), trt.UffInputOrder.NHWC)
        parser.register_output(output_node)
        parser.parse(uff_file, network)
        builder.max_batch_size = max_batch_size
        builder.max_workspace_size = 1 << 30
        logging.info('Writing %s', plan_file)
        with builder.build_cuda_engine(network) as engine:
            with open(plan_file, 'wb') as f:
                f.write(engine.serialize())

    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    write_trt_config(model_dir=model_dir,
                     platform='tensorrt_plan',
                     output_node=output_node,
                     num_classes=num_classes,
                     channel_first=False,
                     max_batch_size=max_batch_size)


def write_trt_config(model_dir: str,
                     platform: str,
                     output_node: str,
                     num_classes: int,
                     channel_first: bool = False,
                     max_batch_size: int = 2):
    """Write a config.pbtxt describing the exported model for the TensorRT Inference Server."""
    from tensorrtserver.api import model_config_pb2 as trt_model_config

    train_config = ModelConfig(model_dir, file_must_exists=True)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))

    # Model-level settings.
    mc = trt_model_config.ModelConfig()
    mc.name = os.path.basename(model_dir)
    mc.platform = platform
    mc.max_batch_size = max_batch_size

    # Input tensor description.
    mc_input = trt_model_config.ModelInput()
    mc_input.name = 'input'
    mc_input.data_type = trt_model_config.TYPE_FP32
    if channel_first:
        mc_input.format = trt_model_config.ModelInput.FORMAT_NCHW
        mc_input.dims.extend([3, image_size, image_size])
    else:
        mc_input.format = trt_model_config.ModelInput.FORMAT_NHWC
        mc_input.dims.extend([image_size, image_size, 3])
    mc.input.extend([mc_input])

    # Output tensor description.
    mc_output = trt_model_config.ModelOutput()
    mc_output.name = output_node
    mc_output.data_type = trt_model_config.TYPE_FP32
    if platform == 'tensorrt_plan':
        mc_output.dims.extend([1, num_classes, 1])
    else:
        mc_output.dims.extend([num_classes])
    mc.output.extend([mc_output])

    # Run three execution instances of the model on GPU.
    mc_instance_grp = trt_model_config.ModelInstanceGroup()
    mc_instance_grp.kind = trt_model_config.ModelInstanceGroup.KIND_GPU
    mc_instance_grp.count = 3
    mc.instance_group.extend([mc_instance_grp])

    output_file = os.path.join(model_dir, 'config.pbtxt')
    logging.info('Writing %s', output_file)
    with tf.io.gfile.GFile(output_file, 'w') as f:
        f.write(str(mc))
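

# --- Example usage (a minimal sketch, not part of the original gist) ---
# The bucket paths, dataset/model names, and step count below are hypothetical
# placeholders; they only illustrate how the export functions are meant to be
# chained: build the export path, freeze the checkpoint, then produce the
# TF-TRT graph and its inference-server config next to it.
if __name__ == '__main__':
    output_dir = get_gcs_output_dir(
        num_classes=5,
        dataset_dir='gs://dc_datasets/Set_Example',   # hypothetical dataset directory
        model_name='mobilenet_v2',                    # hypothetical network name
        global_steps=100000)
    export_model(
        output_dir=output_dir,
        checkpoint_dir='gs://dc_models/checkpoints/example_run',  # hypothetical checkpoint dir
        checkpoint_file='')                           # empty -> use the latest checkpoint
    export_tf_trt_model(output_dir, max_batch_size=2)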