JaxSolver fails when using GPU support with no input parameters #3423

Merged: 7 commits merged into pybamm-team:develop on Oct 31, 2023


jsbrittain
Contributor

Description

When given input parameters, the JaxSolver parallelises the solve across parameter sets, using either asyncio (CPU) or JAX's vmap (GPU). Without input parameters, the CPU pathway still solves the model correctly; however, when the solver is called without an input parameter list in a GPU-enabled JAX environment, it fails. This is because vmap requires at least one of its input arguments to contain a non-empty array, and an empty parameter set provides none.
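For illustration, here is a minimal sketch of the CPU pathway described above, in which each parameter set is solved concurrently under asyncio. It assumes a toy `solve_one` function in place of the solver's compiled JAX solve. `solve_one` and `solve_all_cpu` are hypothetical names, not PyBaMM API, and the use of `asyncio.to_thread` (Python 3.9+) is a choice made for this sketch rather than necessarily how the JaxSolver schedules its solves.

```python
import asyncio
from typing import Dict, List


def solve_one(inputs: Dict[str, float]) -> float:
    # Hypothetical stand-in for a compiled solve: returns a toy result
    # that depends on the (possibly empty) parameter set.
    return sum(inputs.values()) + 1.0


async def _solve_all(parameter_sets: List[Dict[str, float]]) -> List[float]:
    # Run each solve in a worker thread; gather preserves input order.
    tasks = [asyncio.to_thread(solve_one, p) for p in parameter_sets]
    return list(await asyncio.gather(*tasks))


def solve_all_cpu(parameter_sets: List[Dict[str, float]]) -> List[float]:
    # Treat "no input parameters" as a single empty parameter set.
    if not parameter_sets:
        parameter_sets = [{}]
    return asyncio.run(_solve_all(parameter_sets))


print(solve_all_cpu([]))  # [1.0]
print(solve_all_cpu([{"Current function [A]": 1.0}, {"Current function [A]": 2.0}]))  # [2.0, 3.0]
```

Treating the no-input case as one empty set keeps it a single, ordinary solve on this pathway, which is the case the GPU pathway mishandled.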

The fix is to direct input parameter lists of length 0 or 1 to the CPU pathway. The parallelisation happens mostly over parameter sets rather than within individual solves, so with at most one parameter set there is little to gain from vmap.
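The routing rule can be sketched in plain Python. `stack_inputs` and `select_pathway` are hypothetical helpers, not PyBaMM functions, and the `platform` string stands in for whatever backend name JAX reports. The sketch shows that stacking a single empty parameter set leaves a batched (vmap-style) solve nothing to map over, and that sending zero or one set to the CPU pathway avoids the batched call entirely.

```python
from typing import Dict, List


def stack_inputs(parameter_sets: List[Dict[str, float]]) -> Dict[str, List[float]]:
    # Turn a list of parameter dicts into one dict of per-key columns, the
    # shape a batched solve would consume. Keys come from the first set,
    # so [{}] (no input parameters) yields {}: there is nothing to map over.
    if not parameter_sets:
        return {}
    return {key: [p[key] for p in parameter_sets] for key in parameter_sets[0]}


def select_pathway(parameter_sets: List[Dict[str, float]], platform: str) -> str:
    # Zero or one parameter set always goes to the CPU pathway; in this
    # sketch any non-CPU backend with several sets takes the batched path.
    if len(parameter_sets) <= 1 or platform.casefold().startswith("cpu"):
        return "cpu"
    return "gpu"


print(stack_inputs([{}]))                               # {}
print(select_pathway([{}], "gpu"))                      # cpu
print(select_pathway([{"a": 1.0}, {"a": 2.0}], "gpu"))  # gpu
```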

Fixes #3422

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

Note: existing tests should cover this scenario, but they depend on GPU runners, which are still a work in progress.

@codecov

codecov bot commented Oct 9, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (0482121) 99.58% compared to head (6cc3940) 99.58%.
Report is 2 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3423   +/-   ##
========================================
  Coverage    99.58%   99.58%           
========================================
  Files          256      256           
  Lines        20048    20048           
========================================
  Hits         19965    19965           
  Misses          83       83           
Files Coverage Δ
pybamm/solvers/jax_solver.py 90.69% <100.00%> (ø)


Member

@brosaplanella brosaplanella left a comment


Looks good, thanks! Let's hold merging this until the wheel publishing is fixed.

@agriyakhetarpal
Member

agriyakhetarpal commented Oct 12, 2023

Thanks for this @jsbrittain, I tested this on Windows locally and can confirm that #3371 is fixed by the changes in this PR. I didn't expect the fix to be this simple, though that is probably down to my general inexperience with Jax. Could you please also close that issue from here?

I have opened a new issue, #3443, about bumping the jax and jaxlib versions so that GPU support can be targeted across platforms, among other reasons.

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment


The wheel and rc situation is fixed now. The CHANGELOG line should be moved to the unreleased section and everything should be good to go.

@agriyakhetarpal
Member

The failing doctests here should not be a worry. I would suggest setting that configuration value to False, though.

@jsbrittain
Contributor Author

@agriyakhetarpal All done and passing. Note that some tests still appear to be quite fragile, with the Example notebooks in particular requiring three attempts to pass in this case (despite no code changes between runs).

@agriyakhetarpal
Copy link
Member

The example notebooks issue should be this one: #3415, which came after the changes made in #3198.

Co-authored-by: Saransh Chopra <[email protected]>
Member

@Saransh-cpp Saransh-cpp left a comment


Thanks, @jsbrittain! Sorry for all the release mess 😬

@agriyakhetarpal
Member

@brosaplanella: I didn't get to raise this in the meeting. Should this PR be included in the release, given that it is a bug fix? I think #3443 should be looked at too; otherwise, the GPU support mentioned in our CHANGELOG for v23.9rc0 would not be available to users on some platforms (e.g., macOS with Metal requires v0.4.11, whereas we are still on v0.4.8).

I would be happy to write a PR for that this week and take @jsbrittain's suggestions on it into consideration; please let me know if that is needed.

@brosaplanella
Member

If we can include it that would be great. Tagging @Saransh-cpp so he is aware and can provide some input.

@Saransh-cpp
Member

I'll merge this and add it in the rc1 release.

@Saransh-cpp Saransh-cpp merged commit 138cbf2 into pybamm-team:develop Oct 31, 2023
33 checks passed
Saransh-cpp added a commit that referenced this pull request Oct 31, 2023
JaxSolver fails when using GPU support with no input parameters
rtimms added a commit that referenced this pull request May 16, 2024
* Bump to v23.9rc0

* Merge pull request #3412 from agriyakhetarpal/drop-i686-manylinux2014-support

Drop support for i686 manylinux

* Merge pull request #3413 from Saransh-cpp/improve-release-workflow

Improve release workflow, add a note, bump version manually

* Merge pull request #3436 from Saransh-cpp/fortnightly-wheels

Build wheels on the 1st and 15th of every month

* Merge pull request #3445 from pybamm-team/issue-3428-rename-exchange

#3428 exchange-current density error

* Merge pull request #3449 from pybamm-team/i3431-windows-wheels

Fix failing windows wheel builds

* Merge pull request #3456 from abillscmu/issue-3224-initial_soc

make initial soc work with half cell models

* Merge pull request #3467 from abillscmu/bugfix/initial_soh

* Merge pull request #3423 from jsbrittain/jax_gpu

JaxSolver fails when using GPU support with no input parameters

* Fix changelog

* Merge pull request #3475 from arjxn-py/fix-default-imports

Resolve default imports for optional dependencies

* Bump - `v23.9rc1`

* Fix date in CHANGELOG

* Bump version to v23.9

* Fix docs about Jax solver compatibility with Python versions (#3702)

* Ensure correct Python versions for Jax solver compatibility

* Simplify array of Python versions

Co-authored-by: Eric G. Kratz <[email protected]>

* Use different conjunction

Co-authored-by: Eric G. Kratz <[email protected]>

---------

Co-authored-by: Eric G. Kratz <[email protected]>

* Merge pull request #3706 from agriyakhetarpal/fix-pybamm-install-odes

Make `pybamm_install_odes` a bit more robust

* #3690 fix issue with skipped steps (#3708)

* #3690 fix issue with skipped steps

* #3690 changelog

* #3690 add test

* #3611 use actual cell volume for average total heating (#3707)

* #3611 use actual cell volume for average total heating

* #3611 changelog

* #3611 account for number of electrode pairs

* #3611 update variable names

* Improve the release workflow (#3737)

* Try fixing the release workflow

* Turn off safety

* Fix CHANGELOG

* Add OS

* Use regex for better matches

* Update instructions, add safety checks

* checkout to the version branch for the final release

* Bump to v24.1rc1

* #3630 fix interpolant shape error (#3761)

* #3630 fix interpolant shape error

* #3630 changelog

* Bump to v24.1rc2

* Bump to v24.1

* Fix doctests failures in scheduled tests (#3784)

Closes #3781

* Resolve broken `scikits.odes` installation on self-hosted M-series runner (#3785)

* Try fixing M-series runner tests

This is being done by adding SuiteSparse and SUNDIALS installations which might have been missing on the runner, which broke `scikits.odes`.

* Don't use Homebrew SUNDIALS, use LD_LIBRARY_PATH

* Don't use Homebrew to install SUNDIALS

* Force remove pip cache for `scikits.odes`

---------

Co-authored-by: Eric G. Kratz <[email protected]>

* add temperature dependence to MSMR model

* changelog

* fix tests

* fix example

* rob comments

* update notebook

---------

Co-authored-by: Ferran Brosa Planella <[email protected]>
Co-authored-by: Saransh Chopra <[email protected]>
Co-authored-by: Martin Robinson <[email protected]>
Co-authored-by: Agriya Khetarpal <[email protected]>
Co-authored-by: Eric G. Kratz <[email protected]>
Co-authored-by: Robert Timms <[email protected]>
Co-authored-by: Saransh-cpp <[email protected]>
js1tr3 pushed a commit to js1tr3/PyBaMM that referenced this pull request Aug 12, 2024
4 participants