From 5fdfbf35d6bd2fa463284ee21c924a6ba92a35de Mon Sep 17 00:00:00 2001 From: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> Date: Sat, 18 Nov 2023 03:33:06 +0530 Subject: [PATCH] #3443 Add various versions for `jax` and `jaxlib` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add support for Python 3.11 on aarch64 containers 2. Keep Python 3.8 support on older version 3. Add Python 3.9–3.11 support on newer version (same as the one for point 1) 4. Add support for CPU-only Windows installation 5. Pin all versions so as to not break anything. --- setup.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index cab39020a2..e55036059a 100644 --- a/setup.py +++ b/setup.py @@ -382,15 +382,26 @@ def compile_KLU(): "pandas": [ "pandas>=1.5.0", ], + # Note: jax and jaxlib must be pinned to a specific version + # to avoid upstream breaking changes. "jax": [ - "jax==0.4.8", - "jaxlib==0.4.7", + # 0.4.18 provides support for Jax on aarch64 containers + # via the PyBaMM images on Docker Hub which come with + # Python 3.11 installed. + # It also provides support for CPU-only Jax on Windows. + "jax==0.4.18; python_version >= '3.9'", + "jaxlib==0.4.18; python_version >= '3.9'", + # Jax 0.4.13 was the last version to support Python 3.8. + # Support for CPU-only Windows was added in 0.4.13, so + # this version supports Windows too. + "jax==0.4.13; python_version < '3.9'", + "jaxlib==0.4.13; python_version < '3.9'", ], "odes": ["scikits.odes"], "all": [ "autograd>=1.6.2", "scikit-fem>=8.1.0", - "pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]" + "pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]", ], }, entry_points={