import os
from typing import Optional, List

import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.compiler.tensorrt import trt_convert as trt
import tritonclient.grpc.model_config_pb2 as trt_model_config
from absl import logging

from dc_common.modeling.model_config import ModelConfig, ConfigSection, ConfigParam
from trainer import model


def get_gcs_output_dir(num_classes: int,
                       dataset_dir: str,
                       model_name: str,
                       global_steps: int,
                       base_dir='gs://dc_models/Models'):
    class_dir = os.path.join(base_dir, '{}-class'.format(num_classes))
    dataset_dir_name = os.path.basename(dataset_dir)
    if dataset_dir_name.startswith('Set_'):
        dataset_dir_name = dataset_dir_name[4:]
    model_dir_name = '{}_{}_{}k'.format(dataset_dir_name, model_name, int(global_steps / 1000))
    return os.path.join(class_dir, model_dir_name)
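
# Illustrative example (hypothetical values): with num_classes=5,
# dataset_dir='gs://dc_data/Set_flowers', model_name='inception_v3' and
# global_steps=250000, this returns
# 'gs://dc_models/Models/5-class/flowers_inception_v3_250k'.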


def export_model(output_dir: str, checkpoint_dir: str, checkpoint_file: str):
    logging.set_verbosity(logging.INFO)
    if output_dir.endswith('/'):
        output_dir = output_dir[:-1]
    if tf.io.gfile.exists(output_dir):
        logging.error('Output dir already exists! %s', str(output_dir))
    output_file = os.path.join(output_dir, os.path.basename(output_dir) + '.pb')
    if tf.io.gfile.exists(output_file):
        logging.error('Output file already exists!')
        return
    train_config = ModelConfig(checkpoint_dir, file_must_exists=True)
    train_config.log_config()
    model_name = train_config.get_param(ConfigSection.TRAINING, ConfigParam.network)
    input_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.input_layer)
    output_layer = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    use_tpu = train_config.get_bool(ConfigSection.TRAINING, ConfigParam.use_tpu)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))
    moving_average_decay = float(train_config.get_param(ConfigSection.TRAINING, ConfigParam.moving_average_decay))
    # Resolve the checkpoint: explicit file > latest in dir > path from the config.
    if not checkpoint_file:
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
    if checkpoint_path is None:
        checkpoint_path = train_config.get_param(ConfigSection.TRAINING, ConfigParam.checkpoint)
    logging.info('Checkpoint path is %s', checkpoint_path)
    train_config.set_param(ConfigSection.TRAINING, ConfigParam.checkpoint, checkpoint_path)
    with tf.Graph().as_default() as graph:
        network_fn = model.get_network_fn(model_name, num_classes, weight_decay=0.0, use_tpu=use_tpu)
        placeholder = tf.compat.v1.placeholder(name=input_layer, dtype=tf.float32,
                                               shape=[None, image_size, image_size, 3])
        network_fn(placeholder)
        if moving_average_decay:
            # Restore the exponential moving averages of the variables instead
            # of the raw training variables.
            ema = tf.train.ExponentialMovingAverage(moving_average_decay)
            shadow_vars = ema.variables_to_restore()
            saver = tf.compat.v1.train.Saver(shadow_vars, reshape=True)
        else:
            saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session() as sess:
            saver.restore(sess, checkpoint_path)
            logging.info('model restored...')
            op = tf.compat.v1.get_default_graph().get_operation_by_name("input")
            logging.info(op)
            graph_def = graph.as_graph_def()
            output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, graph_def, [output_layer])
            tf.io.gfile.makedirs(output_dir)
            with tf.io.gfile.GFile(output_file, 'wb') as f:
                f.write(output_graph_def.SerializeToString())
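    # At this point output_file holds a frozen GraphDef: every variable has been
    # folded into a constant, so the .pb is self-contained for inference.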
    logging.info('Copying %s to %s\nThis can take a few minutes', checkpoint_path, output_dir)
    for filepath in tf.io.gfile.glob(checkpoint_path + '*'):
        filename = os.path.basename(filepath)
        tf.io.gfile.copy(filepath, os.path.join(output_dir, filename))
    train_config.save_file(output_dir)
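
# Example invocation (all paths hypothetical; real values come from the
# training run being exported):
#
#   export_model(
#       output_dir='gs://dc_models/Models/5-class/flowers_inception_v3_250k',
#       checkpoint_dir='gs://dc_training/flowers_run1',
#       checkpoint_file='')  # empty -> falls back to tf.train.latest_checkpoint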


def export_graph_def_to_trt_model(model_dir: str, max_batch_size: int = 2):
    # Make the model repository structure
    def check_and_make_dirs(dirs: List[str]):
        for dir_path in dirs:
            if not os.path.isdir(dir_path):
                os.mkdir(dir_path)

    repo_dir = os.path.join(os.getcwd(), 'models')
    new_model_dir = os.path.join(repo_dir, model_dir.split('/')[0])
    version_dir = os.path.join(new_model_dir, '1')
    saved_model_dir = os.path.join(version_dir, 'model.savedmodel')
    check_and_make_dirs([repo_dir, new_model_dir, version_dir, saved_model_dir])
    logging.info(f'SAVED MODEL DIR IS {[model_dir, repo_dir, new_model_dir, version_dir, saved_model_dir]}')
    train_config = ModelConfig(model_dir, file_must_exists=True)
    output_node = train_config.get_param(ConfigSection.TRAINING, ConfigParam.output_layer)
    pb_file = os.path.join(model_dir, model_dir + '.pb')
    conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
    conversion_params = conversion_params._replace(precision_mode="FP32")
    # conversion_params = conversion_params._replace(max_batch_size=2)
    num_classes = int(train_config.get_param(ConfigSection.TFRECORD, ConfigParam.number_of_classes))
    convert_from_graph_to_saved_model(pb_file, f'{model_dir}/saved', output_node)
    converter = trt.TrtGraphConverterV2(input_saved_model_dir=f'{model_dir}/saved',
                                        conversion_params=conversion_params)
    converter.convert()

    # input_fn would feed sample batches if converter.build() were re-enabled
    # below to pre-build the TensorRT engines ahead of serving.
    def input_fn():
        inp = np.random.normal(size=(1, 299, 299, 3)).astype(np.float32)
        yield [inp]

    # converter.build(input_fn=input_fn)
    # for node in trt_graph.node:
    #     logging.info(f'Node: {node.op}')
    converter.save(saved_model_dir)
    write_trt_config(model_dir=model_dir, new_model_dir=new_model_dir, platform='tensorflow_savedmodel',
                     output_node=output_node, num_classes=num_classes, channel_first=False,
                     max_batch_size=max_batch_size)
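
# The call above produces a standard Triton model repository layout under the
# current working directory (names shown for illustration):
#
#   models/
#     <model_name>/
#       config.pbtxt          <- written by write_trt_config
#       1/                    <- version directory
#         model.savedmodel/   <- TF-TRT converted SavedModel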


def convert_from_graph_to_saved_model(input_pb: str, output_dir: str, output_node: str):
    """
    Converts a GraphDef .pb (the format primarily used with TF1) to the SavedModel
    format used by TF2. This is necessary because the TF2 TensorRT converter (and
    Triton Server) only accepts TF2 SavedModels, while models built with TF1 Slim
    were historically exported as GraphDef protobufs.
    """
    if os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")
    builder = tf.compat.v1.saved_model.Builder(output_dir)
    with tf.compat.v1.gfile.GFile(input_pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    sigs = {}
    with tf.compat.v1.Session(graph=tf.compat.v1.Graph()) as sess:
        tf.compat.v1.import_graph_def(graph_def, name="")
        g = tf.compat.v1.get_default_graph()
        nodes = [f'{n.name}:0' for n in graph_def.node]
        for node in nodes:
            logging.info(f'Node {node}')
        # The first node of the frozen GraphDef is assumed to be the input placeholder.
        inp = g.get_tensor_by_name(nodes[0])
        logging.info(f'INPUT Is {inp}')
        out = g.get_tensor_by_name(f'{output_node}:0')
        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
            tf.compat.v1.saved_model.signature_def_utils.predict_signature_def(
                {"in": inp}, {"out": out})
        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)
    builder.save()
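
# Sanity check (optional): the exported signature can be inspected with the
# saved_model_cli tool that ships with TensorFlow, e.g.
#
#   saved_model_cli show --dir <output_dir> --tag_set serve \
#       --signature_def serving_default
#
# which should list the "in" and "out" tensors bound above.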


def write_trt_config(model_dir: str,
                     new_model_dir: str,
                     platform: str,
                     output_node: str,
                     num_classes: int,
                     channel_first: bool = False,
                     max_batch_size: int = 2):
    train_config = ModelConfig(model_dir, file_must_exists=True)
    image_size = int(train_config.get_param(ConfigSection.TRAINING, ConfigParam.network_image_size))
    mc = trt_model_config.ModelConfig()
    mc.name = os.path.basename(model_dir)
    mc.platform = platform
    mc.max_batch_size = max_batch_size
    mc_input = trt_model_config.ModelInput()
    mc_input.name = 'input'
    mc_input.data_type = trt_model_config.TYPE_FP32
    if channel_first:
        mc_input.format = trt_model_config.ModelInput.FORMAT_NCHW
        mc_input.dims.extend([3, image_size, image_size])
    else:
        mc_input.format = trt_model_config.ModelInput.FORMAT_NHWC
        mc_input.dims.extend([image_size, image_size, 3])
    mc.input.extend([mc_input])
    mc_output = trt_model_config.ModelOutput()
    mc_output.name = 'out'
    mc_output.data_type = trt_model_config.TYPE_FP32
    if platform == 'tensorrt_plan':
        mc_output.dims.extend([1, num_classes, 1])
    else:
        mc_output.dims.extend([num_classes])
    mc.output.extend([mc_output])
    mc_instance_grp = trt_model_config.ModelInstanceGroup()
    mc_instance_grp.kind = trt_model_config.ModelInstanceGroup.KIND_GPU
    mc_instance_grp.count = 3
    mc.instance_group.extend([mc_instance_grp])
    output_file = os.path.join(new_model_dir, 'config.pbtxt')
    logging.info('Writing %s', output_file)
    with tf.io.gfile.GFile(output_file, 'w') as f:
        f.write(str(mc))
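
# For reference, str(mc) serializes the config as a text-format protobuf, so the
# emitted config.pbtxt looks roughly like this (assuming image_size=299,
# num_classes=5, platform='tensorflow_savedmodel', channel_first=False and
# max_batch_size=2; actual values come from the training config):
#
#   name: "my_model"
#   platform: "tensorflow_savedmodel"
#   max_batch_size: 2
#   input {
#     name: "input"
#     data_type: TYPE_FP32
#     format: FORMAT_NHWC
#     dims: 299
#     dims: 299
#     dims: 3
#   }
#   output {
#     name: "out"
#     data_type: TYPE_FP32
#     dims: 5
#   }
#   instance_group {
#     kind: KIND_GPU
#     count: 3
#   }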