
@cheshire
Created April 28, 2021 03:41
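This patch threads a FunctionLibraryRuntime through BuildXlaCompilationCache so that the XLA local client honors the session's gpu_options.visible_device_list. ParseVisibleDeviceList moves from xla_gpu_device.cc into xla_platform_info.{h,cc}, where the parsed device ids are passed to client_options.set_allowed_devices(); the three call sites (get_compiler_ir.cc, kernels/xla_ops.cc, xla_compile_on_demand_op.cc) are updated for the new signature.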
diff --git a//tensorflow/compiler/jit/BUILD b//tensorflow/compiler/jit/BUILD
--- a//tensorflow/compiler/jit/BUILD
+++ b//tensorflow/compiler/jit/BUILD
@@ -151,6 +151,7 @@ cc_library(
":jit_compilation_passes",
":xla_device",
":xla_kernel_creator", # buildcleaner: keep
+ ":xla_device_no_jit_rewrite_registration",
"//third_party/absl/memory",
"//third_party/absl/strings",
"//third_party/tensorflow/compiler/jit/kernels:xla_ops",
diff --git a//tensorflow/compiler/jit/get_compiler_ir.cc b//tensorflow/compiler/jit/get_compiler_ir.cc
--- a//tensorflow/compiler/jit/get_compiler_ir.cc
+++ b//tensorflow/compiler/jit/get_compiler_ir.cc
@@ -97,7 +97,8 @@ xla::StatusOr<std::string> GetCompilerIr
TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
rmgr->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache_write_into) {
- return BuildXlaCompilationCache(dev, platform_info, cache_write_into);
+ return BuildXlaCompilationCache(dev, flr, platform_info,
+ cache_write_into);
}));
core::ScopedUnref cache_ref(cache);
diff --git a//tensorflow/compiler/jit/kernels/xla_ops.cc b//tensorflow/compiler/jit/kernels/xla_ops.cc
--- a//tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b//tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -188,7 +188,8 @@ static Status CompileToLocalExecutable(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
- return BuildXlaCompilationCache(ctx->device(), platform_info, cache);
+ return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
+ platform_info, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
diff --git a//tensorflow/compiler/jit/xla_compile_on_demand_op.cc b//tensorflow/compiler/jit/xla_compile_on_demand_op.cc
--- a//tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b//tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -119,8 +119,8 @@ Status XlaCompileOnDemandOp::Compile(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", cache,
[&](XlaCompilationCache** write_into_cache) {
- return BuildXlaCompilationCache(ctx->device(), platform_info_,
- write_into_cache);
+ return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
+ platform_info_, write_into_cache);
}));
XlaCompiler::Options options = GenerateCompilerOptions(
diff --git a//tensorflow/compiler/jit/xla_gpu_device.cc b//tensorflow/compiler/jit/xla_gpu_device.cc
--- a//tensorflow/compiler/jit/xla_gpu_device.cc
+++ b//tensorflow/compiler/jit/xla_gpu_device.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/jit/kernels/xla_ops.h"
#include "third_party/tensorflow/compiler/jit/xla_device.h"
#include "third_party/tensorflow/compiler/jit/xla_device_ops.h"
+#include "third_party/tensorflow/compiler/jit/xla_platform_info.h"
#include "third_party/tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "third_party/tensorflow/core/common_runtime/device_factory.h"
#include "third_party/tensorflow/core/common_runtime/gpu/gpu_init.h"
@@ -32,30 +33,6 @@ limitations under the License.
namespace tensorflow {
-// Returns a set containing the device ids contained in visible_device_list or
-// nullopt if it is empty. It returns error in case of malformed configuration
-// string.
-static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
- const string& visible_device_list) {
- std::set<int> gpu_ids;
- if (visible_device_list.empty()) {
- return {{absl::nullopt}};
- }
- const std::vector<string> visible_devices =
- absl::StrSplit(visible_device_list, ',');
- for (const string& platform_device_id_str : visible_devices) {
- int32 platform_device_id;
- if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
- return errors::InvalidArgument(
- "Could not parse entry in 'visible_device_list': '",
- platform_device_id_str,
- "'. visible_device_list = ", visible_device_list);
- }
- gpu_ids.insert(platform_device_id);
- }
- return {{gpu_ids}};
-}
-
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
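The function removed above is not deleted outright: it reappears in xla_platform_info.cc below, unchanged apart from taking absl::string_view instead of const string&.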
diff --git a//tensorflow/compiler/jit/xla_platform_info.cc b//tensorflow/compiler/jit/xla_platform_info.cc
--- a//tensorflow/compiler/jit/xla_platform_info.cc
+++ b//tensorflow/compiler/jit/xla_platform_info.cc
@@ -19,7 +19,28 @@ limitations under the License.
namespace tensorflow {
-Status BuildXlaCompilationCache(DeviceBase* device,
+xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
+ absl::string_view visible_device_list) {
+ std::set<int> gpu_ids;
+ if (visible_device_list.empty()) {
+ return {{absl::nullopt}};
+ }
+ const std::vector<string> visible_devices =
+ absl::StrSplit(visible_device_list, ',');
+ for (const string& platform_device_id_str : visible_devices) {
+ int32 platform_device_id;
+ if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) {
+ return errors::InvalidArgument(
+ "Could not parse entry in 'visible_device_list': '",
+ platform_device_id_str,
+ "'. visible_device_list = ", visible_device_list);
+ }
+ gpu_ids.insert(platform_device_id);
+ }
+ return {{gpu_ids}};
+}
+
+Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
@@ -60,6 +81,13 @@ Status BuildXlaCompilationCache(DeviceBa
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
device->tensorflow_cpu_worker_threads()->num_threads);
+
+ string allowed_gpus =
+ flr->config_proto()->gpu_options().visible_device_list();
+ TF_ASSIGN_OR_RETURN(absl::optional<std::set<int>> gpu_ids,
+ ParseVisibleDeviceList(allowed_gpus));
+ client_options.set_allowed_devices(gpu_ids);
+
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
diff --git a//tensorflow/compiler/jit/xla_platform_info.h b//tensorflow/compiler/jit/xla_platform_info.h
--- a//tensorflow/compiler/jit/xla_platform_info.h
+++ b//tensorflow/compiler/jit/xla_platform_info.h
@@ -81,8 +81,14 @@ class XlaPlatformInfo {
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
+// Returns a set containing the device ids contained in visible_device_list,
+// or nullopt if it is empty. Returns an error if the configuration string is
+// malformed.
+StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
+    absl::string_view visible_device_list);
+
// Returns created XLA compilation cache.
-Status BuildXlaCompilationCache(DeviceBase* dev,
+Status BuildXlaCompilationCache(DeviceBase* dev, FunctionLibraryRuntime* flr,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache);
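
For reference, a minimal usage sketch of the newly exported ParseVisibleDeviceList (not part of the patch; it assumes a TensorFlow build environment with the include path used above, and AllowedDevicesExample is a hypothetical function added only for illustration):

#include <cassert>

#include "third_party/tensorflow/compiler/jit/xla_platform_info.h"

void AllowedDevicesExample() {
  // "0,2" parses into the device-id set {0, 2}.
  auto ids = tensorflow::ParseVisibleDeviceList("0,2");
  assert(ids.ok() && ids.ValueOrDie().has_value());

  // An empty visible_device_list yields absl::nullopt, so
  // set_allowed_devices() places no restriction on the client.
  assert(!tensorflow::ParseVisibleDeviceList("").ValueOrDie().has_value());

  // A malformed entry yields an InvalidArgument error.
  assert(!tensorflow::ParseVisibleDeviceList("0,x").ok());
}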