
@KellenSunderland
Last active September 14, 2018 21:51
TensorRT Walkthrough

Speeding up Deep Learning Computation Graphs in MXNet with TensorRT Integration

TensorRT is a deep learning library from NVIDIA that has been shown to deliver large speedups for network inference. MXNet 1.3 ships with experimental integrated support for TensorRT, so users can easily use this acceleration library to run their networks efficiently. In this blog post we'll see how to install, enable and run TensorRT with MXNet.

Installation and Prerequisites

Installing MXNet with TensorRT integration is an easy process. First ensure that you are running Ubuntu 16.04 or newer and have up-to-date NVIDIA drivers installed. You'll need a Pascal or newer generation NVIDIA GPU. You'll also have to download and install the TensorRT libraries (see NVIDIA's TensorRT installation instructions). Once your drivers are installed and up to date, you can install a special build of MXNet with TensorRT support enabled from PyPI via pip. Builds for CUDA 9.0 and CUDA 9.2 are both available; install the one that matches your CUDA version by running:

CUDA 9.0:

pip install mxnet-tensorrt-cu90

CUDA 9.2:

pip install mxnet-tensorrt-cu92
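Choosing between the two packages can be scripted. The sketch below is a hypothetical helper, tensorrt_package, not part of MXNet: the package names come from the pip commands above, and the "Pascal or newer" requirement is expressed as CUDA compute capability 6.0 or higher (Pascal GPUs are compute capability 6.x). How you obtain the CUDA version and compute capability on your machine is outside this sketch.

```python
from typing import Tuple

# Package names taken from the pip commands above; this mapping is illustrative.
PACKAGES = {"9.0": "mxnet-tensorrt-cu90", "9.2": "mxnet-tensorrt-cu92"}


def tensorrt_package(cuda_version: str, compute_capability: Tuple[int, int]) -> str:
    """Return the pip package to install, or raise if the setup is unsupported."""
    major = compute_capability[0]
    if major < 6:
        # Pascal is compute capability 6.x; older architectures are not supported.
        raise ValueError("a Pascal (compute capability 6.x) or newer GPU is required")
    # Keep only "major.minor", e.g. "9.2.148" -> "9.2".
    key = ".".join(cuda_version.split(".")[:2])
    try:
        return PACKAGES[key]
    except KeyError:
        raise ValueError(f"no mxnet-tensorrt build listed for CUDA {cuda_version}") from None


print("pip install " + tensorrt_package("9.2.148", (7, 0)))
```

For a Titan V (compute capability 7.0) with CUDA 9.2 this prints `pip install mxnet-tensorrt-cu92`.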

If you prefer to use a Docker image with all prerequisites installed, you can instead run:

nvidia-docker run -ti mxnet/tensorrt python

Sample Models

ResNet-18

TensorRT is an inference-only library, so for the purposes of this blog post we will use a pre-trained network, in this case a ResNet-18. ResNets are a computationally intensive model architecture that is often used as a backbone for various computer vision tasks, and they are also commonly used as a reference when benchmarking deep learning library performance. In this section we'll take a pretrained ResNet-18 from the Gluon Model Zoo and compare its inference speed with TensorRT integration enabled against a baseline of MXNet with TensorRT integration turned off.

Model Initialization

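import os
import time

import mxnet as mx
from mxnet.gluon.model_zoo import vision

batch_shape = (1, 3, 224, 224)

# Download a pretrained ResNet-18 from the Gluon Model Zoo and hybridize it
resnet18 = vision.resnet18_v2(pretrained=True)
resnet18.hybridize()

# Run one forward pass so the hybridized graph is traced, then export it
# as resnet18_v2-symbol.json and resnet18_v2-0000.params
resnet18.forward(mx.nd.zeros(batch_shape))
resnet18.export('resnet18_v2')

# Reload the exported graph and weights through the symbolic API
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18_v2', 0)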

In our first section of code we import the modules needed to run MXNet and to time our benchmark runs. We then download a pretrained version of ResNet-18, hybridize it, run a single forward pass so the graph is traced, export it, and load it back symbolically. It's important to note that the experimental TensorRT integration only works with the symbolic MXNet API. If you're using Gluon, you must hybridize your computation graph and export it as a symbol before running inference. This may be addressed in future releases of MXNet, but in general, if you're concerned about getting the best possible inference performance from your models, hybridizing is good practice.
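The export/load round trip relies on MXNet's checkpoint naming convention: export writes prefix-symbol.json and prefix-NNNN.params (the epoch zero-padded to four digits), and mx.model.load_checkpoint(prefix, epoch) reads the same two names. The helpers below, checkpoint_files and missing_checkpoint_files, are hypothetical and only spell out those names, which can be handy for checking that an export succeeded before loading.

```python
import os
from typing import List, Tuple


def checkpoint_files(prefix: str, epoch: int = 0) -> Tuple[str, str]:
    """File names written by export() and read by mx.model.load_checkpoint()."""
    return ("%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, epoch))


def missing_checkpoint_files(prefix: str, epoch: int = 0) -> List[str]:
    """Return the checkpoint files that do not exist on disk (empty if all present)."""
    return [path for path in checkpoint_files(prefix, epoch) if not os.path.exists(path)]


print(checkpoint_files("resnet18_v2", 0))
```

With the prefix used above this prints `('resnet18_v2-symbol.json', 'resnet18_v2-0000.params')`.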

MXNet Performance

# Create a sample input (all zeros; only speed matters here)
input = mx.nd.zeros(batch_shape)

# Execute with MXNet, with TensorRT graph compilation disabled
os.environ['MXNET_USE_TENSORRT'] = '0'
executor = sym.simple_bind(ctx=mx.gpu(0), data=batch_shape, grad_req='null', force_rebind=True)
executor.copy_params_from(arg_params, aux_params)

# Warmup
print('Warming up MXNet')
for i in range(0, 10):
    y_gen = executor.forward(is_train=False, data=input)
    y_gen[0].wait_to_read()

# Timing: measure wall-clock time, and block on each output because
# MXNet executes asynchronously
print('Starting MXNet timed run')
start = time.time()
for i in range(0, 10000):
    y_gen = executor.forward(is_train=False, data=input)
    y_gen[0].wait_to_read()
end = time.time()
print(end - start)

For this experiment we are strictly interested in inference performance, so to simplify the benchmark we pass a tensor filled with zeros as input. We then bind the symbol as usual, which returns a normal MXNet executor, and run forward on this executor in a loop. On a modern PC with a Titan V GPU, the 10,000 timed iterations take 33.73s.
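Both timed loops in this walkthrough follow the same warm-up-then-measure pattern. The sketch below factors it into a hypothetical helper, benchmark, which uses time.perf_counter for wall-clock timing; a CPU-time clock such as time.process_time can under-count time the process spends blocked waiting on the GPU. It assumes the step callable blocks until its result is ready, which for MXNet means calling wait_to_read() on an output, because MXNet's engine executes operations asynchronously.

```python
import time
from typing import Callable


def benchmark(step: Callable[[], None], warmup: int = 10, iters: int = 10000) -> float:
    """Run `step` `warmup` times untimed, then return wall-clock seconds for `iters` calls.

    `step` must block until its work is finished, otherwise only the time to
    enqueue asynchronous work is measured.
    """
    for _ in range(warmup):
        step()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    return time.perf_counter() - start


# Stand-in workload so the sketch runs without a GPU.
elapsed = benchmark(lambda: sum(range(1000)), warmup=2, iters=100)
print(f"{elapsed:.4f}s")
```

With the executor from the walkthrough, the step would be something like `lambda: executor.forward(is_train=False, data=input)[0].wait_to_read()`.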

While TensorRT integration remains experimental, users must set an environment variable to enable graph compilation. You can see that at the start of this test we explicitly disable TensorRT graph compilation. Next, we will run the same predictions using TensorRT. This requires setting the MXNET_USE_TENSORRT environment variable to 1, and we'll also use a slightly different API to bind our symbol.
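The walkthrough sets MXNET_USE_TENSORRT globally through os.environ, so whatever value was set last stays in effect for the rest of the process. A minimal sketch of a hypothetical context manager, tensorrt_enabled, that scopes the setting and restores the previous value afterwards; it assumes, as the walkthrough's code implies by setting the variable just before each bind, that the variable only needs to be set while the executor is being bound.

```python
import os
from contextlib import contextmanager

ENV_KEY = "MXNET_USE_TENSORRT"


@contextmanager
def tensorrt_enabled(enabled: bool):
    """Temporarily set MXNET_USE_TENSORRT to '1' or '0', restoring the old value on exit."""
    previous = os.environ.get(ENV_KEY)
    os.environ[ENV_KEY] = "1" if enabled else "0"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(ENV_KEY, None)  # variable was unset before: unset it again
        else:
            os.environ[ENV_KEY] = previous


with tensorrt_enabled(True):
    print(os.environ[ENV_KEY])
```

In the benchmark, the baseline bind would then read `with tensorrt_enabled(False): executor = sym.simple_bind(...)`.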

MXNet with TensorRT Integration Performance

# Execute with TensorRT
print('Building TensorRT engine')
os.environ['MXNET_USE_TENSORRT'] = '1'

# tensorrt_bind takes one merged dictionary of arg and aux params, placed on the GPU
arg_params.update(aux_params)
all_params = {k: v.as_in_context(mx.gpu(0)) for k, v in arg_params.items()}
executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(0), all_params=all_params,
                                             data=batch_shape, grad_req='null', force_rebind=True)

Instead of calling simple_bind directly on our symbol to return an executor, we call an experimental API from MXNet's contrib module. This call is meant to mirror simple_bind and takes many of the same arguments. One difference to note is that it takes the parameters as a single merged dictionary, which supports the weight-cleanup pass described below.
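The walkthrough builds that merged dictionary with arg_params.update(aux_params), which mutates arg_params in place. If you want to keep the original dictionaries intact (for example to bind the baseline again with simple_bind later), a non-mutating merge is a small alternative. The sketch below uses plain dicts and a hypothetical helper, merge_params; in the real code each value would still be moved to the GPU with as_in_context as in the binding code above.

```python
from typing import Any, Dict


def merge_params(arg_params: Dict[str, Any], aux_params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict holding both arg and aux params, leaving the inputs untouched.

    Raises if a name appears in both, since one value would silently shadow the other.
    """
    overlap = arg_params.keys() & aux_params.keys()
    if overlap:
        raise ValueError(f"parameter names appear in both dicts: {sorted(overlap)}")
    return {**arg_params, **aux_params}


args = {"conv0_weight": "w", "fc1_bias": "b"}
aux = {"bn0_moving_mean": "m"}
print(sorted(merge_params(args, aux)))
```

This prints `['bn0_moving_mean', 'conv0_weight', 'fc1_bias']`, and `args` still has only its two original entries.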

As TensorRT integration improves, our goal is to gradually deprecate this tensorrt_bind call and allow users to use TensorRT transparently (see the Subgraph API for more information). When this happens, the similarity between tensorrt_bind and simple_bind should make it easy to migrate your code.

# Warmup
print('Warming up TensorRT')
for i in range(0, 10):
    y_gen = executor.forward(is_train=False, data=input)
    y_gen[0].wait_to_read()

# Timing: measure wall-clock time, blocking on each output
print('Starting TensorRT timed run')
start = time.time()
for i in range(0, 10000):
    y_gen = executor.forward(is_train=False, data=input)
    y_gen[0].wait_to_read()
end = time.time()
print(end - start)

We once more run a warmup followed by the timed loop, and on the same machine the run takes 18.99s: a 1.8x speed improvement! Speed improvements from libraries like TensorRT can come from a variety of optimizations, but in this case our speedup comes from a technique known as operator fusion.
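The two totals can be turned into a speedup and a per-inference latency with a little arithmetic. The sketch below uses a hypothetical helper, summarize, and plugs in the 33.73s and 18.99s reported above for 10,000 iterations at batch size 1. Because each timed iteration waits for its output, the per-iteration figure is a latency that includes that synchronization.

```python
def summarize(baseline_s: float, tensorrt_s: float, iters: int, batch_size: int = 1) -> dict:
    """Derive speedup, per-iteration latency (ms) and throughput (images/s) from two totals."""
    return {
        "speedup": baseline_s / tensorrt_s,
        "baseline_ms": baseline_s / iters * 1000.0,
        "tensorrt_ms": tensorrt_s / iters * 1000.0,
        "baseline_img_per_s": iters * batch_size / baseline_s,
        "tensorrt_img_per_s": iters * batch_size / tensorrt_s,
    }


stats = summarize(33.73, 18.99, iters=10000)
print(f"speedup: {stats['speedup']:.2f}x")
print(f"latency: {stats['baseline_ms']:.2f} ms -> {stats['tensorrt_ms']:.2f} ms")
```

This prints a speedup of 1.78x and a latency drop from 3.37 ms to 1.90 ms per image, or roughly 296 versus 527 images per second.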

Operators and Subgraph Fusion

Behind the scenes, a number of interesting things happen to make these optimizations possible, and most of them revolve around subgraphs and operator fusion. As we can see in the images below, neural networks can be represented as computation graphs of operators (the nodes in the graphs). Operators can perform a variety of functions, but most run simple mathematics and linear algebra on tensors. These operators often run more efficiently when fused together into a single larger CUDA kernel that is executed on the GPU in one call. The MXNet TensorRT integration scans the entire computation graph, identifies subgraphs that TensorRT can handle, and optimizes them with TensorRT.

In practice, when an MXNet computation graph is constructed, it is parsed to determine whether it contains any subgraphs made up of operator types that TensorRT supports. If MXNet finds one or more compatible subgraphs during this parse, it extracts them and replaces each with a special TensorRT node (visible in the diagrams below). As the graph executes, whenever a TensorRT node is reached, MXNet makes a library call to TensorRT, which runs its own implementation of the subgraph, potentially with many operators fused together into a single CUDA kernel.
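The partitioning idea can be illustrated on a toy graph. The sketch below is a simplification, not MXNet's actual pass: it treats the network as a linear chain of operator names, uses a hypothetical SUPPORTED_OPS set in place of TensorRT's real operator coverage, and replaces each maximal run of supported operators (at least min_size long) with a single ("TensorRT", [...]) node. The real integration works on a general graph with branches and skip connections, where extra checks are needed so that fusing a subgraph does not create cycles.

```python
from typing import List, Set, Tuple, Union

# Hypothetical stand-in for the operator types TensorRT can handle.
SUPPORTED_OPS: Set[str] = {"Convolution", "BatchNorm", "Activation", "Pooling", "FullyConnected"}

Node = Union[str, Tuple[str, List[str]]]


def partition_chain(ops: List[str], supported: Set[str], min_size: int = 2) -> List[Node]:
    """Replace maximal runs of supported ops in a linear chain with one fused node."""
    result: List[Node] = []
    run: List[str] = []

    def flush() -> None:
        # Only runs long enough to be worth offloading become a fused node.
        if len(run) >= min_size:
            result.append(("TensorRT", list(run)))
        else:
            result.extend(run)
        run.clear()

    for op in ops:
        if op in supported:
            run.append(op)
        else:
            flush()  # an unsupported op ends the current run
            result.append(op)
    flush()
    return result


graph = ["Convolution", "BatchNorm", "Activation", "Custom",
         "Pooling", "FullyConnected", "SoftmaxOutput"]
print(partition_chain(graph, SUPPORTED_OPS))
```

Here the unsupported Custom and SoftmaxOutput operators stay as ordinary MXNet nodes, while the two runs around them become ("TensorRT", ["Convolution", "BatchNorm", "Activation"]) and ("TensorRT", ["Pooling", "FullyConnected"]).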

During this process MXNet takes care of passing inputs to the node and fetching the results. MXNet also attempts to remove duplicated weights (parameters) during graph initialization to keep memory usage low: if some graph weights are used only in the TensorRT sections of the graph, they are removed from MXNet's set of parameters and their memory is freed.
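The cleanup rule can be sketched with plain dictionaries. In the hypothetical helper below, split_params, param_users maps each parameter name to the set of node names that consume it (a made-up representation for illustration, not MXNet's internal one). A parameter is handed over to TensorRT only if every consumer is a TensorRT node; anything still used by a regular MXNet node stays in MXNet's dictionary.

```python
from typing import Any, Dict, Set, Tuple


def split_params(all_params: Dict[str, Any],
                 param_users: Dict[str, Set[str]],
                 trt_nodes: Set[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split params into (kept by MXNet, owned only by TensorRT nodes)."""
    kept: Dict[str, Any] = {}
    moved: Dict[str, Any] = {}
    for name, value in all_params.items():
        users = param_users.get(name, set())
        if users and users <= trt_nodes:
            moved[name] = value  # only TensorRT needs it: MXNet can free its copy
        else:
            kept[name] = value   # still consumed outside TensorRT (or no known users)
    return kept, moved


params = {"conv0_weight": 1, "bn0_gamma": 2, "fc1_weight": 3}
users = {"conv0_weight": {"trt0"}, "bn0_gamma": {"trt0"}, "fc1_weight": {"fc1"}}
kept, moved = split_params(params, users, trt_nodes={"trt0"})
print(sorted(kept), sorted(moved))
```

This prints `['fc1_weight'] ['bn0_gamma', 'conv0_weight']`: only the fully-connected weight remains in MXNet's parameter set.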

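Before
(figure: computation graph before TensorRT subgraph replacement)

After
(figure: computation graph after TensorRT subgraph replacement)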

Future Work

As mentioned above, MXNet developers are excited about the possibilities of creating APIs that deal specifically with subgraphs. As this work matures it will bring many improvements for TensorRT users. We hope this will also be an opportunity for other acceleration libraries to integrate with MXNet.

Thanks

Thank you to NVIDIA for contributing this feature, and specifically thanks to Marek Kolodziej and Clement Fuji-Tsang. Thanks to Junyuan Xie and Jun Wu for the reviews and design feedback.

TODO (kellens)

I'd like to include a section where we make use of MXNet export to build a TensorRT engine. (Roshani, let me know if this sounds like a good idea to you.)

ONNX Graph Export

MXNet users with TensorRT 4.0 can build TensorRT engines from MXNet models by (1) exporting their MXNet graphs to ONNX format, (2) using NVIDIA tooling to compile the exported ONNX graphs into a TensorRT engine, and (3) computing predictions with the TensorRT engine through NVIDIA's APIs.

Outline

  • Introduce TensorRT
  • Describe two paths to using TensorRT from MXNet
  • Show workflow diagrams of the two methods.
  • Code samples, highlighting the API changes.
  • Show performance difference between MXNet, TensorRT and MXNet w/ TRT.
  • Code samples highlighting how to reproduce perf gains.
  • Section listing requirements and linking to installation instructions.
  • Talk about cases where it makes sense to run both.