Skip to content

Instantly share code, notes, and snippets.

@wolterlw
Created September 18, 2019 04:33
Show Gist options
  • Save wolterlw/767a467d6f75b3533cd49482ba25ae7e to your computer and use it in GitHub Desktop.
Save wolterlw/767a467d6f75b3533cd49482ba25ae7e to your computer and use it in GitHub Desktop.
Changes made to the TensorFlow source to enable custom operations in the TFLite Python interpreter
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 4be3226938..6bc39a7194 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -37,6 +37,18 @@ NEON_FLAGS_IF_APPLICABLE = select({
],
})
+cc_library(
+ name = "common",
+ srcs = [],
+ hdrs = ["common.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":cpu_check",
+ ":types",
+ "@gemmlowp//:fixedpoint",
+ ],
+)
+
cc_library(
name = "types",
srcs = [],
diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD
index 767a9fc476..c568ed4ccb 100644
--- a/tensorflow/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/lite/python/interpreter_wrapper/BUILD
@@ -16,7 +16,10 @@ cc_library(
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/memory",
- ],
+ "//tensorflow/lite/python/custom_ops:max_pool_argmax",
+ "//tensorflow/lite/python/custom_ops:max_unpooling",
+ "//tensorflow/lite/python/custom_ops:transpose_conv_bias",
+ ],
)
tf_py_wrap_cc(
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index d14af439ec..bb9a3393d0 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -22,6 +22,11 @@ limitations under the License.
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/python/custom_ops/max_pool_argmax.h"
+#include "tensorflow/lite/python/custom_ops/max_unpooling.h"
+#include "tensorflow/lite/python/custom_ops/transpose_conv_bias.h"
+
+
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -198,6 +203,13 @@ InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
}
auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
+ resolver->AddCustom("MaxPoolingWithArgmax2D",
+ tflite_operations::RegisterMaxPoolingWithArgmax2D());
+ resolver->AddCustom("MaxUnpooling2D",
+ tflite_operations::RegisterMaxUnpooling2D());
+ resolver->AddCustom("Convolution2DTransposeBias",
+ tflite_operations::RegisterConvolution2DTransposeBias());
+
auto interpreter = CreateInterpreter(model.get(), *resolver);
if (!interpreter) {
*error_msg = error_reporter->message();
# Custom BUILD file for the tensorflow/lite/python/custom_ops directory
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Every target declared in this BUILD file is visible to all other packages.
package(default_visibility = ["//visibility:public"])
# C++ library built from max_pool_argmax.cc, exposing max_pool_argmax.h as its
# public header. Links against the TFLite kernel helper targets listed in deps.
cc_library(
    name = "max_pool_argmax",
    srcs = ["max_pool_argmax.cc"],
    hdrs = ["max_pool_argmax.h"],
    deps = [
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels:padding",
        "//tensorflow/lite/kernels/internal:common",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:tensor_utils",
    ],
)
# C++ library built from max_unpooling.cc, exposing max_unpooling.h as its
# public header. Uses the same TFLite kernel helper deps as max_pool_argmax.
cc_library(
    name = "max_unpooling",
    srcs = ["max_unpooling.cc"],
    hdrs = ["max_unpooling.h"],
    deps = [
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels:padding",
        "//tensorflow/lite/kernels/internal:common",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:tensor_utils",
    ],
)
# C++ library built from transpose_conv_bias.cc, exposing transpose_conv_bias.h
# as its public header. Unlike the two pooling targets, it depends on
# kernels/internal:types rather than kernels/internal:common.
cc_library(
    name = "transpose_conv_bias",
    srcs = ["transpose_conv_bias.cc"],
    hdrs = ["transpose_conv_bias.h"],
    deps = [
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels:padding",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:tensor_utils",
        "//tensorflow/lite/kernels/internal:types",
    ],
)
@LeeYongchao
Copy link

LeeYongchao commented Dec 13, 2019

in the tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc

tflite_operations::RegisterMaxUnpooling2D()

should be mediapipe::tflite_operations::RegisterMaxUnpooling2D()

thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment