@cswiercz
Last active July 6, 2017 15:42
Introduction to Creating a New MXNet Operator

WIP - There is still much to learn.

Creating a New MXNet Operator

Let's begin with a simple, near-minimal working example. I'll start with the code and then walk through each component. We will create an operator, add_one(), that takes a vector as input, adds one to each element, and returns the result.

First, the meat of the code:

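// add_one.h
#ifndef MXNET_OPERATOR_ADD_ONE_H_
#define MXNET_OPERATOR_ADD_ONE_H_

// Include paths assume this file lives in src/operator/ of the MXNet tree.
#include <mxnet/operator_util.h>
#include <vector>
#include "./mxnet_op.h"
#include "./operator_common.h"
#include "./elemwise_op_common.h"

namespace mxnet {
namespace op {

template<typename xpu>
void AddOneCompute(
    const nnvm::NodeAttrs& attrs,
    const OpContext& ctx,
    const std::vector<TBlob>& inputs,
    const std::vector<OpReqType>& req,
    const std::vector<TBlob>& outputs) {
  using namespace mshadow;  // Stream, Tensor
  using namespace mxnet_op;

  if (req[0] == kNullOp) return;  // the output is not needed
  // The loop below overwrites the output, so accumulation is unsupported.
  CHECK_NE(req[0], kAddTo) << "add_one does not support kAddTo";

  Stream<xpu> *s = ctx.get_stream<xpu>();
  Tensor<xpu, 1, real_t> in_vec = inputs[0].get<xpu, 1, real_t>(s);
  Tensor<xpu, 1, real_t> out_vec = outputs[0].get<xpu, 1, real_t>(s);
  // Host-side element loop: only valid for xpu = cpu, because GPU memory
  // cannot be dereferenced from host code.
  for (index_t i = 0; i < out_vec.size(0); ++i)
    out_vec[i] = in_vec[i] + 1;
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_ADD_ONE_H_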
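The signature above passes a `req` vector alongside the outputs, one `OpReqType` per output, telling the function how each output must be written. MXNet distinguishes `kNullOp` (write nothing), `kWriteTo` (overwrite), `kWriteInplace` (overwrite, where the output shares memory with an input) and `kAddTo` (accumulate into the existing output). The standalone sketch below is not MXNet code: it mirrors those enumerator names in a local enum and uses a hypothetical `AddOneWithReq` helper on `std::vector<float>` to make the semantics of each request explicit for a CPU loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Local stand-in with the same enumerator names as MXNet's OpReqType, so the
// sketch compiles without MXNet. It is not MXNet's actual declaration.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical CPU-only helper: computes in[i] + 1 and stores it into out
// according to req, the way an FCompute function should treat req[0].
void AddOneWithReq(const std::vector<float>& in, std::vector<float>& out,
                   OpReqType req) {
  if (req == kNullOp) return;  // the output is not needed; write nothing
  assert(in.size() == out.size() && "add_one is element-wise");
  for (std::size_t i = 0; i < in.size(); ++i) {
    const float value = in[i] + 1.0f;  // read in[i] before touching out[i]
    if (req == kAddTo) {
      out[i] += value;  // accumulate into the existing output
    } else {
      out[i] = value;   // kWriteTo and kWriteInplace both overwrite
    }
  }
}
```

Because `in[i]` is read before `out[i]` is written, the same helper also behaves correctly for `kWriteInplace`, where `in` and `out` are the same vector.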

Second, the operator registration:

// add_one_op.cc

#include "./add_one.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(add_one)
.describe("Adds one to every element of the input")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", AddOneCompute<cpu>)
.add_argument("data", "NDArray", "Source input");

}  // namespace op
}  // namespace mxnet
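`ElemwiseShape<1, 1>` and `ElemwiseType<1, 1>` declare one input and one output that must share the same shape (respectively, the same dtype). The plain C++ sketch below imitates the shape rule outside MXNet under stated simplifications: `Shape` is a hypothetical alias for `std::vector<std::size_t>` in which an empty vector means "unknown", and a conflict throws `std::invalid_argument` instead of failing MXNet's internal checks. Known shapes propagate in both directions, so a known output shape can also fill in an unknown input shape.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <stdexcept>
#include <vector>

// Hypothetical shape type for this sketch; an empty vector means "unknown".
using Shape = std::vector<std::size_t>;

// Mimics element-wise shape inference for one input and one output: all
// shapes must agree, and any known shape is copied into the unknown ones.
// Returns true once a shape is known, false if there is nothing to infer from.
bool InferElemwiseShape(std::vector<Shape>& in_shapes,
                        std::vector<Shape>& out_shapes) {
  assert(in_shapes.size() == 1 && out_shapes.size() == 1);  // <1, 1> arity
  Shape known;
  for (std::vector<Shape>* group : {&in_shapes, &out_shapes}) {
    for (const Shape& s : *group) {
      if (s.empty()) continue;  // unknown, nothing to compare
      if (known.empty()) {
        known = s;              // first known shape becomes the reference
      } else if (known != s) {
        throw std::invalid_argument("element-wise shapes disagree");
      }
    }
  }
  if (known.empty()) return false;  // every shape is still unknown
  for (std::vector<Shape>* group : {&in_shapes, &out_shapes})
    for (Shape& s : *group) s = known;  // fill inputs and outputs alike
  return true;
}
```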

In-place Operations

To save memory, an operator can allow in-place computation, in which its output is written into the memory of one of its inputs, as in a symbolic operation of the form x = f(x). Declaring the option only permits this; MXNet's memory planner decides whether the buffer is actually shared.

To enable the in-place option, add the following attribute to your operator registration. The function returns a list of {input index, output index} pairs that may share memory; this configuration maps the first input (index 0) to the first output (index 0):

.set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                [](const nnvm::NodeAttrs& attrs) {
                                  return std::vector<std::pair<int, int> >{{0, 0}};
                                })
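Declaring {0, 0} as an in-place option is only safe when writing an output element never clobbers input data that still has to be read. For add_one, each output element depends only on the matching input element, which is read before it is overwritten. The standalone sketch below (plain arrays, no MXNet; `AddOneInto`, `ReverseInto` and `DemoInPlace` are hypothetical names) contrasts this with a reversal, where sharing one buffer between input and output corrupts the result. An operator like the reversal should therefore not declare an `FInplaceOption`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Element-wise: out[i] depends only on in[i], which is read before out[i] is
// written, so in and out may point to the same buffer (x = f(x)).
void AddOneInto(const float* in, float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) out[i] = in[i] + 1.0f;
}

// Not element-wise: out[i] reads in[n - 1 - i], which may already have been
// overwritten when in == out, so this kernel must not run in place.
void ReverseInto(const float* in, float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) out[i] = in[n - 1 - i];
}

// Runs both kernels on a buffer that serves as input and output at once.
void DemoInPlace() {
  std::vector<float> x{1.0f, 2.0f, 3.0f};
  AddOneInto(x.data(), x.data(), x.size());
  assert(x[0] == 2.0f && x[1] == 3.0f && x[2] == 4.0f);  // correct result

  std::vector<float> y{1.0f, 2.0f, 3.0f};
  ReverseInto(y.data(), y.data(), y.size());
  // The correct reversal is {3, 2, 1}; aliasing yields {3, 2, 3}.
  assert(y[0] == 3.0f && y[1] == 2.0f && y[2] == 3.0f);
}
```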