This document is a guide to creating an MXNet operator. Rather than describing the process in the abstract, we will learn by exploring an example: elemwise_sum (.h, .cc, .cu).
The code in elemwise_sum.{h,cc,cu} defines and registers the add_n() operator, which computes the elementwise sum of an arbitrary number of input arguments of the same shape. For example:
>>> import mxnet as mx
>>> a = mx.nd.array([1,2,3])
>>> b = mx.nd.array([4,5,6])
>>> c = mx.nd.array([7,8,9])
>>> x = mx.nd.add_n(a,b,c)
>>> x.asnumpy()
array([ 12., 15., 18.], dtype=float32)
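To pin down the semantics before reading any MXNet code, here is a minimal, framework-free C++ sketch of what add_n computes. It assumes the inputs are flat std::vector<float> buffers of equal length; real MXNet operates on NDArrays of arbitrary shape. The function name add_n_reference is hypothetical and not part of MXNet.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference implementation of add_n's semantics:
// out[k] = a_1[k] + a_2[k] + ... + a_n[k] for every element k.
// Inputs are assumed to be flat buffers of the same length.
std::vector<float> add_n_reference(const std::vector<std::vector<float>>& args) {
  assert(!args.empty());  // add_n always has at least one input
  std::vector<float> out(args[0].size(), 0.0f);
  for (const auto& a : args) {
    assert(a.size() == out.size());  // all inputs must have the same shape
    for (std::size_t k = 0; k < out.size(); ++k) {
      out[k] += a[k];
    }
  }
  return out;
}
```

For the three arrays in the Python session above, add_n_reference({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}) yields {12, 15, 18}.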
Before jumping into the workhorse part of the code, let's see how an operator is registered with MXNet. Once we understand this, we'll see how all of the constituents of the implementation are linked together.
Registration is done using the macro NNVM_REGISTER_OP. The code for registering add_n can be found in elemwise_sum.cc and elemwise_sum.cu.
Since the GPU-side registration inherits most of its attributes from the CPU-side registration, we will focus on the contents of elemwise_sum.cc:
NNVM_REGISTER_OP(add_n)
.add_alias("ElementWiseSum")
.describe(R"doc(Adds all input arguments element-wise.

.. math::
   add\_n(a_1, a_2, ..., a_n) = a_1 + a_2 + ... + a_n

``add_n`` is potentially more efficient than calling ``add`` by `n` times.
)doc" ADD_FILELINE)
.set_attr_parser(ParamParser<ElementWiseSumParam>)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    return ret;
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("arg") + std::to_string(i));
    }
    return ret;
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
.set_attr<nnvm::FInplaceOption>(
  "FInplaceOption", [](const NodeAttrs& attrs) {
    return std::vector<std::pair<int, int> >{{0, 0}};
  })
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FGradient>("FGradient", CloneGradient{"_backward_add_n"})
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments");
There is a lot to parse here, but don't worry: we'll look at each attribute in the above operator definition one at a time. For reference, the source code for the macro and the basic operator structure and attributes can be found in NNVM's op.h. Also see op_attr_types.h for a list of the different types of attributes. If you're already comfortable working with MXNet, go ahead and read these two source files. My goal here is to take some of the documentation in these files and expand upon it.
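Before going through the attributes one by one, it may help to see why the chained calls in the registration above work. Each setter returns a reference to the same operator object, which lives in a global registry keyed by name. The toy C++ sketch below shows that pattern under our own assumptions. MiniOp, MiniRegistry and all of their methods are hypothetical stand-ins, not the NNVM API. In NNVM, NNVM_REGISTER_OP, .add_alias and Op::Get play the corresponding roles.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for an operator entry. Every setter returns
// MiniOp& so that calls can be chained, as in the registration above.
struct MiniOp {
  std::string name;
  std::string description;

  MiniOp& describe(const std::string& doc) {
    description = doc;
    return *this;
  }
  MiniOp& add_alias(const std::string& alias);  // defined after MiniRegistry
};

// Hypothetical global registry: maps names (and aliases) to one shared entry.
class MiniRegistry {
 public:
  static MiniRegistry& Global() {
    static MiniRegistry instance;
    return instance;
  }
  // Create the entry on first use; otherwise return the existing one.
  MiniOp& RegisterOrGet(const std::string& name) {
    auto it = entries_.find(name);
    if (it != entries_.end()) return *it->second;
    auto op = std::make_shared<MiniOp>();
    op->name = name;
    entries_[name] = op;
    return *op;
  }
  // Make `alias` refer to the very same entry as `name`.
  void AddAlias(const std::string& alias, const std::string& name) {
    assert(entries_.count(alias) == 0 && "alias already registered");
    entries_[alias] = entries_.at(name);
  }
  // Look up an operator by its name or any of its aliases.
  const MiniOp* Get(const std::string& name) const {
    auto it = entries_.find(name);
    if (it == entries_.end()) throw std::runtime_error("unknown operator: " + name);
    return it->second.get();
  }

 private:
  std::map<std::string, std::shared_ptr<MiniOp>> entries_;
};

MiniOp& MiniOp::add_alias(const std::string& alias) {
  MiniRegistry::Global().AddAlias(alias, name);
  return *this;
}

// Registration written in the same chained style as NNVM_REGISTER_OP(add_n).
MiniOp& mini_add_n = MiniRegistry::Global()
    .RegisterOrGet("add_n")
    .add_alias("ElementWiseSum")
    .describe("Adds all input arguments element-wise.");
```

After static initialization, MiniRegistry::Global().Get("add_n") and Get("ElementWiseSum") return the same pointer. That mirrors what an alias gives you when you look an operator up by either name.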
-
NNVM_REGISTER_OP(add_n)
This registers the name of the operator. Once MXNet is built, you can invoke the operator from the Python interface, for example via mx.nd.add_n() or mx.sym.add_n(). At the C++ level, the operator can be looked up with Op::Get("add_n"). For example, one could write the following to make the operator easier to use:
using nnvm::Op;
const Op* add_n = Op::Get("add_n");
// use add_n below with an OpKernel (to be discussed below)
-
.add_alias("ElementWiseSum")
Registers an alias for the operator, allowing you to look it up at the C++ level by writing Op::Get("ElementWiseSum").
-
.describe(R"...")
Easy enough: this attaches a docstring. It is parsed as reStructuredText when generating the online documentation. See here for a reST primer.
-
.set_attr_parser(ParamParser<dmlc::Parameter>)
Allows you to customize the way the operator's attributes are parsed in its definition and invocation. In this case, add_n() is a bit tricky since we want to allow an arbitrary number of input arguments.
Let's look at ElementWiseSumParam more closely to see what's going on:
struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
  int num_args;
  DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
    .describe("Number of inputs to be summed.");
  }
};
DMLC_REGISTER_PARAMETER(ElementWiseSumParam);
Right off the bat we see that the struct makes use of the Curiously Recurring Template Pattern, which should make you feel like a rockstar C++ programmer. (Briefly, CRTP is a compile-time polymorphism technique in which a base class template is instantiated with the derived class itself.) Long story short, the struct inherits from dmlc::Parameter, a lightweight parameter management system for operators and objects.
Inside this struct we declare the parameter(s) we want to manage when constructing the operator, in this case num_args. (Again, the whole point is to let the function accept a variable number of arguments.) DMLC_DECLARE_PARAMETER() opens the block in which the fields are declared, and DMLC_DECLARE_FIELD() declares and augments a particular field. Here we give num_args a description and a lower bound of 1.
Note that parameters need to be registered separately from the operator, which is what DMLC_REGISTER_PARAMETER does.
This is all done so we can define the next two attributes...
-
.set_num_inputs(int OR std::function<uint32_t (const NodeAttrs& attr)>)
Sets the number of inputs to the operator. That's it! For some operators the number of inputs is fixed and known, and all you need to do is provide a hard-coded integer, e.g. .set_num_inputs(2) for an operator with two inputs.
But that's not really it, because here we are already in the deep end: add_n() takes a variable number of inputs, so we define how to get that number using our custom parameter parser. The code used in add_n() is repeated here for convenience:
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    uint32_t ret = dmlc::get<ElementWiseSumParam>(attrs.parsed).num_args;
    return ret;
  })
As the prototype requests, we provide a function that accepts a NodeAttrs struct and returns an unsigned integer. In this case, the struct's parsed field holds an ElementWiseSumParam. The lambda above retrieves that ElementWiseSumParam and returns its num_args, which is set to the number of arguments passed to the operator. (TODO: at what point is num_args set?)
-
.set_attr<nnvm::FListInputNames>
There is a class of templated operator attributes; see here for a predefined list. The function .set_attr<T> (here be dragons at this link) accepts three arguments:
const string& attr_name: the name of the attribute
const T& value: the value to set this attribute to
int plevel: the priority level of this attribute. If the same attribute is set more than once on an operator, the value registered with the higher priority level replaces the lower one. The priority level is 10 by default.
So, add_n() defines an attribute called "FListInputNames" whose value has type nnvm::FListInputNames. That type is a function that accepts a const NodeAttrs& and returns a vector of strings; in this case, the parsed field of the NodeAttrs holds an ElementWiseSumParam. We can see in the attribute definition above that the function simply extracts num_args from the ElementWiseSumParam and returns the vector ["arg0", "arg1", ..., "arg(num_args-1)"].
The point of this is to enable automatic variable creation for missing arguments.
-
.set_attr<std::string>("key_var_num_args", "num_args")
This is not well documented, but from examining the source code it seems to be a hint to the docstring generator that this function accepts a variable number of arguments. It seems to be needed only in this kind of situation.
-
.set_attr<FCompute>("FCompute<cpu>", ElementWiseSumCompute<cpu>)
This is a key attribute of a new operator! Assigning FCompute tells MXNet which function to call when the operator is invoked. That is, add_n() is, more or less, the function ElementWiseSumCompute<>() with some layers of pre-processing. Later in the document we'll talk about this function in more detail, but here is the FCompute function prototype anyway:
using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
                                     const OpContext& ctx,
                                     const std::vector<TBlob>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<TBlob>& outputs)>;
This doesn't look like what we pass as input to add_n(). Again, we'll get back to this later.
-
.set_attr<nnvm::FInplaceOption>
Operators have the option of performing computations in place. That is, the output of an operator may be stored in the memory already occupied by one of its inputs. The prototype for FInplaceOption is:
using FInplaceOption = std::function<
    std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
Basically, the value of FInplaceOption is a function that maps the attributes of this compute node (here an ElementWiseSumParam) to a list of pairs. Each pair {i, j} declares that input i and output j may share the same memory location. This is a permission rather than a guarantee; the memory planner decides whether to actually reuse the buffer.
For add_n() the FInplaceOption function always returns {{0, 0}}. This means that no matter how many arguments are passed, only the first input may be mapped to the first (and only) output. This makes sense since add_n() is a variable-argument function and, by the definition of ElementWiseSumParam, will always have at least one input.
The nice thing about this design is that we simply store the result of our computation in the appropriate output, and MXNet takes care of the in-place-ness.
-
.set_attr<nnvm::FInferShape>
See Inferring Shapes.
-
.set_attr<nnvm::FInferType>
See Inferring Types.
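The interplay between .set_attr_parser and .set_num_inputs can be sketched without dmlc or NNVM. In this hypothetical C++ version:

- MiniNodeAttrs holds the raw keyword arguments as strings (playing the role of NodeAttrs::dict) plus the parsed parameter (playing the role of NodeAttrs::parsed).
- ParseSumParam imitates what ParamParser<ElementWiseSumParam> does for this one field, including the set_lower_bound(1) check.
- num_inputs has the same shape as the lambda passed to .set_num_inputs.

None of these names are the real dmlc or NNVM API.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for ElementWiseSumParam.
struct SumParam {
  int num_args = 0;
};

// Hypothetical stand-in for nnvm::NodeAttrs: raw keyword arguments as
// strings, plus the parameter struct filled in by the parser.
struct MiniNodeAttrs {
  std::map<std::string, std::string> dict;
  SumParam parsed;
};

// Imitates ParamParser<ElementWiseSumParam> for the single field num_args,
// including the lower bound of 1 declared on the field.
void ParseSumParam(MiniNodeAttrs* attrs) {
  assert(attrs != nullptr);
  auto it = attrs->dict.find("num_args");
  if (it == attrs->dict.end()) {
    throw std::invalid_argument("missing required parameter num_args");
  }
  const int n = std::stoi(it->second);
  if (n < 1) {
    throw std::invalid_argument("num_args must be >= 1");
  }
  attrs->parsed.num_args = n;
}

// Same shape as the lambda given to .set_num_inputs: read the parsed
// parameter and report how many inputs the node takes.
const std::function<std::uint32_t(const MiniNodeAttrs&)> num_inputs =
    [](const MiniNodeAttrs& attrs) {
      return static_cast<std::uint32_t>(attrs.parsed.num_args);
    };
```

In MXNet the parsed value is stored in a type-erased field and retrieved with dmlc::get<ElementWiseSumParam>(attrs.parsed). Storing SumParam directly just keeps the sketch short.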
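The FListInputNames attribute for add_n boils down to a small, pure function of num_args. The standalone sketch below reproduces that loop. SumParam and ListInputNames are hypothetical names, and the real attribute receives a const NodeAttrs& rather than the parameter struct directly.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the parsed ElementWiseSumParam.
struct SumParam {
  std::uint32_t num_args = 0;
};

// Same logic as the FListInputNames lambda registered for add_n:
// produce "arg0", "arg1", ..., "arg<num_args-1>".
std::vector<std::string> ListInputNames(const SumParam& param) {
  std::vector<std::string> ret;
  ret.reserve(param.num_args);
  for (std::uint32_t i = 0; i < param.num_args; ++i) {
    ret.push_back(std::string("arg") + std::to_string(i));
  }
  return ret;
}
```

With num_args = 3 this returns {"arg0", "arg1", "arg2"}, one name per input slot.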
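The plevel argument of .set_attr<T> is easiest to understand as a conflict-resolution rule. The toy attribute table below assumes the behavior described in the comments of NNVM's op.h:

- plevel must be positive.
- A value set at a higher priority level replaces one set at a lower level.
- Setting the same attribute twice at the same level is an error.

AttrTable and its methods are hypothetical. Values are plain strings standing in for functions.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical per-operator attribute table: attribute name -> (value, plevel).
class AttrTable {
 public:
  void Set(const std::string& attr_name, const std::string& value,
           int plevel = 10) {
    if (plevel <= 0) throw std::invalid_argument("plevel must be > 0");
    auto it = table_.find(attr_name);
    if (it == table_.end()) {
      table_[attr_name] = {value, plevel};
      return;
    }
    if (it->second.second == plevel) {
      throw std::logic_error(attr_name + " already set at this plevel");
    }
    if (plevel > it->second.second) {
      it->second = {value, plevel};  // higher priority replaces lower
    }
    // Otherwise keep the existing, higher-priority value.
  }

  const std::string& Get(const std::string& attr_name) const {
    return table_.at(attr_name).first;
  }

 private:
  std::map<std::string, std::pair<std::string, int>> table_;
};
```

For example, setting "FCompute<cpu>" at level 10 and then at level 11 leaves the level-11 value in place, and a later level-5 setting is ignored.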
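To make the FCompute argument layout and the {{0, 0}} in-place option concrete, here is a framework-free sketch of an element-wise-sum kernel. Like FCompute, it takes a list of inputs, one write request per output, and a list of outputs; it drops the attrs and ctx arguments. Blob and Req are hypothetical stand-ins for TBlob and OpReqType, with enumerator names mirroring MXNet's kNullOp, kWriteTo, kWriteInplace and kAddTo. This is not the actual ElementWiseSumCompute. Because out[k] depends only on the k-th element of each input, the loop stays correct when the output buffer is the same memory as input 0, which is exactly what the in-place option permits.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for OpReqType.
enum Req { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical stand-in for TBlob: a non-owning view of a flat float buffer.
struct Blob {
  float* data;
  std::size_t size;
};

// Same argument layout as FCompute minus attrs and ctx.
// Computes outputs[0] = inputs[0] + inputs[1] + ... element by element.
void ElementWiseSumSketch(const std::vector<Blob>& inputs,
                          const std::vector<Req>& req,
                          const std::vector<Blob>& outputs) {
  assert(!inputs.empty() && req.size() == 1 && outputs.size() == 1);
  if (req[0] == kNullOp) return;  // the caller does not want this output
  const Blob& out = outputs[0];
  for (const Blob& in : inputs) {
    assert(in.size == out.size);  // all inputs share the output's shape
  }
  for (std::size_t k = 0; k < out.size; ++k) {
    // Gather the k-th element of every input before touching out.data[k];
    // this keeps the loop correct when out.data aliases inputs[0].data.
    float sum = 0.0f;
    for (const Blob& in : inputs) {
      sum += in.data[k];
    }
    if (req[0] == kAddTo) {
      out.data[k] += sum;  // accumulate into the existing output
    } else {
      out.data[k] = sum;   // kWriteTo or kWriteInplace: overwrite
    }
  }
}
```

Suppose it is called with inputs {1, 2, 3} and {4, 5, 6}, request kWriteInplace, and the output pointing at the first input's buffer. That buffer then holds {5, 7, 9}, and the kernel never needs to know the memory was shared.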