Bump jax and jaxlib versions #3443

Closed
agriyakhetarpal opened this issue Oct 12, 2023 · 2 comments · Fixed by #3550

@agriyakhetarpal
Member

agriyakhetarpal commented Oct 12, 2023

Currently, we have pinned jaxlib==0.4.8 and jax==0.4.7 in the extra dependencies, but we could look into bumping these versions for several reasons:

  1. Native CPU-only support for amd64 Windows is now available, which will enable the use of the Jax solver on Windows (CPU-only) with Python 3.9–3.12. As is standard practice, we can mark the dependency with a conditional in setup.py so that there are no conflicts across platforms and Python versions, and continue to support a lower version for Python 3.8 until we drop support for it (Drop support for Python 3.8, support Python 3.12? #3390). More details below.
  2. arm64/aarch64 wheels for Linux are now available. This will help the currently proposed solution for multi-platform Docker images in Multi-architecture Docker images and parallel builds #3430, and will make Jax importable in Docker on M-series macOS machines. The current versions fail there because the emulation layer does not support AVX instructions.
  3. GPU support with Jax: on Windows, this is available through WSL2, either by installing CUDA and cuDNN locally and then the GPU-specific wheels from Google's indices, or by installing CUDA directly from pip wheels. However, installing WSL2 requires administrator permissions on Windows machines. Natively on Windows, there are community-built wheels for jax with CUDA 11.8 and 12.1 and cuDNN 8.6 and 8.9 at https://whls.blob.core.windows.net/unstable/index.html, with usage instructions at https://github.com/cloudhan/jax-windows-builder. On M-series macOS, GPU support is available through the Metal plugin (https://developer.apple.com/metal/jax/), which was also noted in testing runner with GPU support #3274 (comment). Note: both of these are currently experimental and potentially unstable, but they can be documented as such in the installation instructions. The good news is that both are in sync and require the same version of jaxlib (0.4.11), and that both have now been tested: all GPU-related functionality in the unit tests seems to work as intended.

(For more details, please see the extended discussion in the #infrastructure channel on Slack)

Therefore, the minimum version to update to could be jaxlib==0.4.11 and jax==0.4.11, with the internal version requirement relaxed slightly in #3121. However, it is also easy to target multiple versions with conditional dependencies so as to accommodate most, if not all, of these requirements. Jax itself does this, and its setup.py file is a useful reference. Since GPU support should be kept standalone, it can be given its own set of extra dependencies, such as [jax-gpu] or simply [gpu], considering that only the Jax solver supports it.

A configuration in the setup.py file could look like this:

extras_require={
    # provides Linux aarch64 wheels for Docker on arm macOS, plus native CPU-only support on Windows
    "jax": [
        "jax==0.4.18; python_version>='3.9' and python_version<'3.12'",
        "jaxlib==0.4.18; python_version>='3.9' and python_version<'3.12'",
        # some other minor version that includes support for
        # Python 3.8 so that our installation does not break
        "jax==0.4.1X; python_version=='3.8'",
    ],
    # provides GPU support for M-series macOS (Metal plugin)
    "jax-gpu": [
        "jax==0.4.10",
        "jaxlib==0.4.11",
        "jax-metal==0.0.3",
    ],
},

Otherwise, given the tricky requirements for setting up CUDA and cuDNN, a nox session or a helper script could be provided to fetch the correct versions with pip based on the platform.

@agriyakhetarpal
Member Author

Some more resources, collected from the Jax CHANGELOG (https://jax.readthedocs.io/en/latest/changelog.html) and elsewhere online:

  1. Windows CPU-only wheels were first added in 0.4.13; unofficial ones existed for versions 0.4.9–0.4.11, for Python 3.8–3.11
  2. Unofficial Windows GPU wheels are still at 0.4.11, with Python 3.9–3.11 for CUDA 11.8 and Python 3.10–3.11 for CUDA 12.1
  3. Linux aarch64/arm64 wheels were added in 0.4.18, which is useful for our Docker images, which use Python 3.9 by default (soon to be bumped to 3.11)
  4. 0.4.14 was the last version to support Python 3.8
  5. M-series macOS GPU support (Metal plugin) requires Python 3.9 or above with 0.4.11. This requires jax-metal, which in turn requires building jaxlib from source.

Therefore, it would be best to handle only CPU support in setup.py and to add instructions to the documentation on how to enable GPU support on each platform (possibly with helper scripts and a pybamm_install_with_gpu entry point to go with them in our next release?)
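A helper script of the kind mentioned above could start from the CPU-only constraints in the list. The following is a minimal sketch under that assumption: the function names `select_jax_requirements` and `pip_install_command` are hypothetical (not part of PyBaMM), GPU handling is deliberately left out, and the pins simply mirror the notes above (0.4.14 for Python 3.8, 0.4.18 for Python 3.9–3.11, no Python 3.8 wheels on Linux aarch64).

```python
import sys


# Hypothetical helper (not part of PyBaMM): maps a platform to CPU-only pins,
# following the notes above. GPU handling is intentionally omitted.
def select_jax_requirements(python_version, system, machine):
    """Return pinned jax/jaxlib requirement strings.

    python_version: (major, minor) tuple, e.g. sys.version_info[:2]
    system: platform.system() value, e.g. "Linux", "Windows", "Darwin"
    machine: platform.machine() value, e.g. "x86_64", "AMD64", "aarch64"
    """
    if python_version == (3, 8):
        if system == "Linux" and machine == "aarch64":
            # Linux aarch64 wheels start at 0.4.18, which no longer supports 3.8
            raise RuntimeError("no jaxlib wheels for Python 3.8 on Linux aarch64")
        version = "0.4.14"  # last release to support Python 3.8
    elif (3, 9) <= python_version <= (3, 11):
        version = "0.4.18"  # includes Linux aarch64 and Windows CPU wheels
    else:
        major, minor = python_version
        raise RuntimeError(f"Python {major}.{minor} is not covered by these pins")
    return [f"jax=={version}", f"jaxlib=={version}"]


def pip_install_command(requirements):
    """Build (but do not run) a pip command for the current interpreter."""
    return [sys.executable, "-m", "pip", "install", *requirements]
```

On a given machine, it could be called as `select_jax_requirements(sys.version_info[:2], platform.system(), platform.machine())`, with the result passed to `subprocess.run(pip_install_command(...), check=True)`. Keeping the selection as a pure function makes the version matrix easy to unit test without touching pip.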

The final dependency resolution, with all of the above information, should look like this:

extras_require={
    "jax": [
        # Linux aarch64 wheels for images on Docker Hub, Windows CPU wheels for 3.9–3.11
        "jax==0.4.18; python_version>='3.9' and python_version<'3.12'",
        "jaxlib==0.4.18; python_version>='3.9' and python_version<'3.12'",
        # Last version to support Python 3.8, supports Windows too
        "jax==0.4.14; python_version=='3.8'",
    ],
    # ...
},

This should cover Windows (3.8–3.11), Linux (amd64 on 3.8–3.11 and aarch64 on 3.9–3.11; there should be a way to check the system platform for the latter), and macOS (3.8–3.11).
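For the platform check, PEP 508 environment markers include `platform_machine`, which is defined as the value of `platform.machine()` (typically `aarch64` on Linux arm64 and `arm64` on native macOS arm64). A minimal sketch, assuming the Python 3.8 pins should simply be skipped on aarch64 and that jaxlib is pinned alongside jax for Python 3.8 (the configuration above pins only jax there); `JAX_EXTRA` and `marker_values` are hypothetical names:

```python
import platform
import sys

# Hypothetical "jax" extra using PEP 508 environment markers. Assumptions:
# jaxlib is pinned alongside jax for Python 3.8, and the 3.8 pins are skipped
# on aarch64 because Linux aarch64 wheels only start at 0.4.18 (no 3.8 support).
JAX_EXTRA = [
    "jax==0.4.18; python_version >= '3.9' and python_version < '3.12'",
    "jaxlib==0.4.18; python_version >= '3.9' and python_version < '3.12'",
    "jax==0.4.14; python_version == '3.8' and platform_machine != 'aarch64'",
    "jaxlib==0.4.14; python_version == '3.8' and platform_machine != 'aarch64'",
]


def marker_values():
    """Return the local values that pip compares the markers above against."""
    return {
        # PEP 508 defines python_version as "major.minor"
        "python_version": ".".join(platform.python_version_tuple()[:2]),
        # Typically "aarch64" on Linux arm64, "arm64" on native macOS arm64
        "platform_machine": platform.machine(),
        "sys_platform": sys.platform,
    }


if __name__ == "__main__":
    for name, value in marker_values().items():
        print(f"{name} = {value!r}")
```

In setup.py, the list would be passed as `extras_require={"jax": JAX_EXTRA}`; running the file directly inside a Docker image prints the values that image would be matched against.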

agriyakhetarpal added a commit to agriyakhetarpal/PyBaMM that referenced this issue Nov 17, 2023
1. Add support for Python 3.11 on aarch64 containers
2. Keep Python 3.8 support on older version
3. Add Python 3.9–3.11 support on newer version (same as the one for point 1)
4. Add support for CPU-only Windows installation
5. Pin all versions so as to not break anything.
@agriyakhetarpal agriyakhetarpal self-assigned this Nov 17, 2023
@agriyakhetarpal
Member Author

I have assigned myself this issue and will work on it over the next week. There seems to be a lot of repetition in the installation instructions, so I'll rewrite them in the same PR (hopefully keeping the changes minor). I am considering moving the section on "Optional solvers" to a separate page. Then there can be just two pages (excluding the Docker instructions): one for the user installation and another for the source installation (instead of having separate pages for GNU/Linux + macOS and for Windows).

I will also improve the content and structure of pages that seem unclear to new users, and add more entries to the FAQ based on common issues we have noticed recently. I would love suggestions! (cc: @arjxn-py @Saransh-cpp) The goal is to keep users from getting overwhelmed by all the different pages of such a sprawling installation guide, which I admit happened to me when I was new to the repository last year.

agriyakhetarpal added a commit to agriyakhetarpal/PyBaMM that referenced this issue Nov 22, 2023
tested with `--upgrade` and `--upgrade-strategy eager` plus `--no-cache-dir`
agriyakhetarpal added a commit to agriyakhetarpal/PyBaMM that referenced this issue Dec 7, 2023