@zer0n
Created February 9, 2019 05:24
LEARNING HYBRID FRONTEND SYNTAX THROUGH EXAMPLE
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\nLearning Hybrid Frontend Syntax Through Example\n===============================================\n**Author:** `Nathan Inkawhich <https://github.com/inkawhich>`_\n\nThis document is meant to highlight the syntax of the Hybrid Frontend\nthrough a non-code intensive example. The Hybrid Frontend is one of the\nnew shiny features of Pytorch 1.0 and provides an avenue for developers\nto transition their models from **eager** to **graph** mode. PyTorch\nusers are very familiar with eager mode as it provides the ease-of-use\nand flexibility that we all enjoy as researchers. Caffe2 users are more\naquainted with graph mode which has the benefits of speed, optimization\nopportunities, and functionality in C++ runtime environments. The hybrid\nfrontend bridges the gap between the the two modes by allowing\nresearchers to develop and refine their models in eager mode (i.e.\nPyTorch), then gradually transition the proven model to graph mode for\nproduction, when speed and resouce consumption become critical.\n\nHybrid Frontend Information\n---------------------------\n\nThe process for transitioning a model to graph mode is as follows.\nFirst, the developer constructs, trains, and tests the model in eager\nmode. Then they incrementally **trace** and **script** each\nfunction/module of the model with the Just-In-Time (JIT) compiler, at\neach step verifying that the output is correct. Finally, when each of\nthe components of the top-level model have been traced and scripted, the\nmodel itself is traced. At which point the model has been transitioned\nto graph mode, and has a complete python-free representation. With this\nrepresentation, the model runtime can take advantage of high-performance\nCaffe2 operators and graph based optimizations.\n\nBefore we continue, it is important to understand the idea of tracing\nand scripting, and why they are separate. 
The goal of **trace** and\n**script** is the same: to create a graph representation of\nthe operations taking place in a given function. The difference arises\nfrom the flexibility of eager mode, which allows for **data-dependent\ncontrol flow** within the model architecture. Tracing records only the\noperations executed for the example input, so a branch or loop count that\ndepends on the input value would be frozen into the graph. When a function\ndoes NOT have data-dependent control flow, it may be *traced* with\n``torch.jit.trace``. However, when the function *has* data-dependent\ncontrol flow it must be *scripted* with ``torch.jit.script``. We will\nleave the details of the inner workings of the hybrid frontend for\nanother document, but the code example below will show the syntax of how\nto trace and script different pure Python functions and torch Modules.\nHopefully, you will find that using the hybrid frontend is non-intrusive\nas it mostly involves adding decorators to the existing function and\nclass definitions.\n\nMotivating Example\n------------------\n\nIn this example we will implement a strange math function that may be\nlogically broken up into four parts, some of which contain\ndata-dependent control flow and some of which do not. The purpose here is\nto show an example that is light on code and highlights the use of the JIT.\nThis example is a stand-in representation of a useful model, whose\nimplementation has been divided into various pure Python functions and\nmodules.\n\nThe function we seek to implement, $Y(x)$, is defined for\n$x \\in \\mathbb{N}$ as\n\n\\begin{align}z(x) = \\Biggl \\lfloor \\frac{\\sqrt{\\prod_{i=1}^{|2 x|}i}}{5} \\Biggr \\rfloor\\end{align}\n\n\\begin{align}Y(x) = \\begin{cases}\n \\frac{z(x)}{2} & \\text{if } z(x) \\bmod 2 = 0, \\\\\n -z(x) & \\text{otherwise}\n \\end{cases}\\end{align}\n\n\\begin{align}\\begin{array}{| r | r | r | r | r | r | r | r |} \\hline\n x &1 &2 &3 &4 &5 &6 &7 \\\\ \\hline\n Y(x) &0 &0 &-5 &20 &190 &-4377 &-59051 \\\\ \\hline\n \\end{array}\\end{align}\n\nAs mentioned, the computation is split into four parts. 
Part one is the\nsimple tensor calculation of $|2x|$, which can be traced. Part two\nis the iterative product calculation, which contains data-dependent\ncontrol flow and must be scripted (the number of loop iterations depends\non the input at runtime). Part three is a traceable\n$\\lfloor \\sqrt{a}/5 \\rfloor$ calculation. Finally, part four handles\nthe output cases depending on the value of $z(x)$ and must be\nscripted due to the data dependency. Now, let's see how this looks in\ncode.\n\nPart 1 - Tracing a pure python function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWe can implement part one as a pure Python function as below. Notice that\nto trace this function we call ``torch.jit.trace`` and pass in the function\nto be traced. Since the trace requires a dummy input of the expected\nruntime type and shape, we also pass ``torch.rand(())`` to generate a\nsingle-valued torch tensor.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n\ndef fn(x):\n return torch.abs(2*x)\n\n# This is how you define a traced function\n# Pass in both the function to be traced and an example input to ``torch.jit.trace``\ntraced_fn = torch.jit.trace(fn, torch.rand(()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Part 2 - Scripting a pure python function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWe can also implement part 2 as a pure python function where we\niteratively compute the product. Since the number of iterations depends\non the value of the input, we have a data dependent control flow, so the\nfunction must be scripted. We can script python functions simply with\nthe ``@torch.jit.script`` decorator.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# This is how you define a script function\n# Apply this decorator directly to the function\n@torch.jit.script\ndef script_fn(x):\n z = torch.ones([1], dtype=torch.int64)\n for i in range(int(x)):\n z = z * (i + 1)\n return z"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Part 3 - Tracing a nn.Module\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nNext, we will implement part 3 of the computation within the forward\nfunction of a ``torch.nn.Module``. This module may be traced, but rather\nthan adding a decorator here, we will handle the tracing where the\nModule is constructed. Thus, the class definition is not changed at all.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# This is a normal module that can be traced.\nclass TracedModule(torch.nn.Module):\n def forward(self, x):\n x = x.type(torch.float32)\n return torch.floor(torch.sqrt(x) / 5.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Part 4 - Scripting a nn.Module\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIn the final part of the computation we have a ``torch.nn.Module`` that\nmust be scripted. To accomodate this, we inherit from\n``torch.jit.ScriptModule`` and add the ``@torch.jit.script_method``\ndecorator to the forward function.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# This is how you define a scripted module.\n# The module should inherit from ScriptModule and the forward should have the\n# script_method decorator applied to it.\nclass ScriptModule(torch.jit.ScriptModule):\n @torch.jit.script_method\n def forward(self, x):\n r = -x\n if int(torch.fmod(x, 2.0)) == 0.0:\n r = x / 2.0\n return r"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Top-Level Module\n~~~~~~~~~~~~~~~~\n\nNow we will put together the pieces of the computation via a top level\nmodule called ``Net``. In the constructor, we will instantiate the\n``TracedModule`` and ``ScriptModule`` as attributes. This must be done\nbecause we ultimately want to trace/script the top level module, and\nhaving the traced/scripted modules as attributes allows the Net to\ninherit the required submodules' parameters. Notice, this is where we\nactually trace the ``TracedModule`` by calling ``torch.jit.trace()`` and\nproviding the necessary dummy input. Also notice that the\n``ScriptModule`` is constructed as normal because we handled the\nscripting in the class definition.\n\nHere we can also print the graphs created for each individual part of\nthe computation. The printed graphs allows us to see how the JIT\nultimately interpreted the functions as graph computations.\n\nFinally, we define the ``forward`` function for the Net module where we\nrun the input data ``x`` through the four parts of the computation.\nThere is no strange syntax here and we call the traced and scripted\nmodules and functions as expected.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# This is a demonstration net that calls all of the different types of\n# methods and functions\nclass Net(torch.nn.Module):\n def __init__(self):\n super(Net, self).__init__()\n # Modules must be attributes on the Module because if you want to trace\n # or script this Module, we must be able to inherit the submodules'\n # params.\n self.traced_module = torch.jit.trace(TracedModule(), torch.rand(()))\n self.script_module = ScriptModule()\n\n print('traced_fn graph', traced_fn.graph)\n print('script_fn graph', script_fn.graph)\n print('TracedModule graph', self.traced_module.__getattr__('forward').graph)\n print('ScriptModule graph', self.script_module.__getattr__('forward').graph)\n\n def forward(self, x):\n # Call a traced function\n x = traced_fn(x)\n\n # Call a script function\n x = script_fn(x)\n\n # Call a traced submodule\n x = self.traced_module(x)\n\n # Call a scripted submodule\n x = self.script_module(x)\n\n return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running the Model\n~~~~~~~~~~~~~~~~~\n\nAll that's left to do is construct the Net and compute the output\nthrough the forward function. Here, we use $x=5$ as the test input\nvalue and expect $Y(x)=190.$ Also, check out the graphs that were\nprinted during the construction of the Net.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Instantiate this net and run it\nn = Net()\nprint(n(torch.tensor([5]))) # 190."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tracing the Top-Level Model\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe last part of the example is to trace the top-level module, ``Net``.\nAs mentioned previously, since the traced/scripted modules are\nattributes of Net, we are able to trace ``Net`` as it inherits the\nparameters of the traced/scripted submodules. Note, the syntax for\ntracing Net is identical to the syntax for tracing ``TracedModule``.\nAlso, check out the graph that is created.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"n_traced = torch.jit.trace(n, torch.tensor([5]))\nprint(n_traced(torch.tensor([5])))\nprint('n_traced graph', n_traced.graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hopefully, this document can serve as an introduction to the hybrid\nfrontend as well as a syntax reference guide for more experienced users.\nAlso, there are a few things to keep in mind when using the hybrid\nfrontend. There is a constraint that traced/scripted methods must be\nwritten in a restricted subset of python, as features like generators,\ndefs, and Python data structures are not supported. As a workaround, the\nscripting model *is* designed to work with both traced and non-traced\ncode which means you can call non-traced code from traced functions.\nHowever, such a model may not be exported to ONNX.\n\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
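
As a cross-check on the table of $Y(x)$ values in the motivating example, the same four-part computation can be reproduced without PyTorch. The following is a minimal sketch in plain Python and is not part of the original notebook. The helper names ``z_ref`` and ``y_ref`` are hypothetical, and the sketch assumes exact integer arithmetic (``math.isqrt``) in place of the float32 square root used by ``TracedModule``, relying on the identity floor(sqrt(p)/5) = floor(floor(sqrt(p))/5) for positive integers.

```python
import math


def z_ref(x: int) -> int:
    """Reference z(x) = floor(sqrt((|2x|)!) / 5) using exact integers.

    Hypothetical helper, not from the notebook.
    """
    n = abs(2 * x)                 # part one: |2x|
    product = math.factorial(n)    # part two: product of 1..|2x|
    # part three: floor(sqrt(p) / 5) equals isqrt(p) // 5 for integers
    return math.isqrt(product) // 5


def y_ref(x: int) -> int:
    """Reference Y(x): z/2 when z is even, -z otherwise (part four)."""
    z = z_ref(x)
    return z // 2 if z % 2 == 0 else -z


# Recompute the row of Y(x) values for x = 1..7
table = [y_ref(x) for x in range(1, 8)]
print(table)  # [0, 0, -5, 20, 190, -4377, -59051]
```

The value for $x=5$ matches the 190 expected from ``Net``. For larger inputs the float32 path in the notebook may round differently from this exact version, so agreement is only claimed for the tabulated range.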