diff --git a/recipe/build.sh b/recipe/build.sh index 2cbe5e18..9844292e 100644 --- a/recipe/build.sh +++ b/recipe/build.sh @@ -16,16 +16,33 @@ export CXXFLAGS="${CXXFLAGS} -DNDEBUG" if [[ "${cuda_compiler_version:-None}" != "None" ]]; then if [[ ${cuda_compiler_version} == 11.8 ]]; then - export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_62,sm_70,sm_72,sm_75,sm_80,sm_86,sm_87,sm_89,sm_90,compute_90 - export TF_CUDA_PATHS="${CUDA_HOME},${PREFIX}" + export HERMETIC_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_50,sm_60,sm_62,sm_70,sm_72,sm_75,sm_80,sm_86,sm_87,sm_89,sm_90,compute_90 + export TF_CUDA_PATHS="${CUDA_HOME},${PREFIX}" elif [[ ${cuda_compiler_version} == 12* ]]; then - export TF_CUDA_COMPUTE_CAPABILITIES=sm_60,sm_70,sm_75,sm_80,sm_86,sm_89,sm_90,compute_90 + export HERMETIC_CUDA_COMPUTE_CAPABILITIES=sm_60,sm_70,sm_75,sm_80,sm_86,sm_89,sm_90,compute_90 export CUDA_HOME="${BUILD_PREFIX}/targets/x86_64-linux" export TF_CUDA_PATHS="${BUILD_PREFIX}/targets/x86_64-linux,${PREFIX}/targets/x86_64-linux" - # Needed for some nvcc binaries - export PATH=$PATH:${BUILD_PREFIX}/nvvm/bin - # XLA can only cope with a single cuda header include directory, merge both - rsync -a ${PREFIX}/targets/x86_64-linux/include/ ${BUILD_PREFIX}/targets/x86_64-linux/include/ + # Needed for some nvcc binaries + export PATH=$PATH:${BUILD_PREFIX}/nvvm/bin + # XLA can only cope with a single cuda header include directory, merge both + rsync -a ${PREFIX}/targets/x86_64-linux/include/ ${BUILD_PREFIX}/targets/x86_64-linux/include/ + + # Although XLA supports a non-hermetic build, it still tries to find headers in the hermetic locations. + # We do this in the BUILD_PREFIX to not have any impact on the resulting jaxlib package. + # Otherwise, these copied files would be included in the package. + rm -rf ${BUILD_PREFIX}/targets/x86_64-linux/include/third_party + mkdir -p ${BUILD_PREFIX}/targets/x86_64-linux/include/third_party/gpus/cuda/extras/CUPTI + cp -r ${PREFIX}/targets/x86_64-linux/include ${BUILD_PREFIX}/targets/x86_64-linux/include/third_party/gpus/cuda/ + cp -r ${PREFIX}/targets/x86_64-linux/include ${BUILD_PREFIX}/targets/x86_64-linux/include/third_party/gpus/cuda/extras/CUPTI/ + mkdir -p ${BUILD_PREFIX}/targets/x86_64-linux/include/third_party/gpus/cudnn + cp ${PREFIX}/include/cudnn.h ${BUILD_PREFIX}/targets/x86_64-linux/include/third_party/gpus/cudnn/ + export LOCAL_CUDA_PATH="${BUILD_PREFIX}/targets/x86_64-linux" + export LOCAL_CUDNN_PATH="${PREFIX}/targets/x86_64-linux" + export LOCAL_NCCL_PATH="${PREFIX}/targets/x86_64-linux" + mkdir -p ${BUILD_PREFIX}/targets/x86_64-linux/bin + ln -s $(which ptxas) ${BUILD_PREFIX}/targets/x86_64-linux/bin/ptxas + ln -s $(which nvlink) ${BUILD_PREFIX}/targets/x86_64-linux/bin/nvlink + ln -s $(which fatbinary) ${BUILD_PREFIX}/targets/x86_64-linux/bin/fatbinary else echo "unsupported cuda version." exit 1 @@ -41,9 +58,7 @@ if [[ "${cuda_compiler_version:-None}" != "None" ]]; then CUDA_ARGS="--enable_cuda \ --enable_nccl \ - --cuda_path=${TF_CUDA_PATHS} \ - --cudnn_path=${PREFIX} \ - --cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \ + --cuda_compute_capabilities=$HERMETIC_CUDA_COMPUTE_CAPABILITIES \ --cuda_version=$TF_CUDA_VERSION \ --cudnn_version=$TF_CUDNN_VERSION" fi @@ -88,6 +103,9 @@ if [[ "${target_platform}" == "osx-arm64" || "${target_platform}" != "${build_pl else EXTRA="${CUDA_ARGS:-}" fi +if [[ "${target_platform}" == linux-* ]]; then + EXTRA="${EXTRA} --nouse_clang" +fi ${PYTHON} build/build.py \ --target_cpu_features default \ --enable_mkl_dnn \ diff --git a/recipe/meta.yaml b/recipe/meta.yaml index 5c43b1cc..b927b56c 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -1,5 +1,5 @@ -{% set version = "0.4.31" %} -{% set build = 2 %} +{% set version = "0.4.32" %} +{% set build = 0 %} {% if cuda_compiler_version != "None" %} {% set build = build + 200 %} @@ -13,10 +13,13 @@ package: source: # only pull sources after upstream PyPI release... url: https://github.com/google/jax/archive/jaxlib-v{{ version }}.tar.gz - sha256: 022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94 + sha256: 3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2 patches: - patches/0001-Allow-for-custom-CUDA-build.patch - patches/0002-Consolidated-build-fixes-for-XLA.patch + - patches/0003-Simplify-logic-in-jaxlib-FFI_ASSIGN_OR_RETURN-macro-.patch + - patches/0004-Fix-XLA_FFIR_REGISTER-macros.patch + - patches/0005-Add-missing-typename.patch build: number: {{ build }} @@ -57,6 +60,7 @@ requirements: - cuda-cudart-dev # [(cuda_compiler_version or "").startswith("12")] - cuda-nvml-dev # [(cuda_compiler_version or "").startswith("12")] - cuda-nvtx-dev # [(cuda_compiler_version or "").startswith("12")] + - cuda-nvcc-tools # [(cuda_compiler_version or "").startswith("12")] - libcublas-dev # [(cuda_compiler_version or "").startswith("12")] - libcusolver-dev # [(cuda_compiler_version or "").startswith("12")] - libcurand-dev # [(cuda_compiler_version or "").startswith("12")] @@ -65,6 +69,7 @@ requirements: - python - pip - numpy + - setuptools - wheel - cuda-version {{ cuda_compiler_version }} # [cuda_compiler_version != "None"] # avoid not being able to pass `-C=--build-option=--python-tag=cp` due to @@ -82,9 +87,7 @@ requirements: - scipy >=1.9 - ml_dtypes >=0.2.0 - __cuda # [cuda_compiler_version != "None"] - - cuda-nvcc # [(cuda_compiler_version or "").startswith("12")] - # Workaround for https://github.com/conda-forge/cuda-cupti-feedstock/issues/14 - - cuda-cupti >=12.0.90,<13.0a0 # [(cuda_compiler_version or "").startswith("12")] + - cuda-nvcc-tools # [(cuda_compiler_version or "").startswith("12")] run_constrained: - jax >={{ version }} diff --git a/recipe/patches/0001-Allow-for-custom-CUDA-build.patch b/recipe/patches/0001-Allow-for-custom-CUDA-build.patch index 484c4c39..1c1b9f32 100644 --- a/recipe/patches/0001-Allow-for-custom-CUDA-build.patch +++ b/recipe/patches/0001-Allow-for-custom-CUDA-build.patch @@ -1,25 +1,25 @@ -From 1daa8cc30c7c2d70a71aa164d9ecb5923b34e0c0 Mon Sep 17 00:00:00 2001 +From 1ec53ea591323e47c8ce53ed9b0736e98784ff68 Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Sun, 8 Oct 2023 19:34:34 +0200 -Subject: [PATCH 1/2] Allow for custom CUDA build +Subject: [PATCH 1/5] Allow for custom CUDA build --- build/build.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/build/build.py b/build/build.py -index 2f6822281..0000dbbf8 100755 +index c3a8627..5bfbdcb 100755 --- a/build/build.py +++ b/build/build.py -@@ -277,6 +277,11 @@ def write_bazelrc(*, remote_build, - f.write("build --action_env TF_CUDA_PATHS=\"{tf_cuda_paths}\"\n" - .format(tf_cuda_paths=",".join(tf_cuda_paths))) - if cuda_version: -+ # set GCC_HOST_COMPILER_PATH for toolchain with conda-forge -+ f.write("build --action_env GCC_HOST_COMPILER_PATH=\"{gcc_host_compiler_path}\"\n" -+ .format(gcc_host_compiler_path=os.environ["GCC"])) -+ f.write("build --action_env GCC_HOST_COMPILER_PREFIX=\"{gcc_host_compiler_prefix}\"\n" +@@ -289,6 +289,11 @@ def write_bazelrc(*, remote_build, + f.write("build --config=nvcc_clang\n") + f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") + if cuda_version: ++ # set GCC_HOST_COMPILER_PATH for toolchain with conda-forge ++ f.write("build --action_env GCC_HOST_COMPILER_PATH=\"{gcc_host_compiler_path}\"\n" ++ .format(gcc_host_compiler_path=os.environ["GCC"])) ++ f.write("build --action_env GCC_HOST_COMPILER_PREFIX=\"{gcc_host_compiler_prefix}\"\n" + .format(gcc_host_compiler_prefix=os.path.dirname(os.environ["GCC"]))) - f.write("build --action_env TF_CUDA_VERSION=\"{cuda_version}\"\n" - .format(cuda_version=cuda_version)) - if cudnn_version: + f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" + .format(cuda_version=cuda_version)) + if cudnn_version: diff --git a/recipe/patches/0002-Consolidated-build-fixes-for-XLA.patch b/recipe/patches/0002-Consolidated-build-fixes-for-XLA.patch index 5722a364..364f7697 100644 --- a/recipe/patches/0002-Consolidated-build-fixes-for-XLA.patch +++ b/recipe/patches/0002-Consolidated-build-fixes-for-XLA.patch @@ -1,7 +1,7 @@ -From 3967a662a3cb00e8144628ba021116ee59d74134 Mon Sep 17 00:00:00 2001 +From a7e732c129f51d16dd59b353e8c66ceae9b5529c Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Thu, 14 Dec 2023 17:06:15 +0100 -Subject: [PATCH 2/2] Consolidated build fixes for XLA +Subject: [PATCH 2/5] Consolidated build fixes for XLA jax vendors xla, but only populates the sources through bazel, so we cannot patch as usual through conda, but rather need to teach the bazel build file @@ -14,24 +14,75 @@ which is also where we're patching in the list of patches to apply to xla. Co-Authored-By: H. Vetinari --- + .../xla/0001-Omit-usage-of-StrFormat.patch | 43 ++++ ...pport-third-party-build-of-boringssl.patch | 51 ++++ third_party/xla/0002-Fix-abseil-headers.patch | 73 ++++++ .../xla/0003-Omit-usage-of-StrFormat.patch | 43 ++++ ...0004-Add-missing-bits-absl-systemlib.patch | 226 ++++++++++++++++++ third_party/xla/workspace.bzl | 6 + - 5 files changed, 399 insertions(+) + 6 files changed, 442 insertions(+) + create mode 100644 third_party/xla/0001-Omit-usage-of-StrFormat.patch create mode 100644 third_party/xla/0001-Support-third-party-build-of-boringssl.patch create mode 100644 third_party/xla/0002-Fix-abseil-headers.patch create mode 100644 third_party/xla/0003-Omit-usage-of-StrFormat.patch create mode 100644 third_party/xla/0004-Add-missing-bits-absl-systemlib.patch +diff --git a/third_party/xla/0001-Omit-usage-of-StrFormat.patch b/third_party/xla/0001-Omit-usage-of-StrFormat.patch +new file mode 100644 +index 0000000..d1b4765 +--- /dev/null ++++ b/third_party/xla/0001-Omit-usage-of-StrFormat.patch +@@ -0,0 +1,43 @@ ++From b7d3f685ea9f58f0054af0f34d0bc3ccac43fa5c Mon Sep 17 00:00:00 2001 ++From: "Uwe L. Korn" ++Date: Thu, 4 Jul 2024 10:36:03 +0200 ++Subject: [PATCH] Omit usage of StrFormat ++ ++--- ++ xla/stream_executor/gpu/gpu_executor.h | 9 ++++++--- ++ 1 file changed, 6 insertions(+), 3 deletions(-) ++ ++diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h ++index 3a5945e884..9bdc2acd47 100644 ++--- a/xla/stream_executor/gpu/gpu_executor.h +++++ b/xla/stream_executor/gpu/gpu_executor.h ++@@ -29,6 +29,7 @@ limitations under the License. ++ #include ++ #include ++ #include +++#include ++ #include ++ #include ++ ++@@ -38,7 +39,6 @@ limitations under the License. ++ #include "absl/numeric/int128.h" ++ #include "absl/status/status.h" ++ #include "absl/status/statusor.h" ++-#include "absl/strings/str_format.h" ++ #include "absl/synchronization/mutex.h" ++ #include "absl/types/span.h" ++ #include "xla/stream_executor/blas.h" ++@@ -187,8 +187,11 @@ class GpuExecutor : public StreamExecutor { ++ uint64_t size) override { ++ auto* buffer = GpuDriver::HostAllocate(context_, size); ++ if (buffer == nullptr && size > 0) { ++- return absl::InternalError( ++- absl::StrFormat("Failed to allocate HostMemory of size %d", size)); +++ std::ostringstream stringStream; +++ stringStream << "Failed to allocate HostMemory of size "; +++ stringStream << size; +++ std::string res = stringStream.str(); +++ return absl::InternalError(res); ++ } ++ return std::make_unique(buffer, size, this); ++ } diff --git a/third_party/xla/0001-Support-third-party-build-of-boringssl.patch b/third_party/xla/0001-Support-third-party-build-of-boringssl.patch new file mode 100644 -index 000000000..e24a45e1f +index 0000000..26a9904 --- /dev/null +++ b/third_party/xla/0001-Support-third-party-build-of-boringssl.patch @@ -0,0 +1,51 @@ -+From 876bfe566992d7829dc4fdb82de72ff2c622f015 Mon Sep 17 00:00:00 2001 ++From 9a5932bb8a363f777ce39ff75a52eda2bba9c21f Mon Sep 17 00:00:00 2001 +From: "Uwe L. Korn" +Date: Thu, 14 Dec 2023 15:04:51 +0100 +Subject: [PATCH 1/4] Support third-party build of boringssl @@ -70,10 +121,10 @@ index 000000000..e24a45e1f ++ ], ++) +diff --git a/workspace2.bzl b/workspace2.bzl -+index 5c9d465040..69dfa954b3 100644 ++index 1809702d8b..6fc538d3a2 100644 +--- a/workspace2.bzl ++++ b/workspace2.bzl -+@@ -67,7 +67,7 @@ def _tf_repositories(): ++@@ -69,7 +69,7 @@ def _tf_repositories(): + name = "boringssl", + sha256 = "9dc53f851107eaf87b391136d13b815df97ec8f76dadb487b58b2fc45e624d2c", + strip_prefix = "boringssl-c00d7ca810e93780bd0c8ee4eea28f4f2ea4bcdc", @@ -84,11 +135,11 @@ index 000000000..e24a45e1f + diff --git a/third_party/xla/0002-Fix-abseil-headers.patch b/third_party/xla/0002-Fix-abseil-headers.patch new file mode 100644 -index 000000000..7a58075e1 +index 0000000..96a6fec --- /dev/null +++ b/third_party/xla/0002-Fix-abseil-headers.patch @@ -0,0 +1,73 @@ -+From adc3749cd0a77a502c9ffd9c558dbee96c1fc0ab Mon Sep 17 00:00:00 2001 ++From 97ad75d0bf891be488fb223ac95ff6572b4ecd88 Mon Sep 17 00:00:00 2001 +From: "Uwe L. Korn" +Date: Thu, 23 May 2024 15:45:52 +0200 +Subject: [PATCH 2/4] Fix abseil headers @@ -101,10 +152,10 @@ index 000000000..7a58075e1 + 4 files changed, 10 insertions(+) + +diff --git a/third_party/tsl/tsl/platform/default/BUILD b/third_party/tsl/tsl/platform/default/BUILD -+index 01cf593888..ba5b5cc068 100644 ++index b3ce4301fb..9b72c2eb42 100644 +--- a/third_party/tsl/tsl/platform/default/BUILD ++++ b/third_party/tsl/tsl/platform/default/BUILD -+@@ -220,6 +220,8 @@ cc_library( ++@@ -225,6 +225,8 @@ cc_library( + deps = [ + "//tsl/platform:logging", + "@com_google_absl//absl/log:check", @@ -114,7 +165,7 @@ index 000000000..7a58075e1 + ) + +diff --git a/third_party/tsl/tsl/profiler/rpc/client/BUILD b/third_party/tsl/tsl/profiler/rpc/client/BUILD -+index 03f8c1deff..1f081a14d1 100644 ++index 4b8ece7403..a2772846b8 100644 +--- a/third_party/tsl/tsl/profiler/rpc/client/BUILD ++++ b/third_party/tsl/tsl/profiler/rpc/client/BUILD +@@ -101,6 +101,8 @@ cc_library( @@ -150,7 +201,7 @@ index 000000000..7a58075e1 + ], + alwayslink = True, +diff --git a/xla/tsl/distributed_runtime/rpc/BUILD b/xla/tsl/distributed_runtime/rpc/BUILD -+index 0f9a93eb1a..e5f11fa62c 100644 ++index 817c4dc5a4..d6f27deb5c 100644 +--- a/xla/tsl/distributed_runtime/rpc/BUILD ++++ b/xla/tsl/distributed_runtime/rpc/BUILD +@@ -37,6 +37,7 @@ cc_library( @@ -163,11 +214,11 @@ index 000000000..7a58075e1 + "@tsl//tsl/platform:status", diff --git a/third_party/xla/0003-Omit-usage-of-StrFormat.patch b/third_party/xla/0003-Omit-usage-of-StrFormat.patch new file mode 100644 -index 000000000..541c06f40 +index 0000000..67d2275 --- /dev/null +++ b/third_party/xla/0003-Omit-usage-of-StrFormat.patch @@ -0,0 +1,43 @@ -+From 8434fbb499a3c035c9b028f1500b01229ce04a4a Mon Sep 17 00:00:00 2001 ++From a360cd33b748c4f6b1ab00e386ac8031112c5b2f Mon Sep 17 00:00:00 2001 +From: "Uwe L. Korn" +Date: Thu, 4 Jul 2024 10:36:03 +0200 +Subject: [PATCH 3/4] Omit usage of StrFormat @@ -177,17 +228,17 @@ index 000000000..541c06f40 + 1 file changed, 6 insertions(+), 3 deletions(-) + +diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h -+index c19fa1ccee..c1565b864e 100644 ++index 8e9a8352e2..36d42493c6 100644 +--- a/xla/stream_executor/gpu/gpu_executor.h ++++ b/xla/stream_executor/gpu/gpu_executor.h -+@@ -28,6 +28,7 @@ limitations under the License. ++@@ -27,6 +27,7 @@ limitations under the License. + #include + #include + #include ++#include + #include + #include -+ ++ #include +@@ -37,7 +38,6 @@ limitations under the License. + #include "absl/numeric/int128.h" + #include "absl/status/status.h" @@ -196,7 +247,7 @@ index 000000000..541c06f40 + #include "absl/synchronization/mutex.h" + #include "absl/types/span.h" + #include "xla/stream_executor/blas.h" -+@@ -177,8 +177,11 @@ class GpuExecutor : public StreamExecutorCommon { ++@@ -166,8 +166,11 @@ class GpuExecutor : public StreamExecutorCommon { + uint64_t size) override { + auto* buffer = GpuDriver::HostAllocate(context_, size); + if (buffer == nullptr && size > 0) { @@ -212,16 +263,15 @@ index 000000000..541c06f40 + } diff --git a/third_party/xla/0004-Add-missing-bits-absl-systemlib.patch b/third_party/xla/0004-Add-missing-bits-absl-systemlib.patch new file mode 100644 -index 000000000..e151c23c8 +index 0000000..1941f79 --- /dev/null +++ b/third_party/xla/0004-Add-missing-bits-absl-systemlib.patch @@ -0,0 +1,226 @@ -+From f43652257c58896305d13c6dc9829c9f3f522a8f Mon Sep 17 00:00:00 2001 ++From fc6d67a2f5fce78eb91477fa4bca5c47b6fc31fd Mon Sep 17 00:00:00 2001 +From: "Uwe L. Korn" +Date: Thu, 4 Jul 2024 15:58:32 +0200 +Subject: [PATCH 4/4] Add missing bits absl systemlib + -+Co-Authored-By: H. Vetinari +--- + .../third_party/absl/system.absl.base.BUILD | 16 +++++ + .../third_party/absl/system.absl.crc.BUILD | 70 +++++++++++++++++++ @@ -425,14 +475,15 @@ index 000000000..e151c23c8 + name = "strings", + linkopts = ["-labsl_strings"], +diff --git a/third_party/tsl/third_party/absl/workspace.bzl b/third_party/tsl/third_party/absl/workspace.bzl -+index 06f75166ce..446dbc4081 100644 ++index 9565a82c33..e71aa16726 100644 +--- a/third_party/tsl/third_party/absl/workspace.bzl ++++ b/third_party/tsl/third_party/absl/workspace.bzl -+@@ -15,11 +15,13 @@ def repo(): ++@@ -14,12 +14,14 @@ def repo(): ++ SYS_DIRS = [ + "algorithm", + "base", -+ "cleanup", ++ "crc", ++ "cleanup", + "container", + "debugging", + "flags", @@ -443,7 +494,7 @@ index 000000000..e151c23c8 + "meta", + "numeric", diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl -index af52e7671..76fb83680 100644 +index 8f4accc..3b7afaf 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -30,6 +30,12 @@ def repo(): diff --git a/recipe/patches/0003-Simplify-logic-in-jaxlib-FFI_ASSIGN_OR_RETURN-macro-.patch b/recipe/patches/0003-Simplify-logic-in-jaxlib-FFI_ASSIGN_OR_RETURN-macro-.patch new file mode 100644 index 00000000..3d4e600c --- /dev/null +++ b/recipe/patches/0003-Simplify-logic-in-jaxlib-FFI_ASSIGN_OR_RETURN-macro-.patch @@ -0,0 +1,62 @@ +From 69f982e2ce5408b0d782545e346c9c100b916b8b Mon Sep 17 00:00:00 2001 +From: Dan Foreman-Mackey +Date: Tue, 17 Sep 2024 11:22:49 -0700 +Subject: [PATCH 3/5] Simplify logic in jaxlib FFI_ASSIGN_OR_RETURN macro, and + fix gcc build. + +In https://github.com/google/jax/issues/23687, it was reported that recent jaxlib changes introduced issues when building from source using gcc, instead of the clang build that we test. I'm not 100% sure why the previous macro didn't work, but in investigating I found a version that seems to work on both clang and gcc with simpler logic. + +PiperOrigin-RevId: 675641259 +--- + jaxlib/ffi_helpers.h | 38 +++++++++----------------------------- + 1 file changed, 9 insertions(+), 29 deletions(-) + +diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h +index fba57d1..4750502 100644 +--- a/jaxlib/ffi_helpers.h ++++ b/jaxlib/ffi_helpers.h +@@ -62,35 +62,15 @@ namespace jax { + FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) + + // All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN +-// where the LHS is wrapped in parentheses. +-#define FFI_ASSIGN_OR_RETURN_EAT(...) +-#define FFI_ASSIGN_OR_RETURN_REM(...) __VA_ARGS__ +-#define FFI_ASSIGN_OR_RETURN_EMPTY() +- +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1)) +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER(args) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I args +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty +- +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(__VA_ARGS__) +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(_, ##__VA_ARGS__) +- +-#define FFI_ASSIGN_OR_RETURN_IF_1(_Then, _Else) _Then +-#define FFI_ASSIGN_OR_RETURN_IF_0(_Then, _Else) _Else +-#define FFI_ASSIGN_OR_RETURN_IF(_Cond, _Then, _Else) \ +- FFI_ASSIGN_OR_RETURN_CONCAT_(FFI_ASSIGN_OR_RETURN_IF_, _Cond)(_Then, _Else) +- +-#define FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY(FFI_ASSIGN_OR_RETURN_EAT __VA_ARGS__) +- +-#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(...) \ +- FFI_ASSIGN_OR_RETURN_IF(FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(__VA_ARGS__), \ +- FFI_ASSIGN_OR_RETURN_REM, \ +- FFI_ASSIGN_OR_RETURN_EMPTY()) \ +- __VA_ARGS__ ++// where the LHS is wrapped in parentheses. See a more detailed discussion at ++// https://stackoverflow.com/a/62984543 ++#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \ ++ FFI_ASSIGN_OR_RETURN_ESCAPE(FFI_ASSIGN_OR_RETURN_EMPTY X) ++#define FFI_ASSIGN_OR_RETURN_EMPTY(...) FFI_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__ ++#define FFI_ASSIGN_OR_RETURN_ESCAPE(...) \ ++ FFI_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__) ++#define FFI_ASSIGN_OR_RETURN_ESCAPE_(...) FFI_ASSIGN_OR_RETURN_##__VA_ARGS__ ++#define FFI_ASSIGN_OR_RETURN_FFI_ASSIGN_OR_RETURN_EMPTY + + template + inline absl::StatusOr MaybeCastNoOverflow( diff --git a/recipe/patches/0004-Fix-XLA_FFIR_REGISTER-macros.patch b/recipe/patches/0004-Fix-XLA_FFIR_REGISTER-macros.patch new file mode 100644 index 00000000..4c6e42e7 --- /dev/null +++ b/recipe/patches/0004-Fix-XLA_FFIR_REGISTER-macros.patch @@ -0,0 +1,147 @@ +From 39e78f89a5c0c92c5e8591aa458de268493b829d Mon Sep 17 00:00:00 2001 +From: "Uwe L. Korn" +Date: Tue, 8 Oct 2024 12:57:07 +0000 +Subject: [PATCH 4/5] Fix XLA_FFIR_REGISTER macros + +--- + ..._FFI_REGISTER_-macros-global-qualifi.patch | 118 ++++++++++++++++++ + third_party/xla/workspace.bzl | 1 + + 2 files changed, 119 insertions(+) + create mode 100644 third_party/xla/0005-PR-17477-Fix-XLA_FFI_REGISTER_-macros-global-qualifi.patch + +diff --git a/third_party/xla/0005-PR-17477-Fix-XLA_FFI_REGISTER_-macros-global-qualifi.patch b/third_party/xla/0005-PR-17477-Fix-XLA_FFI_REGISTER_-macros-global-qualifi.patch +new file mode 100644 +index 0000000..8fc06e0 +--- /dev/null ++++ b/third_party/xla/0005-PR-17477-Fix-XLA_FFI_REGISTER_-macros-global-qualifi.patch +@@ -0,0 +1,118 @@ ++From 6b73d321ad45ca86cba50a308f12215a6f96ee28 Mon Sep 17 00:00:00 2001 ++From: Alexander Pivovarov ++Date: Mon, 23 Sep 2024 00:35:34 -0700 ++Subject: [PATCH 5/5] PR #17477: Fix XLA_FFI_REGISTER_ macros - global ++ qualification of class name is invalid ++ ++Imported from GitHub PR https://github.com/openxla/xla/pull/17477 ++ ++Currently `bazel test //xla/ffi/api:ffi_test` fails with compilation error: ++```bash ++In file included from ./xla/ffi/api/ffi.h:48, ++ from xla/ffi/api/ffi_test.cc:16: ++./xla/ffi/api/api.h:1774:38: error: global qualification of class name is invalid before '{' token ++ 1774 | struct ::xla::ffi::AttrDecoding { \ ++ | ^ ++xla/ffi/api/ffi_test.cc:71:1: note: in expansion of macro 'XLA_FFI_REGISTER_ENUM_ATTR_DECODING' ++ 71 | XLA_FFI_REGISTER_ENUM_ATTR_DECODING(::xla::ffi::Int32BasedEnum); ++ | ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ++``` ++ ++To solve "global qualification of class name is invalid" issue we can add `namespace xla::ffi { ` block to the macros and remove `::xla::ffi::` prefix in struct decls inside `XLA_FFI_REGISTER_*` macros ++ ++### Testing ++``` ++bazel test //xla/ffi/... ++ ++INFO: Build completed successfully, 47 total actions ++//xla/ffi:ffi_test PASSED in 0.5s ++//xla/ffi:call_frame_test PASSED in 0.1s ++//xla/ffi:execution_context_test PASSED in 0.1s ++//xla/ffi:execution_state_test PASSED in 0.1s ++//xla/ffi:type_id_registry_test PASSED in 0.1s ++//xla/ffi/api:ffi_test PASSED in 0.5s ++ ++Executed 6 out of 6 tests: 6 tests pass. ++``` ++ ++### Related links: ++- https://github.com/openxla/xla/pull/15747 ++- https://github.com/openxla/xla/commit/ef49d057bffd4b8ff14bda925d48ea7610aaa856 ++ ++Copybara import of the project: ++ ++-- ++fffa62b2d47feb915c0c6300b0af5540974911d4 by Alexander Pivovarov : ++ ++Fix XLA_FFI_REGISTER_ macros ++ ++Merging this change closes #17477 ++ ++COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17477 from apivovarov:fix_XLA_FFI_REGISTER_macro fffa62b2d47feb915c0c6300b0af5540974911d4 ++PiperOrigin-RevId: 677666742 ++--- ++ xla/ffi/api/api.h | 20 ++++++++++++++------ ++ 1 file changed, 14 insertions(+), 6 deletions(-) ++ ++diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h ++index 8e3774f45c..914a1a8697 100644 ++--- a/xla/ffi/api/api.h +++++ b/xla/ffi/api/api.h ++@@ -1678,13 +1678,14 @@ auto DictionaryDecoder(Members... m) { ++ // binding specification inference from a callable signature. ++ // ++ #define XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(T, ...) \ +++ namespace xla::ffi { \ ++ template <> \ ++- struct ::xla::ffi::AttrsBinding { \ +++ struct AttrsBinding { \ ++ using Attrs = T; \ ++ }; \ ++ \ ++ template <> \ ++- struct ::xla::ffi::AttrDecoding { \ +++ struct AttrDecoding { \ ++ using Type = T; \ ++ static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ ++ DiagnosticEngine& diagnostic) { \ ++@@ -1699,13 +1700,17 @@ auto DictionaryDecoder(Members... m) { ++ reinterpret_cast(attr), \ ++ internal::StructMemberNames(__VA_ARGS__), diagnostic); \ ++ } \ ++- } +++ }; \ +++ } /* namespace xla::ffi */ \ +++ static_assert(std::is_class_v<::xla::ffi::AttrsBinding>); \ +++ static_assert(std::is_class_v<::xla::ffi::AttrDecoding>) ++ ++ // Registers decoding for a user-defined enum class type. Uses enums underlying ++ // type to decode the attribute as a scalar value and cast it to the enum type. ++ #define XLA_FFI_REGISTER_ENUM_ATTR_DECODING(T) \ +++ namespace xla::ffi { \ ++ template <> \ ++- struct ::xla::ffi::AttrDecoding { \ +++ struct AttrDecoding { \ ++ using Type = T; \ ++ using U = std::underlying_type_t; \ ++ static_assert(std::is_enum::value, "Expected enum class"); \ ++@@ -1718,7 +1723,8 @@ auto DictionaryDecoder(Members... m) { ++ } \ ++ \ ++ auto* scalar = reinterpret_cast(attr); \ ++- auto expected_dtype = internal::NativeTypeToCApiDataType(); \ +++ auto expected_dtype = \ +++ ::xla::ffi::internal::NativeTypeToCApiDataType(); \ ++ if (XLA_FFI_PREDICT_FALSE(scalar->dtype != expected_dtype)) { \ ++ return diagnostic.Emit("Wrong scalar data type: expected ") \ ++ << expected_dtype << " but got " << scalar->dtype; \ ++@@ -1727,7 +1733,9 @@ auto DictionaryDecoder(Members... m) { ++ auto underlying = *reinterpret_cast(scalar->value); \ ++ return static_cast(underlying); \ ++ } \ ++- }; +++ }; \ +++ } /* namespace xla::ffi */ \ +++ static_assert(std::is_class_v<::xla::ffi::AttrDecoding>) ++ ++ //===----------------------------------------------------------------------===// ++ // Helper macro for registering FFI implementations +diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl +index 3b7afaf..013a020 100644 +--- a/third_party/xla/workspace.bzl ++++ b/third_party/xla/workspace.bzl +@@ -35,6 +35,7 @@ def repo(): + "//third_party/xla:0002-Fix-abseil-headers.patch", + "//third_party/xla:0003-Omit-usage-of-StrFormat.patch", + "//third_party/xla:0004-Add-missing-bits-absl-systemlib.patch", ++ "//third_party/xla:0005-PR-17477-Fix-XLA_FFI_REGISTER_-macros-global-qualifi.patch", + ], + ) + diff --git a/recipe/patches/0005-Add-missing-typename.patch b/recipe/patches/0005-Add-missing-typename.patch new file mode 100644 index 00000000..6b3be9ba --- /dev/null +++ b/recipe/patches/0005-Add-missing-typename.patch @@ -0,0 +1,22 @@ +From 86454745f436411956850d19238bfd3b55aa2a7f Mon Sep 17 00:00:00 2001 +From: "Uwe L. Korn" +Date: Wed, 9 Oct 2024 09:39:46 +0000 +Subject: [PATCH 5/5] Add missing typename + +--- + jaxlib/gpu/solver_kernels_ffi.cc | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc +index 3c74b85..e839494 100644 +--- a/jaxlib/gpu/solver_kernels_ffi.cc ++++ b/jaxlib/gpu/solver_kernels_ffi.cc +@@ -618,7 +618,7 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); +- auto w_data = static_cast::Type*>(w->untyped_data()); ++ auto w_data = static_cast::Type*>(w->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( diff --git a/recipe/patches/c61e49cd4a6c58b3b9823a32fe1320d65c98c45d.patch b/recipe/patches/c61e49cd4a6c58b3b9823a32fe1320d65c98c45d.patch new file mode 100644 index 00000000..3fc88041 --- /dev/null +++ b/recipe/patches/c61e49cd4a6c58b3b9823a32fe1320d65c98c45d.patch @@ -0,0 +1,62 @@ +From c61e49cd4a6c58b3b9823a32fe1320d65c98c45d Mon Sep 17 00:00:00 2001 +From: Dan Foreman-Mackey +Date: Tue, 17 Sep 2024 11:22:49 -0700 +Subject: [PATCH] Simplify logic in jaxlib FFI_ASSIGN_OR_RETURN macro, and fix + gcc build. + +In https://github.com/google/jax/issues/23687, it was reported that recent jaxlib changes introduced issues when building from source using gcc, instead of the clang build that we test. I'm not 100% sure why the previous macro didn't work, but in investigating I found a version that seems to work on both clang and gcc with simpler logic. + +PiperOrigin-RevId: 675641259 +--- + jaxlib/ffi_helpers.h | 38 +++++++++----------------------------- + 1 file changed, 9 insertions(+), 29 deletions(-) + +diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h +index fba57d11b9f2..47505020f3b8 100644 +--- a/jaxlib/ffi_helpers.h ++++ b/jaxlib/ffi_helpers.h +@@ -62,35 +62,15 @@ namespace jax { + FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) + + // All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN +-// where the LHS is wrapped in parentheses. +-#define FFI_ASSIGN_OR_RETURN_EAT(...) +-#define FFI_ASSIGN_OR_RETURN_REM(...) __VA_ARGS__ +-#define FFI_ASSIGN_OR_RETURN_EMPTY() +- +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER((__VA_ARGS__, 0, 1)) +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_HELPER(args) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I args +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty +- +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(__VA_ARGS__) +-#define FFI_ASSIGN_OR_RETURN_IS_EMPTY_I(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY_INNER(_, ##__VA_ARGS__) +- +-#define FFI_ASSIGN_OR_RETURN_IF_1(_Then, _Else) _Then +-#define FFI_ASSIGN_OR_RETURN_IF_0(_Then, _Else) _Else +-#define FFI_ASSIGN_OR_RETURN_IF(_Cond, _Then, _Else) \ +- FFI_ASSIGN_OR_RETURN_CONCAT_(FFI_ASSIGN_OR_RETURN_IF_, _Cond)(_Then, _Else) +- +-#define FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(...) \ +- FFI_ASSIGN_OR_RETURN_IS_EMPTY(FFI_ASSIGN_OR_RETURN_EAT __VA_ARGS__) +- +-#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(...) \ +- FFI_ASSIGN_OR_RETURN_IF(FFI_ASSIGN_OR_RETURN_IS_PARENTHESIZED(__VA_ARGS__), \ +- FFI_ASSIGN_OR_RETURN_REM, \ +- FFI_ASSIGN_OR_RETURN_EMPTY()) \ +- __VA_ARGS__ ++// where the LHS is wrapped in parentheses. See a more detailed discussion at ++// https://stackoverflow.com/a/62984543 ++#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \ ++ FFI_ASSIGN_OR_RETURN_ESCAPE(FFI_ASSIGN_OR_RETURN_EMPTY X) ++#define FFI_ASSIGN_OR_RETURN_EMPTY(...) FFI_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__ ++#define FFI_ASSIGN_OR_RETURN_ESCAPE(...) \ ++ FFI_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__) ++#define FFI_ASSIGN_OR_RETURN_ESCAPE_(...) FFI_ASSIGN_OR_RETURN_##__VA_ARGS__ ++#define FFI_ASSIGN_OR_RETURN_FFI_ASSIGN_OR_RETURN_EMPTY + + template + inline absl::StatusOr MaybeCastNoOverflow(