Bump `jax` and `jaxlib` versions #3443
Some more resources, collected from the Jax CHANGELOG (https://jax.readthedocs.io/en/latest/changelog.html) and related places on the internet:
Therefore, it would be suitable to handle just CPU support in …

The final dependency resolution, with all of the above information, should look like this:

    extras_require={
        "jax": [
            # Linux aarch64 wheels for images on Docker Hub, Windows CPU wheels for 3.9–3.11
            "jax==0.4.18; python_version>='3.9' and python_version<'3.12'",
            "jaxlib==0.4.18; python_version>='3.9' and python_version<'3.12'",
            # Last version to support Python 3.8, supports Windows too
            "jax==0.4.14; python_version=='3.8'",
        ],
    },
This should cover Windows (3.8–3.11), Linux (amd64 on 3.8–3.11 and aarch64 on 3.9–3.11; there should be a way to check for the system platform for the latter), and macOS (3.8–3.11).
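The coverage above depends on two PEP 508 environment markers: `python_version` and, for the aarch64 case, `platform_machine` (which reports the same value as Python's `platform.machine()`). As a rough, stdlib-only sketch, the hypothetical helper below mirrors the intended matrix by hand so it can be checked against a given interpreter. The version ranges come from the pins above; the function name and the assumption that Python 3.8 on aarch64 gets no pins are illustrative choices, not part of the proposal.

```python
from __future__ import annotations

import platform
import sys


def jax_pins_for(python_version: tuple[int, int], machine: str) -> list[str]:
    """Return the pins the proposed ``jax`` extra would select (hypothetical helper).

    Assumes Python 3.8 on aarch64 is unsupported, so no pins are returned for it.
    """
    # Marker equivalent: python_version >= '3.9' and python_version < '3.12'
    if (3, 9) <= python_version < (3, 12):
        return ["jax==0.4.18", "jaxlib==0.4.18"]
    # Marker equivalent: python_version == '3.8' and platform_machine != 'aarch64'
    if python_version == (3, 8) and machine.lower() != "aarch64":
        return ["jax==0.4.14"]
    # Anything else (e.g. Python 3.12) is not covered by the pins above
    return []


if __name__ == "__main__":
    # The platform_machine marker reports the same value as platform.machine()
    current = sys.version_info[:2]
    print(jax_pins_for(current, platform.machine()))
```

In `setup.py` itself no such code is needed: the conditions stay as marker strings inside each requirement (for example `"jaxlib==0.4.18; python_version>='3.9' and python_version<'3.12'"`), and pip evaluates them at install time.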
1. Add support for Python 3.11 on aarch64 containers.
2. Keep Python 3.8 support on an older version.
3. Add Python 3.9–3.11 support on a newer version (the same one as in point 1).
4. Add support for CPU-only Windows installation.
5. Pin all versions so as not to break anything.
I have assigned myself this issue and will work on it over the next week. There is a lot of repetition in the installation instructions, so I'll also rewrite them (hopefully only in minor ways) in the same PR. I am considering moving the section on "Optional solvers" to a separate page. Then there can be just two pages (excluding the Docker instructions): one for the user installation and another for the source installation, instead of separate pages for GNU/Linux + macOS and for Windows. I will also improve the content and structure of pages that seem unclear to new users, and add some entries to the FAQ based on common issues we have noticed recently. I would love suggestions! (cc: @arjxn-py @Saransh-cpp) The goal is to keep users from getting overwhelmed by the many pages of such a sprawling installation guide, as I admittedly was when I was new to the repository last year.
Tested with `--upgrade` and `--upgrade-strategy eager`, plus `--no-cache-dir`.
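To make that test install repeatable, the pip invocation can be built from Python. This is a minimal sketch with a hypothetical helper name; it only uses the pip flags listed above and runs pip through the current interpreter so the right environment is targeted. The requirement spec passed in (for example `".[jax]"` for a local checkout with the `jax` extra) is an assumption to be replaced with the actual package.

```python
import sys


def pip_install_command(spec: str) -> list[str]:
    """Build the pip invocation for a clean, fully upgraded test install."""
    return [
        sys.executable, "-m", "pip", "install",
        "--upgrade",
        # "eager" also upgrades dependencies that are already satisfied
        "--upgrade-strategy", "eager",
        # disable pip's cache so nothing previously downloaded or built is reused
        "--no-cache-dir",
        spec,
    ]
```

Running `subprocess.run(pip_install_command(".[jax]"), check=True)` from the repository root would then perform the install and raise if pip fails.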
Currently, we have pinned `jaxlib==0.4.8` and `jax==0.4.7` in the extras dependencies, but we could look into bumping the versions for a multitude of reasons:

- … `3.8`, support Python `3.12`? (#3390). More details below.

… `jax` with CUDA 11.8 and 12.1 with CuDNN 8.6 and 8.9 on https://whls.blob.core.windows.net/unstable/index.html, and the instructions for their usage are at https://github.com/cloudhan/jax-windows-builder. On M-series macOS, GPU support is available through the Metal plugin (https://developer.apple.com/metal/jax/), which was also noted in testing runner with GPU support #3274 (comment). Note: both of these are experimental at this time and potentially unstable, but they can be documented as such in the installation instructions. The good news is that both of them are in sync and require the same version of `jaxlib`, i.e., `0.4.11`, and that both have now been tested: all GPU-related functionality in the unit tests seems to work as intended. (For more details, please see the extended discussion in the #infrastructure channel on Slack.)

Therefore, the minimum version to update to could be `jaxlib==0.4.11` and `jax==0.4.11`, with the internal version requirement being relaxed a bit in #3121. However, it is also easy to target multiple versions with conditional dependencies so as to accommodate most, if not all, of these requirements. Jax itself does this; its `setup.py` file is a useful reference. Considering that GPU support should be kept standalone, it can be given its own extra dependency set, something like `[jax-gpu]` or just `[gpu]`, since only the Jax solver has it. A configuration in the `setup.py` file could look like this:

…

Otherwise, given the tricky requirements for pre-compiling CUDA and CuDNN, a `nox` session or a helper script can be provided to fetch the correct versions with `pip` based on the platform.
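As a rough sketch of the helper-script option, the stdlib-only function below picks pip arguments for the two experimental GPU routes described above: the unofficial Windows CUDA wheels (passed to pip via `-f`/`--find-links`) and Apple's Metal plugin, which is published on PyPI as `jax-metal`. The function name, returning `None` on every other platform (so the caller falls back to the CPU-only extra), and leaving `jax-metal` unpinned are assumptions. A real script would pin a `jax-metal` release compatible with `jaxlib==0.4.11` and should confirm which CUDA/CuDNN variant pip actually selects from the Windows index.

```python
from __future__ import annotations

import platform
import sys

# jaxlib version both experimental GPU routes were tested against
JAX_GPU_VERSION = "0.4.11"
# Unofficial Windows CUDA wheels from cloudhan/jax-windows-builder
WINDOWS_WHEEL_INDEX = "https://whls.blob.core.windows.net/unstable/index.html"


def gpu_pip_args(system: str, machine: str) -> list[str] | None:
    """Return pip arguments for an experimental GPU-enabled JAX (hypothetical helper).

    Returns None where no GPU route is covered, so the caller can fall back
    to the CPU-only extra.
    """
    pins = [f"jax=={JAX_GPU_VERSION}", f"jaxlib=={JAX_GPU_VERSION}"]
    if system == "Windows":
        # -f (--find-links) lets pip find the CUDA-enabled jaxlib wheels
        return [*pins, "-f", WINDOWS_WHEEL_INDEX]
    if system == "Darwin" and machine == "arm64":
        # Apple's Metal plugin for M-series Macs (left unpinned here)
        return [*pins, "jax-metal"]
    return None


if __name__ == "__main__":
    args = gpu_pip_args(platform.system(), platform.machine())
    if args:
        print([sys.executable, "-m", "pip", "install", *args])
    else:
        print("No GPU route for this platform; use the CPU-only extra.")
```

Keeping the platform logic in a plain function means it could later be called from a `nox` session or a standalone script without duplicating the version pins.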