Skip to content

Instantly share code, notes, and snippets.

@samskalicky
Created August 19, 2020 08:08
Show Gist options
Save samskalicky/750fc456fe838325c0701e30ac5cc3c8 to your computer and use it in GitHub Desktop.
MXNet Ops
// Save this file under src/operator/ and rebuild MXNet so that listOps() is compiled into libmxnet.
#include <exception>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "operator_common.h"
extern "C" int listOps() {
// get op registry
::dmlc::Registry<::nnvm::Op>* reg = ::dmlc::Registry<::nnvm::Op>::Get();
// get list of registered op names
std::vector<std::string> ops = reg->ListAllNames();
// create inverse map of Op to name (to find aliases)
std::map<const ::nnvm::Op*,std::vector<std::string> > op_map;
for(auto &name : ops) {
const ::nnvm::Op* op = reg->Find(name);
if(op_map.count(op) > 0) {
if(name.compare(op->name) != 0)
op_map[op].push_back(name);
} else {
op_map[op]={};
if(name.compare(op->name) != 0)
op_map[op].push_back(name);
}
}
// print out the op mapping
for(auto &kv : op_map) {
std::cout << kv.first->name << ", ";
for(auto &n : kv.second)
std::cout << n << ", ";
std::cout << std::endl;
}
return 0;
}
// Calls listOps() once during this translation unit's dynamic initialization,
// so the op table is printed when the library containing this file is loaded.
// NOTE(review): C++ leaves the order of dynamic initialization across
// translation units unspecified. If ops are registered by static initializers
// in other source files (presumably what NNVM_REGISTER_OP does — verify), some
// may not be registered yet here, and this load-time listing can be
// incomplete. Calling listOps() explicitly after the library is loaded (as the
// Python snippet does) does not have this problem.
// `static` gives this throwaway variable internal linkage, so the library does
// not export a global symbol named `n` that could clash with another definition.
static int n = listOps();
import mxnet as mx

# listOps is exported with C linkage by the C++ snippet above, so it can be
# reached through MXNet's ctypes handle to the loaded library. It prints the
# op table to stdout and returns an int status; nonzero means failure.
ret = mx.base._LIB.listOps()
if ret != 0:
    raise RuntimeError("listOps returned status %d" % ret)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment