diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 094f84cd781..2b365769adf 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -489,6 +489,10 @@ class DirectSessionFactory : public SessionFactory { ResourceMgr* gpu_shared_rmgr = nullptr; #if GOOGLE_CUDA + bool use_per_session_host_allocator = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("PER_SESSION_HOSTALLOC", + /*default_val=*/false, + &use_per_session_host_allocator)); if (use_multi_stream) { // Create shared resource for gpu devices gpu_shared_rmgr = new ResourceMgr("localhost"); @@ -496,7 +500,7 @@ class DirectSessionFactory : public SessionFactory { for (int i = 0; i < session_num; ++i) { dev_rmgr_map.device_rmgr_map[gpu_dev_prefix+std::to_string(base_index+i)] = gpu_shared_rmgr; - if (i > 0) { + if (use_per_session_host_allocator && i > 0) { dev_rmgr_map.device_rmgr_map[dev_prefix+"/device:CPU:"+std::to_string(i)] = shared_rmgr; dev_rmgr_map.device_rmgr_map[dev_prefix+"/device:cpu:"+std::to_string(i)] = shared_rmgr; dev_rmgr_map.device_rmgr_map["/device:CPU:"+std::to_string(i)] = shared_rmgr; @@ -571,8 +575,13 @@ class DirectSessionFactory : public SessionFactory { follower_options.config.add_per_session_devices( "/job:localhost/replica:0/task:0/device:GPU:" + std::to_string(base_index+i)); - follower_options.config.add_per_session_devices( - "/job:localhost/replica:0/task:0/device:CPU:"+std::to_string(i)); + if (use_per_session_host_allocator) { + follower_options.config.add_per_session_devices( + "/job:localhost/replica:0/task:0/device:CPU:"+std::to_string(i)); + } else { + follower_options.config.add_per_session_devices( + "/job:localhost/replica:0/task:0/device:CPU:0"); + } } #endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc index bb4c510253f..11421ad0999 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc @@ -199,8 +199,12 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory { int num_numa_nodes = options.config.experimental().use_numa_affinity() ? port::NUMANumNodes() : 1; + bool use_per_session_host_allocator = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("PER_SESSION_HOSTALLOC", + /*default_val=*/false, + &use_per_session_host_allocator)); int sess_num = 1; - if (dev_rmgr_map) { + if (use_per_session_host_allocator && dev_rmgr_map) { for (auto& item : dev_rmgr_map->device_rmgr_map) { int sess_idx = std::stoi(item.first.substr(item.first.rfind(":")+1)); if (sess_idx >= sess_num) { diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index f8faf15f7d6..222c7408652 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -737,11 +737,13 @@ Status GraphExecutionState::InitBaseGraph(std::unique_ptr&& new_graph) { break; } } - const auto& dname1 = session_options_->config.per_session_devices(1); - for (auto& d : device_set_->devices()) { - if (d->name() == dname1) { - devices.AddDevice(d); - break; + if (session_options_->config.per_session_devices_size() > 1) { + const auto& dname1 = session_options_->config.per_session_devices(1); + for (auto& d : device_set_->devices()) { + if (d->name() == dname1) { + devices.AddDevice(d); + break; + } } } }