Skip to content

Instantly share code, notes, and snippets.

@tonyreina
Created November 10, 2020 14:34
Show Gist options
  • Save tonyreina/50c5e57053612142395d64ef948b31fa to your computer and use it in GitHub Desktop.
Save tonyreina/50c5e57053612142395d64ef948b31fa to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config
from pathlib import Path
import argparse
def frozen_keras_graph(model):
    """Freeze a Keras model into an optimized ``GraphDef``.

    The model is traced as a ``tf.function``, its variables are converted
    to constants, and the resulting graph is run through the Grappler
    "constfold" and "function" optimizers.

    Args:
        model: A Keras model exposing a non-empty ``inputs`` list.
            Single-input models are traced with one ``TensorSpec``;
            multi-input models are traced with a list of specs, one per
            entry in ``model.inputs``.

    Returns:
        The optimized graph definition produced by
        ``run_graph_optimizations``.

    Raises:
        ValueError: If ``model.inputs`` is empty or ``None``, so no trace
            signature can be derived.
    """
    if not model.inputs:
        raise ValueError(
            'Cannot freeze a model without defined inputs; '
            'build or call the model first.'
        )

    input_specs = [
        tf.TensorSpec(tensor.shape, tensor.dtype) for tensor in model.inputs
    ]
    # Keep the single-input call form (a bare spec); models with several
    # inputs are traced with the full list of specs instead of only the
    # first input.
    signature = input_specs[0] if len(input_specs) == 1 else input_specs
    concrete_func = tf.function(model).get_concrete_function(signature)

    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(
        concrete_func
    )

    # Resource-typed inputs are excluded from the list handed to the
    # optimizer (presumably leftover variable handles -- confirm).
    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != tf.resource
    ]
    output_tensors = frozen_func.outputs

    return run_graph_optimizations(
        graph_def,
        input_tensors,
        output_tensors,
        config=get_grappler_config(["constfold", "function"]),
        graph=frozen_func.graph,
    )
def get_args():
    """Build the command-line parser for this script.

    The parser itself is returned, not the parsed arguments; the caller
    invokes ``parse_args()`` on the result.

    Returns:
        An ``argparse.ArgumentParser`` with one required option,
        ``--input_model`` / ``-m``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--input_model',
        '-m',
        type=str,
        required=True,
        help='Path to Keras model.',
    )
    return cli
def export_keras_to_tf(input_model, output_model):
    """Load a Keras model and write it out as a binary frozen graph.

    Args:
        input_model: Path of the saved Keras model to load.
        output_model: File name of the frozen graph; it is written
            relative to the current directory in binary (non-text) form.
    """
    print('Loading Keras model: ', input_model)
    # Only the forward graph is exported, so skip restoring the training
    # configuration; compiling on load can fail for models saved with
    # custom losses or metrics that are not registered here.
    model = tf.keras.models.load_model(input_model, compile=False)
    model.summary()

    graph_def = frozen_keras_graph(model)
    tf.io.write_graph(graph_def, '.', output_model, as_text=False)
def main():
    """Parse the command line, export the model, and report the output name."""
    cli_args = get_args().parse_args()
    source_path = cli_args.input_model

    # The output name is the final path component of the input with '.pb'
    # appended; any directory part of the input path is dropped.
    target_name = f"{Path(source_path).name}.pb"

    export_keras_to_tf(source_path, target_name)
    print('Saved as TF frozen model to: ', target_name)
# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment