Skip to content

Instantly share code, notes, and snippets.

@lindahua
Created October 3, 2016 13:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lindahua/33110dfbcb1542f3742474aaf46b61af to your computer and use it in GitHub Desktop.
Save lindahua/33110dfbcb1542f3742474aaf46b61af to your computer and use it in GitHub Desktop.
Simplified way to register operations (proof of concept)
// A proof-of-concept demonstration of Operation definition & registration
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
/// A named numeric attribute used to parameterize an operation
/// (e.g. {"p", 2.0} selects the exponent of the "pow" op).
struct Attr {
    std::string name;   ///< attribute key
    double value;       ///< attribute value
};

/// Ordered list of attributes handed to an op creator.
using attr_list_t = std::vector<Attr>;

/// Computation kernel: receives an argument count and a pointer to that many doubles.
using calc_func_t = std::function<double(size_t, const double*)>;

// Operation that actually performs the computation
class Op final {
private:
    calc_func_t fwd_;
public:
    /// Wraps a computation kernel. The kernel is taken by value and moved into
    /// place, so passing a temporary (e.g. a lambda) costs no extra std::function copy.
    Op(calc_func_t f) : fwd_(std::move(f)) {}

    /// Invokes the wrapped kernel on `n` arguments starting at `args`.
    /// Throws std::bad_function_call if the wrapped kernel is empty.
    double forward(size_t n, const double* args) const {
        return fwd_(n, args);
    }
};
/// Factory that builds a concrete Op from a list of attributes.
using op_creator_t = std::function<Op(const attr_list_t&)>;

// Proto that specifies useful meta information
class OpProto final {
private:
    std::string name_;
    size_t ninputs_ = 0;        // declared arity; 0 until set_ninputs() is called
    op_creator_t opcreator_;    // empty until set_opcreator() is called
public:
    /// Creates a proto with the given name and no creator.
    /// The name is a sink parameter: taken by value and moved into the member.
    OpProto(std::string name)
        : name_(std::move(name)) {}

    /// Name the proto was created with.
    const std::string& name() const {
        return name_;
    }

    /// Sets the declared number of inputs; returns *this for chaining.
    OpProto& set_ninputs(size_t n) {
        ninputs_ = n;
        return *this;
    }

    /// Declared number of inputs (0 if never set).
    size_t ninputs() const {
        return ninputs_;
    }

    /// Installs the factory used by createOp(); returns *this for chaining.
    /// Moved into place to avoid copying the std::function (which may allocate).
    OpProto& set_opcreator(op_creator_t cf) {
        opcreator_ = std::move(cf);
        return *this;
    }

    /// Builds an Op by invoking the installed creator with `attrs`.
    /// Throws std::bad_function_call if no creator has been installed.
    Op createOp(const attr_list_t& attrs) const {
        return opcreator_(attrs);
    }
};
// Registration facilities
// NOTE(review): `registry` has internal linkage (static) while the accessors below
// are `inline`. That is fine in a single .cpp. If this section moves into a header,
// each translation unit would get its own registry, and the inline definitions would
// refer to different objects.
static std::unordered_map<std::string, OpProto> registry;

/// Registers an operation named `name` and returns its proto for chained
/// configuration (set_ninputs / set_opcreator). If `name` is already registered,
/// the existing proto is returned as-is, so subsequent setter calls overwrite its
/// settings.
inline OpProto& registerOp(const std::string& name) {
    // emplace() hands back an iterator to the new or pre-existing element,
    // so no second hash lookup is needed.
    return registry.emplace(name, OpProto(name)).first->second;
}

/// Looks up a registered operation by name.
/// @throws std::out_of_range if no operation named `name` has been registered.
inline const OpProto& getOp(const std::string& name) {
    const auto it = registry.find(name);
    if (it == registry.end()) {
        throw std::out_of_range("getOp: unknown operation '" + name + "'");
    }
    return it->second;
}
// main
int main() {
// register operations
registerOp("add")
.set_ninputs(2)
.set_opcreator([](const attr_list_t&){
return Op([](size_t n, const double *args){
assert(n == 2);
return args[0] + args[1];
});
});
registerOp("mul")
.set_ninputs(2)
.set_opcreator([](const attr_list_t&){
return Op([](size_t n, const double *args){
assert(n == 2);
return args[0] * args[1];
});
});
registerOp("pow")
.set_ninputs(1)
.set_opcreator([](const attr_list_t& attrs){
double p = 1;
// extract attributes
for (const auto& a: attrs) {
if (a.name == "p") { p = a.value; }
}
// create operator accordingly
return Op([p](size_t n, const double *args){
assert(n == 1);
return std::pow(args[0], p);
});
});
// set a sequence of computations to be done
using item_t = std::tuple<OpProto, attr_list_t, std::vector<double>>;
std::vector<item_t> items {
item_t{getOp("add"), {}, {1.0, 2.0}},
item_t{getOp("mul"), {}, {2.0, 3.0}},
item_t{getOp("pow"), {{"p", 2.0}}, {5.0}},
};
// run the items
for (const auto& item: items) {
const OpProto& proto = std::get<0>(item);
const attr_list_t& attrs = std::get<1>(item);
const std::vector<double>& args = std::get<2>(item);
Op op = proto.createOp(attrs);
double r = op.forward(args.size(), args.data());
std::cout << proto.name() << " --> " << r << std::endl;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment