Skip to content

Instantly share code, notes, and snippets.

@tiandiao123
Created March 15, 2022 04:25
Show Gist options
  • Save tiandiao123/4dd6d5d882e9934efc5cfe04960f8738 to your computer and use it in GitHub Desktop.
Save tiandiao123/4dd6d5d882e9934efc5cfe04960f8738 to your computer and use it in GitHub Desktop.
import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime
import numpy as np
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
# PyTorch imports
import torch
import torchvision
# Instantiate the pretrained torchvision model selected by name and switch it
# to inference mode.
model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True).eval()

# Trace the model on a random (1, 3, 224, 224) tensor to obtain the
# TorchScript module that is later handed to the Relay PyTorch frontend.
input_shape = [1, 3, 224, 224]
input_data = torch.randn(*input_shape)
scripted_model = torch.jit.trace(model, input_data)
scripted_model = scripted_model.eval()
from PIL import Image
from torchvision import transforms

# Fetch the sample picture and load it at 224x224.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path)
img = img.resize((224, 224))

# Resize, center-crop, convert to a tensor, then normalize per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# expected by torchvision's pretrained models -- confirm.
my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Run the pipeline and prepend an axis of size 1 so the array is a batch of one.
img = np.expand_dims(my_preprocess(img), 0)
# Device kind the model will be compiled for.
target = "cuda"

# Import the traced model into Relay; the frontend takes a list of
# (input name, input shape) pairs describing the graph inputs.
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

# Dump the imported main function for inspection.
print(mod["main"])
# Toggle TensorRT offloading. When False, the module is built for `target`
# without the TensorRT partitioning step.
use_trt = False
if use_trt:
    # Partition the module for TensorRT. The returned `config` must be passed
    # to the pass context under the "relay.ext.tensorrt.options" key.
    # NOTE(review): this unpacking assumes a TVM version whose
    # partition_for_tensorrt returns a (module, config) pair -- confirm.
    mod, config = partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True)
    print("after partition using trt ... ")
    print(mod["main"])
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.ext.tensorrt.options": config}
    ):
        lib = relay.build(mod, target=target, params=params)
else:
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
# Create the device handle for `target` (device index 0) and wrap the compiled
# library in a graph runtime module bound to that device.
ctx = tvm.context(str(target), 0)
module = runtime.GraphModule(lib["default"](ctx))

# Feed the preprocessed image under the same input name that was declared when
# the model was imported, instead of repeating the string literal here.
module.set_input(input_name, tvm.nd.array(img, ctx))

print("Evaluate inference time cost...")
# Time the "run" function: 10 repeats, each lasting at least 500 ms.
ftimer = module.module.time_evaluator("run", ctx, repeat=10, min_repeat_ms=500)
prof_res = np.array(ftimer().results) * 1e3  # convert to milliseconds
message = "Mean inference time (std dev): %.2f ms (%.2f ms)" % (
    np.mean(prof_res),
    np.std(prof_res),
)
print(message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment