Created December 7, 2017 17:26
Save zmjjmz/00ee658007b57111b9b936e1ae94c0cd to your computer and use it in GitHub Desktop.
TF serving export code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _resolve_export_dir(export_dir, version_str, conflict_policy):
    """Return the versioned directory to export into, applying ``conflict_policy``.

    If ``<export_dir>/<version_str>`` does not exist it is returned unchanged.
    Otherwise the conflict is resolved according to ``conflict_policy``:

    * ``'bump'``: use one past the largest integer-named entry in
      ``export_dir`` (or ``1`` when no integer-named entry exists).
    * ``'replace'``: delete the existing directory and reuse its path. No
      backup is made, so if the export fails afterwards the old version is gone.
    * ``'fail'``: raise ``ValueError``.

    Raises:
        ValueError: if the directory exists and ``conflict_policy`` is
            ``'fail'`` or is not one of the recognised policies.
    """
    final_export_dir = os.path.join(export_dir, version_str)
    if not os.path.exists(final_export_dir):
        return final_export_dir

    print("Export dir exists! Resolving according to conflict policy: {0}".format(
        conflict_policy))
    if conflict_policy == 'bump':
        import glob
        viable_versions = []
        for path in glob.glob(os.path.join(export_dir, '*')):
            try:
                # basename (not split('/')) so this also works with OS-native separators.
                viable_versions.append(int(os.path.basename(path)))
            except ValueError:
                # Non-numeric entries are not version directories; ignore them.
                continue
        # The conflicting entry may itself be non-numeric, in which case there
        # are no integer versions at all; start at 1 rather than crash in max().
        new_version = max(viable_versions) + 1 if viable_versions else 1
        print("Bumping to version {0}".format(new_version))
        return os.path.join(export_dir, str(new_version))
    if conflict_policy == 'replace':
        import shutil
        shutil.rmtree(final_export_dir)
        return final_export_dir
    if conflict_policy == 'fail':
        raise ValueError(
            "Directory {0} already exists, please specify a new version!".format(final_export_dir))
    # Previously an unknown policy fell through and SavedModelBuilder failed
    # later on the existing directory; report the real cause instead.
    raise ValueError(
        "Unknown conflict_policy {0!r}; expected 'bump', 'replace' or 'fail'".format(
            conflict_policy))


def export_keras_model_simple(model, export_dir, version, input_name_map=None, output_name_map=None,
                              model_name=None, conflict_policy='fail',
                              exclude_outputs=None, verbose=True):
    """Export a Keras model as a versioned SavedModel with a single 'predict' signature.

    Every model input layer and every non-excluded output layer is placed in
    the signature, keyed by its layer name (optionally renamed via the name maps).

    Args:
        model: a Keras model, or a path accepted by ``datautil.is_valid_path``,
            which is then loaded with the ``'keras'`` whole-model loader.
        export_dir: base directory; the model is written to
            ``<export_dir>/<version>``.
        version: version subdirectory name (expected to be a non-negative int).
        input_name_map: optional ``{layer_name: signature_key}`` for inputs.
        output_name_map: optional ``{layer_name: signature_key}`` for outputs.
        model_name: accepted for backward compatibility; currently unused.
        conflict_policy: ``'fail'`` (default), ``'bump'`` or ``'replace'``;
            how to handle an already-existing version directory.
        exclude_outputs: optional collection of output layer names to leave
            out of the signature.
        verbose: print progress and the resulting maps/signature.

    Returns:
        The directory actually written to. With ``'bump'`` this may differ
        from ``<export_dir>/<version>``.

    Raises:
        ValueError: if the version directory exists and ``conflict_policy`` is
            ``'fail'`` or unrecognised.
    """
    # None sentinels instead of shared mutable default arguments.
    if input_name_map is None:
        input_name_map = {}
    if output_name_map is None:
        output_name_map = {}
    if exclude_outputs is None:
        exclude_outputs = []

    model_format = 'keras'
    if datautil.is_valid_path(model):
        if verbose:
            print("Exporting {0}".format(model))
        model = _WHOLE_MODEL_LOADERS[model_format](model)
    # Put Keras in inference mode (learning phase 0) for anything whose
    # behaviour depends on it. Note this is a global Keras setting.
    keras.backend.set_learning_phase(0)

    final_export_dir = _resolve_export_dir(export_dir, str(version), conflict_policy)
    builder = tensorflow.saved_model.builder.SavedModelBuilder(final_export_dir)

    input_map = {input_name_map.get(layer.name, layer.name): tensor
                 for layer, tensor in zip(model.input_layers, model.inputs)}
    output_map = {output_name_map.get(layer.name, layer.name): tensor
                  for layer, tensor in zip(model.output_layers, model.outputs)
                  if layer.name not in exclude_outputs}
    if verbose:
        print("Input map: {}".format(input_map))
        print("Output map: {}".format(output_map))
    signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
        inputs=input_map,
        outputs=output_map)
    if verbose:
        print("Signature: {}".format(signature))

    # main_op for the SavedModel: initialise lookup tables and local variables.
    # Global variables are not re-initialised, so the trained weights restored
    # from the checkpoint are kept.
    main_op = control_flow_ops.group(
        lookup_ops.tables_initializer(),
        variables.local_variables_initializer(),
    )

    sess = keras.backend.get_session()
    # Do NOT use `with sess:` here: Session.__exit__ closes the session, and
    # this is Keras' shared global session, so the model would be unusable
    # afterwards. Enter the same graph/default-session contexts without closing.
    with sess.graph.as_default(), sess.as_default():
        builder.add_meta_graph_and_variables(
            sess=sess,
            tags=[tensorflow.saved_model.tag_constants.SERVING],
            signature_def_map={'predict': signature},
            main_op=main_op,
        )
    builder.save()
    return final_export_dir
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.