Skip to content

Instantly share code, notes, and snippets.

@tiandiao123
Last active August 23, 2021 21:15
Show Gist options
  • Save tiandiao123/67adb11ab3d73df8e83a1469707d7db4 to your computer and use it in GitHub Desktop.
Save tiandiao123/67adb11ab3d73df8e83a1469707d7db4 to your computer and use it in GitHub Desktop.
# Run with TensorRT FP16: TVM_TENSORRT_USE_FP16=1 python test_trt.py
# Run with TensorRT INT8 (10 calibration runs): TVM_TENSORRT_USE_INT8=1 TENSORRT_NUM_CALI_INT8=10 python test_trt.py
# Requires this TVM branch: https://github.com/tiandiao123/tvm/tree/pr_trt_int8
import tvm
from tvm import relay
import numpy as np
from tvm.contrib.download import download_testdata
import os
# PyTorch imports
import torch
import torchvision
import numpy as np
import cv2
# PyTorch imports
import torch
import torchvision
# Look the architecture up by name in torchvision.models, load it with
# pretrained weights, and put it in eval mode right away.
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True).eval()

# Trace the model with a random tensor of shape `input_shape` to obtain the
# TorchScript module that is later handed to the Relay frontend.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(*input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
from PIL import Image
from torchvision import transforms

# Fetch the sample image through download_testdata and open it at 224x224.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))

# Preprocessing pipeline: resize, center-crop, convert to a tensor, then
# normalize each channel with the mean/std values given below.
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
my_preprocess = transforms.Compose(preprocess_steps)

# Apply the pipeline and prepend a leading axis of size 1.
img = np.expand_dims(my_preprocess(img), axis=0)
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# Import the traced PyTorch model into Relay; the single graph input is
# registered under `input_name` with the shape of the preprocessed image.
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

# compile the model
target = "cuda"
dev = tvm.cuda(0)

# Partition the module for TensorRT; the returned config is handed to the
# build through the pass context below.
mod, config = partition_for_tensorrt(mod, params)

print("python script started building --------------")
build_context = tvm.transform.PassContext(
    opt_level=3,
    config={"relay.ext.tensorrt.options": config},
)
with build_context:
    lib = relay.build(mod, target=target, params=params)
print("python script finsihed building -------------------")

dtype = "float32"

# Round-trip the compiled library through disk, then wrap the reloaded
# module's "default" entry in a graph executor bound to `dev`.
lib.export_library("compiled.so")
loaded_lib = tvm.runtime.load_module("compiled.so")
gen_module = tvm.contrib.graph_executor.GraphModule(loaded_lib["default"](dev))
# Number of calibration runs to perform before the real inference, taken from
# the TENSORRT_NUM_CALI_INT8 environment variable. When the variable is not
# set, no calibration runs are performed.
num_cali_int8 = os.environ.get("TENSORRT_NUM_CALI_INT8")
if num_cali_int8 is None:
    print("no TENSORRT_NUM_CALI_INT8 found in this case ... ")
    num_cali_int8 = 0
else:
    print("we are going to set {} times calibration in this case".format(num_cali_int8))
    try:
        num_cali_int8 = int(num_cali_int8)
    except ValueError as err:
        # Fail with a message that names the offending variable instead of a
        # bare int() traceback.
        raise ValueError(
            "TENSORRT_NUM_CALI_INT8 must be an integer, got {!r}".format(num_cali_int8)
        ) from err

if num_cali_int8 > 0:
    print("calibration steps ... ")
    for _ in range(num_cali_int8):
        # The graph input is registered as `input_name` ("input0"), not "data".
        # GraphModule skips keyword inputs whose name it cannot find, so the
        # image must be bound through the real input name for the calibration
        # runs to actually see it.
        gen_module.set_input(input_name, img)
        gen_module.run()
    print("finished calibration step")
import time

# Single inference on the preprocessed image, printing the first output.
print("test run ... ")
# Bind the image to the graph's real input name (`input_name`); passing it as
# `data=` would be skipped by GraphModule because no input has that name.
gen_module.set_input(input_name, img)
gen_module.run()
out = gen_module.get_output(0)
print(out)

# Latency benchmark: average wall-clock time of `epochs` runs. The input
# stays bound from the call above, so the loop times `run()` only.
epochs = 100
total_time = 0
for _ in range(epochs):
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump with system clock adjustments.
    start = time.perf_counter()
    gen_module.run()
    end = time.perf_counter()
    total_time += end - start
print("the average time is {} ms".format(str(total_time / epochs * 1000)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment