Skip to content

Instantly share code, notes, and snippets.

@comaniac
Created August 12, 2020 20:03
Show Gist options
  • Save comaniac/b689010eaf0467105248efccffa13352 to your computer and use it in GitHub Desktop.
Save comaniac/b689010eaf0467105248efccffa13352 to your computer and use it in GitHub Desktop.
import numpy as np
import tvm
from tvm import relay
from tvm.autotvm.graph_tuner import DPTuner
from tvm.contrib import graph_runtime
import torch
import torchvision
# Script: import torchvision's Inception-v3 into Relay, run the AutoTVM graph
# tuner over its conv2d layers, then build and time the tuned model on CPU.

# --- Compilation target and execution device --------------------------------
target = 'llvm'
ctx = tvm.cpu(0)

# --- Load the PyTorch model and convert it to Relay -------------------------
model_name = 'inception_v3'
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

input_shape = [1, 3, 299, 299]
input_data = torch.randn(input_shape)
# Trace with a random example input to obtain a TorchScript module that the
# Relay PyTorch frontend can consume.
scripted_model = torch.jit.trace(model, input_data).eval()

input_name = 'img'
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

### Graph tuning
# NOTE(review): 'inception_v3.json' is passed as DPTuner's third positional
# argument; it is presumably the existing kernel-level tuning log and must be
# present before this script runs -- confirm against the DPTuner signature.
executor = DPTuner(
    mod['main'],
    {input_name: input_shape},
    'inception_v3.json',
    [relay.op.get('nn.conv2d')],
    target,
)
executor.benchmark_layout_transform(min_exec_num=2000)
executor.run()
executor.write_opt_sch2record_file('graph.log')
###

# --- Build with the graph-tuned schedules -----------------------------------
# The original line referenced a bare `autotvm` name that is never imported
# (NameError); it is reachable as `tvm.autotvm`. Entering the dispatch context
# with `with` scopes it to the build instead of overwriting
# DispatchContext.current by hand.
with tvm.autotvm.apply_graph_best('graph.log'):
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = relay.build_module.build(
            mod, target=target, params=params)

# --- Run and time the compiled module ---------------------------------------
runtime = graph_runtime.create(graph, lib, ctx)
runtime.set_input(
    input_name,
    tvm.nd.array(np.random.uniform(size=input_shape).astype('float32')),
)
runtime.set_input(**params)

ftimer = runtime.module.time_evaluator('run', ctx, number=10, repeat=3)
# Scale the timer results by 1000 so they match the "ms" unit printed below.
prof_res = np.array(ftimer().results) * 1000
print('Mean inference time (std dev): %.2f ms (%.2f ms)' %
      (np.mean(prof_res), np.std(prof_res)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment