# Extracted from a GitHub Gist page; the page navigation text has been dropped.
# Author: @guschmue
# Created: November 13, 2020 15:35
# Gist: guschmue/a8e6a52619e5f1734e89a98102d37146
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
"""Test examples."""
import os
import subprocess
import unittest
import tensorflow as tf
import numpy as np
import tensorflow_hub as hub
import tf2onnx
import timeit
import time
import zipfile
import keras2onnx
import onnxruntime as rt
from onnx import helper
from tensorflow.python.keras.saving import saving_utils as _saving_utils
from tensorflow.keras import layers, models
from tensorflow import keras
from common import check_opset_min_version, check_opset_max_version, check_tf_min_version
def onnx_session(model):
    """Open an onnxruntime inference session for ``model``.

    Args:
        model: Handed unchanged to ``rt.InferenceSession``.

    Returns:
        A ``(sess, inputs, outputs)`` tuple: the session followed by the names
        of its inputs and of its outputs, in the order the session reports
        them.
    """
    # CPU is the default. CUDA is selected only when onnxruntime reports a GPU
    # device and CUDA_VISIBLE_DEVICES is unset or longer than one character.
    # NOTE(review): a one-character value such as "0" keeps the CPU provider;
    # confirm that this is the intended rule.
    use_cuda = False
    if rt.get_device() == "GPU":
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        use_cuda = visible_devices is None or len(visible_devices) > 1
    providers = ['CUDAExecutionProvider'] if use_cuda else ['CPUExecutionProvider']

    options = rt.SessionOptions()
    options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Profiling is off; set options.enable_profiling = True to turn it on.
    session = rt.InferenceSession(model, sess_options=options, providers=providers)

    input_names = [meta.name for meta in session.get_inputs()]
    output_names = [meta.name for meta in session.get_outputs()]
    return session, input_names, output_names
def to_onnx(model, output=None, use_tf2onnx=True, large_model=False, opset=12):
    """Convert a Keras model to ONNX and optionally write it to disk.

    Args:
        model: The Keras model to convert.
        output: Optional destination path. When truthy, the converted model is
            written there with ``tf2onnx.utils.save_onnx_zip`` if
            ``large_model`` is set, otherwise with
            ``tf2onnx.utils.save_protobuf``.
        use_tf2onnx: Convert with tf2onnx when true, with keras2onnx otherwise.
        large_model: tf2onnx path only. Compresses the frozen graph's constants
            and routes them through a ``tf2onnx.graph.ExternalTensorStorage``.
        opset: ONNX opset passed to tf2onnx (ignored by the keras2onnx path).

    Returns:
        ``(model_proto, inputs, outputs, external_tensor_storage)`` where
        ``inputs`` / ``outputs`` are the graph input / output names of the
        converted model and ``external_tensor_storage`` is ``None`` unless
        ``large_model`` was requested.

    Raises:
        ValueError: If ``large_model`` is requested together with the
            keras2onnx path, which has no external tensor storage to save.
    """
    if large_model and not use_tf2onnx:
        raise ValueError("large_model=True requires use_tf2onnx=True")

    # Bound before branching: it is part of the return value on both
    # conversion paths (previously unbound on the keras2onnx path).
    external_tensor_storage = None

    if use_tf2onnx:
        const_node_values = None
        function = _saving_utils.trace_model_call(model)
        concrete_func = function.get_concrete_function()
        # Tensors of resource dtype are left out of the graph inputs/outputs.
        input_names = [input_tensor.name for input_tensor in concrete_func.inputs
                       if input_tensor.dtype != tf.dtypes.resource]
        output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                        if output_tensor.dtype != tf.dtypes.resource]
        frozen_graph = tf2onnx.tf_loader.from_function(
            concrete_func, input_names, output_names, large_model=large_model)
        if large_model:
            from tf2onnx.tf_utils import compress_graph_def
            const_node_values = compress_graph_def(frozen_graph)
            external_tensor_storage = tf2onnx.graph.ExternalTensorStorage()
        with tf.Graph().as_default() as tf_graph:
            tf.import_graph_def(frozen_graph, name='')
            g = tf2onnx.tfonnx.process_tf_graph(
                tf_graph, opset=opset, input_names=input_names,
                output_names=output_names, const_node_values=const_node_values)
            onnx_graph = tf2onnx.optimizer.optimize_graph(g)
            model_proto = onnx_graph.make_model(
                "converted from tf2onnx", external_tensor_storage=external_tensor_storage)
    else:
        model_proto = keras2onnx.convert_keras(model, model.name)

    if output:
        if large_model:
            tf2onnx.utils.save_onnx_zip(output, model_proto, external_tensor_storage)
        else:
            tf2onnx.utils.save_protobuf(output, model_proto)

    inputs = [n.name for n in model_proto.graph.input]
    outputs = [n.name for n in model_proto.graph.output]
    return model_proto, inputs, outputs, external_tensor_storage
def _mean_runtime_ms(fn, runs):
    """Return the mean wall-clock duration of ``fn()`` over ``runs`` calls, in ms."""
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return 1000 * (time.perf_counter() - start) / runs


