From cab9e1e49668480b64191d28e56004223236bba9 Mon Sep 17 00:00:00 2001 From: Brady Planden Date: Fri, 29 Sep 2023 10:12:22 +0100 Subject: [PATCH] Add Jax metal support, Jax FP64 conditions --- pybamm/expression_tree/operations/evaluate_python.py | 4 +++- pybamm/solvers/jax_bdf_solver.py | 4 +++- pybamm/solvers/jax_solver.py | 6 +++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index ae17a333ec..1f44a69784 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -13,7 +13,9 @@ import jax from jax.config import config - config.update("jax_enable_x64", True) + platform = jax.lib.xla_bridge.get_backend().platform.casefold() + if platform != "metal": + config.update("jax_enable_x64", True) class JaxCooMatrix: diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index b69744dd08..2f334ed8ec 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -18,7 +18,9 @@ from jax.tree_util import tree_flatten, tree_map, tree_unflatten from jax.util import cache, safe_map, split_list - config.update("jax_enable_x64", True) + platform = jax.lib.xla_bridge.get_backend().platform.casefold() + if platform != "metal": + config.update("jax_enable_x64", True) MAX_ORDER = 5 NEWTON_MAXITER = 4 diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index 8e7b1b5cc5..b1928fa82a 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -227,7 +227,11 @@ async def solve_model_async(inputs_v): return await asyncio.gather(*coro) y = asyncio.run(solve_model_for_inputs()) - elif platform.startswith("gpu") or platform.startswith("tpu"): + elif ( + platform.startswith("gpu") + or platform.startswith("tpu") + or platform.startswith("metal") + ): # gpu execution runs faster when parallelised with vmap # (see also comment below regarding single-program multiple-data # execution (SPMD) using pmap on multiple XLAs)