# Source: GitHub Gist tiandiao123/844ce92d56ef1127b23205ea82fba1ff
# Author: @tiandiao123, created October 19, 2021 05:55
### Demo: how to convert a TF2 (Keras SavedModel) model into a TVM Relay module
import tensorflow as tf
from tensorflow.python.tools import saved_model_utils
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from tvm import relay
from tvm.relay.frontend.tensorflow2 import from_tensorflow
# Map TensorFlow dtype reprs (repr(tf.float32) == "tf.float32") to the dtype
# strings TVM uses. Dtypes not listed here fall back to "float32" where this
# table is consulted, so every dtype a model input may use should appear.
dtype_dict = {
    "tf.float32": "float32",
    "tf.float16": "float16",
    "tf.float64": "float64",
    "tf.int32": "int32",
    "tf.int16": "int16",
    "tf.int8": "int8",
    "tf.uint8": "uint8",
    "tf.int64": "int64",
    # Further dtypes supported by both TF and TVM; without these entries such
    # inputs would be mislabelled as "float32" by the fallback.
    "tf.bool": "bool",
    "tf.uint16": "uint16",
    "tf.uint32": "uint32",
    "tf.uint64": "uint64",
}
# Directory of the Keras SavedModel to convert (machine-specific; edit as needed).
model_path = "/data00/cuiqing.li/models/debug_model/1"

### Convert to a frozen graph (GraphDef)
# Expose the Keras backend module under both the "K" and "backend" names.
# NOTE(review): presumably so custom/Lambda layers that reference these names
# can be deserialized — confirm against the saved model.
custom_objects = {alias: tf.keras.backend for alias in ("K", "backend")}
# compile=False: load without compiling; only the forward graph is used below.
new_model = tf.keras.models.load_model(
    model_path, custom_objects=custom_objects, compile=False
)
# Model.summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that would emit a stray "None").
new_model.summary()
print("print input info: ")
for model_input in new_model.inputs:
    print(model_input.shape)
# Describe every Keras input as a TensorSpec (echoing each input tensor) so
# the wrapped call can be traced with concrete shapes and dtypes.
input_info = []
for keras_input in new_model.inputs:
    print(keras_input)
    spec = tf.TensorSpec(keras_input.shape, keras_input.dtype)
    input_info.append(spec)

# Trace the model call into a concrete function. The lambda's parameter is
# deliberately named `x`: the traced graph's placeholder names are derived
# from it, and those names are what the later steps collect.
full_model = tf.function(lambda x: new_model(x)).get_concrete_function(input_info)

# Inline variables as constants and export the resulting graph as a GraphDef.
frozen_func = convert_variables_to_constants_v2(full_model)
graph_def = frozen_func.graph.as_graph_def()
# Gather the frozen graph's input signature for TVM. Every unknown (None)
# dimension — presumably the batch axis — is pinned to `batch_size` so Relay
# receives a fully static shape.
batch_size = 1
input_names = []
input_shapes = []
input_types = []
for tensor in frozen_func.inputs:
    # TVM's TensorFlow frontends match `shape` keys against GraphDef node
    # names (e.g. "x"), not tensor names (e.g. "x:0"), so record the op name.
    input_names.append(tensor.op.name)
    input_types.append(tensor.dtype)
    # as_list() yields plain ints / None for each dimension; iterating the raw
    # TensorShape may yield Dimension objects whose `==` returns None.
    input_shapes.append(
        [batch_size if dim is None else dim for dim in tensor.shape.as_list()]
    )
# Translate each TF dtype into a TVM dtype string via dtype_dict.
# repr(tf.float32) is "tf.float32", matching dtype_dict's keys; str() would
# give "<dtype: 'float32'>", never match, and turn every input into "float32".
# Unmapped dtypes still fall back to "float32".
# TODO(review): consider raising on unmapped dtypes instead of guessing.
input_types[:] = [dtype_dict.get(repr(dtype), "float32") for dtype in input_types]
# Echo the collected input signature (label line, then the values).
for label, values in (
    ("input names: ", input_names),
    ("input_shapes: ", input_shapes),
    ("input_types", input_types),
):
    print(label)
    print(values)

# Pair each input name with its static shape for the Relay frontend.
tvm_shape_dict = dict(zip(input_names, input_shapes))
print(tvm_shape_dict)

print("converting to relay graph ... ")
mod, params = from_tensorflow(graph_def, shape=tvm_shape_dict)
# Show the Relay IR of the module's "main" function.
print(mod["main"])
# (End of gist. Commenting on the original gist requires a GitHub account.)