-
-
Save sd12832/829d214e4f92e6d81c5c248bfe7a264c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import tensorflow as tf | |
from absl import logging | |
from dc_common.modeling.model_config import ModelConfig, ConfigSection, ConfigParam | |
from tensorflow_core.python.framework import graph_util | |
from trainer import model | |
def get_gcs_output_dir(num_classes: int,
                       dataset_dir: str,
                       model_name: str,
                       global_steps: int,
                       base_dir='gs://dc_models/Models'):
    """Build the GCS directory path where a trained model is stored.

    Layout: <base_dir>/<num_classes>-class/<dataset>_<model>_<steps/1000>k,
    where a leading 'Set_' prefix on the dataset directory name is dropped.

    Args:
        num_classes: number of output classes the model was trained for.
        dataset_dir: path to the dataset directory; only its basename is used.
        model_name: network architecture name used in the directory name.
        global_steps: total training steps; rendered in thousands ('300k').
        base_dir: root GCS bucket/prefix for all models.

    Returns:
        The joined output directory path as a string.
    """
    dataset_name = os.path.basename(dataset_dir)
    if dataset_name.startswith('Set_'):
        dataset_name = dataset_name[4:]
    steps_in_thousands = int(global_steps / 1000)
    leaf_name = '{}_{}_{}k'.format(dataset_name, model_name, steps_in_thousands)
    return os.path.join(base_dir, '{}-class'.format(num_classes), leaf_name)
def export_model(output_dir: str, checkpoint_dir: str, checkpoint_file: str):
    """Freeze a trained checkpoint into a self-contained GraphDef .pb file.

    Rebuilds the network from the training config stored in checkpoint_dir,
    restores its variables (via their EMA shadow values when a moving-average
    decay is configured), converts variables to constants, and writes
    <basename(output_dir)>.pb into output_dir. The raw checkpoint files and
    the training config are copied alongside the frozen graph.

    Args:
        output_dir: destination directory (GCS or local); a trailing '/' is
            stripped so basename() yields the model name.
        checkpoint_dir: directory holding the training checkpoint and config.
        checkpoint_file: specific checkpoint file name; if falsy, the latest
            checkpoint in checkpoint_dir is used, falling back to the
            checkpoint path recorded in the training config.
    """
    logging.set_verbosity(logging.INFO)
    # Normalize so os.path.basename(output_dir) below is non-empty.
    if output_dir.endswith('/'):
        output_dir = output_dir[:-1]
    # NOTE(review): an existing directory is only logged; only an existing
    # .pb file aborts the export — confirm this leniency is intentional.
    if tf.io.gfile.exists(output_dir):
        logging.error('Output dir already exists! %s', str(output_dir))
    output_file = os.path.join(output_dir, os.path.basename(output_dir) + '.pb')
    if tf.io.gfile.exists(output_file):
        logging.error('Output file already exists!')
        return
    train_config = ModelConfig(checkpoint_dir, file_must_exists=True)
    train_config.log_config()
    model_name = train_config.get_param(ConfigSection.TRAINING, ConfigParam.network)
    input_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.input_layer)
    output_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    use_tpu = train_config.get_bool(ConfigSection.TRAINING, ConfigParam.use_tpu)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))
    moving_average_decay = float(train_config.get_param(ConfigSection.TRAINING, ConfigParam.moving_average_decay))
    # Resolve which checkpoint to freeze: explicit file > latest in dir >
    # checkpoint path recorded in the training config.
    if not checkpoint_file:
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
    if checkpoint_path is None:
        checkpoint_path = train_config.get_param(ConfigSection.TRAINING, ConfigParam.checkpoint)
    logging.info('Checkpoint path is %s', checkpoint_path)
    # Record the resolved checkpoint so the exported config is reproducible.
    train_config.set_param(ConfigSection.TRAINING, ConfigParam.checkpoint, checkpoint_path)
    with tf.Graph().as_default() as graph:
        network_fn = model.get_network_fn(model_name, num_classes, weight_decay=0.0, use_tpu=use_tpu)
        # The placeholder becomes the frozen graph's input node (batch dim
        # left dynamic; HWC image layout).
        placeholder = tf.placeholder(name=input_layer, dtype=tf.float32, shape=[None, image_size, image_size, 3])
        network_fn(placeholder)
        if moving_average_decay:
            # Restore EMA shadow variables in place of the raw weights so the
            # exported model uses the averaged parameters.
            ema = tf.train.ExponentialMovingAverage(moving_average_decay)
            shadow_vars = ema.variables_to_restore()
            saver = tf.train.Saver(shadow_vars, reshape=True)
        else:
            saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, checkpoint_path)
            logging.info('model restored...')
            # NOTE(review): looks up a node literally named "input" for
            # logging only — presumably input_layer == 'input'; verify.
            op = tf.get_default_graph().get_operation_by_name("input")
            logging.info(op)
            graph_def = graph.as_graph_def()
            # Bake restored variables into constants so the .pb is
            # self-contained (no checkpoint needed at inference time).
            output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [output_layer])
            tf.io.gfile.makedirs(output_dir)
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.write(output_graph_def.SerializeToString())
    # Keep the raw checkpoint shards and the training config next to the
    # frozen graph for traceability.
    logging.info('Copying %s to %s\nThis can take a few minutes', checkpoint_path, output_dir)
    for filepath in tf.io.gfile.glob(checkpoint_path + '*'):
        filename = os.path.basename(filepath)
        tf.io.gfile.copy(filepath, os.path.join(output_dir, filename))
    train_config.save_file(output_dir)
def export_tf_trt_model(model_dir: str, max_batch_size: int = 2):
    """Convert the frozen .pb in model_dir into a TF-TRT optimized graphdef.

    Reads <model_dir>/<basename(model_dir)>.pb, runs the TF-TRT converter on
    it, writes the result to <model_dir>/1/model.graphdef, and emits a
    matching config.pbtxt via write_trt_config.

    Args:
        model_dir: directory containing the frozen graph and the training
            config file; its basename is the model name.
        max_batch_size: maximum batch size baked into the converted graph.
    """
    from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt
    pb_file = os.path.join(model_dir, os.path.basename(model_dir) + '.pb')
    graph = tf.Graph()
    # Version subdirectory '1' — presumably the TensorRT Inference Server
    # model-repository layout; confirm against the serving setup.
    output_dir = os.path.join(model_dir, '1')
    tf.io.gfile.mkdir(output_dir)
    trt_file = os.path.join(output_dir, 'model.graphdef')
    train_config = ModelConfig(model_dir, file_must_exists=True)
    output_node = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    with graph.as_default():
        # First deserialize your frozen graph:
        with tf.io.gfile.GFile(pb_file, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # Now you can create a TensorRT inference graph from your
        # frozen graph:
        converter = tf_trt.TrtGraphConverter(
            input_graph_def=graph_def,
            nodes_blacklist=[output_node],  # keep the output node out of TRT segments
            max_batch_size=max_batch_size,
            max_workspace_size_bytes=1 << 30,  # 1 GiB TRT workspace
            precision_mode="fp32")
        trt_graph = converter.convert()
        # Log the resulting ops (shows which subgraphs became TRTEngineOps).
        for node in trt_graph.node:
            logging.info(f'Node: {node.op}')
        logging.info('Writing %s', trt_file)
        with tf.io.gfile.GFile(trt_file, 'wb') as f:
            f.write(trt_graph.SerializeToString())
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    write_trt_config(model_dir=model_dir,
                     platform='tensorflow_graphdef',
                     output_node=output_node,
                     num_classes=num_classes,
                     channel_first=False,
                     max_batch_size=max_batch_size)
def export_plan_model(model_dir: str, max_batch_size: int = 2):
    """Convert the frozen .pb in model_dir into a serialized TensorRT plan.

    Pipeline: frozen GraphDef -> UFF (<model_dir>/<name>.uff) -> TensorRT
    engine written to <model_dir>/1/model.plan, plus a matching config.pbtxt
    via write_trt_config.

    Args:
        model_dir: directory containing the frozen graph and the training
            config file; its basename is the model name.
        max_batch_size: maximum batch size baked into the TensorRT engine.
    """
    import tensorrt as trt
    import uff
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    pb_file = os.path.join(model_dir, os.path.basename(model_dir) + '.pb')
    uff_file = os.path.join(model_dir, os.path.basename(model_dir) + '.uff')
    # Version subdirectory '1' — presumably the TensorRT Inference Server
    # model-repository layout; confirm against the serving setup.
    output_dir = os.path.join(model_dir, '1')
    # Fix: use tf.io.gfile like the rest of this module instead of the
    # deprecated tf.gfile.MkDir alias.
    tf.io.gfile.mkdir(output_dir)
    plan_file = os.path.join(output_dir, 'model.plan')
    train_config = ModelConfig(model_dir, file_must_exists=True)
    output_node = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))
    logging.info('Writing %s', uff_file)
    uff.from_tensorflow_frozen_model(
        pb_file,
        [output_node],
        output_filename=str(uff_file)
    )
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        parser.register_input("input", (image_size, image_size, 3), trt.UffInputOrder.NHWC)
        parser.register_output(output_node)
        # Fix: UffParser.parse returns False on failure; surface the error
        # instead of silently building an engine from an empty network.
        if not parser.parse(uff_file, network):
            logging.error('UFF parsing failed for %s', uff_file)
            return
        builder.max_batch_size = max_batch_size
        builder.max_workspace_size = 1 << 30  # 1 GiB builder workspace
        logging.info('Writing %s', plan_file)
        with builder.build_cuda_engine(network) as engine:
            with open(plan_file, 'wb') as f:
                f.write(engine.serialize())
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    write_trt_config(model_dir=model_dir,
                     platform='tensorrt_plan',
                     output_node=output_node,
                     num_classes=num_classes,
                     channel_first=False,
                     max_batch_size=max_batch_size)
def write_trt_config(model_dir: str,
                     platform: str,
                     output_node: str,
                     num_classes: int,
                     channel_first: bool = False,
                     max_batch_size: int = 2):
    """Write a TensorRT Inference Server config.pbtxt for the model in model_dir.

    Describes one FP32 image input named 'input' (NCHW or NHWC per
    channel_first), one FP32 output named output_node, and a 3-instance
    GPU group. The image size is read from the stored training config.

    Args:
        model_dir: model directory; its basename becomes the served model name.
        platform: serving platform string (e.g. 'tensorrt_plan',
            'tensorflow_graphdef').
        output_node: name of the model's output tensor/layer.
        num_classes: number of output classes (sets the output dims).
        channel_first: True for NCHW input layout, False for NHWC.
        max_batch_size: maximum batch size advertised to the server.
    """
    from tensorrtserver.api import model_config_pb2 as trt_model_config
    train_config = ModelConfig(model_dir, file_must_exists=True)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))

    model_cfg = trt_model_config.ModelConfig()
    model_cfg.name = os.path.basename(model_dir)
    model_cfg.platform = platform
    model_cfg.max_batch_size = max_batch_size

    # Single square image input; layout and dims follow channel_first.
    input_cfg = trt_model_config.ModelInput()
    input_cfg.name = 'input'
    input_cfg.data_type = trt_model_config.TYPE_FP32
    if channel_first:
        input_cfg.format = trt_model_config.ModelInput.FORMAT_NCHW
        input_cfg.dims.extend([3, image_size, image_size])
    else:
        input_cfg.format = trt_model_config.ModelInput.FORMAT_NHWC
        input_cfg.dims.extend([image_size, image_size, 3])
    model_cfg.input.extend([input_cfg])

    # Single class-score output. Plan models get a [1, num_classes, 1]
    # shape — presumably the CHW-style shape the UFF path produces; other
    # platforms report a flat [num_classes].
    output_cfg = trt_model_config.ModelOutput()
    output_cfg.name = output_node
    output_cfg.data_type = trt_model_config.TYPE_FP32
    if platform == 'tensorrt_plan':
        output_cfg.dims.extend([1, num_classes, 1])
    else:
        output_cfg.dims.extend([num_classes])
    model_cfg.output.extend([output_cfg])

    # Serve three concurrent model instances on the GPU.
    instance_group = trt_model_config.ModelInstanceGroup()
    instance_group.kind = trt_model_config.ModelInstanceGroup.KIND_GPU
    instance_group.count = 3
    model_cfg.instance_group.extend([instance_group])

    output_file = os.path.join(model_dir, 'config.pbtxt')
    logging.info('Writing %s', output_file)
    with tf.io.gfile.GFile(output_file, 'w') as f:
        f.write(str(model_cfg))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment