Created February 1, 2020 06:23
-
-
Save sanjoy/2f355116c7ca88bfcd8665ba988a1bbf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 28994e468464c6e5402234899f71883838f2b185
Author: Sanjoy Das <sanjoy@playingwithpointers.com>
Date: Fri Jan 31 22:11:28 2020 -0800
Fix bad_function_call
diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc | |
index 62171dbaa7..9231865e93 100644 | |
--- a/tensorflow/core/kernels/ops_testutil.cc | |
+++ b/tensorflow/core/kernels/ops_testutil.cc | |
@@ -30,7 +30,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type, | |
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device)); | |
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( | |
device_mgr_.get(), Env::Default(), /*config=*/nullptr, | |
- TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions()); | |
+ TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), thread_pool_.get()); | |
device_type_ = device_type; | |
#ifdef GOOGLE_CUDA | |
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h | |
index 2b4a1a7cca..5eed5bd979 100644 | |
--- a/tensorflow/core/kernels/ops_testutil.h | |
+++ b/tensorflow/core/kernels/ops_testutil.h | |
@@ -23,6 +23,7 @@ limitations under the License. | |
#include "tensorflow/core/common_runtime/device_factory.h" | |
#include "tensorflow/core/common_runtime/device_mgr.h" | |
#include "tensorflow/core/common_runtime/process_function_library_runtime.h" | |
+#include "tensorflow/core/platform/threadpool.h" | |
#include "tensorflow/core/framework/allocator.h" | |
#include "tensorflow/core/framework/device_base.h" | |
#include "tensorflow/core/framework/graph.pb.h" | |
@@ -76,6 +77,8 @@ class OpsTestBase : public ::testing::Test { | |
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); | |
CHECK(device) << "Could not create CPU device"; | |
+ thread_pool_ = absl::make_unique<thread::ThreadPool>(Env::Default(), /*name=*/"default", /*num_threads=*/1); | |
+ | |
device_ = device.get(); | |
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device)); | |
@@ -85,7 +88,7 @@ class OpsTestBase : public ::testing::Test { | |
OpRegistry::Global(), FunctionDefLibrary{}); | |
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( | |
device_mgr_.get(), Env::Default(), /*config=*/nullptr, | |
- TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions()); | |
+ TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), thread_pool_.get()); | |
} | |
~OpsTestBase() override { | |
@@ -274,6 +277,7 @@ class OpsTestBase : public ::testing::Test { | |
std::unique_ptr<FunctionLibraryDefinition> flib_def_; | |
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; | |
+ std::unique_ptr<thread::ThreadPool> thread_pool_; | |
private: | |
TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase); |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.