Skip to content

Instantly share code, notes, and snippets.

@dneprDroid
Created June 15, 2018 07:38
Show Gist options
  • Save dneprDroid/17cc74a62f54ce66a4d56e124d98cf1d to your computer and use it in GitHub Desktop.
Save dneprDroid/17cc74a62f54ce66a4d56e124d98cf1d to your computer and use it in GitHub Desktop.
TensorFlow KernelRegistration.cpp
#include <memory>

#include <tensorflow/core/framework/op.h>
#include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/framework/shape_inference.h>
using namespace tensorflow;
using namespace register_kernel;
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f)
: def(d), kernel_class_name(c.ToString()), factory(f) {}
const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory;
};
// Registry keyed by the string built by _GenKey; a multimap, so several
// registrations may share one key.
using KernelRegistry = std::unordered_multimap<string, KernelRegistration>;
// Builds the registry lookup key "<op_type>:<device type>:<label>".
static string _GenKey(StringPiece op_type, const DeviceType& device_type,
                      StringPiece label) {
  const auto device_name = device_type.type();
  const string key = strings::StrCat(op_type, ":", device_name, ":", label);
  return key;
}
// Returns TensorFlow's global kernel registry viewed as this file's
// KernelRegistry type.
// NOTE(review): GlobalKernelRegistry() is presumably declared to return an
// opaque void*; viewing it as KernelRegistry* is only valid if the local
// KernelRegistration/KernelRegistry definitions match TensorFlow's internal
// ones exactly for the TF version being linked — re-check when upgrading TF.
static KernelRegistry* GlobalKernelRegistryTyped() {
  // static_cast is the designated conversion from void* to T*; reinterpret_cast
  // is unnecessary here and would silently accept unrelated pointer types.
  return static_cast<KernelRegistry*>(GlobalKernelRegistry());
}
// Device type used to instantiate the custom CPU kernel in RegisterOp.
using CPUDevice = Eigen::ThreadPoolDevice;
// Registers CustomOpKernel<CPUDevice, float, false> as a CPU kernel for the op
// named `kernel_origin_name` by inserting an entry into the global kernel
// registry directly.
//
// @param kernel_origin_name  Op name to register the kernel for; it is also
//                            stored as the registration's kernel_class_name.
// @param replace_origin      When true, every registration already stored
//                            under the same op/device/label key is erased first.
void RegisterOp(string kernel_origin_name, bool replace_origin) {
  auto factory = [](OpKernelConstruction* context) -> ::tensorflow::OpKernel* {
    return new CustomOpKernel<CPUDevice, float, false>(context);
  };

  // Build() hands back a heap-allocated KernelDef that the caller owns. Holding
  // it in a unique_ptr releases it on every exit path, including when key
  // construction or the map insertion throws (the previous manual `delete`
  // leaked in those cases).
  // NOTE(review): only the device is set here — no type constraint — so this
  // float-only kernel is registered without restricting the op's type attrs;
  // confirm that is intended.
  const std::unique_ptr<const KernelDef> kernel_def(
      Name(kernel_origin_name.c_str()).Device(DEVICE_CPU).Build());

  // Definitions named "_no_register" are skipped.
  // NOTE(review): appears to mirror the check in TensorFlow's own
  // OpKernelRegistrar — confirm.
  if (kernel_def->op() == "_no_register") {
    return;
  }

  const string key = _GenKey(kernel_def->op(),
                             DeviceType(kernel_def->device_type()),
                             kernel_def->label());
  KernelRegistry* registry = GlobalKernelRegistryTyped();

  // KernelRegistry is a multimap: erase(key) drops every registration under
  // this key, not just one.
  if (replace_origin) {
    registry->erase(key);
  }
  registry->emplace(key,
                    KernelRegistration(*kernel_def, kernel_origin_name, factory));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment