From aca5a1643dae79eee77e51bd951d2d4a09dbd565 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:49:33 +0200 Subject: [PATCH] Work around JAX issue in 0.4.29 --- diffrax/_integrate.py | 3 ++- diffrax/_term.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 7c53e50c..6cb25140 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -144,7 +144,8 @@ def _check(term_cls, term, term_contr_kwargs, yi): raise ValueError contr = ft.partial(term.contr, **term_contr_kwargs) - control_type = jax.eval_shape(contr, 0.0, 0.0) + # Work around https://github.com/google/jax/issues/21825 + control_type = eqx.filter_eval_shape(contr, 0.0, 0.0) control_type_compatible = eqx.filter_eval_shape( better_isinstance, control_type, control_type_expected ) diff --git a/diffrax/_term.py b/diffrax/_term.py index e61ca751..4fe4e9bc 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -493,7 +493,7 @@ def is_vf_expensive( ], args: Args, ) -> bool: - control_struct = jax.eval_shape(self.contr, t0, t1) + control_struct = eqx.filter_eval_shape(self.contr, t0, t1) if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1): return False else: