diff --git a//tensorflow/compiler/jit/BUILD b//tensorflow/compiler/jit/BUILD
--- a//tensorflow/compiler/jit/BUILD
+++ b//tensorflow/compiler/jit/BUILD
@@ -151,6 +151,7 @@ cc_library(
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator", # buildcleaner: keep
+ ":xla_device_no_jit_rewrite_registration",
"//third_party/absl/memory",
"//third_party/absl/strings",
"//third_party/tensorflow/compiler/jit/kernels:xla_ops",
diff --git a//tensorflow/compiler/jit/get_compiler_ir.cc b//tensorflow/compiler/jit/get_compiler_ir.cc
--- a//tensorflow/compiler/jit/get_compiler_ir.cc
+++ b//tensorflow/compiler/jit/get_compiler_ir.cc
@@ -97,7 +97,8 @@ xla::StatusOr<std::string> GetCompilerIr
TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
rmgr->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache_write_into) {
- return BuildXlaCompilationCache(dev, platform_info, cache_write_into);
+ return BuildXlaCompilationCache(dev, flr, platform_info,
+ cache_write_into);
}));
core::ScopedUnref cache_ref(cache);
diff --git a//tensorflow/compiler/jit/kernels/xla_ops.cc b//tensorflow/compiler/jit/kernels/xla_ops.cc
--- a//tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b//tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -188,7 +188,8 @@ static Status CompileToLocalExecutable(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
- return BuildXlaCompilationCache(ctx->device(), platform_info, cache);
+ return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
+ platform_info, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
diff --git a//tensorflow/compiler/jit/xla_compile_on_demand_op.cc b//tensorflow/compiler/jit/xla_compile_on_demand_op.cc
--- a//tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b//tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -119,8 +119,8 @@ Status XlaCompileOnDemandOp::Compile(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", cache,
[&](XlaCompilationCache** write_into_cache) {
- return BuildXlaCompilationCache(ctx->device(), platform_info_,
- write_into_cache);
+ return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
+ platform_info_, write_into_cache);
}));
XlaCompiler::Options options = GenerateCompilerOptions(
diff --git a//tensorflow/compiler/jit/xla_gpu_device.cc b//tensorflow/compiler/jit/xla_gpu_device.cc
--- a//tensorflow/compiler/jit/xla_gpu_device.cc
+++ b//tensorflow/compiler/jit/xla_gpu_device.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/jit/kernels/xla_ops.h"
#include "third_party/tensorflow/compiler/jit/xla_device.h"
#include "third_party/tensorflow/compiler/jit/xla_device_ops.h"
+#include "third_party/tensorflow/compiler/jit/xla_platform_info.h"
#include "third_party/tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "third_party/tensorflow/core/common_runtime/device_factory.h"
#include "third_party/tensorflow/core/common_runtime/gpu/gpu_init.h"
@@ -32,30 +33,6 @@ limitations under the License.
namespace tensorflow {
-// Returns a set containing the device ids contained in visible_device_list or
-// nullopt if it is empty. It returns error in case of malformed configuration
-// string.
-static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
-    const string& visible_device_list) {
-  std::set<int> gpu_ids;
-  if (visible_device_list.empty()) {
-    return {{absl::nullopt}};
-  }
-  const std::vector<string> visible_devices =
-      absl::StrSplit(visible_device_list, ',');
-  for (const string& platform_device_id_str : visible_devices) {
-    int32 platform_device_id;
-    if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
-      return errors::InvalidArgument(
-          "Could not parse entry in 'visible_device_list': '",
-          platform_device_id_str,
-          "'. visible_device_list = ", visible_device_list);
-    }
-    gpu_ids.insert(platform_device_id);
-  }
-  return {{gpu_ids}};
-}
-
class XlaGpuDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;
diff --git a//tensorflow/compiler/jit/xla_platform_info.cc b//tensorflow/compiler/jit/xla_platform_info.cc
--- a//tensorflow/compiler/jit/xla_platform_info.cc
+++ b//tensorflow/compiler/jit/xla_platform_info.cc
@@ -19,7 +19,28 @@ limitations under the License.
namespace tensorflow {
-Status BuildXlaCompilationCache(DeviceBase* device,
+xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
+    absl::string_view visible_device_list) {
+  std::set<int> gpu_ids;
+  if (visible_device_list.empty()) {
+    return {{absl::nullopt}};
+  }
+  const std::vector<string> visible_devices =
+      absl::StrSplit(visible_device_list, ',');
+  for (const string& platform_device_id_str : visible_devices) {
+    int32 platform_device_id;
+    if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
+      return errors::InvalidArgument(
+          "Could not parse entry in 'visible_device_list': '",
+          platform_device_id_str,
+          "'. visible_device_list = ", visible_device_list);
+    }
+    gpu_ids.insert(platform_device_id);
+  }
+  return {{gpu_ids}};
+}
+
+Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr,
                                 const XlaPlatformInfo& platform_info,
                                 XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
@@ -60,6 +81,13 @@ Status BuildXlaCompilationCache(DeviceBa
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      device->tensorflow_cpu_worker_threads()->num_threads);
+
+  string allowed_gpus =
+      flr->config_proto()->gpu_options().visible_device_list();
+  TF_ASSIGN_OR_RETURN(absl::optional<std::set<int>> gpu_ids,
+                      ParseVisibleDeviceList(allowed_gpus));
+  client_options.set_allowed_devices(gpu_ids);
+
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
diff --git a//tensorflow/compiler/jit/xla_platform_info.h b//tensorflow/compiler/jit/xla_platform_info.h
--- a//tensorflow/compiler/jit/xla_platform_info.h
+++ b//tensorflow/compiler/jit/xla_platform_info.h
@@ -81,8 +81,14 @@ class XlaPlatformInfo {
  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
+// Returns a set containing the device ids contained in visible_device_list or
+// nullopt if it is empty. It returns error in case of malformed configuration
+// string.
+StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
+    absl::string_view visible_device_list);
+
// Returns created XLA compilation cache.
-Status BuildXlaCompilationCache(DeviceBase* dev,
+Status BuildXlaCompilationCache(DeviceBase* dev, FunctionLibraryRuntime* flr,
                                 const XlaPlatformInfo& platform_info,
                                 XlaCompilationCache** cache);
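
In short, the patch threads the FunctionLibraryRuntime into BuildXlaCompilationCache so the session's gpu_options.visible_device_list can be read from flr->config_proto(), parsed by ParseVisibleDeviceList (moved from xla_gpu_device.cc into xla_platform_info.{h,cc} and made non-static), and forwarded to the XLA local client via client_options.set_allowed_devices(). The sketch below only illustrates the parsing behavior in plain C++17 without the TensorFlow/Abseil types; ParseVisibleDeviceListSketch is a hypothetical stand-in, and its error handling (an exception instead of errors::InvalidArgument) is simplified.

// Minimal standalone sketch of the visible_device_list parsing behavior:
// an empty list means "no restriction" (nullopt); otherwise the
// comma-separated ids become a std::set<int>; a non-numeric entry is an error.
#include <iostream>
#include <optional>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>

std::optional<std::set<int>> ParseVisibleDeviceListSketch(
    const std::string& visible_device_list) {
  if (visible_device_list.empty()) return std::nullopt;  // no restriction
  std::set<int> gpu_ids;
  std::stringstream ss(visible_device_list);
  std::string entry;
  while (std::getline(ss, entry, ',')) {
    std::size_t pos = 0;
    int id = std::stoi(entry, &pos);  // throws on non-numeric entries
    if (pos != entry.size()) {
      throw std::invalid_argument(
          "Could not parse entry in 'visible_device_list': '" + entry + "'");
    }
    gpu_ids.insert(id);
  }
  return gpu_ids;
}

int main() {
  auto ids = ParseVisibleDeviceListSketch("0,2");
  if (ids) {
    for (int id : *ids) std::cout << id << ' ';  // prints "0 2"
  }
  std::cout << '\n';
  return 0;
}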