@vinx13 · Created September 20, 2018 06:34
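# Tune a single int8 3x3 conv2d workload (512 -> 512 channels, 7x7 spatial) with
# NNVM + AutoTVM on CUDA, then compile with the best schedules and benchmark it.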
import os
import nnvm
import nnvm.testing
import nnvm.compiler
from nnvm import sym
import tvm
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
import numpy as np
import logging
logging.getLogger('autotvm').setLevel(logging.DEBUG)
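
# Tuning target and workload settings; tuning records are written to "custom.log".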
target = tvm.target.cuda()
network = 'custom'
dtype = 'int8'
log_file = "%s.log" % network
tuning_option = {
    'log_filename': log_file,
    'tuner': 'xgb',
    'n_trial': 10,
    'early_stopping': 600,
    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4),
    ),
}
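
# Build the workload: a single 3x3 conv2d with bias (512 -> 512 channels) on a
# (batch, 512, 7, 7) input, wrapped into an NNVM graph by create_workload.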
def get_network(batch_size):
    input_shape = (batch_size, 512, 7, 7)
    output_shape = (batch_size, 512, 7, 7)
    data = sym.Variable(name="data")
    symbol = sym.conv2d(data=data, kernel_size=(3, 3), padding=(1, 1), channels=512, name="conv1", use_bias=True)
    net, params = nnvm.testing.create_workload(symbol, batch_size, (512, 7, 7), dtype)
    return net, params, input_shape, output_shape
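
# Tune a list of extracted tasks. Conv2d tasks whose channel counts are
# multiples of the block factor are first rewritten to the int8 blocked-layout
# template; the best record for each task ends up in log_filename.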
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True):
    for i in range(len(tasks)):
        data, kernel, padding, stride, layout, dtype = tasks[i].args
        block_factor = 4
        N, CI, H, W = data[1]
        CO, _, KH, KW = kernel[1]
        if CO % block_factor == 0 and CI % block_factor == 0:
            # use the int8 template when CI and CO are multiples of block_factor:
            # repack data to NCHW4c and kernel to OIHW4o4i blocked layouts
            data = (data[0], (N, CI // block_factor, H, W, block_factor), data[2])
            kernel = (kernel[0], (CO // block_factor, CI // block_factor, KH, KW, block_factor, block_factor), kernel[2])
            new_args = (data, kernel, padding, stride, layout, dtype)
            new_task = autotvm.task.create(tasks[i].name, new_args, tasks[i].target,
                                           tasks[i].target_host, 'int8')
            tasks[i] = new_task

    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=100)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)
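
# Extract conv2d tuning tasks from the graph, tune them, then rebuild the graph
# with the best schedules, run it on the GPU, and report mean inference time.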
def tune_and_evaluate(tuning_opt):
    net, params, input_shape, out_shape = get_network(batch_size=1)
    tasks = autotvm.task.extract_from_graph(net, target=target,
                                            shape={'data': input_shape}, dtype=dtype,
                                            symbols=(nnvm.sym.conv2d,))
    m = nnvm.compiler.build(net, target=target, shape={'data': input_shape}, dtype=dtype)

    tune_tasks(tasks, **tuning_opt)

    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with nnvm.compiler.build_config(opt_level=3):
            graph, lib, params = nnvm.compiler.build(
                net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)

        # export library
        tmp = tempdir()
        filename = "net.tar"
        lib.export_library(tmp.relpath(filename))

        # load parameters
        ctx = tvm.context(str(target), 0)
        params_tvm = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        module = runtime.create(graph, lib, ctx)
        module.set_input('data', data_tvm)
        module.set_input(**params_tvm)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=400, repeat=3)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))


tune_and_evaluate(tuning_option)
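
# After a run, the records in custom.log can be inspected with autotvm's record
# utilities. A minimal sketch, not part of the original gist: print the best
# measured time per tuned workload (assumes custom.log exists).
best = {}
for inp, res in autotvm.record.load_from_file("custom.log"):
    if res.error_no != 0:
        continue  # skip failed measurements
    cost = np.mean(res.costs)
    key = str(inp.task)
    if key not in best or cost < best[key][0]:
        best[key] = (cost, inp.config)
for key, (cost, config) in best.items():
    print("%s: %.6f s\n  %s" % (key, cost, config))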