@comaniac
Created November 12, 2020 01:48
import logging
import numpy as np
import tvm
from tvm import relay, te, topi, transform, auto_scheduler
from tvm.contrib import graph_runtime
from tvm.relay.backend import compile_engine
# logging.basicConfig(level=logging.INFO)
ishape = (1, 3, 224, 224)
w1shape = (32, 3, 3, 3)
w2shape = (32, 32, 3, 3)
dtype = "float32"
target = tvm.target.Target("cuda")
log_file = "tune.json"
def get_relay_func():
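    # Build a conv2d -> relu -> conv2d -> relu Relay function together with
    # random input data and parameter tensors.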
    data = relay.var("data", shape=ishape, dtype=dtype)
    weight1 = relay.var("weight1", shape=w1shape, dtype=dtype)
    weight2 = relay.var("weight2", shape=w2shape, dtype=dtype)
    conv2d = relay.nn.conv2d(data, weight1, kernel_size=(3, 3), padding=(1, 1))
    relu = relay.nn.relu(conv2d)
    conv2d = relay.nn.conv2d(relu, weight2, kernel_size=(3, 3), padding=(1, 1))
    out = relay.nn.relu(conv2d)
    func = relay.Function([data, weight1, weight2], out)

    data = np.random.uniform(-1, 1, size=ishape).astype("float32")
    w1 = np.random.uniform(-1, 1, size=w1shape).astype("float32")
    w2 = np.random.uniform(-1, 1, size=w2shape).astype("float32")
    params = {"weight1": w1, "weight2": w2}
    return func, {"data": data}, params

def get_relay_fused_func():
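    # Wrap the two-conv2d function from get_relay_func in a single primitive
    # call so it is compiled as one fused function.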
    data = relay.var("data", shape=ishape, dtype=dtype)
    weight1 = relay.var("weight1", shape=w1shape, dtype=dtype)
    weight2 = relay.var("weight2", shape=w2shape, dtype=dtype)
    fused_func, args, params = get_relay_func()
    # Mark the inner function as primitive so fuse_ops leaves it untouched.
    fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    call = relay.Call(fused_func, [data, weight1, weight2])
    func = relay.Function([data, weight1, weight2], call)
    return func, args, params

def get_relay_injective_func():
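    # Build an image-transform function (affine_grid -> grid_sample ->
    # crop_and_resize) that contains only injective ops, i.e. no conv2d.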
    batch_size = ishape[0]
    grid_size = (150, 150)
    crop_size = (112, 112)
    input_tensor = relay.var("image", relay.TensorType(ishape, "float32"))
    transform_tensor = relay.var("transform", relay.TensorType((batch_size, 2, 3), "float32"))
    bbx_tensor = relay.var("bbox", relay.TensorType((batch_size, 4), "float32"))
    idx_tensor = relay.var("index", relay.TensorType((batch_size,), "int32"))
    affine_grid_tensor = relay.image.affine_grid(transform_tensor, grid_size)
    alignment_tensor = relay.image.grid_sample(input_tensor, affine_grid_tensor)
    output_tensor = relay.image.crop_and_resize(
        alignment_tensor, bbx_tensor, idx_tensor, crop_size, layout="NCHW", out_dtype="float32"
    )
    func = relay.Function([input_tensor, transform_tensor, bbx_tensor, idx_tensor], output_tensor)

    image = np.random.uniform(-1, 1, size=ishape).astype("float32")
    trans = np.random.uniform(-1, 1, size=(batch_size, 2, 3)).astype("float32")
    bbx = np.random.uniform(-1, 1, size=(batch_size, 4)).astype("float32")
    # The box indices must be valid int32 batch indices to match idx_tensor.
    idx = np.random.randint(0, batch_size, size=(batch_size,)).astype("int32")
    return func, {"image": image, "transform": trans, "bbox": bbx, "index": idx}, None

def tune_task(task, cost_model="random"):
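    # Tune a single auto_scheduler task for a small number of trials and
    # append the measurement records to log_file.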
    measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
    cost_model = auto_scheduler.XGBModel() if cost_model == "xgb" else auto_scheduler.RandomModel()
    policy = auto_scheduler.SketchPolicy(task, cost_model)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=16,
        runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    sch, args = auto_scheduler.auto_schedule(task, search_policy=policy, tuning_options=tune_option)
    # Terminate the measurement process.
    del measure_ctx

def run_relay_task(func, args, params, need_tune=False, apply_log_file=None):
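    # Optionally tune the function, build it with either the default TOPI
    # schedules or auto_scheduler schedules, and report the mean inference time.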
    compile_engine.get().clear()
    mod = tvm.IRModule.from_expr(func)
    print(mod)

    if need_tune:
        print("Extract tasks...")
        tasks, _ = auto_scheduler.extract_tasks(mod["main"], params, target)
        assert len(tasks) == 1
        print("Tuning...")
        tune_task(tasks[0])

    print("Compile...")
    lib = None
    if apply_log_file is not None:
        try:
            with auto_scheduler.ApplyHistoryBest(apply_log_file):
                with tvm.transform.PassContext(opt_level=3):
                    lib = relay.build(mod, target=target, params=params)
        except Exception as err:
            print("Failed to build with auto_scheduler schedule: %s" % str(err))
    else:
        try:
            with tvm.transform.PassContext(opt_level=3):
                lib = relay.build(mod, target=target, params=params)
        except Exception as err:
            print("Failed to build with TOPI schedule: %s" % str(err))
    if lib is None:
        return

    ctx = tvm.context(str(target))
    module = graph_runtime.GraphModule(lib["default"](ctx))
    for name, data in args.items():
        module.set_input(name, data)

    print("Evaluate inference time cost...")
    ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=600)
    prof_res = np.array(ftimer().results) * 1000  # convert to milliseconds
    print(
        "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))
    )
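
# Run each workload with the default TOPI schedules and with schedules taken
# from the auto_scheduler log. The auto_scheduler runs below use
# need_tune=False, so log_file ("tune.json") is assumed to already contain
# tuning records from a previous run.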
print("=== An injective function with TOPI schedule")
run_relay_task(*get_relay_injective_func())
print("=== An injective function with auto_scheduler schedule")
run_relay_task(*get_relay_injective_func(), need_tune=False, apply_log_file=log_file)
print("=== A two-conv2d function with TOPI schedule (expected to fail)")
run_relay_task(*get_relay_fused_func())
print("=== A two-conv2d function with auto_scheduler schedule")
run_relay_task(*get_relay_fused_func(), need_tune=False, apply_log_file=log_file)