# FP16 TRT run command:  TVM_TENSORRT_USE_FP16=1 python test_trt.py
# INT8 TRT run command:  TVM_TENSORRT_USE_INT8=1 TENSORRT_NUM_CALI_INT8=10 python test_trt.py
# Use the TVM branch: https://github.com/tiandiao123/tvm/tree/pr_trt_int8
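#
# Not part of the original gist: since these flags are ordinary environment variables, they can
# also be set from inside the script (assuming, as the commands above imply, the TensorRT
# integration reads them when the engine is built and run), e.g. for the INT8 path:
#     import os
#     os.environ["TVM_TENSORRT_USE_INT8"] = "1"
#     os.environ["TENSORRT_NUM_CALI_INT8"] = "10"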
import os

import numpy as np
import onnx

import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib.download import download_testdata
model_url = "".join(
[
"https://gist.github.com/zhreshold/",
"bcda4716699ac97ea44f791c24310193/raw/",
"93672b029103648953c4e5ad3ac3aadf346a4cdc/",
"super_resolution_0.2.onnx",
]
)
model_path = download_testdata(model_url, "super_resolution.onnx", module="onnx")
# now you have super_resolution.onnx on disk
onnx_model = onnx.load(model_path)
from PIL import Image
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
img_ycbcr = img.convert("YCbCr")  # convert to YCbCr
img_y, img_cb, img_cr = img_ycbcr.split()
# The super-resolution model takes only the luminance (Y) channel, in NCHW layout: (1, 1, 224, 224).
x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
input_name = "1"
shape_dict = {input_name: x.shape}
print("shape: ")
print(x.shape)
input_shape = x.shape
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
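# Not part of the original gist: printing the module at this point shows the plain Relay IR
# produced by the ONNX importer, which is useful to compare against the TensorRT-partitioned
# module created below:
#     print(mod)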
# compile the model
target = "cuda"
dev = tvm.cuda(0)
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
# Offload every TensorRT-supported operator into TensorRT subgraphs; `config` carries the
# TensorRT options consumed by the PassContext below.
mod, config = partition_for_tensorrt(mod, params)
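# Not part of the original gist: upstream TVM's TensorRT integration also exposes helpers to
# check that the TensorRT runtime is available and to report its version; the names below are
# taken from mainline TVM and assumed to exist on the pr_trt_int8 branch as well.
from tvm.relay.op.contrib.tensorrt import is_tensorrt_runtime_enabled, get_tensorrt_version
if is_tensorrt_runtime_enabled():
    print("TensorRT runtime enabled, version: {}".format(get_tensorrt_version()))
else:
    print("TensorRT runtime not enabled; the partitioned subgraphs cannot be executed here")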
print("python script building --------------")
with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
lib = relay.build(mod, target=target, params=params)
print("python script finsihed building -------------------")
dtype = "float32"
lib.export_library('compiled.so')
loaded_lib = tvm.runtime.load_module('compiled.so')
gen_module = tvm.contrib.graph_executor.GraphModule(loaded_lib['default'](dev))
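# Not part of the original gist: a quick sanity check on the loaded module. get_num_outputs()
# is the standard graph-executor helper; this model should report a single output.
print("number of outputs: {}".format(gen_module.get_num_outputs()))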
# Number of INT8 calibration passes, read from the TENSORRT_NUM_CALI_INT8 environment variable.
num_cali_int8 = int(os.environ.get("TENSORRT_NUM_CALI_INT8", "0"))
if num_cali_int8 != 0:
    print("running {} INT8 calibration passes ...".format(num_cali_int8))
    for i in range(num_cali_int8):
        gen_module.run()
    print("finished calibration")
else:
    print("no TENSORRT_NUM_CALI_INT8 found, skipping INT8 calibration ...")
print("test run ... ")
gen_module.run(data=x)
out = gen_module.get_output(0)
print(out)
# Naive wall-clock benchmark of the end-to-end run.
import time

epochs = 100
total_time = 0.0
for i in range(epochs):
    start = time.time()
    gen_module.run()
    dev.sync()  # make sure all queued GPU work has finished before stopping the timer
    end = time.time()
    total_time += end - start
print("the average time is {} ms".format(total_time / epochs * 1000))