Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to newest versions of jax + jaxlib and add Windows support for JAX Solver #3550

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
9b73f8e
#3443 Add various versions for `jax` and `jaxlib`
agriyakhetarpal Nov 17, 2023
8a049d9
#3312 #3443 Build both arm64 + amd64 images for all solvers
agriyakhetarpal Nov 17, 2023
d5d22d2
Merge branch 'develop' into bump-jax-jaxlib-versions
agriyakhetarpal Nov 22, 2023
e2d849d
#3443 Install `jax` extras in the same command
agriyakhetarpal Nov 22, 2023
e17163f
#3443 Bump to latest version of `jax` and `jaxlib`
agriyakhetarpal Nov 22, 2023
05d1061
#3443 Add Windows support via nox
agriyakhetarpal Nov 22, 2023
30db13b
#3443 Install `[jax]` for the integration tests
agriyakhetarpal Nov 22, 2023
5fd45c6
#3443 Fix expression tree Jax evaluator test
agriyakhetarpal Nov 22, 2023
9103a10
Remove explainer comments about version constraints
agriyakhetarpal Nov 22, 2023
d99acee
Remove explainer comment about pinning
agriyakhetarpal Nov 22, 2023
96e059b
Merge branch 'develop' into bump-jax-jaxlib-versions
agriyakhetarpal Nov 25, 2023
c066c81
Bump `jax` and `jaxlib` versions again
agriyakhetarpal Nov 25, 2023
f229ab8
Add a condition to not install `jax` if < py3.9
agriyakhetarpal Nov 25, 2023
a47e78d
Add a CHANGELOG entry for `jax` and `jax` versions
agriyakhetarpal Nov 25, 2023
8301d26
Remove incorrect `jax` and `jaxlib` version pins
agriyakhetarpal Nov 25, 2023
3f422bd
Update changelog about breaking change for Jax solver
agriyakhetarpal Dec 7, 2023
ae9a637
#3443 Add minimal docs about Windows and Python support
agriyakhetarpal Dec 7, 2023
f41be98
Merge branch 'develop' into bump-jax-jaxlib-versions
agriyakhetarpal Dec 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 1 addition & 26 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ jobs:
echo "tag=all" >> "$GITHUB_OUTPUT"
fi

- name: Build and push Docker image to Docker Hub (no solvers)
if: matrix.build-args == 'No solvers'
- name: Build and push Docker image to Docker Hub (${{ matrix.build-args }})
uses: docker/build-push-action@v5
with:
context: .
Expand All @@ -58,29 +57,5 @@ jobs:
push: true
platforms: linux/amd64, linux/arm64

- name: Build and push Docker image to Docker Hub (with ODES and IDAKLU solvers)
if: matrix.build-args == 'ODES' || matrix.build-args == 'IDAKLU'
uses: docker/build-push-action@v5
with:
context: .
file: scripts/Dockerfile
tags: pybamm/pybamm:${{ steps.tags.outputs.tag }}
push: true
build-args: ${{ matrix.build-args }}=true
platforms: linux/amd64, linux/arm64

- name: Build and push Docker image to Docker Hub (with ALL and JAX solvers)
if: matrix.build-args == 'ALL' || matrix.build-args == 'JAX'
uses: docker/build-push-action@v5
with:
context: .
file: scripts/Dockerfile
tags: pybamm/pybamm:${{ steps.tags.outputs.tag }}
push: true
build-args: ${{ matrix.build-args }}=true
# exclude arm64 for JAX and ALL builds for now, see
# https://github.com/google/jax/issues/13608
platforms: linux/amd64

- name: List built image(s)
run: docker images
13 changes: 5 additions & 8 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ def run_coverage(session):
"""Run the coverage tests and generate an XML report."""
set_environment_variables(PYBAMM_ENV, session=session)
session.install("coverage", silent=False)
session.install("-e", ".[all]", silent=False)
session.install("-e", ".[all,jax]", silent=False)
if sys.platform != "win32":
session.install("-e", ".[odes]", silent=False)
session.install("-e", ".[jax]", silent=False)
session.run("coverage", "run", "--rcfile=.coveragerc", "run-tests.py", "--nosub")
session.run("coverage", "combine")
session.run("coverage", "xml")
Expand All @@ -73,7 +72,7 @@ def run_coverage(session):
def run_integration(session):
"""Run the integration tests."""
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]", silent=False)
session.install("-e", ".[all,jax]", silent=False)
if sys.platform == "linux":
session.install("-e", ".[odes]", silent=False)
session.run("python", "run-tests.py", "--integration")
Expand All @@ -90,10 +89,9 @@ def run_doctests(session):
def run_unit(session):
"""Run the unit tests."""
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]", silent=False)
session.install("-e", ".[all,jax]", silent=False)
if sys.platform == "linux":
session.install("-e", ".[odes]", silent=False)
session.install("-e", ".[jax]", silent=False)
session.run("python", "run-tests.py", "--unit")


Expand Down Expand Up @@ -130,17 +128,16 @@ def set_dev(session):
external=True,
)
else:
session.run(python, "-m", "pip", "install", "-e", ".[all,dev]", external=True)
session.run(python, "-m", "pip", "install", "-e", ".[all,dev,jax]", external=True)


@nox.session(name="tests")
def run_tests(session):
"""Run the unit tests and integration tests sequentially."""
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]", silent=False)
session.install("-e", ".[all,jax]", silent=False)
if sys.platform == "linux" or sys.platform == "darwin":
session.install("-e", ".[odes]", silent=False)
session.install("-e", ".[jax]", silent=False)
session.run("python", "run-tests.py", "--all")


Expand Down
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,17 @@ def compile_KLU():
"pandas>=1.5.0",
],
"jax": [
"jax==0.4.8",
"jaxlib==0.4.7",
"jax==0.4.20; python_version >= '3.9'",
"jaxlib==0.4.20; python_version >= '3.9'",
# The versions below can be removed once PyBaMM no longer supports python 3.8
"jax==0.4.13; python_version < '3.9'",
"jaxlib==0.4.13; python_version < '3.9'",
],
"odes": ["scikits.odes"],
"all": [
"autograd>=1.6.2",
"scikit-fem>=8.1.0",
"pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]"
"pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]",
],
},
entry_points={
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def test_evaluator_jax(self):
expr = pybamm.exp(a * b)
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, np.exp(6))
np.testing.assert_array_almost_equal(result, np.exp(6), decimal=15)

# test a constant expression
expr = pybamm.Scalar(2) * pybamm.Scalar(3)
Expand Down
Loading