Skip to content

Instantly share code, notes, and snippets.

@zmjjmz
Created December 7, 2017 17:26
Show Gist options
  • Save zmjjmz/00ee658007b57111b9b936e1ae94c0cd to your computer and use it in GitHub Desktop.
Save zmjjmz/00ee658007b57111b9b936e1ae94c0cd to your computer and use it in GitHub Desktop.
TF serving export code
def _resolve_export_dir_conflict(export_dir, version_str, conflict_policy):
    """Return the directory to export into, applying `conflict_policy` if it already exists.

    Policies (only consulted when ``export_dir/version_str`` already exists):
      'fail'    -- raise ValueError.
      'bump'    -- export into ``max(integer-named entries of export_dir) + 1``
                   (or version 1 if there are none).
      'replace' -- delete the existing version directory (NO backup: if the export
                   fails afterwards, the old version is gone) and reuse its path.

    Raises:
        ValueError: for the 'fail' policy, or for an unrecognised policy.
    """
    final_export_dir = os.path.join(export_dir, version_str)
    if not os.path.exists(final_export_dir):
        return final_export_dir

    print("Export dir exists! Resolving according to conflict policy: {0}".format(
        conflict_policy))

    if conflict_policy == 'bump':
        viable_versions = []
        for entry in os.listdir(export_dir):
            try:
                viable_versions.append(int(entry))
            except ValueError:
                # Non-numeric entries are not version directories; ignore them.
                continue
        # Guard against an empty list (e.g. the clashing version was not an int).
        new_version = max(viable_versions) + 1 if viable_versions else 1
        print("Bumping to version {0}".format(new_version))
        return os.path.join(export_dir, str(new_version))

    if conflict_policy == 'replace':
        import shutil
        shutil.rmtree(final_export_dir)
        return final_export_dir

    if conflict_policy == 'fail':
        raise ValueError(
            "Directory {0} already exists, please specify a new version!".format(final_export_dir))

    raise ValueError(
        "Unknown conflict_policy {0!r}; expected 'fail', 'bump' or 'replace'".format(conflict_policy))


def export_keras_model_simple(model, export_dir, version, input_name_map=None, output_name_map=None,
                              model_name=None, conflict_policy='fail', exclude_outputs=None, verbose=True):
    """Export a Keras model as a TensorFlow SavedModel with a single 'predict' signature.

    Every model input and output becomes an entry in the signature, keyed by its
    layer name (or by the renamed key given in the name maps).

    Args:
        model: A Keras model, or a path that ``datautil.is_valid_path`` accepts.
            A path is loaded with the 'keras' whole-model loader.
        export_dir: Base directory; the model is written to ``export_dir/<version>``.
        version: Export version. It is expected to be a non-negative integer
            (presumably what TF Serving requires for version dirs — confirm).
        input_name_map: Optional ``{input_layer_name: signature_key}`` renames.
        output_name_map: Optional ``{output_layer_name: signature_key}`` renames.
        model_name: Currently unused; kept for backward compatibility.
        conflict_policy: 'fail', 'bump' or 'replace'. It is applied when the
            version directory already exists; see ``_resolve_export_dir_conflict``.
        exclude_outputs: Optional iterable of output layer names to leave out of
            the signature.
        verbose: Print progress, maps and the signature.

    Returns:
        The directory the SavedModel was written to. This can differ from
        ``export_dir/version`` under the 'bump' policy.

    Raises:
        ValueError: if the version directory exists and the policy is 'fail' or
            unrecognised.
    """
    input_name_map = {} if input_name_map is None else input_name_map
    output_name_map = {} if output_name_map is None else output_name_map
    excluded_outputs = set(exclude_outputs) if exclude_outputs is not None else set()
    model_format = 'keras'

    # Switch Keras to inference mode BEFORE loading, so a model deserialized here
    # is built with phase-dependent layers in test mode. For an already-built
    # `model` object passed in, this may not rewrite its existing graph —
    # NOTE(review): verify for the Keras version in use.
    keras.backend.set_learning_phase(0)

    if datautil.is_valid_path(model):
        if verbose:
            print("Exporting {0}".format(model))
        model = _WHOLE_MODEL_LOADERS[model_format](model)

    final_export_dir = _resolve_export_dir_conflict(export_dir, str(version), conflict_policy)
    builder = tensorflow.saved_model.builder.SavedModelBuilder(final_export_dir)

    input_map = {input_name_map.get(il.name, il.name): il_t
                 for il, il_t in zip(model.input_layers, model.inputs)}
    output_map = {output_name_map.get(ol.name, ol.name): ol_t
                  for ol, ol_t in zip(model.output_layers, model.outputs)
                  if ol.name not in excluded_outputs}
    if verbose:
        print("Input map: {}".format(input_map))
        print("Output map: {}".format(output_map))

    signature = tensorflow.saved_model.signature_def_utils.predict_signature_def(
        inputs=input_map,
        outputs=output_map)
    if verbose:
        print("Signature: {}".format(signature))

    # Run on load after variables are restored: initialise lookup tables and
    # local variables only. Global variables are deliberately not re-initialised,
    # since that would clobber the restored weights.
    main_op = control_flow_ops.group(
        lookup_ops.tables_initializer(),
        variables.local_variables_initializer(),
    )

    # Do NOT use the session as a context manager: leaving a `with` block closes
    # the session, which would kill Keras' global session (and the model) for
    # the rest of the process.
    sess = keras.backend.get_session()
    builder.add_meta_graph_and_variables(
        sess=sess,
        tags=[tensorflow.saved_model.tag_constants.SERVING],
        signature_def_map={'predict': signature},
        main_op=main_op,
    )
    builder.save()
    return final_export_dir
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment