Skip to content

Commit

Permalink
Add AssembleCompilationProvider routine
Browse files Browse the repository at this point in the history
`AssembleCompilationProvider` checks the availability of all the different PTX compilation methods and takes the user's preferences (DebugOptions) into account to create the best suitable PTX CompilationProvider.

Unfortunately the logic is rather convoluted since it mimics the current behaviour in NVPTXCompiler. More cleanup can be done at a later point in small steps.

PiperOrigin-RevId: 699071889
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Dec 2, 2024
1 parent e4473de commit e2f8c1b
Show file tree
Hide file tree
Showing 6 changed files with 570 additions and 2 deletions.
64 changes: 64 additions & 0 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1761,3 +1761,67 @@ xla_cc_test(
"@tsl//tsl/platform:test",
],
)

cc_library(
name = "assemble_compilation_provider",
srcs = ["assemble_compilation_provider.cc"],
hdrs = ["assemble_compilation_provider.h"],
tags = [
"cuda-only",
"gpu",
],
deps = [
":compilation_provider",
":composite_compilation_provider",
":defer_relocatable_compilation_compilation_provider",
":driver_compilation_provider",
":nvjitlink_compilation_provider",
":nvjitlink_known_issues",
":nvjitlink_support",
":nvptxcompiler_compilation_provider",
":ptx_compiler_support",
":subprocess_compilation",
":subprocess_compilation_provider",
"//xla:xla_proto_cc",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "assemble_compilation_provider_test",
srcs = ["assemble_compilation_provider_test.cc"],
data = [
":nvlink",
":ptxas",
],
tags = [
"cuda-only",
"gpu",
"requires-gpu-nvidia",
],
deps = [
":assemble_compilation_provider",
":compilation_provider",
":cuda_platform",
":nvjitlink_support",
":ptx_compiler_support",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:cuda_root_path",
"@tsl//tsl/platform:path",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
285 changes: 285 additions & 0 deletions xla/stream_executor/cuda/assemble_compilation_provider.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/assemble_compilation_provider.h"

#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "xla/stream_executor/cuda/compilation_provider.h"
#include "xla/stream_executor/cuda/composite_compilation_provider.h"
#include "xla/stream_executor/cuda/defer_relocatable_compilation_compilation_provider.h"
#include "xla/stream_executor/cuda/driver_compilation_provider.h"
#include "xla/stream_executor/cuda/nvjitlink_compilation_provider.h"
#include "xla/stream_executor/cuda/nvjitlink_known_issues.h"
#include "xla/stream_executor/cuda/nvjitlink_support.h"
#include "xla/stream_executor/cuda/nvptxcompiler_compilation_provider.h"
#include "xla/stream_executor/cuda/ptx_compiler_support.h"
#include "xla/stream_executor/cuda/subprocess_compilation.h"
#include "xla/stream_executor/cuda/subprocess_compilation_provider.h"
#include "xla/stream_executor/semantic_version.h"
#include "tsl/platform/errors.h"

namespace stream_executor::cuda {
namespace {

// Returns true if NvJitLink is supported and should be used.
absl::Status HasNvJitLinkSupport(const xla::DebugOptions& debug_options) {
if (!IsLibNvJitLinkSupported()) {
return absl::UnavailableError(
"LibNvJitLink is not supported (disabled during compilation).");
}

if (debug_options.xla_gpu_libnvjitlink_mode() ==
xla::DebugOptions::LIB_NV_JIT_LINK_MODE_DISABLED) {
return absl::UnavailableError(
"LibNvJitLink is disabled (explicitly disabled via flag).");
}

if (debug_options.xla_gpu_libnvjitlink_mode() ==
xla::DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED) {
VLOG(4) << "Considering NvJitLink since it was explicitly enabled.";
return absl::OkStatus();
}

if (LoadedNvJitLinkHasKnownIssues()) {
return absl::UnavailableError(
"LibNvJitLink is disabled since the loaded library version has known "
"issues.");
}

VLOG(4)
<< "Considering NvJitLink since the loaded library version has no known "
"issues.";
return absl::OkStatus();
}

// Returns true if LibNvPtxCompiler is supported and should be used.
absl::Status HasNvptxcompilerSupport(const xla::DebugOptions& debug_options) {
if (!IsLibNvPtxCompilerSupported()) {
return absl::UnavailableError(
"LibNvPtxCompiler is not supported (disabled during compilation).");
}

if (!debug_options.xla_gpu_enable_libnvptxcompiler()) {
return absl::UnavailableError(
"LibNvPtxCompiler is disabled (explicitly disabled via flag).");
}

VLOG(4) << "Considering NvPtxCompiler since it was supported and enabled.";
return absl::OkStatus();
}

// Returns an error if the user-set flags are not compatible with each other and
// the build of XLA.
absl::Status CheckIncompatibleFlagSettings(
const xla::DebugOptions& debug_options) {
if (debug_options.xla_gpu_libnvjitlink_mode() ==
xla::DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED &&
!IsLibNvJitLinkSupported()) {
return absl::UnavailableError("LibNvJitLink is not supported.");
}

if (debug_options.xla_gpu_enable_libnvptxcompiler() &&
!IsLibNvPtxCompilerSupported()) {
return absl::UnavailableError("LibNvPtxCompiler is not supported.");
}

return absl::OkStatus();
}

// Calls `GetToolVersion` on the given path if it's OK. Otherwise returns the
// error status.
absl::StatusOr<SemanticVersion> GetToolVersionIfToolAvailable(
const absl::StatusOr<std::string>& path) {
if (!path.ok()) {
return path.status();
}

return GetToolVersion(path.value());
}

// Returns the given non-OK status or the value as a string.
template <typename T>
std::string ToDebugString(const absl::StatusOr<T>& status_or) {
if (status_or.ok()) {
return absl::StrCat(status_or.value());
}
return std::string{status_or.status().message()};
}

} // namespace

absl::StatusOr<std::unique_ptr<CompilationProvider>>
AssembleCompilationProvider(const xla::DebugOptions& debug_options) {
// TODO(b/381059098): Simplify this logic

TF_RETURN_IF_ERROR(CheckIncompatibleFlagSettings(debug_options));

std::string decision_log;
const auto append_to_decision_log = [&](std::string_view decision) {
VLOG(4) << decision;
absl::StrAppend(&decision_log, " - ", decision, "\n");
};

const absl::Status has_nvjitlink = HasNvJitLinkSupport(debug_options);
append_to_decision_log(
absl::StrCat("Has NvJitLink support: ", has_nvjitlink.message()));

const absl::Status has_nvptxcompiler = HasNvptxcompilerSupport(debug_options);
append_to_decision_log(
absl::StrCat("Has NvPtxCompiler support: ", has_nvptxcompiler.message()));

const bool parallel_compilation_support_is_desired =
debug_options.xla_gpu_enable_llvm_module_compilation_parallelism();
append_to_decision_log(
absl::StrCat("Parallel compilation support is desired: ",
parallel_compilation_support_is_desired));

if (has_nvjitlink.ok() && has_nvptxcompiler.ok()) {
// If both libraries are supported, we will use them together. This setup
// supports parallel compilation and we have the most control over the
// versions being used.
VLOG(3) << "Using libnvptxcompiler for compilation and libnvjitlink for "
"linking.";
std::vector<std::unique_ptr<CompilationProvider>> providers;
providers.push_back(std::make_unique<NvptxcompilerCompilationProvider>());
providers.push_back(std::make_unique<NvJitLinkCompilationProvider>());
return CompositeCompilationProvider::Create(std::move(providers));
}

if (has_nvjitlink.ok() && !has_nvptxcompiler.ok()) {
// If we only have libnvjitlink, we use it for both compilation and
// linking. To support parallel compilation we defer compilation into
// relocatable modules to the linking step by using the
// DeferRelocatableCompilationCompilationProvider.
VLOG(3) << "Using libnvjitlink for compilation and linking.";
return DeferRelocatableCompilationCompilationProvider::Create(
std::make_unique<NvJitLinkCompilationProvider>());
}

if (has_nvptxcompiler.ok() && !parallel_compilation_support_is_desired) {
// If we only have libnvptxcompiler, but don't need parallel compilation, we
// can just use the library on its own - no linking required.
VLOG(3) << "Using only libnvptxcompiler for compilation - no parallel "
"compilation support needed.";
return std::make_unique<NvptxcompilerCompilationProvider>();
}

absl::StatusOr<std::string> ptxas_path =
FindPtxAsExecutable(debug_options.xla_gpu_cuda_data_dir());
absl::StatusOr<SemanticVersion> ptxas_version =
GetToolVersionIfToolAvailable(ptxas_path);

absl::StatusOr<std::string> nvlink_path =
FindNvlinkExecutable(debug_options.xla_gpu_cuda_data_dir());
absl::StatusOr<SemanticVersion> nvlink_version =
GetToolVersionIfToolAvailable(nvlink_path);

append_to_decision_log(
absl::StrCat("ptxas_path: ", ToDebugString(ptxas_path)));
append_to_decision_log(
absl::StrCat("ptxas_version: ", ToDebugString(ptxas_version)));
append_to_decision_log(
absl::StrCat("nvlink_path: ", ToDebugString(nvlink_path)));
append_to_decision_log(
absl::StrCat("nvlink_version: ", ToDebugString(nvlink_version)));

const bool has_subprocess_compilation_support =
ptxas_path.ok() && nvlink_path.ok();

if (has_subprocess_compilation_support) {
VLOG(3) << "Using ptxas(path=" << ptxas_path.value()
<< ", version=" << ptxas_version.value() << ") and "
<< "nvlink(path=" << nvlink_path.value()
<< ", version=" << nvlink_version.value()
<< ") for compilation and linking.";
return std::make_unique<SubprocessCompilationProvider>(ptxas_path.value(),
nvlink_path.value());
}

const bool has_driver_compilation_support =
debug_options.xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found();
append_to_decision_log(absl::StrCat("Driver compilation is enabled: ",
has_driver_compilation_support));

if (parallel_compilation_support_is_desired && has_nvptxcompiler.ok() &&
has_driver_compilation_support) {
// It's possible to use libnvptxcompiler for compilation and the driver for
// linking. This setup supports parallel compilation but is less desired
// because we don't control the driver version. A too old driver might lead
// to linking errors.
VLOG(3) << "Using libnvptxcompiler for compilation and the driver for "
"linking.";
std::vector<std::unique_ptr<CompilationProvider>> providers;
providers.push_back(std::make_unique<NvptxcompilerCompilationProvider>());
providers.push_back(std::make_unique<DriverCompilationProvider>());
return CompositeCompilationProvider::Create(std::move(providers));
}

if (ptxas_path.ok() && has_driver_compilation_support) {
// It's possible to use ptxas for compilation and the driver for linking.
// This setup supports parallel compilation but is less desired because we
// don't control the driver version. A too old driver might lead to linking
// errors.
VLOG(3) << "Using libnvptxcompiler for compilation and the driver for "
"linking.";
std::vector<std::unique_ptr<CompilationProvider>> providers;
auto ptxas_provider = std::make_unique<SubprocessCompilationProvider>(
ptxas_path.value(), std::string{});
providers.push_back(std::move(ptxas_provider));
providers.push_back(std::make_unique<DriverCompilationProvider>());
return CompositeCompilationProvider::Create(std::move(providers));
}

// Passed this point we won't be able to support parallel compilation, so we
// error out if it was requested.
if (parallel_compilation_support_is_desired) {
return absl::UnavailableError(
absl::StrCat("Parallel compilation was requested, but no available "
"compilation provider supports it. Details: \n",
decision_log));
}

if (ptxas_path.ok()) {
VLOG(3) << "Using ptxas(path=" << ptxas_path.value()
<< ", version=" << ptxas_version.value()
<< ") for compilation. nvlink is not available.";
return std::make_unique<SubprocessCompilationProvider>(ptxas_path.value(),
std::string{});
}

if (has_driver_compilation_support) {
VLOG(3) << "Using the driver for compilation.";
return std::make_unique<DriverCompilationProvider>();
}

return absl::UnavailableError(absl::StrCat(
"No PTX compilation provider is available. Neither ptxas/nvlink nor "
"nvjtlink is available. As a fallback you can enable JIT compilation "
"in the CUDA driver via the flag "
"`--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found`. Details: \n",
decision_log));
}

} // namespace stream_executor::cuda
Loading

0 comments on commit e2f8c1b

Please sign in to comment.