OpenVINO nGraph custom layer at runtime
This gist consists of two files: a CMakeLists.txt build script and main.cpp, which defines the custom nGraph operation, its CPU implementation, and a small demo network.

CMakeLists.txt:
cmake_minimum_required(VERSION 3.4.3)
project(ngraph_custom_layer CXX)

find_package(InferenceEngine REQUIRED)
find_package(ngraph REQUIRED)
find_package(TBB REQUIRED tbb tbbmalloc)

# Project-local headers live in include/; OpenCV is not used by main.cpp
include_directories(include)

file(GLOB SOURCES main.cpp)

add_executable(${CMAKE_PROJECT_NAME} ${SOURCES} ${HEADERS})
target_compile_features(${CMAKE_PROJECT_NAME} PRIVATE cxx_range_for)

target_link_libraries(${CMAKE_PROJECT_NAME}
    ${InferenceEngine_LIBRARIES}
    ${NGRAPH_LIBRARIES}
    ${TBB_IMPORTED_TARGETS}
)
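Building assumes an OpenVINO environment in which CMake can locate InferenceEngine, ngraph, and TBB (e.g. after sourcing the distribution's setupvars.sh); a standard CMake configure-and-build then produces a binary named ngraph_custom_layer, after the project name.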
main.cpp:
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include <ngraph/ngraph.hpp>
#include <inference_engine.hpp>

// nGraph Op definition for the custom layer: it only declares the node type
// and its shape/type inference; the arithmetic lives in the CPU kernel below.
class MyLayerOp : public ngraph::op::Op {
public:
    static constexpr ngraph::NodeTypeInfo type_info{"MyLayer", 0};
    const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; }

    MyLayerOp() = default;

    MyLayerOp(const ngraph::Output<ngraph::Node>& input) : Op({input}) {
        constructor_validate_and_infer_types();
    }

    // The layer preserves the input shape and element type
    void validate_and_infer_types() override {
        set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
    }

    std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
        return std::make_shared<MyLayerOp>(new_args.at(0));
    }

    bool visit_attributes(ngraph::AttributeVisitor& visitor) override {
        return true;  // the layer has no attributes to serialize
    }
};

constexpr ngraph::NodeTypeInfo MyLayerOp::type_info;
// Custom layer implementation: a CPU kernel that adds 1.0f to every element
class MyLayerImpl : public InferenceEngine::ILayerExecImpl
{
public:
    explicit MyLayerImpl(const std::shared_ptr<ngraph::Node>& node)
    {
        inpShape = node->get_input_shape(0);
        outShape = node->get_output_shape(0);
    }

    ~MyLayerImpl() = default;

    InferenceEngine::StatusCode init(InferenceEngine::LayerConfig& config,
                                     InferenceEngine::ResponseDesc* resp) noexcept
    {
        return InferenceEngine::StatusCode::OK;
    }

    // Report a single supported configuration: FP32 tensors in plain row-major order
    virtual InferenceEngine::StatusCode
    getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf,
                               InferenceEngine::ResponseDesc* resp) noexcept
    {
        std::vector<InferenceEngine::DataConfig> inDataConfig;
        std::vector<InferenceEngine::DataConfig> outDataConfig;
        InferenceEngine::SizeVector order(inpShape.size());
        std::iota(order.begin(), order.end(), 0);

        // Allow any offset before data
        size_t offset((std::numeric_limits<size_t>::max)());

        // Input shape
        InferenceEngine::DataConfig inpConf;
        inpConf.desc = InferenceEngine::TensorDesc(InferenceEngine::Precision::FP32, inpShape, {inpShape, order, offset});
        inDataConfig.push_back(inpConf);

        // Output shape
        InferenceEngine::DataConfig outConf;
        outConf.desc = InferenceEngine::TensorDesc(InferenceEngine::Precision::FP32, outShape, {outShape, order, offset});
        outDataConfig.push_back(outConf);

        InferenceEngine::LayerConfig layerConfig;
        layerConfig.inConfs = inDataConfig;
        layerConfig.outConfs = outDataConfig;

        conf.push_back(layerConfig);
        return InferenceEngine::StatusCode::OK;
    }

    virtual InferenceEngine::StatusCode execute(std::vector<InferenceEngine::Blob::Ptr>& inputs,
                                                std::vector<InferenceEngine::Blob::Ptr>& outputs,
                                                InferenceEngine::ResponseDesc* resp) noexcept
    {
        const float* inp = inputs[0]->cbuffer().as<float*>();
        float* out = outputs[0]->buffer().as<float*>();
        for (size_t i = 0; i < inputs[0]->size(); ++i) {
            out[i] = inp[i] + 1.0f;
        }
        return InferenceEngine::StatusCode::OK;
    }

private:
    ngraph::Shape inpShape;
    ngraph::Shape outShape;
};
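// The extension below is what gets registered with InferenceEngine::Core:
// it reports which implementation types are available for a node ("CPU")
// and returns a MyLayerImpl kernel whenever the plugin meets a MyLayerOp.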
class InfEngineNgraphExtension : public InferenceEngine::IExtension
{
public:
    void Unload() noexcept override {}
    void Release() noexcept override { delete this; }
    void GetVersion(const InferenceEngine::Version*&) const noexcept override {}

    std::vector<std::string> getImplTypes(const std::shared_ptr<ngraph::Node>& node) override {
        return {"CPU"};
    }

    InferenceEngine::ILayerImpl::Ptr getImplementation(const std::shared_ptr<ngraph::Node>& node,
                                                       const std::string& implType) override {
        if (std::dynamic_pointer_cast<MyLayerOp>(node) && implType == "CPU") {
            return std::make_shared<MyLayerImpl>(node);
        }
        return nullptr;
    }
};
// Example function which replaces a subgraph with a custom node.
// It assumes the nodes form a single chain: the first node's input becomes
// the fused node's input, and consumers of the last node are rewired to it.
void replaceSubgraph(std::shared_ptr<ngraph::Function> graph,
                     const std::vector<std::shared_ptr<ngraph::Node> >& nodes) {
    std::shared_ptr<ngraph::Node> fusedNode;
    for (auto& node : graph->get_ordered_ops()) {
        // Find the first node of the subgraph and disconnect it
        if (node->get_name() == nodes[0]->get_name()) {
            auto inp = node->get_input_node_shared_ptr(0);
            node->clear_control_dependencies();
            node->clear_control_dependents();
            // Create a fused node on top of the subgraph's input
            fusedNode = std::make_shared<MyLayerOp>(inp);
        }
        // Connect the new node to the consumers of the subgraph's last node
        if (node->get_name() == nodes.back()->get_name()) {
            for (auto& consumer : node->output(0).get_target_inputs()) {
                consumer.replace_source_output(fusedNode->output(0));
            }
        }
    }
}
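// Demo entry point: builds a small nGraph function (Parameter -> Relu ->
// MaxPool -> Relu), fuses the Relu+MaxPool pair into a single MyLayerOp,
// and runs the result on CPU with the extension registered above.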
int main(int argc, char** argv) {
    std::vector<size_t> inpShape{1, 3, 10, 11};
    std::vector<size_t> outShape{1, 3, 10, 11};

    // Build network topology: Parameter -> Relu -> MaxPool -> Relu
    std::shared_ptr<ngraph::Node> input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape(inpShape));
    std::shared_ptr<ngraph::Node> relu1 = std::make_shared<ngraph::op::Relu>(input);
    std::shared_ptr<ngraph::Node> pool = std::make_shared<ngraph::op::v1::MaxPool>(relu1,
        ngraph::Strides({1, 1}),  // strides
        ngraph::Shape({0, 0}),    // pads_begin
        ngraph::Shape({0, 0}),    // pads_end
        ngraph::Shape({1, 1}),    // kernel
        ngraph::op::RoundingType::FLOOR,
        ngraph::op::PadType::VALID);
    std::shared_ptr<ngraph::Node> relu = std::make_shared<ngraph::op::Relu>(pool);

    auto ngraph_function = std::make_shared<ngraph::Function>(
        relu, ngraph::ParameterVector{std::dynamic_pointer_cast<ngraph::op::Parameter>(input)});

    // Replace the Relu+MaxPool subgraph with the custom node
    replaceSubgraph(ngraph_function, {relu1, pool});

    // Load the network with the custom extension registered for CPU
    InferenceEngine::Core ie;
    ie.AddExtension(std::make_shared<InfEngineNgraphExtension>(), "CPU");

    InferenceEngine::CNNNetwork net = InferenceEngine::CNNNetwork(ngraph_function);
    net.serialize("model.xml", "model.bin");

    InferenceEngine::ExecutableNetwork execNet = ie.LoadNetwork(net, "CPU");
    InferenceEngine::InferRequest infRequest = execNet.CreateInferRequest();

    // Run inference on zero-filled input
    std::vector<float> inpData(inpShape[0] * inpShape[1] * inpShape[2] * inpShape[3], 0);
    std::vector<float> outData(outShape[0] * outShape[1] * outShape[2] * outShape[3], 0);

    InferenceEngine::BlobMap inputBlobs, outputBlobs;
    inputBlobs[net.getInputsInfo().begin()->first] = InferenceEngine::make_shared_blob<float>({
        InferenceEngine::Precision::FP32,
        inpShape,
        InferenceEngine::Layout::ANY}, inpData.data());
    outputBlobs[net.getOutputsInfo().begin()->first] = InferenceEngine::make_shared_blob<float>({
        InferenceEngine::Precision::FP32,
        outShape,
        InferenceEngine::Layout::ANY}, outData.data());

    infRequest.SetInput(inputBlobs);
    infRequest.SetOutput(outputBlobs);
    infRequest.Infer();

    return 0;
}
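Since the input blob is zero-filled and the custom kernel computes out[i] = inp[i] + 1.0f, the fused graph (MyLayerOp followed by a Relu, which leaves positive values unchanged) should produce an output of all ones. A minimal sanity check, a hypothetical addition not in the original gist, that could go right before `return 0;` in main() (it assumes <iostream> and <cstdlib> are also included):

    // Hypothetical sanity check: zero input -> MyLayerOp adds 1 -> Relu keeps 1
    for (size_t i = 0; i < outData.size(); ++i) {
        if (outData[i] != 1.0f) {
            std::cerr << "Unexpected value " << outData[i] << " at index " << i << std::endl;
            std::exit(1);
        }
    }
    std::cout << "Custom layer output verified: all ones" << std::endl;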