Skip to content

Instantly share code, notes, and snippets.

@sanjoy
Created February 1, 2020 06:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sanjoy/2f355116c7ca88bfcd8665ba988a1bbf to your computer and use it in GitHub Desktop.
commit 28994e468464c6e5402234899f71883838f2b185
Author: Sanjoy Das <sanjoy@playingwithpointers.com>
Date: Fri Jan 31 22:11:28 2020 -0800
Fix bad_function_call in OpsTestBase by passing a single-threaded default thread pool to ProcessFunctionLibraryRuntime
diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc
index 62171dbaa7..9231865e93 100644
--- a/tensorflow/core/kernels/ops_testutil.cc
+++ b/tensorflow/core/kernels/ops_testutil.cc
@@ -30,7 +30,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
- TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
+ TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), thread_pool_.get());
device_type_ = device_type;
#ifdef GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index 2b4a1a7cca..5eed5bd979 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -76,6 +77,8 @@ class OpsTestBase : public ::testing::Test {
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
CHECK(device) << "Could not create CPU device";
+ thread_pool_ = absl::make_unique<thread::ThreadPool>(Env::Default(), /*name=*/"default", /*num_threads=*/1);
+
device_ = device.get();
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
@@ -85,7 +88,7 @@ class OpsTestBase : public ::testing::Test {
OpRegistry::Global(), FunctionDefLibrary{});
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
- TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
+ TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), thread_pool_.get());
}
~OpsTestBase() override {
@@ -274,6 +277,7 @@ class OpsTestBase : public ::testing::Test {
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment