From 531df5f9780aa35ce49fca588c46d8b78dd5507a Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Mon, 2 Dec 2024 03:17:04 -0800 Subject: [PATCH] PR #19927: [cuda] Warn about ptxas versions before CUDA 12.6.3 Imported from GitHub PR https://github.com/openxla/xla/pull/19927 This PR adds version checks to determine whether the current setup is affected by nvbug 4826023. We already have a JAX PR (https://github.com/jax-ml/jax/pull/25091) that bumps its dependency on the relevant CUDA wheel; the present XLA PR is designed to get users with an existing installation to upgrade. CUDA 12.x < 12.6.3 on Hopper+ is known to be affected. The first CUDA 12.6.3 nvidia-cuda-nvcc wheel is patch number 85, hence we specifically check for `CC >= SM90 and 12.0.0 <= ptxas_version < 12.6.85`. If such a version is found to be present, we issue a warning prompting the user to upgrade to CUDA 12.6.3 or newer. Implementing the above-mentioned checks is complicated by the fact that XLA may compile PTX in three (four?) different ways ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L761-L778)): - nvJitLink (linkable library; [nvjitlink_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/nvjitlink_impl.cc#L154)) - nvPtxCompiler (another linkable library; [ptx_compiler_impl.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/ptx_compiler_impl.cc#L84)) - ptxas (spawn a PTX compiler binary as a subprocess; [subprocess_compilation.cc](https://github.com/openxla/xla/blob/846e02df32d53921950fdf240b9fa3ca53351821/xla/stream_executor/cuda/subprocess_compilation.cc#L263)) (As a bonus, `nvptx_compiler.cc` alludes to `--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found` possibly falling back to compiling ptx through the driver ([nvptx_compiler.cc](https://github.com/openxla/xla/blob/2f79665f7ea93b9b13d99eceb468dce313ab609e/xla/service/gpu/nvptx_compiler.cc#L795-L806)). As far as I can tell the flag currently doesn't do anything, though.) **Caveat:** We may show a spurious warning for some CUDA releases `>=12.6.3` as the `nvJitLink` only seems to expose major and minor versions, but not the patch number. By default at least JAX seems to use the subprocess_compilation route, which _is_ aware of the patch number and hence will show no such spurious warning. The warning is currently logged at the `ERROR` log level, since `WARNING` doesn't seem to be shown by default. --- Example: ``` # A JAX-Toolbox image affected $ docker run -it --gpus=all jax:jax-2024-11-25 $ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))" E1128 15:53:19.872235 2401322 ptx_compiler_helpers.cc:40] *** WARNING *** Invoking PTXAS with version 12.6.77, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x up to and including 12.6.2 miscompile certain edge cases around clamping. Please upgrade to CUDA 12.6.3 or newer. [0 0 0 0 3 6] $ pip install -U "nvidia-cuda-nvcc-cu12>=12.6.85" (...) $ python3 -c "import jax; import jax.numpy as jnp; A = jnp.arange(18).reshape(6, 3); m = jnp.arange(-3, 3); print(jax.jit(lambda _0, _1: _0.at[jnp.abs(_1), 0].get())(A, m))" [9 6 3 0 3 6] ``` --- On a general note: I'm not particularly happy with adding all this new code for version checks, but don't see any particularly better immediate solution. Note that similar checks are already spread across the three variants _and_ the dispatching code in `nvptx_compiler.cc`. However, all of these have slightly different semantics (warning vs ignoring versions) and only target a single variant. Copybara import of the project: -- d32e9b03e8c6c1afc957268f1eefec0e10c5df78 by Georg Stefan Schmid : [cuda] Warn about ptxas versions before CUDA 12.6.3 Merging this change closes #19927 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19927 from gspschmid:gschmid/ptxax-version-warn d32e9b03e8c6c1afc957268f1eefec0e10c5df78 PiperOrigin-RevId: 701903542 --- xla/stream_executor/cuda/BUILD | 6 +++- xla/stream_executor/cuda/nvjitlink_impl.cc | 3 ++ .../cuda/ptx_compiler_helpers.cc | 28 +++++++++++++++++++ .../cuda/ptx_compiler_helpers.h | 7 +++++ xla/stream_executor/cuda/ptx_compiler_impl.cc | 4 +++ .../cuda/subprocess_compilation.cc | 3 ++ 6 files changed, 50 insertions(+), 1 deletion(-) diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index dace5221f9905..e059c81f99ff9 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -660,6 +660,9 @@ cc_library( srcs = ["ptx_compiler_helpers.cc"], hdrs = ["ptx_compiler_helpers.h"], deps = [ + "//xla/stream_executor:device_description", + "//xla/stream_executor:semantic_version", + "@com_google_absl//absl/base", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -667,7 +670,7 @@ cc_library( ], ) -cc_test( +xla_cc_test( name = "ptx_compiler_helpers_test", srcs = ["ptx_compiler_helpers_test.cc"], deps = [ @@ -717,6 +720,7 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:nvptxcompiler", "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/stream_executor/cuda/nvjitlink_impl.cc b/xla/stream_executor/cuda/nvjitlink_impl.cc index 2de6aff6b79d8..160f8bfcc50ef 100644 --- a/xla/stream_executor/cuda/nvjitlink_impl.cc +++ b/xla/stream_executor/cuda/nvjitlink_impl.cc @@ -131,6 +131,9 @@ absl::StatusOr> CompileAndLinkUsingLibNvJitLink( return std::vector(); } + TF_ASSIGN_OR_RETURN((auto [major, minor]), GetNvJitLinkVersion()); + WarnIfBadPtxasVersion("nvJitLink", cc, {major, minor, 0}); + std::vector cli_args; // On Hopper, default to sm_90a so that all instructions can be used. But // only sm_90 is forward compatible, so don't use sm_90a with newer hardware: diff --git a/xla/stream_executor/cuda/ptx_compiler_helpers.cc b/xla/stream_executor/cuda/ptx_compiler_helpers.cc index 8905b65913ecd..596fb58521a5a 100644 --- a/xla/stream_executor/cuda/ptx_compiler_helpers.cc +++ b/xla/stream_executor/cuda/ptx_compiler_helpers.cc @@ -56,4 +56,32 @@ absl::Status CreateErrorFromPTXASLog(std::string_view log, return absl::OkStatus(); } +// Warns if the ptxas version should be upgraded. +// Only prints the warning upon the first invocation. +void WarnIfBadPtxasVersion(std::string_view method, + const CudaComputeCapability& cc, + SemanticVersion compiler_version) { + static absl::once_flag run_once; + absl::call_once(run_once, [&] { + // nvbug 4826023: Occurs on Hopper+ in CUDA versions 12.x up to and + // including CUDA 12.6.2; the earliest ptxas release that corresponds to + // CUDA 12.6.3 is 12.6.85. + if (cc.major >= 9 && compiler_version >= SemanticVersion{12, 0, 0} && + compiler_version < SemanticVersion{12, 6, 85}) { + LOG(ERROR) + << "*** WARNING *** Invoking " << method << " with version " + << compiler_version + << ", which corresponds to a CUDA version <=12.6.2. CUDA versions " + "12.x.y up to and including 12.6.2 miscompile certain edge " + "cases around clamping.\nPlease upgrade to CUDA 12.6.3 or newer."; + if (method != "ptxas" && compiler_version.major() == 12 && + compiler_version.minor() == 6) { + LOG(ERROR) << "(Note that this warning may be shown spuriously for " + "CUDA 12.6.y, since " + << method << " does not report patch versions.)"; + } + } + }); +} + } // namespace stream_executor diff --git a/xla/stream_executor/cuda/ptx_compiler_helpers.h b/xla/stream_executor/cuda/ptx_compiler_helpers.h index af49a4f4a01df..d061eee6184fd 100644 --- a/xla/stream_executor/cuda/ptx_compiler_helpers.h +++ b/xla/stream_executor/cuda/ptx_compiler_helpers.h @@ -17,6 +17,8 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/semantic_version.h" namespace stream_executor { // Checks whether ptxas log contains errors related to register allocation. @@ -30,6 +32,11 @@ bool IsPtxRegisterAllocationError(std::string_view); absl::Status CreateErrorFromPTXASLog(std::string_view log, std::string_view architecture, bool cancel_if_reg_spill); + +// Warns if the ptxas version should be upgraded. +void WarnIfBadPtxasVersion(std::string_view method, + const CudaComputeCapability& cc, + SemanticVersion compiler_version); } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_ diff --git a/xla/stream_executor/cuda/ptx_compiler_impl.cc b/xla/stream_executor/cuda/ptx_compiler_impl.cc index 745d7e1bf09e2..e48d73ca1c729 100644 --- a/xla/stream_executor/cuda/ptx_compiler_impl.cc +++ b/xla/stream_executor/cuda/ptx_compiler_impl.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/semantic_version.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace stream_executor { @@ -84,6 +85,9 @@ static std::string_view ToString(nvPTXCompileResult status) { absl::StatusOr> CompileGpuAsmUsingLibNvPtxCompiler( const CudaComputeCapability& cc, const std::string& ptx_contents, GpuAsmOpts options, bool cancel_if_reg_spill) { + TF_ASSIGN_OR_RETURN(auto version, GetLibNvPtxCompilerVersion()); + WarnIfBadPtxasVersion("nvPTXCompiler", cc, version); + nvPTXCompilerHandle compiler_handle{}; RETURN_IF_NVPTXCOMPILER_ERROR(nvPTXCompilerCreate( &compiler_handle, ptx_contents.size(), ptx_contents.data())); diff --git a/xla/stream_executor/cuda/subprocess_compilation.cc b/xla/stream_executor/cuda/subprocess_compilation.cc index f299fd536ef05..b3885becc7b05 100644 --- a/xla/stream_executor/cuda/subprocess_compilation.cc +++ b/xla/stream_executor/cuda/subprocess_compilation.cc @@ -263,6 +263,9 @@ absl::StatusOr> CompileGpuAsmUsingPtxAs( absl::StatusOr> CompileGpuAsmUsingPtxAs( std::string_view ptxas_path, const CudaComputeCapability& cc, std::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill) { + TF_ASSIGN_OR_RETURN(auto version, GetToolVersion(ptxas_path)); + WarnIfBadPtxasVersion("ptxas", cc, version); + // Write ptx into a temporary file. std::string ptx_path; auto env = tsl::Env::Default();