Skip to content

Instantly share code, notes, and snippets.

@tiandiao123
Created September 19, 2021 02:45
Show Gist options
  • Save tiandiao123/2d15f321c42d3e6a77c0f69b0d0f6107 to your computer and use it in GitHub Desktop.
"""Compile ResNet-18 with TVM + TensorRT and benchmark its inference time.

Imports a pretrained resnet18_v1 from the MXNet model zoo, converts it to
Relay, partitions TensorRT-compatible subgraphs, builds for CUDA, exports
and reloads the compiled library, then reports mean/std latency.
"""
import mxnet
from mxnet.gluon.model_zoo.vision import get_model
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
import tvm.contrib.graph_runtime as runtime
import numpy as np

dtype = "float32"
input_shape = (1, 3, 224, 224)  # NCHW, single image

# Import the pretrained MXNet model into a Relay module.
block = get_model('resnet18_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)

# Offload supported subgraphs to TensorRT; `config` carries the TRT options
# that must be active while building.
mod, config = partition_for_tensorrt(mod, params)

target = "cuda"
with tvm.transform.PassContext(opt_level=3, config={'relay.ext.tensorrt.options': config}):
    # BUG FIX: the build must be indented inside the PassContext (the
    # original left it at top level, a SyntaxError, and the TensorRT
    # options would not be applied during compilation).
    lib = relay.build(mod, target=target, params=params)

# BUG FIX: the original loaded 'compiled.so' without ever creating it;
# export the freshly built library first so the load below succeeds.
lib.export_library('compiled.so')

# NOTE(review): device index 1 assumes a second GPU is present — confirm;
# index 0 is the conventional choice. Kept as-is to preserve behavior.
dev = tvm.context(str(target), 1)
loaded_lib = tvm.runtime.load_module('compiled.so')
gen_module = runtime.GraphModule(loaded_lib['default'](dev))

# Run once with random input to warm up / sanity-check execution.
input_data = np.random.uniform(0, 1, input_shape).astype(dtype)
gen_module.run(data=input_data)

# Evaluate: 10 repeats, each repeat running for at least 500 ms.
print("Evaluate inference time cost...")
ftimer = gen_module.module.time_evaluator("run", dev, repeat=10, min_repeat_ms=500)
prof_res = np.array(ftimer().results) * 1e3  # convert seconds to milliseconds
message = "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))
print(message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment