
JAX does not recognise my NVIDIA GPU when installed via conda #162

Open
1 task done
zairving opened this issue Oct 30, 2024 · 3 comments
Labels
bug Something isn't working

zairving commented Oct 30, 2024

Solution to issue cannot be found in the documentation.

  • I checked the documentation.

Issue

I initially opened this issue on the JAX repo and they redirected me here. Below is my original post:

I previously had a working installation of JAX (installed via conda) that recognised my NVIDIA GPU without issue. However, I recently migrated to a new machine and now I cannot get JAX to recognise my GPU when I install via conda. I'm using Miniforge to manage my conda environments, as I did on my old machine, and I installed JAX according to the docs:

conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia

When I then try to import JAX and check my available devices using:

import jax

print(jax.devices())

I get the following output:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
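A slightly more explicit check asks JAX for GPU devices directly instead of relying on the fallback warning. This is a minimal sketch assuming JAX 0.4.x, where `jax.devices("gpu")` raises a `RuntimeError` if no GPU backend could be initialised; the helper name `gpu_devices` is hypothetical.

```python
import jax


def gpu_devices():
    """Return the GPU devices JAX can see, or an empty list if there are none."""
    try:
        return jax.devices("gpu")
    except RuntimeError:
        # Raised when no CUDA/ROCm backend was found or it failed to initialise.
        return []


print("default backend:", jax.default_backend())
print("GPU devices:", gpu_devices())
```

With the affected conda build this should report a `cpu` default backend and an empty GPU list.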

TensorFlow, however, does recognise my GPU, so I tried installing JAX with pip instead. I created a new environment and ran:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

When I then ran the JAX code above, I got the output:

[CudaDevice(id=0)]

Voilà, my GPU has been found!
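Listing a device does not by itself prove that kernels compile and run on it. A minimal sketch, assuming JAX 0.4.x where `jax.Array.devices()` returns the set of devices holding an array: jit-compile a tiny function and check where its result lives (`sum_of_squares` is a hypothetical name).

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(x):
    # Compiled by XLA for the default backend (the GPU, if one was found).
    return jnp.sum(x ** 2)


x = jnp.arange(4.0)  # [0., 1., 2., 3.], placed on the default device
result = sum_of_squares(x)
print(float(result), result.devices())
```

On a working CUDA install the printed device set should contain a `CudaDevice` rather than a `CpuDevice`.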

It therefore appears that the conda section of the docs might need updating.

Installed packages

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
binutils_impl_linux-64    2.43                 h4bf12b8_2    conda-forge
binutils_linux-64         2.43                 h4852527_2    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
c-ares                    1.34.2               heb4867d_0    conda-forge
ca-certificates           2024.8.30            hbcca054_0    conda-forge
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
cuda-cccl_linux-64        12.6.77                       0    nvidia
cuda-crt-dev_linux-64     12.6.77                       0    nvidia
cuda-crt-tools            12.6.77                       0    nvidia
cuda-cudart               12.6.77                       0    nvidia
cuda-cudart-dev           12.6.77                       0    nvidia
cuda-cudart-dev_linux-64  12.6.77                       0    nvidia
cuda-cudart-static        12.6.77                       0    nvidia
cuda-cudart-static_linux-64 12.6.77                       0    nvidia
cuda-cudart_linux-64      12.6.77                       0    nvidia
cuda-cupti                12.6.80              hbd13f7d_0    conda-forge
cuda-driver-dev_linux-64  12.6.77                       0    nvidia
cuda-nvcc                 12.6.77                       0    nvidia
cuda-nvcc-dev_linux-64    12.6.77                       0    nvidia
cuda-nvcc-impl            12.6.77                       0    nvidia
cuda-nvcc-tools           12.6.77                       0    nvidia
cuda-nvcc_linux-64        12.6.77                       0    nvidia
cuda-nvrtc                12.6.77              hbd13f7d_0    conda-forge
cuda-nvtx                 12.6.77              hbd13f7d_0    conda-forge
cuda-nvvm-dev_linux-64    12.6.77                       0    nvidia
cuda-nvvm-impl            12.6.77                       0    nvidia
cuda-nvvm-tools           12.6.77                       0    nvidia
cuda-version              12.6                          3    nvidia
cudnn                     9.3.0.75             h93bb076_0    conda-forge
debugpy                   1.8.7           py313h46c70d0_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.2              pyhd8ed1ab_0    conda-forge
executing                 2.1.0              pyhd8ed1ab_0    conda-forge
gcc_impl_linux-64         12.4.0               hb2e57f8_1    conda-forge
gcc_linux-64              12.4.0               h6b7512a_5    conda-forge
gxx_impl_linux-64         12.4.0               h613a52c_1    conda-forge
gxx_linux-64              12.4.0               h8489865_5    conda-forge
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
ipykernel                 6.29.5             pyh3099207_0    conda-forge
ipython                   8.29.0             pyh707e725_0    conda-forge
jax                       0.4.34             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.34          cuda120py313h3b1fb80_200    conda-forge
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jupyter_client            8.6.3              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2              pyh31011fe_1    conda-forge
kernel-headers_linux-64   3.10.0              he073ed8_18    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
krb5                      1.21.3               h659f571_0    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_2    conda-forge
libabseil                 20240722.0      cxx17_h5888daf_1    conda-forge
libblas                   3.9.0           25_linux64_openblas    conda-forge
libcblas                  3.9.0           25_linux64_openblas    conda-forge
libcublas                 12.6.3.3             hbd13f7d_1    conda-forge
libcufft                  11.3.0.4             hbd13f7d_0    conda-forge
libcurand                 10.3.7.77            hbd13f7d_0    conda-forge
libcusolver               11.7.1.2             hbd13f7d_0    conda-forge
libcusparse               12.5.4.2             hbd13f7d_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libexpat                  2.6.3                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-devel_linux-64     12.4.0             ha4f9413_101    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran               14.2.0               h69a702a_1    conda-forge
libgfortran-ng            14.2.0               h69a702a_1    conda-forge
libgfortran5              14.2.0               hd5240d6_1    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
libgrpc                   1.65.5               hf5c653b_0    conda-forge
liblapack                 3.9.0           25_linux64_openblas    conda-forge
libmpdec                  4.0.0                h4bc722e_0    conda-forge
libnvjitlink              12.6.77              hbd13f7d_1    conda-forge
libopenblas               0.3.28          pthreads_h94d23a6_0    conda-forge
libprotobuf               5.27.5               h5b01275_2    conda-forge
libre2-11                 2024.07.02           hbbce691_1    conda-forge
libsanitizer              12.4.0               h46f95d5_1    conda-forge
libsodium                 1.0.20               h4ab18f5_0    conda-forge
libsqlite                 3.47.0               hadc24fc_1    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-devel_linux-64  12.4.0             ha4f9413_101    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
ml_dtypes                 0.5.0           py313ha87cce1_0    conda-forge
nccl                      2.23.4.1             h52f6c39_1    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
numpy                     2.1.2           py313h4bf6692_0    conda-forge
openssl                   3.3.2                hb9d3cd8_0    conda-forge
opt-einsum                3.4.0                hd8ed1ab_0    conda-forge
opt_einsum                3.4.0              pyhd8ed1ab_0    conda-forge
packaging                 24.1               pyhd8ed1ab_0    conda-forge
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pip                       24.3.1             pyh145f28c_0    conda-forge
platformdirs              4.3.6              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.48             pyha770c72_0    conda-forge
psutil                    6.1.0           py313h536fd9c_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_0    conda-forge
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
python                    3.13.0          h9ebbce0_100_cp313    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python_abi                3.13                    5_cp313    conda-forge
pyzmq                     26.2.0          py313h8e95178_3    conda-forge
re2                       2024.07.02           h77b4e00_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.14.1          py313h27c5614_1    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
sysroot_linux-64          2.17                h4a8ded7_18    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tornado                   6.4.1           py313h536fd9c_1    conda-forge
traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zeromq                    4.3.5                h3b0a872_6    conda-forge
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge

Environment info

     active environment : jax_ml7
    active env location : /home/zac/miniforge3/envs/jax_ml7
            shell level : 2
       user config file : /home/zac/.condarc
 populated config files : /home/zac/miniforge3/.condarc
          conda version : 24.9.0
    conda-build version : not installed
         python version : 3.12.7.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=zen4
                          __conda=24.9.0=0
                          __cuda=12.4=0
                          __glibc=2.39=0
                          __linux=6.8.0=0
                          __unix=0=0
       base environment : /home/zac/miniforge3  (writable)
      conda av data dir : /home/zac/miniforge3/etc/conda
  conda av metadata url : None
           channel URLs : https://conda.anaconda.org/conda-forge/linux-64
                          https://conda.anaconda.org/conda-forge/noarch
          package cache : /home/zac/miniforge3/pkgs
                          /home/zac/.conda/pkgs
       envs directories : /home/zac/miniforge3/envs
                          /home/zac/.conda/envs
               platform : linux-64
             user-agent : conda/24.9.0 requests/2.32.3 CPython/3.12.7 Linux/6.8.0-47-generic ubuntu/24.04.1 glibc/2.39 solver/libmamba conda-libmamba-solver/24.9.0 libmambapy/1.5.9
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False
@hanbin973

I'm experiencing the same bug. Installing jax-cuda12-pjrt with pip resolves the issue; it ships a single shared library (the PJRT plugin) that connects JAX to XLA.
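To see which of these packages an environment actually contains, the Python standard library can query installed distribution metadata. A minimal sketch: only `jax`, `jaxlib` and `jax-cuda12-pjrt` come from this thread, and `installed_version` is a hypothetical helper. Conda-built Python packages normally ship the same metadata, so this should also pick up a conda-installed `jaxlib`, though mixing pip and conda installs in one environment can be fragile.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


for name in ("jax", "jaxlib", "jax-cuda12-pjrt"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```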

@traversaro
Contributor

Thanks for reporting the issue. As @hanbin973 reported, this is related to conda-forge/jaxlib-feedstock#285. If you need a conda-only solution, the 0.4.31 builds of jax and jaxlib are not affected by this problem.

@traversaro
Contributor

@zairving @hanbin973, this should have been fixed by conda-forge/jaxlib-feedstock#288 (jaxlib version 0.4.34=*_201).

Can you test if the problem is solved for you? Thanks!
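One way to check an existing environment against what this thread reports (0.4.31 builds unaffected; 0.4.34 affected up to build `_200` and fixed from `_201`) is to read the `jaxlib` entry from `conda list --json`. This is a minimal sketch under those assumptions: the helper names are hypothetical, it assumes `conda list --json` reports `name`, `version` and `build_string` fields, and any version the thread does not mention is reported as unknown rather than guessed.

```python
import json
import re
import subprocess


def build_number(build_string):
    """Return the trailing build number of a conda build string ('..._200' -> 200)."""
    match = re.search(r"_(\d+)$", build_string)
    return int(match.group(1)) if match else None


def jaxlib_gpu_ok(version, build_string):
    """Classify a conda-forge jaxlib build using only what this thread reports.

    Returns True (reported to work), False (CPU-only or reported broken),
    or None (version not covered by the thread).
    """
    if "cuda" not in build_string:
        return False  # CPU-only build: it can never see the GPU
    if version == "0.4.31":
        return True
    if version == "0.4.34":
        number = build_number(build_string)
        return number is not None and number >= 201
    return None


def installed_jaxlib():
    """Return (version, build_string) of jaxlib in the active conda env, or None."""
    out = subprocess.run(
        ["conda", "list", "--json", "jaxlib"],
        capture_output=True, text=True, check=True,
    ).stdout
    for pkg in json.loads(out):
        if pkg.get("name") == "jaxlib":
            return pkg["version"], pkg["build_string"]
    return None


if __name__ == "__main__":
    try:
        found = installed_jaxlib()
    except (FileNotFoundError, subprocess.CalledProcessError, json.JSONDecodeError):
        found = None  # conda is not on PATH, or its output was unusable
    if found is None:
        print("jaxlib not found via conda")
    else:
        print(found, "->", jaxlib_gpu_ok(*found))
```

For example, `jaxlib_gpu_ok("0.4.34", "cuda120py313h3b1fb80_200")`, the build listed in the original report, returns `False`.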
