import os
import logging

import numpy as np

import tvm
from tvm import auto_scheduler, te, topi

logging.basicConfig(level=logging.INFO, filename='time.log')
resnet_conv2d_configs = {
    # format: H, W, CI, CO, KH, KW, strides, padding, dilation
    '18': [
        (224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)),
        (56, 56, 64, 128, 3, 3, (2, 2), (1, 1), (1, 1)),
        (56, 56, 64, 128, 1, 1, (2, 2), (0, 0), (1, 1)),
        (56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)),
        (56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 128, 256, 3, 3, (2, 2), (1, 1), (1, 1)),
        (28, 28, 128, 256, 1, 1, (2, 2), (0, 0), (1, 1)),
        (28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)),
        (14, 14, 256, 512, 3, 3, (2, 2), (1, 1), (1, 1)),
        (14, 14, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)),
        (14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)),
        (7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)),
    ],
    '50': [
        (224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)),
        (56, 56, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)),
        (56, 56, 256, 128, 1, 1, (2, 2), (0, 0), (1, 1)),
        (56, 56, 256, 64, 1, 1, (1, 1), (0, 0), (1, 1)),
        (56, 56, 64, 256, 1, 1, (1, 1), (0, 0), (1, 1)),
        (56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)),
        (56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 512, 1024, 1, 1, (2, 2), (0, 0), (1, 1)),
        (28, 28, 512, 256, 1, 1, (2, 2), (0, 0), (1, 1)),
        (28, 28, 512, 128, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 128, 512, 1, 1, (1, 1), (0, 0), (1, 1)),
        (28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)),
        (14, 14, 1024, 2048, 1, 1, (2, 2), (0, 0), (1, 1)),
        (14, 14, 1024, 512, 1, 1, (2, 2), (0, 0), (1, 1)),
        (14, 14, 1024, 256, 1, 1, (1, 1), (0, 0), (1, 1)),
        (14, 14, 256, 1024, 1, 1, (1, 1), (0, 0), (1, 1)),
        (14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)),
        (7, 7, 2048, 512, 1, 1, (1, 1), (0, 0), (1, 1)),
        (7, 7, 512, 2048, 1, 1, (1, 1), (0, 0), (1, 1)),
        (7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)),
    ],
}
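
# Example reading of one entry: (56, 56, 64, 128, 3, 3, (2, 2), (1, 1), (1, 1))
# is a 3x3 convolution on a 56x56 feature map with 64 input and 128 output
# channels, stride 2, padding 1, and no dilation.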

def get_log_file_name_from_task(task):
    """Derive a per-task log-file name from the task's workload key."""
    return "{}.json".format(
        task.workload_key.replace("[", "")
        .replace("]", "")
        .replace('"', "")
        .replace(",", "_")
        .replace(" ", "")
    )
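
# For illustration (a hypothetical key; real keys are generated by
# auto_scheduler): '["conv2d_nchw", 32, 224, 224, 3, 64, 7, 7, [2, 2], [3, 3], [1, 1]]'
# would map to "conv2d_nchw_32_224_224_3_64_7_7_2_2_3_3_1_1.json".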

@auto_scheduler.register_workload
def conv2d_nchw(N, H, W, CI, CO, KH, KW, stride, padding, dilation):
    """Forward conv2d in NCHW layout."""
    data = te.placeholder((N, CI, H, W), name="data")
    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
    out = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation, out_dtype="float32")
    return [data, kernel, out]


@auto_scheduler.register_workload
def conv2d_nchw_gd(N, H, W, CI, CO, KH, KW, stride, padding, dilation):
    """Backward (gradient) conv2d in NCHW layout w.r.t. data and kernel."""
    data, kernel, f_out = conv2d_nchw(N, H, W, CI, CO, KH, KW, stride, padding, dilation)
    dy = te.placeholder(f_out.shape, name="dy")
    out = te.gradient(f_out, [data, kernel], head=dy)
    return [data, kernel, dy, *out]


@auto_scheduler.register_workload
def conv2d_nhwc(N, H, W, CI, CO, KH, KW, stride, padding):
    """Forward conv2d in NHWC layout (dilation fixed to 1)."""
    data = te.placeholder((N, H, W, CI), name="data")
    kernel = te.placeholder((KH, KW, CI, CO), name="kernel")
    out = topi.nn.conv2d_nhwc(data, kernel, stride, padding, dilation=1, out_dtype="float32")
    return [data, kernel, out]


@auto_scheduler.register_workload
def conv2d_nhwc_gd(N, H, W, CI, CO, KH, KW, stride, padding):
    """Backward (gradient) conv2d in NHWC layout w.r.t. data and kernel."""
    data, kernel, f_out = conv2d_nhwc(N, H, W, CI, CO, KH, KW, stride, padding)
    dy = te.placeholder(f_out.shape, name="dy")
    out = te.gradient(f_out, [data, kernel], head=dy)
    return [data, kernel, dy, *out]
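
# Note: te.gradient returns the adjoints in the order of the inputs it is
# given, so each *_gd workload above exposes two outputs: the gradient
# w.r.t. data followed by the gradient w.r.t. kernel.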

target = tvm.target.Target("cuda -model=t4")
batch = 32

tasks = []
for cfg in resnet_conv2d_configs["18"]:
    tasks.append(auto_scheduler.create_task(conv2d_nchw, (batch, *cfg), target))
    tasks.append(auto_scheduler.create_task(conv2d_nchw_gd, (batch, *cfg), target))

print('Getting device...')
measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
ctx = tvm.gpu()
for idx, task in enumerate(tasks):
    log_file = get_log_file_name_from_task(task)
    logging.info("[%d / %d Tasks] Log to %s" % (idx + 1, len(tasks), log_file))

    cost_model = auto_scheduler.XGBModel()
    if os.path.exists(log_file):
        # Resume from an existing log: warm-start the cost model and
        # preload the already-measured states into the search policy.
        cost_model.update_from_file(log_file)
        search_policy = auto_scheduler.SketchPolicy(
            task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
        )
    else:
        search_policy = auto_scheduler.SketchPolicy(task, cost_model)

    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=1500,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )

    # print(task.compute_dag)
    sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)
    # inp, res = auto_scheduler.load_best(log_file, task.workload_key)
    # print(task.compute_dag.print_python_code_from_state(inp.state))
    # sch, args = task.compute_dag.apply_steps_from_state(inp.state)
    func = tvm.build(sch, args, target)

    # Forward workloads have a single output tensor; the *_gd workloads
    # return the two gradient tensors (w.r.t. data and kernel).
    num_out = 2 if "_gd" in task.workload_key else 1
    in_nps = [
        np.random.uniform(size=[v.value for v in a.shape]).astype(np.float32)
        for a in args[:-num_out]
    ]
    in_args = [tvm.nd.array(dnp, ctx=ctx) for dnp in in_nps]
    out_args = [tvm.nd.empty([v.value for v in a.shape], ctx=ctx) for a in args[-num_out:]]

    # Evaluate execution time.
    evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500)
    logging.info(
        "Median execution time: %.3f ms"
        % (np.median(evaluator(*in_args, *out_args).results) * 1000)
    )
    # np.testing.assert_equal(ref_out_args[1].asnumpy(), out_args[0].asnumpy())

del measure_ctx
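
# A minimal sketch (assuming the same log-file naming as above) of how a tuned
# schedule could later be re-applied without searching again:
#
#   inp, _ = auto_scheduler.load_best(log_file, task.workload_key)
#   sch, args = task.compute_dag.apply_steps_from_state(inp.state)
#   func = tvm.build(sch, args, target)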