#!/usr/bin/env python
# coding: utf-8
# Gist by @SrivastavaKshitij, created December 2, 2020 20:16
# # [TVMConf 2020] BYOC Tutorial Demo
#
# Cody Yu (hyuz@amazon.com), Zhi Chen (chzhi@amazon.com) from AWS AI.
#
#
#
# This demo has two parts. In the first part, we use a simple Relay graph to walk through the BYOC workflow. In the second part, we build an SSD model with the TensorRT BYOC integration to showcase a real-world application.
#
# This tutorial is based on upstream TVM commit d42c1553952366154c001d7bfce73ca239ab0700 with the following settings in `config.cmake`:
#
# ```
# set(USE_CUDA ON)
# set(USE_GRAPH_RUNTIME ON)
# set(USE_LLVM llvm-config-10)
# set(USE_CUDNN ON)
# set(USE_CUBLAS ON)
# set(USE_SORT ON)
# set(USE_THRUST ON)
# set(USE_TENSORRT_CODEGEN ON)
# set(USE_TENSORRT_RUNTIME /path/to/TensorRT)
# ```
#
# And run on the following environment:
#
# ```
# Platform:
# - Amazon EC2 g4dn.4xl
# - NVIDIA T4 TensorCore
#
# Software:
# - Ubuntu 18.04
# - CUDA 10.0
# - LLVM 10.0
# - TensorRT-7.0.0.11
# ```
#
#
# ## Part 1: Workflow Walkthrough
#
# Let's first use a simple example to walk through the graph partitioning flow.
# In[1]:
import tvm
from tvm import relay
# Here we demonstrate how BYOC annotates a Relay graph.
#
# Let's first define a simple Relay graph with both supported and unsupported operators. This function uses a loop (control flow) to represent three convolution layers, although it is a bit odd to apply the same weights and biases repeatedly.
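# You can check the loop semantics without TVM. The plain-Python sketch below mirrors the control flow of `get_demo_mod()`. `conv_block` is a hypothetical stand-in for the conv2d + bias_add + relu sequence. It only records that it ran, so the trace shows how many convolution blocks execute before the final `reshape`.

```python
from typing import List


def conv_block(x, trace: List[str]):
    """Hypothetical stand-in for conv2d -> bias_add -> relu (shape-preserving)."""
    trace.append("conv2d_bias_relu")
    return x


def toy_while_loop(i: int, x, trace: List[str]):
    """Mirrors the Relay `while_loop` function: recurse while i < 2."""
    if i < 2:
        # loop_cond_out: apply the block, then recurse with i + 1
        return toy_while_loop(i + 1, conv_block(x, trace), trace)
    # loop_break_out: apply the block one last time, then reshape
    y = conv_block(x, trace)
    trace.append("reshape")
    return y


trace: List[str] = []
toy_while_loop(0, "data", trace)
print(trace.count("conv2d_bias_relu"))  # 3
```

# Starting from `iter1 = 0`, the loop body runs for 0 and 1, and the break branch runs once at 2. That is where the three convolution layers come from.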
# In[2]:
def get_demo_mod():
    # Loop
    iter1 = relay.var("iter1", shape=(), dtype="int32")
    cond = relay.less(iter1, relay.const(2, dtype="int32"))
    inc = iter1 + relay.const(1, dtype="int32")
    loop_var = relay.var("while_loop")

    # Loop body
    d1 = relay.var("d1", shape=(1, 32, 56, 56), dtype="float32")
    w1 = relay.var("w1", shape=(32, 32, 3, 3), dtype="float32")
    b1 = relay.var("b1", shape=(32,), dtype="float32")
    conv = relay.nn.conv2d(d1, w1, strides=(1, 1), padding=(1, 1))
    bias = relay.nn.bias_add(conv, b1)
    relu = relay.nn.relu(bias)
    loop_cond_out = loop_var(inc, relu, w1, b1)

    conv = relay.nn.conv2d(d1, w1, strides=(1, 1), padding=(1, 1))
    bias = relay.nn.bias_add(conv, b1)
    relu = relay.nn.relu(bias)
    loop_break_out = relay.reshape(relu, (1, 56, 56, 32))

    ife = relay.If(cond, loop_cond_out, loop_break_out)

    data = relay.var("data", shape=(1, 32, 56, 56), dtype="float32")
    weight = relay.var("weight", shape=(32, 32, 3, 3), dtype="float32")
    bias = relay.var("bias", shape=(32,), dtype="float32")
    loop_func = relay.Function([iter1, d1, w1, b1], ife)
    out = relay.Let(loop_var, loop_func, loop_var(relay.const(0, dtype="int32"), data, weight, bias))
    func = relay.Function([data, weight, bias], out)

    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    return mod
mod = get_demo_mod()
print(mod["main"].astext(show_meta_data=False))
# Then we define the annotation rules. As mentioned in the presentation, developers can specify both operator-based and pattern-based annotation rules. Here, we declare that the single operators `reshape` and `add` are supported. In addition, we define two supported patterns: Conv2D - BiasAdd - ReLU and Conv2D - ReLU.
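# Operator-based rules boil down to a lookup table that maps (operator, `target.<name>`) to a predicate. The sketch below is a toy registry that imitates the shape of `tvm.ir.register_op_attr`. `register_rule`, `is_supported`, and `_OP_RULES` are hypothetical names, not TVM internals. The predicate receives the expression, so a real rule can inspect it and return `False` for configurations the backend cannot handle.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: (op name, attribute key) -> support predicate.
_OP_RULES: Dict[Tuple[str, str], Callable[[object], bool]] = {}


def register_rule(op_name: str, attr_key: str):
    """Decorator that records `fn` as the support check for `op_name`."""
    def _register(fn):
        _OP_RULES[(op_name, attr_key)] = fn
        return fn
    return _register


def is_supported(op_name: str, target: str, expr=None) -> bool:
    """Ops without a registered rule are left to the default compiler."""
    rule = _OP_RULES.get((op_name, "target." + target))
    return bool(rule and rule(expr))


@register_rule("reshape", "target.byoc-target")
def _reshape(expr):
    return True


@register_rule("add", "target.byoc-target")
def _add(expr):
    return True


print([op for op in ("reshape", "add", "less") if is_supported(op, "byoc-target")])  # ['reshape', 'add']
```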
# In[3]:
demo_target = "byoc-target"
# Operator-based annotation rules
@tvm.ir.register_op_attr("reshape", "target.byoc-target")
def reshape(expr):
    return True


@tvm.ir.register_op_attr("add", "target.byoc-target")
def add(expr):
    return True
# Pattern-based annotation rules
def make_pattern(with_bias=True):
    from tvm.relay.dataflow_pattern import is_op, wildcard

    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    conv = is_op("nn.conv2d")(data, weight)
    if with_bias:
        conv_out = is_op("nn.bias_add")(conv, bias)
    else:
        conv_out = conv
    return is_op("nn.relu")(conv_out)
conv2d_bias_relu_pat = ("byoc-target.conv2d_relu_with_bias", make_pattern(with_bias=True))
conv2d_relu_pat = ("byoc-target.conv2d_relu_wo_bias", make_pattern(with_bias=False))
patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
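# To make the pattern semantics concrete, here is a toy matcher for Conv2D - (BiasAdd) - ReLU over a minimal expression tree. The `Call` class and `match_conv_relu` are hypothetical. TVM's `dataflow_pattern` module does this generically (wildcards, arbitrary nesting), and this sketch only shows which chains the two patterns accept. Joining the returned op list with `_` reproduces the `PartitionedFromPattern` string that `MergeComposite` attaches to each composite function.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Call:
    """Hypothetical minimal call node: an op name plus argument expressions."""
    op: str
    args: List[object] = field(default_factory=list)


def match_conv_relu(expr) -> Optional[Tuple[str, List[str]]]:
    """Match nn.conv2d -> [nn.bias_add] -> nn.relu; return (pattern name, ops)."""
    if not (isinstance(expr, Call) and expr.op == "nn.relu" and expr.args):
        return None
    ops = ["nn.relu"]
    name = "byoc-target.conv2d_relu_wo_bias"
    inner = expr.args[0]
    if isinstance(inner, Call) and inner.op == "nn.bias_add" and inner.args:
        ops.insert(0, "nn.bias_add")
        name = "byoc-target.conv2d_relu_with_bias"
        inner = inner.args[0]
    if isinstance(inner, Call) and inner.op == "nn.conv2d":
        ops.insert(0, "nn.conv2d")
        return name, ops
    return None


toy_expr = Call("nn.relu", [Call("nn.bias_add", [Call("nn.conv2d", ["d1", "w1"]), "b1"])])
pat_name, pat_ops = match_conv_relu(toy_expr)
print(pat_name)                             # byoc-target.conv2d_relu_with_bias
print("".join(op + "_" for op in pat_ops))  # nn.conv2d_nn.bias_add_nn.relu_
```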
# Now let's perform pattern-based annotation:
# In[4]:
mod2 = relay.transform.MergeComposite(patterns)(mod)
print(mod2["main"].astext(show_meta_data=False))
# We can see that all subgraphs matching the defined patterns are now partitioned into "composite functions". In this example, we get two composite functions:
#
# ```
# %4 = fn (%FunctionVar_1_0: Tensor[(1, 32, 56, 56), float32],
# %FunctionVar_1_1: Tensor[(32, 32, 3, 3), float32],
# %FunctionVar_1_2: Tensor[(32), float32],
# PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_",
# Composite="byoc-target.conv2d_relu_with_bias") -> Tensor[(1, 32, 56, 56), float32] {
# %2 = nn.conv2d(%FunctionVar_1_0, %FunctionVar_1_1, padding=[1, 1, 1, 1]);
# %3 = nn.bias_add(%2, %FunctionVar_1_2);
# nn.relu(%3)
# };
#
# %8 = fn (%FunctionVar_0_0: Tensor[(1, 32, 56, 56), float32],
# %FunctionVar_0_1: Tensor[(32, 32, 3, 3), float32],
# %FunctionVar_0_2: Tensor[(32), float32],
# PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_",
# Composite="byoc-target.conv2d_relu_with_bias") -> Tensor[(1, 32, 56, 56), float32] {
# %6 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1]);
# %7 = nn.bias_add(%6, %FunctionVar_0_2);
# nn.relu(%7)
# };
# ```
#
# A composite function has two specialized attributes -- "PartitionedFromPattern" and "Composite":
# * PartitionedFromPattern: indicates the operators in the function body.
# * Composite: indicates the name of the pattern we defined.
#
# As you can imagine, this information could be useful for you to map a composite function to your accelerator in the codegen.
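# For example, a codegen could keep a dispatch table keyed by the `Composite` attribute and emit one accelerator call per composite function. In the sketch below, composite functions are modeled as plain attribute dictionaries. `lower_composite` and the `accel_*` kernel names are hypothetical placeholders for whatever your accelerator library provides.

```python
from typing import Callable, Dict, Sequence


def emit_conv2d_bias_relu(args: Sequence[str]) -> str:
    return "accel_conv2d_bias_relu(%s)" % ", ".join(args)


def emit_conv2d_relu(args: Sequence[str]) -> str:
    return "accel_conv2d_relu(%s)" % ", ".join(args)


# Hypothetical dispatch table: Composite attribute value -> emitter.
COMPOSITE_HANDLERS: Dict[str, Callable[[Sequence[str]], str]] = {
    "byoc-target.conv2d_relu_with_bias": emit_conv2d_bias_relu,
    "byoc-target.conv2d_relu_wo_bias": emit_conv2d_relu,
}


def lower_composite(attrs: Dict[str, str], args: Sequence[str]) -> str:
    """Pick the emitter from the function's Composite attribute."""
    composite = attrs.get("Composite")
    handler = COMPOSITE_HANDLERS.get(composite)
    if handler is None:
        raise NotImplementedError("unsupported composite: %r" % composite)
    return handler(args)


example_attrs = {
    "PartitionedFromPattern": "nn.conv2d_nn.bias_add_nn.relu_",
    "Composite": "byoc-target.conv2d_relu_with_bias",
}
print(lower_composite(example_attrs, ["%d1", "%w1", "%b1"]))  # accel_conv2d_bias_relu(%d1, %w1, %b1)
```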
#
# Next, let's continue to apply the operator-based annotation rules:
# In[5]:
mod3 = relay.transform.AnnotateTarget("byoc-target")(mod2)
print(mod3["main"].astext(show_meta_data=False))
# Looks scary! Let me make this Relay graph more readable so that we can easily find some interesting points.
#
# ```
# fn (%data: Tensor[(1, 32, 56, 56), float32],
# %weight: Tensor[(32, 32, 3, 3), float32],
# %bias: Tensor[(32), float32]) -> Tensor[(1, 56, 56, 32), float32] {
# let %while_loop: fn (int32,
# Tensor[(1, 32, 56, 56), float32],
# Tensor[(32, 32, 3, 3), float32],
# Tensor[(32), float32]) -> Tensor[(1, 56, 56, 32), float32] =
# fn (%iter1: int32,
# %d1: Tensor[(1, 32, 56, 56), float32],
# %w1: Tensor[(32, 32, 3, 3), float32],
# %b1: Tensor[(32), float32]) -> Tensor[(1, 56, 56, 32), float32] {
# %0 = annotation.compiler_begin(%iter1, compiler=default);
# %1 = annotation.compiler_begin(2, compiler=default);
# %2 = less(%0, %1);
# %3 = annotation.compiler_end(%2, compiler=default);
# if (%3) {
# %4 = annotation.compiler_begin(%iter1, compiler=byoc-target);
# %5 = annotation.compiler_begin(1, compiler=byoc-target);
# %6 = add(%4, %5);
# %7 = annotation.compiler_end(%6, compiler=byoc-target);
#
# %8 = annotation.compiler_begin(%7, compiler=byoc-target);
# %9 = annotation.compiler_begin(%d1, compiler=byoc-target);
# %10 = annotation.compiler_begin(%w1, compiler=byoc-target);
# %11 = annotation.compiler_begin(%b1, compiler=byoc-target);
# %14 = /* skip composite function body */;
# %15 = %14(%9, %10, %11);
# %16 = annotation.compiler_end(%15, compiler=byoc-target);
#
# %17 = annotation.compiler_begin(%16, compiler=default);
# %18 = annotation.compiler_begin(%w1, compiler=default);
# %19 = annotation.compiler_begin(%b1, compiler=default);
# %20 = %while_loop(%8, %17, %18, %19);
# annotation.compiler_end(%20, compiler=default)
# } else {
# %21 = annotation.compiler_begin(%d1, compiler=byoc-target);
# %22 = annotation.compiler_begin(%w1, compiler=byoc-target);
# %23 = annotation.compiler_begin(%b1, compiler=byoc-target);
# %26 = /* skip composite function body */;
# %27 = %26(%21, %22, %23);
# %28 = annotation.compiler_end(%27, compiler=byoc-target);
#
# %29 = annotation.compiler_begin(%28, compiler=byoc-target);
# %30 = reshape(%29, newshape=[1, 56, 56, 32]);
# annotation.compiler_end(%30, compiler=byoc-target)
# }
# };
# %31 = annotation.compiler_begin(0, compiler=default);
# %32 = annotation.compiler_begin(%data, compiler=default);
# %33 = annotation.compiler_begin(%weight, compiler=default);
# %34 = annotation.compiler_begin(%bias, compiler=default);
# %35 = %while_loop(%31, %32, %33, %34);
# annotation.compiler_end(%35, compiler=default)
# }
# ```
#
# * Almost all nodes in the graph are annotated with `compiler_begin` and `compiler_end` nodes. Each `compiler_*` node has a `compiler` attribute indicating which target the node should go to. In this example, it is either `default` or `byoc-target`.
#
# * Composite function calls are also annotated with `compiler=byoc-target`, indicating that this entire function can be offloaded.
#
# * Some annotations can actually be merged, such as those for the composite function and the following `reshape`. The next pass, `MergeCompilerRegions`, merges them to minimize the number of subgraphs.
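# You can sketch the core idea of region merging on a single dataflow chain. Consecutive operators assigned to the same target collapse into one region, so each remaining region boundary costs one `compiler_end`/`compiler_begin` pair. This is a simplified model. The real `MergeCompilerRegions` pass works on a general dataflow graph and has to respect additional constraints, such as not creating dependency cycles between regions.

```python
from itertools import groupby
from typing import List, Tuple


def merge_regions(chain: List[Tuple[str, str]]) -> List[Tuple[str, List[str]]]:
    """Group consecutive (op, target) pairs that share a target into one region.

    Assumes `chain` is one dataflow path in topological order, so neighbours
    are directly connected and merging them is always safe.
    """
    return [
        (target, [op for op, _ in group])
        for target, group in groupby(chain, key=lambda item: item[1])
    ]


# The else branch of the demo: composite conv2d-bias-relu followed by reshape.
else_branch = [("conv2d_relu_with_bias", "byoc-target"), ("reshape", "byoc-target")]
print(merge_regions(else_branch))
# [('byoc-target', ['conv2d_relu_with_bias', 'reshape'])]
```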
# In[6]:
mod4 = relay.transform.MergeCompilerRegions()(mod3)
print(mod4["main"].astext(show_meta_data=False))
# We can see that the composite function call and the following `reshape` in the `else` branch now share the same set of annotation nodes. In other words, the consecutive `compiler_end` and `compiler_begin` nodes between them have been removed:
#
# ```
# %21 = annotation.compiler_begin(%d1, compiler=byoc-target);
# %22 = annotation.compiler_begin(%w1, compiler=byoc-target);
# %23 = annotation.compiler_begin(%b1, compiler=byoc-target);
# %26 = /* skip composite function body */;
# %27 = %26(%21, %22, %23);
# %28 = reshape(%27, newshape=[1, 56, 56, 32]);
# annotation.compiler_end(%28, compiler=byoc-target) /* Only one compiler_end means only one subgraph! */
# ```
#
# Finally, let's partition this graph:
# In[7]:
mod5 = relay.transform.PartitionGraph()(mod4)
print(mod5["main"].astext(show_meta_data=False))
# It's much cleaner now, right? We can see that 3 subgraphs have been partitioned for `byoc-target`. Let's dive into each of them:
# In[8]:
for name in ["byoc-target_0", "byoc-target_2", "byoc-target_5"]:
    print("%s: " % name)
    print(mod5[name].astext(show_meta_data=False))
    print("==================")
# * **byoc-target_0** contains only one add operator.
# * **byoc-target_2** contains only one composite function call.
# * **byoc-target_5** contains one composite function call as well as the `reshape`.
#
# Each partitioned function will be sent to the "byoc-target" codegen for code generation. As a result, the customized codegen only needs to consider these subgraphs and does not have to worry about the rest of the graph. In this example, it also means that the customized codegen doesn't have to handle the control flow (nice!).
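# Conceptually, the build step walks the partitioned module and hands each function tagged for an external compiler to that compiler's codegen. Everything else, including the control flow in `main`, stays with TVM. The sketch below models functions as dictionaries with an `attrs` entry. `dispatch_to_codegens` and `toy_codegen` are hypothetical, and reading the target from a `Compiler` attribute is an assumption about how partitioned functions are tagged.

```python
from typing import Callable, Dict


def dispatch_to_codegens(
    functions: Dict[str, dict],
    codegens: Dict[str, Callable[[str, dict], str]],
) -> Dict[str, str]:
    """Send each function tagged with a `Compiler` attribute to that codegen."""
    artifacts = {}
    for name, func in sorted(functions.items()):
        compiler = func.get("attrs", {}).get("Compiler")
        if compiler is None:
            continue  # untagged functions (e.g. "main") are compiled by TVM itself
        artifacts[name] = codegens[compiler](name, func)
    return artifacts


def toy_codegen(name: str, func: dict) -> str:
    # A real codegen would emit C source, a JSON graph, or a binary here.
    return "%s: %s" % (name, " -> ".join(func["ops"]))


toy_module = {
    "main": {"ops": ["less", "if", "while_loop"]},
    "byoc-target_0": {"attrs": {"Compiler": "byoc-target"}, "ops": ["add"]},
    "byoc-target_2": {"attrs": {"Compiler": "byoc-target"}, "ops": ["composite"]},
    "byoc-target_5": {"attrs": {"Compiler": "byoc-target"}, "ops": ["composite", "reshape"]},
}
for line in dispatch_to_codegens(toy_module, {"byoc-target": toy_codegen}).values():
    print(line)
```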
#
# In the rest of this demo, we build a real-world SSD model with the TensorRT BYOC integration. The integration is already available in upstream TVM, so you are welcome to try it yourself. Specifically, we will build a Gluon CV SSD-ResNet50 model with TensorRT.
#
# Please note that to run this example yourself, you need to set up TensorRT in your environment and build TVM with the TensorRT runtime. You can refer to the TVM TensorRT tutorial for detailed instructions: https://tvm.apache.org/docs/deploy/tensorrt.html
# ## Part 2: Build an SSD Model with TensorRT
# In[9]:
import time
from gluoncv import data as gcv_data, model_zoo, utils
import mxnet as mx
import numpy as np
from tvm.relay.backend import compile_engine
from tvm.contrib import graph_runtime
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib import tensorrt
# Then we download an image as our input data.
# In[10]:
im_fname = download_testdata(
    "https://github.com/dmlc/web-data/blob/master/gluoncv/detection/street_small.jpg?raw=true",
    "street_small.jpg",
    module="data",
)
# Next, let's load the MXNet SSD model from the Gluon CV model zoo and convert it to a Relay graph. We use the Gluon CV SSD model with a ResNet-50 backbone, trained on the COCO dataset.
# In[11]:
def get_ssd_model(model_name, image_size=512):
    # Setup model
    input_name = "data"
    input_shape = (1, 3, image_size, image_size)

    # Prepare model input data
    data, img = gcv_data.transforms.presets.ssd.load_test(im_fname, short=image_size)

    # Prepare SSD model
    block = model_zoo.get_model(model_name, pretrained=True)
    block.hybridize()
    block.forward(data)
    block.export("temp")

    model_json = mx.symbol.load("temp-symbol.json")
    save_dict = mx.ndarray.load("temp-0000.params")
    arg_params = {}
    aux_params = {}
    for param, val in save_dict.items():
        param_type, param_name = param.split(":", 1)
        if param_type == "arg":
            arg_params[param_name] = val
        elif param_type == "aux":
            aux_params[param_name] = val

    # Convert the MXNet SSD model to Relay module
    mod, params = relay.frontend.from_mxnet(
        model_json, {input_name: input_shape}, arg_params=arg_params, aux_params=aux_params
    )
    return mod, params, block.classes, data.asnumpy(), img
mod, params, class_names, data, img = get_ssd_model("ssd_512_resnet50_v1_coco")
# Since the entire Relay graph is pretty large, here we use a simple Relay pass to show the total number of operators it has and what they are.
# In[12]:
def profile_graph(func):
    class OpProfiler(tvm.relay.ExprVisitor):
        def __init__(self):
            super().__init__()
            self.ops = {}

        def visit_call(self, call):
            op = call.op
            if op not in self.ops:
                self.ops[op] = 0
            self.ops[op] += 1
            super().visit_call(call)

        def get_trt_graph_num(self):
            cnt = 0
            for op in self.ops:
                if str(op).find("tensorrt") != -1:
                    cnt += 1
            return cnt

    profiler = OpProfiler()
    profiler.visit(func)
    print("Total number of operators: %d" % sum(profiler.ops.values()))
    print("Detail breakdown")
    for op, count in profiler.ops.items():
        print("\t%s: %d" % (op, count))
    print("TensorRT subgraph #: %d" % profiler.get_trt_graph_num())
profile_graph(mod["main"])
# Wow, we have 676 operators in this big model! In addition to the usual compute-intensive operators, it also has many data-processing operators such as `reshape` and `where`. It also has the Non-Maximum Suppression (NMS) operator, which many accelerators cannot accelerate.
#
# We first try to build and run this model without TensorRT:
# In[44]:
def build_and_run(mod, data, params, build_config=None):
    compile_engine.get().clear()
    with tvm.transform.PassContext(opt_level=3, config=build_config):
        lib = relay.build(mod, target="cuda", params=params)

    # Create the runtime module
    ctx = tvm.gpu(0)
    mod = graph_runtime.GraphModule(lib["default"](ctx))

    # Run inference 10 times. CUDA kernels are launched asynchronously, so sync
    # before stopping the clock to make each sample include the GPU work itself.
    times = []
    for _ in range(10):
        start = time.time()
        mod.run(data=data)
        ctx.sync()
        times.append(time.time() - start)

    print("Runtime module structure:")
    print("\t %s" % str(lib.get_lib()))
    for sub_mod in lib.get_lib().imported_modules:
        print("\t |- %s" % str(sub_mod))
    print("Median inference latency %.2f ms" % (1000 * np.median(times)))
    return mod, lib
_ = build_and_run(mod, data, params)
# It takes about 50 ms on an NVIDIA T4 GPU! To improve performance, we could use either AutoTVM or the auto-scheduler, but tuning takes hours. Fortunately, TensorRT is available in our environment! So let's partition the graph and accelerate the bottleneck part (e.g., the ResNet-50 backbone) with TensorRT.
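# Kernels on a GPU are launched asynchronously, so a latency measurement should wait for the device before stopping the clock. It also helps to discard a few warm-up runs. A minimal, framework-agnostic sketch follows. `median_latency_ms` is a hypothetical helper. With this notebook's objects, you could call it as `median_latency_ms(lambda: graph_mod.run(data=data), tvm.gpu(0).sync)`, where `graph_mod` is the `GraphModule` returned by `build_and_run`. This assumes `sync()` blocks until all queued work on the device has finished.

```python
import statistics
import time
from typing import Callable


def median_latency_ms(
    run: Callable[[], None],
    sync: Callable[[], None],
    warmup: int = 3,
    repeat: int = 10,
) -> float:
    """Median wall-clock latency of `run()` in milliseconds."""
    for _ in range(warmup):  # keep one-time costs (lazy init, caching) out of the samples
        run()
    sync()
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        run()
        sync()  # block until queued device work has finished
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


calls = []


def fake_run():
    time.sleep(0.001)  # CPU-only stand-in for an inference call
    calls.append(1)


latency = median_latency_ms(fake_run, lambda: None)
print("median latency: %.2f ms" % latency)
```

# Reporting the median rather than the mean keeps a single slow outlier (for example, a one-off allocation) from skewing the result.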
# In[14]:
trt_mod, config = tensorrt.partition_for_tensorrt(mod, params)
print(config)
# The `config` is a TensorRT-specific build configuration. Its content depends on the BYOC backend and can vary.
#
# Now let's see the partitioned Relay graph. Since the entire graph is still too large, we only focus on the partitioned TensorRT functions. We can see that there will be 10 subgraphs to offload to TensorRT, including all the `conv2d` operators.
# In[15]:
profile_graph(trt_mod["main"])
print("=======================")
print(trt_mod["main"].astext(show_meta_data=False))
# Let's see what `tensorrt_0` has:
# In[16]:
profile_graph(trt_mod["tensorrt_0"])
# Apparently, this is the entire backbone ResNet-50 network.
#
# Let's see another function:
# In[17]:
profile_graph(trt_mod["tensorrt_351"])
# Oops, this function only has one operator -- `multiply`. This is because the original graph looks like the following:
#
# ```
# %88 = ones_like(%87);
# %89 = @tensorrt_351(%88);
# %90 = where(%3, %87, %89);
# ```
#
# Since `ones_like` and `where` are not supported by the TVM TensorRT integration, the supported `multiply` ends up alone in this subgraph. This is definitely not what we want, because the data-transfer and kernel-launch overhead will offset any speedup that TensorRT brings to a single `multiply`.
#
# To deal with this case, the TVM TensorRT integration provides a backend-specific optimization: `prune_tensorrt_subgraphs`. This TensorRT-specific Relay pass removes subgraphs that contain no compute-intensive operators, i.e., no operators that perform multiply-accumulate (MAC) computation.
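# You can sketch the pruning criterion in a few lines: keep a subgraph only if at least one of its operators does multiply-accumulate work, and hand the rest back to TVM. `MAC_OPS`, `prune_no_mac_subgraphs`, and the op lists below are illustrative assumptions. The real pass operates on Relay functions, and its exact operator list and any additional criteria may differ.

```python
from typing import Dict, List

# Assumed set of operators that perform multiply-accumulate (MAC) computation.
MAC_OPS = {"nn.conv2d", "nn.conv2d_transpose", "nn.dense", "nn.batch_matmul"}


def prune_no_mac_subgraphs(subgraphs: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Keep only subgraphs containing at least one MAC operator.

    Dropped subgraphs would be inlined back and compiled by TVM instead.
    """
    return {
        name: ops
        for name, ops in subgraphs.items()
        if any(op in MAC_OPS for op in ops)
    }


# Illustrative op lists, loosely modeled on the two functions inspected above.
candidates = {
    "tensorrt_0": ["nn.conv2d", "add", "nn.relu", "nn.max_pool2d"],
    "tensorrt_351": ["multiply"],
}
print(sorted(prune_no_mac_subgraphs(candidates)))  # ['tensorrt_0']
```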
# In[30]:
config["remove_no_mac_subgraphs"] = True
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    trt_mod = tensorrt.prune_tensorrt_subgraphs(trt_mod)
# Let's see how many subgraphs we have now:
# In[19]:
profile_graph(trt_mod["main"])
# Nice! We only have one giant subgraph!
#
# Let's build and run this module to see what speedup we can achieve.
# In[45]:
runtime_mod, lib = build_and_run(
trt_mod, data, params, build_config={"relay.ext.tensorrt.options": config}
)
# Wow, we get about 2x speedup!
#
# Diving into the new runtime module, we can see that it contains a `tensorrt` module. This is where our subgraphs live. We can also print its source to see what was generated. The TensorRT BYOC runtime includes a JSON graph interpreter, so the codegen simply emits JSON to represent the subgraphs.
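# To illustrate the idea of a JSON graph interpreter, the sketch below executes a tiny graph serialized as JSON. Nodes are stored in topological order and refer to their inputs by index. The schema here is a hypothetical simplification, not TVM's actual JSON format. The real TensorRT runtime also does not evaluate scalar ops one at a time like this toy does. It converts the graph into a TensorRT network and builds an engine from it.

```python
import json
import operator

# Hypothetical, simplified schema: nodes in topological order, inputs by index.
GRAPH_JSON = """
{
  "nodes": [
    {"op": "input", "name": "x"},
    {"op": "input", "name": "y"},
    {"op": "add", "inputs": [0, 1]},
    {"op": "multiply", "inputs": [2, 1]}
  ],
  "output": 3
}
"""

KERNELS = {"add": operator.add, "multiply": operator.mul}


def interpret(graph_json: str, feeds: dict):
    """Evaluate each node in order; values[i] holds the result of node i."""
    graph = json.loads(graph_json)
    values = []
    for node in graph["nodes"]:
        if node["op"] == "input":
            values.append(feeds[node["name"]])
        else:
            args = [values[i] for i in node["inputs"]]
            values.append(KERNELS[node["op"]](*args))
    return values[graph["output"]]


print(interpret(GRAPH_JSON, {"x": 2, "y": 3}))  # 15
```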
# In[43]:
print(lib.get_lib().imported_modules[1].get_source())
# Finally, let's see the inference output:
# In[21]:
from matplotlib import pyplot as plt
results = [runtime_mod.get_output(i).asnumpy() for i in range(runtime_mod.get_num_outputs())]
ax = utils.viz.plot_bbox(
img, results[2][0], results[1][0], results[0][0], class_names=class_names
)
plt.show()
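# `plot_bbox` only draws detections above a score threshold. If you want the detections as data rather than a plot, a filtering step like the sketch below works on one image's outputs. It assumes the layout used in the call above: `results[0]` holds class IDs, `results[1]` scores, and `results[2]` boxes. It also assumes padded entries carry a negative class ID. `filter_detections` and its `0.5` default threshold are assumptions for illustration. With this notebook's arrays, you could pass `results[0][0].ravel()`, `results[1][0].ravel()`, and `results[2][0]`.

```python
from typing import List, Sequence, Tuple


def filter_detections(
    class_ids: Sequence[float],
    scores: Sequence[float],
    boxes: Sequence[Sequence[float]],
    class_names: Sequence[str],
    thresh: float = 0.5,
) -> List[Tuple[str, float, Tuple[float, ...]]]:
    """Return (label, score, box) for detections whose score is >= thresh."""
    kept = []
    for cid, score, box in zip(class_ids, scores, boxes):
        if cid < 0 or score < thresh:  # skip padding entries and low-confidence boxes
            continue
        kept.append((class_names[int(cid)], float(score), tuple(box)))
    return kept


# Toy inputs standing in for results[0][0], results[1][0], and results[2][0].
toy_names = ["person", "car"]
dets = filter_detections(
    class_ids=[1, 0, -1],
    scores=[0.92, 0.31, -1.0],
    boxes=[[10, 20, 110, 80], [5, 5, 50, 120], [-1, -1, -1, -1]],
    class_names=toy_names,
)
print(dets)  # [('car', 0.92, (10, 20, 110, 80))]
```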
# Lastly, the TensorRT BYOC integration is only about 4K lines of code, mainly contributed by Trevor Morris (trevmorr@amazon.com). This demonstrates that any hardware vendor can bring their own accelerator codegen to TVM!