Generate also jax-cuda-plugin and jax-cuda-pjrt in CUDA builds, bump CUDA used at build time to 12.6, and add fixes for CUDA #288
Conversation
Hi! This is the friendly automated conda-forge-linting service. I just wanted to let you know that I linted all conda-recipes in your PR.
@conda-forge-admin, please rerender
…nda-forge-pinning 2024.11.18.19.00.37
I started a build of a …
I tried to build several hours ago but got the following error
Indeed the same for me:
The full log: log-jaxlib-cuda.txt. Probably somehow some headers try to use the internal cudnn.
Probably we need to add …
The cudnn fix worked fine; now the new error is:
The only occurrence of a similar problem is in conda-forge/bazel-feedstock#188 (comment), but there the affected user reports that the problem was solved, without saying what the corresponding change was (see https://xkcd.com/979/, but in this case the user is myself :D ).
Actually, now that I think about it, I probably made a patch that was then squashed during a rebase to clean up the PR. The related patch is probably something like https://github.com/conda-forge/bazel-feedstock/blob/764ac0bb362224f0e8deb53b1a6a3f441b6ead7d/recipe/patches/0002-Build-with-native-dependencies.patch#L179-L189 .
The linker command seems to contain some absl libraries, but not all of the required ones:
After a bit of a hack (passing all the missing linker flags as part of an unrelated absl target that I know is linked), the compilation ends successfully, but the produced jaxlib crashes at runtime:
The backtrace is:
Related to this part of the code: https://github.com/openxla/xla/blob/7fd2196f3f21f67bd1bbde9adfe819117454acb3/xla/pjrt/c/pjrt_c_api_gpu.cc#L25-L30 .
xref: abseil/abseil-cpp#1656 .
Indeed, this issue seems to describe exactly the problem. In a nutshell, apparently two parts of the code call the same abseil initialization. Probably this does not happen with the PyPI packages, as there abseil is presumably linked statically. Possible solutions: use static abseil (at least for …).
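To see which situation a given build is in, one can check whether the built extension modules pull in a shared abseil at all. The following is a minimal diagnostic sketch, assuming a Linux conda environment and the usual site-packages layout; it is not part of the recipe, and the paths are assumptions:

```python
# Diagnostic sketch (assumed paths): list the shared abseil libraries that the
# built jaxlib extension modules are linked against. If none show up, abseil
# is most likely linked statically into each module.
import glob
import os
import subprocess

prefix = os.environ.get("CONDA_PREFIX", "/opt/conda")
pattern = os.path.join(prefix, "lib", "python3*", "site-packages", "jaxlib", "**", "*.so")

for lib in sorted(glob.glob(pattern, recursive=True)):
    out = subprocess.run(["ldd", lib], capture_output=True, text=True).stdout
    absl_deps = sorted({line.split()[0] for line in out.splitlines() if "libabsl" in line})
    print(lib)
    print("  shared abseil deps:", absl_deps if absl_deps else "none (likely static)")
```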
recipe/build.sh (Outdated)

```diff
@@ -78,7 +78,7 @@ build --verbose_failures
 build --toolchain_resolution_debug
 build --define=PREFIX=${PREFIX}
 build --define=PROTOBUF_INCLUDE_PATH=${PREFIX}/include
-build --local_cpu_resources=${CPU_COUNT}
+build --local_cpu_resources=120
```
Sorry, this was not supposed to be committed, my bad.
That's a workaround I would be happy with for now. I would expect that the packages will always be imported one after another, …
…nda-forge-pinning 2024.11.22.09.17.35
Cleaned up a bit and implemented the suggestion. @traversaro Can you check whether this fixes your problem?
Hi! This is the friendly automated conda-forge-linting service. I just wanted to let you know that I linted all conda-recipes in your PR and found it was in an excellent condition. I do have some suggestions for making it better though... For recipe/meta.yaml:
This message was generated by GitHub Actions workflow run https://github.com/conda-forge/conda-forge-webservices/actions/runs/12233205650. Examine the logs at this URL for more detail.
Yes, that segfault. The related part of the code is https://github.com/openxla/xla/blob/626f1d2aadd2bb6d2217ffdcf6dba3933cffa183/xla/stream_executor/cuda/cuda_blas.cc#L188-L208 . I need to understand how to investigate this better, but my guess is that the following is happening: somehow the conda-provided cuBLAS is not found/not initialized (and this is the real problem), while if cuBLAS is installed in the system it is correctly found/initialized. Then an error message would be printed, but using the log results in a segfault.
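As a quick way to separate "cuBLAS is not discoverable" from "the error path itself crashes", something like the following can be run inside the environment; the SONAMEs are assumptions based on the CUDA 12 packages, and this is only a diagnostic sketch:

```python
# Diagnostic sketch (not XLA code): check which cuBLAS names the dynamic
# loader can resolve from the active environment. If only the ".12"-suffixed
# name resolves, any code that asks for a different name will fail to load it.
import ctypes

for name in ("libcublas.so", "libcublas.so.12"):
    try:
        ctypes.CDLL(name)
        print(f"{name}: loaded")
    except OSError as err:
        print(f"{name}: not found ({err})")
```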
Ok, for now I have just inspected the code, but I think I am understanding what is going on (no, that was the wrong system). However, the CUDA xla plugin calls cuBLAS via a trampoline, and the trampoline is quite picky about the version of cuBLAS installed, trying explicitly to load the exact version used at build time:
Ok, I noticed that also …
Ok, I am not sure, but apparently setting …
Indeed, before segfaulting it is possible to see by setting …
Interestingly, for some reason, when jax is installed via pip, …
A bit of an update on my weekend findings/theories, as I am not sure when I will come back to this: I think the issue is the fact that conda's build of jaxlib looks for … This depends on the value of the …
Note that GitHub search filters results in …
When building locally, before the build directory is cleaned, `find -name 'cuda_config.h'` returns:

```
./build_artifacts/jaxlib_1731957330152/_build_env/share/bazel/a6978338d085f2f71b32b3d4d50f2908/external/local_config_cuda/cuda/cuda/cuda_config.h
```

and yes, only TF_CUDA_VERSION and TF_CUDART_VERSION are set, all the other versions are empty:

```c
#ifndef CUDA_CUDA_CONFIG_H_
#define CUDA_CUDA_CONFIG_H_
#define TF_CUDA_VERSION "12.0"
#define TF_CUDART_VERSION "12"
#define TF_CUPTI_VERSION ""
#define TF_CUBLAS_VERSION ""
#define TF_CUSOLVER_VERSION ""
#define TF_CURAND_VERSION ""
#define TF_CUFFT_VERSION ""
#define TF_CUSPARSE_VERSION ""
#define TF_CUDNN_VERSION ""
#define TF_CUDA_TOOLKIT_PATH ""
#define TF_CUDA_COMPUTE_CAPABILITIES 60, 70, 75, 80, 86, 89, 90, 90
#endif  // CUDA_CUDA_CONFIG_H_
```

(this file is generated before the CUDA version is bumped to 12.6)

In the same directory:

```python
config = {"cuda_version": "12.0", "cudnn_version": "", "cuda_compute_capabilities": ["sm_60", "sm_70", "sm_75", "sm_80", "sm_86", "sm_89", "sm_90", "compute_90"], "cpu_compiler": ""}
```
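The practical effect of those empty version macros is that the loading code ends up asking for unsuffixed library names. Roughly, as an illustration of the effect only (this is not XLA's actual loading code):

```python
# Illustration only: with TF_CUBLAS_VERSION empty in the generated
# cuda_config.h, the requested SONAME loses its ".12" suffix, and the
# unsuffixed "libcublas.so" symlink is only shipped by the -dev packages.
tf_cublas_version = ""  # value baked into the cuda_config.h shown above

soname = "libcublas.so" + (f".{tf_cublas_version}" if tf_cublas_version else "")
print(soname)  # prints "libcublas.so" instead of the expected "libcublas.so.12"
```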
It reads the version from … and the detected version is taken from the filenames. I checked the cublas directory and I don't find …
Oh, it's in …
This is set by jaxlib-feedstock/recipe/build.sh line 41 (at commit 266fbcd).
Good catch! Either we need to fix/patch that part somehow, or add …
Here, the relevant line is jaxlib-feedstock/recipe/build.sh line 30 (at commit 266fbcd).
Note that cuDNN has the same issue: only the header file is copied (jaxlib-feedstock/recipe/build.sh lines 39 to 40 at commit 266fbcd); see also line 42.
I’ve been following this issue from afar. Just wanted to say I appreciate the work that you guys are doing! Thank you for the time and effort put in!
That sounds like the easiest solution by far? 🤔
Definitely! I pushed a commit in 8c12eaf that adds those as dependencies, and also patches XLA so that it can find the CUDA libraries as installed by the conda cuda packages without the need to set the XLA_FLAGS environment variable.
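For context, the user-side workaround that this patch makes unnecessary looks roughly like the sketch below, assuming the CUDA libraries are installed in the active conda environment:

```python
# Workaround sketch that should no longer be needed once the openxla/xla#20288
# backport is in: point XLA at the conda prefix before importing jax.
import os

os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=" + os.environ["CONDA_PREFIX"]

import jax  # must be imported after XLA_FLAGS is set
print(jax.devices())  # should list the CUDA device(s)
```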
@njzjz I was not able to look into your findings, so if you have any alternative solution feel free to propose it, thanks!
It looks good to me! I think this PR can be merged, and other improvements can be done in future PRs.
Thanks a lot for the merge, @xhochy!
Fix #285 and conda-forge/jax-feedstock#162.

Recap of changes:

- Add libcublas-dev, libcusolver-dev, libcurand-dev, cuda-cupti-dev, libcufft-dev, libcusparse-dev as run dependencies, to work around the dynamic loading logic of xla that looks for libraries without the .12 suffix.
- Patch xla so that it is no longer necessary to export XLA_FLAGS=--xla_gpu_cuda_data_dir=$CONDA_PREFIX (backport of "cuda_root_path: Find cuda libraries when installed with conda packages", openxla/xla#20288).

A minimal smoke test of the result is sketched below.
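This is a rough sketch, assuming a CUDA-capable machine with the updated jaxlib and the CUDA run dependencies listed above installed from conda-forge; it only checks that the GPU backend initializes and runs a computation without the crashes discussed in this thread:

```python
# Smoke test sketch: the CUDA backend should initialize (no missing-library
# errors) and a simple computation should run on the GPU without segfaulting.
import jax
import jax.numpy as jnp

print(jax.devices())   # expect at least one CUDA device
x = jnp.arange(16.0)
print(jnp.dot(x, x))   # runs on the default (GPU) backend
```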
Checklist:

- Reset the build number to 0 (if the version changed).
- Re-rendered with the latest conda-smithy (use the phrase @conda-forge-admin, please rerender in a comment in this PR for automated rerendering).