Skip to content

Commit

Permalink
Work around JAX issue in 0.4.29
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jun 14, 2024
1 parent 5e57351 commit aca5a16
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def _check(term_cls, term, term_contr_kwargs, yi):
raise ValueError

contr = ft.partial(term.contr, **term_contr_kwargs)
control_type = jax.eval_shape(contr, 0.0, 0.0)
# Work around https://github.com/google/jax/issues/21825
control_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
control_type_compatible = eqx.filter_eval_shape(
better_isinstance, control_type, control_type_expected
)
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def is_vf_expensive(
],
args: Args,
) -> bool:
control_struct = jax.eval_shape(self.contr, t0, t1)
control_struct = eqx.filter_eval_shape(self.contr, t0, t1)
if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1):
return False
else:
Expand Down

0 comments on commit aca5a16

Please sign in to comment.