# Gist by @comaniac (created July 16, 2021)
import numpy as np

import tvm
from tvm import relay, auto_scheduler
import tvm.relay.testing
from tvm.contrib import graph_executor


def get_network(name, batch_size, layout="NHWC", dtype="float32"):
    """Get the symbol definition and random weights of a network."""
    # auto-scheduler prefers the NHWC layout
    if layout == "NHWC":
        image_shape = (224, 224, 3)
    elif layout == "NCHW":
        image_shape = (3, 224, 224)
    else:
        raise ValueError("Invalid layout: " + layout)

    input_shape = (batch_size,) + image_shape
    output_shape = (batch_size, 1000)

    if name.startswith("resnet-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
elif name.startswith("resnet3d-"):
n_layer = int(name.split("-")[1])
mod, params = relay.testing.resnet.get_workload(
num_layers=n_layer,
batch_size=batch_size,
layout=layout,
dtype=dtype,
image_shape=image_shape,
)
elif name == "mobilenet":
mod, params = relay.testing.mobilenet.get_workload(
batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
)
elif name == "squeezenet_v1.1":
assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
mod, params = relay.testing.squeezenet.get_workload(
version="1.1",
batch_size=batch_size,
dtype=dtype,
image_shape=image_shape,
)
elif name == "inception_v3":
input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == "mxnet":
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
assert layout == "NCHW"
block = get_model("resnet18_v1", pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
net = mod["main"]
net = relay.Function(
net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
)
mod = tvm.IRModule.from_expr(net)
return mod, params, input_shape, output_shape

# Define the neural network and compilation target.
network = "resnet-18"
batch_size = 1
layout = "NCHW"
# Heterogeneous target: a dict mapping each target kind to its target string.
target = {"llvm": "llvm", "cuda": "cuda"}
dtype = "float32"
log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, "cuda")

mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)

@relay.transform.function_pass(opt_level=1)
class MyPass:
    """Annotate every nn.conv2d call so it is assigned to the CUDA device."""

    def __init__(self):
        self.var = 0

    def transform_function(self, func, mod, ctx):
        class Test(tvm.relay.ExprMutator):
            def visit_call(self, expr):
                visit = super().visit_call(expr)
                if expr.op == tvm.relay.op.get("nn.conv2d"):
                    # Pin conv2d to the CUDA device; everything else stays put.
                    return relay.annotation.on_device(visit, "cuda")
                return visit

        return Test().visit(func)
dev1 = tvm.device("llvm")
dev2 = tvm.device("cuda")
custom_pass = MyPass()
mod = custom_pass(mod)
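
# Optional sanity check (purely illustrative): printing the module should
# show each nn.conv2d call wrapped in an on_device annotation.
# print(mod["main"])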

# Toggle between the failing task-extraction path and the working build path.
if True:
    # Error: task extraction fails when the target is a heterogeneous dict.
    print("Extract tasks...")
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    for idx, task in enumerate(tasks):
        print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key))
        print(task.compute_dag)
    assert False  # stop the script here
else:
    # Working: building directly with the heterogeneous target succeeds.
    print("Compile...")
    with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
        lib = relay.build(mod, target=target, params=params)

    # Create the graph executor with one device per target.
    module = graph_executor.create(lib.get_graph_json(), lib.get_lib(), [dev1, dev2])
    data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
    module.set_input(**lib.get_params())
    module.set_input("data", data_tvm)
    module.run()
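
    # A minimal sketch of reading the result back, assuming the standard
    # graph_executor API; output_shape comes from get_network() above.
    out = module.get_output(0, tvm.nd.empty(output_shape)).asnumpy()
    print("Output shape:", out.shape)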