@cswiercz
Last active August 3, 2017 17:30
Understanding MXNet elemwise_sum

Table of Contents

  1. Introduction
  2. Registering an Operator
    1. Operator Registration Attributes
  3. Inferring Shapes
  4. Inferring Types

Introduction

This document is a sort of guide to creating an MXNet operator, except that we will learn by exploring an example: elemwise_sum (elemwise_sum.h, elemwise_sum.cc, elemwise_sum.cu).

The code in elemwise_sum.{h,cc,cu} defines and registers the add_n() operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape.

>>> import mxnet as mx
>>> a = mx.nd.array([1,2,3])
>>> b = mx.nd.array([4,5,6])
>>> c = mx.nd.array([7,8,9])
>>> x = mx.nd.add_n(a,b,c)
>>> x.asnumpy()
array([ 12.,  15.,  18.], dtype=float32)
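
To pin down exactly what add_n() computes before digging into MXNet's implementation, here is a minimal, framework-free C++ sketch. It assumes each input is a flat std::vector<float> standing in for an NDArray; the function name ElementwiseSumReference is hypothetical and not part of MXNet.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference version of add_n(): each input is a flat float
// buffer standing in for an NDArray, and all inputs share one shape.
std::vector<float> ElementwiseSumReference(
    const std::vector<std::vector<float>>& inputs) {
  assert(!inputs.empty());  // add_n() always has at least one argument
  const std::size_t size = inputs[0].size();
  std::vector<float> out(size, 0.0f);
  for (const auto& in : inputs) {
    assert(in.size() == size);  // inputs must all have the same shape
    for (std::size_t i = 0; i < size; ++i) {
      out[i] += in[i];  // element i of the result sums element i of every input
    }
  }
  return out;
}
```

Called with {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, it returns {12, 15, 18}, the same values as the Python session above.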

Registering an Operator

Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this, we'll see how all of the constituents of the implementation are linked together.

Registration is done using the macro NNVM_REGISTER_OP. The code for registering add_n can be found in elemwise_sum.cc and elemwise_sum.cu. The GPU-side registration in elemwise_sum.cu reopens the same operator entry and essentially only attaches the GPU compute function, so everything else comes from the CPU-side registration. We will therefore focus on the contents of elemwise_sum.cc:

NNVM_REGISTER_OP(add_n)
.add_alias("ElementWiseSum")
.describe(R"doc(Adds all input arguments element-wise.
.. math::
   add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n
``add_n`` is potentially more efficient than calling ``add`` by `n` times.
)doc" ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    return ret;
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("arg") + std::to_string(i));
    }
    return ret;
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
.set_attr<nnvm::FInplaceOption>(
    "FInplaceOption", [](const NodeAttrs& attrs) {
      return std::vector<std::pair<int, int> >{{0, 0}};
    })
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add_n"})
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");

There is a lot to parse here, but don't worry, we'll take each component one at a time.
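
One thing that helps when reading the registration code: each call in the chain (.add_alias(), .describe(), .set_attr(), ...) returns a reference to the same operator entry, which is why the calls can be strung together. The sketch below illustrates that builder pattern, plus alias lookup, with a toy registry. Everything in it (ToyOp, ToyRegistry, RegisterToyOp, GetToyOp) is hypothetical and much simpler than NNVM; in particular, as far as I can tell, the real NNVM_REGISTER_OP macro expands to a static reference that is initialized when the library is loaded, and nnvm::Op::Get() fails loudly for unknown names instead of returning nullptr.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical, heavily simplified stand-in for an NNVM operator entry.
struct ToyOp {
  std::string name;
  std::string description;
  int num_inputs = 1;
  std::map<std::string, std::string> attrs;  // string-valued attributes only

  // Every setter returns *this, so calls can be chained as in NNVM_REGISTER_OP.
  ToyOp& describe(const std::string& doc) {
    description = doc;
    return *this;
  }
  ToyOp& set_num_inputs(int n) {
    num_inputs = n;
    return *this;
  }
  ToyOp& set_attr(const std::string& key, const std::string& value) {
    attrs[key] = value;
    return *this;
  }
  ToyOp& add_alias(const std::string& alias);  // defined after the registry
};

// Global name -> entry table. An alias is just a second key for the same entry.
std::map<std::string, std::shared_ptr<ToyOp>>& ToyRegistry() {
  static std::map<std::string, std::shared_ptr<ToyOp>> registry;
  return registry;
}

// Creates (or replaces) the entry for `name` and returns it for chaining.
ToyOp& RegisterToyOp(const std::string& name) {
  auto op = std::make_shared<ToyOp>();
  op->name = name;
  ToyRegistry()[name] = op;
  return *op;
}

ToyOp& ToyOp::add_alias(const std::string& alias) {
  assert(ToyRegistry().count(name) == 1);  // the op itself must be registered
  ToyRegistry()[alias] = ToyRegistry()[name];
  return *this;
}

// Rough analogue of nnvm::Op::Get; this toy returns nullptr for unknown names.
const ToyOp* GetToyOp(const std::string& name) {
  auto it = ToyRegistry().find(name);
  return it == ToyRegistry().end() ? nullptr : it->second.get();
}
```

For example, RegisterToyOp("add_n").add_alias("ElementWiseSum").set_num_inputs(3); makes GetToyOp("ElementWiseSum") and GetToyOp("add_n") return the same entry.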

Operator Registration Attributes

Let's look at each of the attributes in the above operator definition one by one. For reference, the source code for the macro and the basic operator structure/attributes can be found in NNVM's op.h. Also see op_attr_types.h for a list of the different types of attributes. If you're already comfortable working with MXNet, go ahead and read these two source files. My goal here is to take some of the documentation in these files and expand upon it.

  • NNVM_REGISTER_OP(add_n)

    This registers the name of the operator. When built, you can invoke the operator from the Python interface, for example, from mx.nd.add_n() or mx.sym.add_n(). At the C++ level, a handle to the operator is retrieved with nnvm::Op::Get("add_n"). For example:

    #include <nnvm/op.h>  // nnvm::Op

    // Op lives in the nnvm namespace (not mxnet::op); Get() looks the name up in the registry
    const nnvm::Op* add_n = nnvm::Op::Get("add_n");
    // add_n can now be used to query the attributes registered below
  • .add_alias("ElementWiseSum")

    Registers an alias for the operator, so you can also retrieve it at the C++ level with nnvm::Op::Get("ElementWiseSum").

  • .describe(R"...")

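    Easy enough: this is the operator's docstring. It is parsed as reStructuredText when the online documentation is generated; see the Sphinx reStructuredText Primer for the syntax. The ADD_FILELINE macro appends the source file and line number where the operator is defined.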

  • .set_attr_parser(ParamParser<dmlc::Parameter>)

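    Sets the function that parses the operator's attributes. Attributes arrive as strings in attrs.dict (e.g. num_args="3"); ParamParser<ElementWiseSumParam> converts them into an ElementWiseSumParam and stores it in attrs.parsed, where the other attribute functions can read it. We need this for add_n() because we want to allow an arbitrary number of input arguments, and the operator has to learn that number from somewhere.

    Let's look at ElementWiseSumParam more closely to see what's going on: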

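    // The parameter struct for add_n: it only has to carry the input count.
    struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
      int num_args;  // number of arrays passed to add_n
      DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
        // Declare the field, forbid fewer than one input, and document it.
        DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
            .describe("Number of inputs to be summed.");
      }
    };

    // Registers the parameter struct with dmlc (separately from the operator).
    DMLC_REGISTER_PARAMETER(ElementWiseSumParam);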

    Right off the bat we see that the struct makes use of the Curiously Recurring Template Pattern (CRTP), which should make you feel like you're a rockstar C++ programmer: ElementWiseSumParam inherits from dmlc::Parameter<ElementWiseSumParam>, a base class templated on the derived class itself. (Basically, CRTP is a compile-time polymorphism technique: the base class can call into the derived class without virtual functions.) Long story short, dmlc::Parameter is a lightweight parameter management system for operators and objects.

    Inside this struct we declare the parameter(s) we want to manage in the construction of the operator, in this case num_args. (Again, the whole point of this is so that we can provide the function a variable number of arguments.) DMLC_DECLARE_PARAMETER() opens the block in which the fields are declared, and DMLC_DECLARE_FIELD() declares a single field and lets us attach properties to it. In this case, we give num_args a description and a lower bound of 1.

    Note that parameters need to be registered separately from the operator, using DMLC_REGISTER_PARAMETER().

    This is all done so we can define the next two attributes...

  • .set_num_inputs(int OR std::function<uint32_t (const NodeAttrs& attr)>)

    Sets the number of inputs to the operator. That's it! For some operators the number of inputs is fixed and known, in which case all you need to do is provide a hard-coded integer (e.g. .set_num_inputs(2) for an operator with two inputs).

    But that's not really it, because we are already in the deep end: add_n() has to compute its number of inputs from the parameters produced by our custom parameter parser. The code used in add_n() is repeated here for convenience:

    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
      uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
      return ret;
    })

    As the prototype requests, we provide a function that accepts a NodeAttrs struct and returns an unsigned integer. The struct's parsed field holds the ElementWiseSumParam produced by the attribute parser, so the lambda above extracts it with dmlc::get<ElementWiseSumParam>() and returns its num_args, which is set to the number of arguments passed to the operator. (TODO: at what point is num_args set?)

  • .set_attr<nnvm::FListInputNames>

    There is a class of templated operator attributes; see op_attr_types.h for a pre-defined list. The function .set_attr<T> (defined in op.h; here be dragons) accepts three arguments:

    • const string& attr_name - the name of the attribute
    • const T& value - the value to set this attribute to
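    • int plevel - the priority level of this attribute. If the same attribute is set more than once for an operator (for example, by a later registration of the same operator), the value with the higher priority level is the one that is kept. This is resolved in the operator registry at run time, not by the compiler. The priority level is set to 10 by default.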

    So, add_n() defines an attribute called "FListInputNames" whose value has the type nnvm::FListInputNames: a function which accepts a const NodeAttrs& (whose parsed field again holds an ElementWiseSumParam) and returns a vector of strings. We can see in the registration code above that the function simply extracts num_args from the ElementWiseSumParam and returns the vector ["arg0", "arg1", ..., "arg(num_args-1)"].

    The point of this is to enable automatic variable creation for missing arguments.

  • .set_attr<std::string>("key_var_num_args", "num_args")

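    This isn't well documented, but judging from the source code it names the parameter that holds the argument count. The frontends can query it (it is one of the values returned by the C API function MXSymbolGetAtomicSymbolInfo): the generated Python wrappers use it to fill in num_args automatically when you call, e.g., mx.nd.add_n(a, b, c) without passing num_args, and the docstring generator uses it to note that the function accepts a variable number of positional arguments. It only seems to be needed for variadic operators like this one.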

  • .set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)

    This is a key attribute of a new operator! Assigning FCompute tells MXNet which function to call when the operator is executed; the <cpu> suffix in the attribute name marks this as the CPU implementation. That is, add_n() is, more or less, the function ElementWiseSumCompute<cpu>() but with some layers of pre-processing. Later in the document we'll talk about this function in more detail, but here is the FCompute function prototype anyway:

    using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
                                         const OpContext& ctx,
                                         const std::vector<TBlob>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<TBlob>& outputs)>;

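    This doesn't look like what we pass as input to add_n(): the NDArrays have become TBlobs (plain views of the data), and there is one OpReqType per output telling the function how to write its result (skip it, overwrite it, write it in place, or add to what is already there). Again, we'll get back to this later.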

  • .set_attr<nnvm::FInplaceOption>

    Operators have the option of performing computations in-place. That is, the output of an operator can optionally be stored in the memory already occupied by one of its inputs. The prototype for FInplaceOption is:

    using FInplaceOption = std::function<
      std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;

    Basically, the value of FInplaceOption is a function mapping the attributes of this compute node (whose parsed field holds an ElementWiseSumParam) to a list of pairs. Each pair {i, j} says that input i and output j may share memory: when MXNet applies the option, output j is written to the memory location of input i.

    For add_n() the FInplaceOption function always returns {{0, 0}}, meaning that no matter how many arguments are passed we only map the first input to the first (and only) output. This makes sense since add_n() is a variable argument function and, by the definition of ElementWiseSumParam, will always have at least one input. It is also safe, because each output element depends only on the input elements at the same position.

    The nice thing about this design is that the compute function simply writes its result to the appropriate output, and MXNet takes care of the in-place-ness.

  • .set_attr<nnvm::FInferShape>

    See Inferring Shapes.

  • .set_attr<nnvm::FInferType>

    See Inferring Types.
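
The ElementWiseSumParam item above mentions the Curiously Recurring Template Pattern. As a rough illustration of the idea (not dmlc-core's actual dmlc::Parameter code), the sketch below puts the shared parsing logic in a base class template, ToyParameter<Derived>, which reaches the derived struct's field declarations through static_cast rather than a virtual call. ToyParameter, ToySumParam, Declare and GetInt are all hypothetical names; the lower-bound check imitates set_lower_bound(1).

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

using Kwargs = std::map<std::string, std::string>;

// Hypothetical CRTP base: the parsing entry point lives here once, and each
// derived struct only supplies a Declare() hook listing its own fields.
template <typename Derived>
struct ToyParameter {
  void Init(const Kwargs& kwargs) {
    // Safe because Derived inherits from ToyParameter<Derived>; resolved at
    // compile time, so no virtual function is involved.
    static_cast<Derived*>(this)->Declare(kwargs);
  }

 protected:
  // Shared helper: read an integer field and enforce a lower bound.
  static int GetInt(const Kwargs& kwargs, const std::string& key,
                    int lower_bound) {
    auto it = kwargs.find(key);
    if (it == kwargs.end()) throw std::invalid_argument("missing field: " + key);
    const int value = std::stoi(it->second);
    if (value < lower_bound) throw std::out_of_range(key + " is below its lower bound");
    return value;
  }
};

// Hypothetical analogue of ElementWiseSumParam.
struct ToySumParam : public ToyParameter<ToySumParam> {
  int num_args = 0;

  void Declare(const Kwargs& kwargs) {
    num_args = GetInt(kwargs, "num_args", /*lower_bound=*/1);
  }
};
```

For instance, ToySumParam p; p.Init({{"num_args", "3"}}); leaves p.num_args == 3, while num_args = "0" throws std::out_of_range, much like the lower bound on ElementWiseSumParam rejects a sum of zero inputs.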
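
The set_num_inputs and FListInputNames items both derive everything from num_args. The following standalone sketch mirrors those two lambdas, under the simplifying assumption that num_args has already been parsed into a plain struct. SumParam, NumInputs and ListInputNames are hypothetical names; the real lambdas receive a const NodeAttrs& and pull num_args out of attrs.parsed with dmlc::get<ElementWiseSumParam>().

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the parsed ElementWiseSumParam.
struct SumParam {
  int num_args;
};

// Mirrors the set_num_inputs lambda: add_n() takes exactly num_args inputs.
std::uint32_t NumInputs(const SumParam& param) {
  assert(param.num_args >= 1);  // MXNet enforces this via set_lower_bound(1)
  return static_cast<std::uint32_t>(param.num_args);
}

// Mirrors the FListInputNames lambda: "arg0", "arg1", ..., "arg<num_args-1>".
std::vector<std::string> ListInputNames(const SumParam& param) {
  const std::uint32_t n = NumInputs(param);
  std::vector<std::string> names;
  names.reserve(n);
  for (std::uint32_t i = 0; i < n; ++i) {
    names.push_back("arg" + std::to_string(i));
  }
  return names;
}
```

With SumParam{3}, NumInputs returns 3 and ListInputNames returns {"arg0", "arg1", "arg2"}.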
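
The FCompute and FInplaceOption items fit together: the compute function receives output buffers plus one write request per output, and when MXNet applies the {0, 0} in-place option the output buffer is the same memory as input 0. The sketch below shows how an elementwise sum can honor those requests. It uses raw float pointers instead of TBlobs and a hypothetical Req enum whose values mirror the names of MXNet's OpReqType (kNullOp, kWriteTo, kWriteInplace, kAddTo); SumKernel is a simplified illustration, not ElementWiseSumCompute itself.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical copy of the request kinds; MXNet's real enum is OpReqType.
enum class Req { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Simplified FCompute-style kernel: every input and the output are buffers of
// length n. `out` may alias inputs[0], which is what the {0, 0} option allows.
void SumKernel(const std::vector<const float*>& inputs, Req req, float* out,
               std::size_t n) {
  assert(!inputs.empty());          // num_args >= 1
  if (req == Req::kNullOp) return;  // this output is not needed
  for (std::size_t i = 0; i < n; ++i) {
    // Read every input at position i before touching out[i], so writing
    // into the memory of inputs[0] cannot corrupt values still to be read.
    float acc = 0.0f;
    for (const float* in : inputs) acc += in[i];
    if (req == Req::kAddTo) {
      out[i] += acc;  // accumulate into the existing contents
    } else {
      out[i] = acc;   // kWriteTo and kWriteInplace both overwrite
    }
  }
}
```

Because each out[i] is computed from the inputs at position i only, and all of them are read before out[i] is written, passing the first input's buffer as out gives the correct result; this is why the {0, 0} option is safe for add_n().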

Inferring Shapes

foo
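
The registration above uses ElemwiseShape<-1, 1> for this. As I read MXNet's elemwise_op_common.h, the template arguments are the expected numbers of inputs and outputs (-1 meaning "any number", which is what a variadic operator like add_n() needs), and the helper unifies a single shape across all inputs and the output: a known shape is copied into every unknown slot, conflicting known shapes are an error, and the function returns false while the shape is still unknown. The sketch below imitates that behavior under hypothetical conventions (a Shape is a std::vector<std::int64_t> and an empty vector means "unknown"); ToyElemwiseShape is not MXNet's code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::int64_t>;  // hypothetical: empty means "unknown"

// Toy elementwise shape inference. n_in / n_out are the expected numbers of
// inputs and outputs; -1 means "any number" (as for add_n).
template <int n_in, int n_out>
bool ToyElemwiseShape(std::vector<Shape>* in_shapes,
                      std::vector<Shape>* out_shapes) {
  if (n_in != -1) assert(in_shapes->size() == static_cast<std::size_t>(n_in));
  if (n_out != -1) assert(out_shapes->size() == static_cast<std::size_t>(n_out));

  // Pass 1: find the one shape every slot must have; fail on a conflict.
  Shape common;
  for (std::vector<Shape>* shapes : {in_shapes, out_shapes}) {
    for (const Shape& s : *shapes) {
      if (s.empty()) continue;  // unknown slots carry no information
      if (common.empty()) {
        common = s;
      } else if (s != common) {
        throw std::invalid_argument("incompatible shapes in elementwise op");
      }
    }
  }
  // Pass 2: write the common shape into every slot (it may still be unknown).
  for (std::vector<Shape>* shapes : {in_shapes, out_shapes}) {
    for (Shape& s : *shapes) s = common;
  }
  return !common.empty();  // true once the shape is fully determined
}
```

For example, input shapes (2, 3), unknown, (2, 3) with an unknown output all become (2, 3), while inputs (2, 3) and (3, 2) throw std::invalid_argument.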

Inferring Types

bar
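
ElemwiseType<-1, 1> plays the same role for data types. Assuming, as I believe MXNet does, that a dtype is an integer code with -1 meaning "not yet known", the rule is: all inputs and the output must share one dtype, any known dtype propagates to the unknown slots, and mixing two different known dtypes is an error. A hypothetical sketch of that rule follows (ToyElemwiseType and kUnknownType are illustrative names, and the integer codes in the example are arbitrary).

```cpp
#include <cassert>
#include <initializer_list>
#include <stdexcept>
#include <vector>

constexpr int kUnknownType = -1;  // hypothetical marker for "dtype not known yet"

// Toy elementwise type inference: all inputs and outputs end up with one
// common dtype code. Throws on a conflict; returns false if nothing is known.
bool ToyElemwiseType(std::vector<int>* in_types, std::vector<int>* out_types) {
  assert(!in_types->empty());  // add_n() always has at least one input
  int common = kUnknownType;
  for (std::vector<int>* types : {in_types, out_types}) {
    for (int t : *types) {
      if (t == kUnknownType) continue;
      if (common == kUnknownType) {
        common = t;
      } else if (t != common) {
        throw std::invalid_argument("mixed dtypes in elementwise op");
      }
    }
  }
  for (std::vector<int>* types : {in_types, out_types}) {
    for (int& t : *types) t = common;
  }
  return common != kUnknownType;
}
```

For instance, input types {0, -1, 0} with an unknown output all become 0, while {0, 1} throws.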