def bert_from_hub(url, output_name):
    """Convert a TF-Hub BERT layer to ONNX and compare it against TensorFlow.

    Wraps the hub layer at ``url`` in a Keras model with three int32 inputs of
    length 512, runs it on all-ones inputs, converts it with ``to_onnx``, runs
    the converted model with onnxruntime on the same inputs, prints the mean
    latency of both runtimes and asserts that the outputs agree.

    Args:
        url: TF-Hub handle passed to ``hub.KerasLayer``.
        output_name: Path the converted model is written to. A name ending in
            ``.zip`` selects the large-model zip format; the archive is then
            unpacked into ``output_name + ".unpacked"`` before loading. Any
            other name is written as a plain protobuf and loaded directly.

    Raises:
        AssertionError: If a TensorFlow output and the matching onnxruntime
            output differ beyond the tolerance.
    """
    max_seq_length = 512
    runs = 20  # timed iterations per runtime

    input_word_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name="input_word_ids")
    input_mask = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name="input_mask")
    segment_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name="segment_ids")
    bert_layer = hub.KerasLayer(url, trainable=False)
    bert_inputs = [input_word_ids, input_mask, segment_ids]
    pooled_output, sequence_output = bert_layer(bert_inputs)

    model = tf.keras.models.Model(inputs=bert_inputs, outputs=[
        pooled_output, sequence_output])
    model.compile(loss='binary_crossentropy',
                  optimizer='adam', metrics=['accuracy'])
    print("inputs", model.inputs)
    print("outputs", model.outputs)

    in_words = np.ones([1, max_seq_length], dtype=np.int32)
    in_mask = np.ones([1, max_seq_length], dtype=np.int32)
    in_segment = np.ones([1, max_seq_length], dtype=np.int32)
    tf_inputs = [in_words, in_mask, in_segment]

    # The first call yields the reference outputs and runs before timing starts.
    tf_ret = model.predict(tf_inputs)

    print("tf perftest ...")
    tf_res = _mean_runtime_ms(lambda: model.predict(tf_inputs), runs)

    print("convert ...")
    large_model = output_name.endswith(".zip")
    to_onnx(model, output_name, use_tf2onnx=True, large_model=large_model)
    tf.compat.v1.reset_default_graph()

    feeds = {"input_word_ids:0": in_words,
             "input_mask:0": in_mask,
             "segment_ids:0": in_segment}

    if large_model:
        # Only the large-model path produces a zip archive; unpack it and
        # load the model file it contains.
        save_path = output_name+".unpacked"
        with zipfile.ZipFile(output_name, 'r') as z:
            z.extractall(save_path)
        model_path = os.path.join(save_path, "__MODEL_PROTO.onnx")
    else:
        # A plain protobuf was written; it can be loaded as is.
        model_path = output_name

    sess, _, outputs = onnx_session(model_path)
    # As above: one untimed run that also provides the outputs to compare.
    onnx_ret = sess.run(outputs, feeds)

    print("ort perftest ...")
    ort_res = _mean_runtime_ms(lambda: sess.run(outputs, feeds), runs)
    print(f"{url} TF={tf_res} ORT={ort_res}")

    # NOTE(review): rtol=1 accepts up to 100% relative error, so this is only
    # a very coarse check; confirm whether a tighter tolerance is intended.
    for i1, i2 in zip(tf_ret, onnx_ret):
        np.testing.assert_allclose(i1, i2, rtol=1)
class TestKerasApps(unittest.TestCase):
    """Conversion tests for TF-Hub BERT models; each one calls ``bert_from_hub``.

    Only the first test is active, the others are skipped with reason "later".
    """

    def setUp(self):
        """Reset the default TF graph before every test."""
        tf.compat.v1.reset_default_graph()

    def test_tfkeras_bert0(self):
        """bert_en_cased_L-12_H-768_A-12, version 2."""
        bert_from_hub(
            "https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/2",
            "/tmp/bert_en_cased_L-12_H-768_A-12.zip",
        )

    @unittest.skip("later")
    def test_tfkeras_bert1(self):
        """bert_en_cased_L-24_H-1024_A-16, version 2."""
        bert_from_hub(
            "https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/2",
            "/tmp/bert_en_cased_L-24_H-1024_A-16.zip",
        )

    @unittest.skip("later")
    def test_tfkeras_bert2(self):
        """bert_en_uncased_L-12_H-768_A-12, version 2."""
        bert_from_hub(
            "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2",
            "/tmp/bert_en_uncased_L-12_H-768_A-12.zip",
        )

    @unittest.skip("later")
    def test_tfkeras_bert3(self):
        """bert_en_uncased_L-24_H-1024_A-16, version 2."""
        bert_from_hub(
            "https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/2",
            "/tmp/bert_en_uncased_L-24_H-1024_A-16.zip",
        )

    @unittest.skip("later")
    def test_tfkeras_bert4(self):
        """bert_en_wwm_cased_L-24_H-1024_A-16, version 2."""
        bert_from_hub(
            "https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/2",
            "/tmp/bert_en_wwm_cased_L-24_H-1024_A-16.zip",
        )

    @unittest.skip("later")
    def test_tfkeras_bert5(self):
        """bert_en_wwm_uncased_L-24_H-1024_A-16, version 2."""
        bert_from_hub(
            "https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/2",
            "/tmp/bert_en_wwm_uncased_L-24_H-1024_A-16.zip",
        )

    @unittest.skip("later")
    def test_tfkeras_bert6(self):
        """bert_en_cased_L-12_H-768_A-12, version 3."""
        bert_from_hub(
            "https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3",
            "/tmp/bert_en_cased_L-12_H-768_A-12_3.zip",
        )
# Run the unittest runner only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
# (Gist page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.