JaxSolver fails when using GPU support with no input parameters #3423
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files (coverage diff, develop vs. #3423):
| Metric | develop | #3423 |
| --- | --- | --- |
| Coverage | 99.58% | 99.58% |
| Files | 256 | 256 |
| Lines | 20048 | 20048 |
| Hits | 19965 | 19965 |
| Misses | 83 | 83 |
☔ View full report in Codecov by Sentry.
Looks good, thanks! Let's hold merging this until the wheel publishing is fixed.
Thanks for this @jsbrittain, I tested this on Windows locally and I can confirm that #3371 is fixed through the changes in this PR. I didn't know it would be as easy as this but I get it because of my general inexperience with Jax. Could you please close that issue too here? I have opened a new issue about bumping the |
The wheel and rc situation is fixed now. The CHANGELOG line should be moved to the unreleased section and everything should be good to go.
The failing doctests here should not be a worry. I would suggest setting that configuration value to |
@agriyakhetarpal All done and passing. Note that some tests still appear to be quite fragile, with the Example notebooks in particular requiring three attempts to pass in this case (despite no code changes between runs).
Co-authored-by: Saransh Chopra <[email protected]>
Thanks, @jsbrittain! Sorry for all the release mess 😬
@brosaplanella: I missed talking about it in the meeting—should this PR be included in the release given that this is a bug fix? I think #3443 ought to be taken a look at too otherwise the GPU support mentioned in our CHANGELOG for v23.9rc0 would not be available for users on some platforms (e.g., macOS with Metal requires I would be happy to write a PR for that this week and incorporate @jsbrittain's suggestions on it into consideration, please let me know if that would be needed |
If we can include it that would be great. Tagging @Saransh-cpp so he is aware and can provide some input.
I'll merge this and add it in the rc1 release.
* Bump to v23.9rc0 * Merge pull request #3412 from agriyakhetarpal/drop-i686-manylinux2014-support Drop support for i686 manylinux * Merge pull request #3413 from Saransh-cpp/improve-release-workflow Improve release workflow, add a note, bump version manually * Merge pull request #3436 from Saransh-cpp/fortnightly-wheels Build wheels on the 1st and 15th of every month * Merge pull request #3445 from pybamm-team/issue-3428-rename-exchange #3428 exchange-current density error * Merge pull request #3449 from pybamm-team/i3431-windows-wheels Fix failing windows wheel builds * Merge pull request #3456 from abillscmu/issue-3224-initial_soc make initial soc work with half cell models * Merge pull request #3467 from abillscmu/bugfix/initial_soh * Merge pull request #3423 from jsbrittain/jax_gpu JaxSolver fails when using GPU support with no input parameters * Fix changelog * Merge pull request #3475 from arjxn-py/fix-default-imports Resolve default imports for optional dependencies * Bump - `v23.9rc1` * Fix date in CHANGELOG * Bump version to v23.9 * Fix docs about Jax solver compatibility with Python versions (#3702) * Ensure correct Python versions for Jax solver compatibility * Simplify array of Python versions Co-authored-by: Eric G. Kratz <[email protected]> * Use different conjunction Co-authored-by: Eric G. Kratz <[email protected]> --------- Co-authored-by: Eric G. 
Kratz <[email protected]> * Merge pull request #3706 from agriyakhetarpal/fix-pybamm-install-odes Make `pybamm_install_odes` a bit more robust * #3690 fix issue with skipped steps (#3708) * #3690 fix issue with skipped steps * #3690 changelog * #3690 add test * #3611 use actual cell volume for average total heating (#3707) * #3611 use actual cell volume for average total heating * #3611 changelog * #3611 account for number of electrode pairs * #3611 update variable names * Improve the release workflow (#3737) * Try fixing the release workflow * Turn off safety * Fix CHANGELOG * Add OS * Use regex for better matches * Update instructions, add safety checks * checkout to the version branch for the final release * Bump to v24.1rc1 * #3630 fix interpolant shape error (#3761) * #3630 fix interpolant shape error * #3630 changelog * Bump to v24.1rc2 * Bump to v24.1 * Fix doctests failures in scheduled tests (#3784) Closes #3781 * Resolve broken `scikits.odes` installation on self-hosted M-series runner (#3785) * Try fixing M-series runner tests This is being done by adding SuiteSparse and SUNDIALS installations which might have been missing on the runner, which broke `scikits.odes`. * Don't use Homebrew SUNDIALS, use LD_LIBRARY_PATH * Don't use Homebrew to install SUNDIALS * Force remove pip cache for `scikits.odes` --------- Co-authored-by: Eric G. Kratz <[email protected]> * add temperature dependence to MSMR model * changelog * fix tests * fix example * rob comments * update notebook --------- Co-authored-by: Ferran Brosa Planella <[email protected]> Co-authored-by: Saransh Chopra <[email protected]> Co-authored-by: Martin Robinson <[email protected]> Co-authored-by: Agriya Khetarpal <[email protected]> Co-authored-by: Eric G. Kratz <[email protected]> Co-authored-by: Robert Timms <[email protected]> Co-authored-by: Saransh-cpp <[email protected]>
Description
When called with input parameters, the JaxSolver parallelises the solve across parameter sets using either asyncio (CPU) or jax-vmap (GPU). Without input parameters, the CPU pathway still solves the model correctly, but in a GPU-enabled jax environment the solver fails, because the vmap function requires at least one input argument to contain a non-empty array.
The fix is to direct input parameter sets of length 0 or 1 to the CPU pathway, since the parallelisation occurs mostly over parameter sets rather than within solves.
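The routing described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in this PR: `solve_one`, `solve_batched`, and `solve` are hypothetical names, the batched function merely stands in for the vmap pathway (it mimics the requirement for a non-empty batch), and a simple loop stands in for the asyncio-based CPU pathway.

```python
def solve_one(inputs):
    """Hypothetical stand-in for a single model solve."""
    # An empty or missing input set is valid for a single solve.
    return float(sum(inputs.values())) if inputs else 0.0


def solve_batched(inputs_list):
    """Hypothetical stand-in for the batched (vmap-style) pathway.

    Like vmap, it refuses to run without at least one non-empty input set.
    """
    if not inputs_list or not any(inputs_list):
        raise ValueError("batched pathway needs a non-empty input set")
    return [solve_one(inputs) for inputs in inputs_list]


def solve(inputs_list=None, gpu=False):
    """Route 0 or 1 parameter sets to the sequential pathway."""
    # No inputs at all is treated as one solve with no parameters.
    inputs_list = inputs_list or [None]
    if gpu and len(inputs_list) > 1:
        # Only genuine batches benefit from parallelising over sets.
        return solve_batched(inputs_list)
    # Sequential pathway (the real CPU pathway uses asyncio instead).
    return [solve_one(inputs) for inputs in inputs_list]
```

With this guard, `solve(gpu=True)` and `solve([{"a": 1.0}], gpu=True)` never reach the batched function, so the empty-input failure cannot occur there.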
Fixes #3422
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
- `$ pre-commit run` (or `$ nox -s pre-commit`) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
- `$ python run-tests.py --all` (or `$ nox -s tests`)
- `$ python run-tests.py --doctest` (or `$ nox -s doctests`)
You can run integration tests, unit tests, and doctests together at once, using `$ python run-tests.py --quick` (or `$ nox -s quick`).
Further checks:
**Existing tests should cover this scenario, but they depend on GPU runners, which are a work in progress.**