diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 3d85ec69..fb4f20ec 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -107,7 +107,7 @@ def get_ys(_final_state): return final_state -class AbstractAdjoint(eqx.Module): +class AbstractAdjoint(eqx.Module, strict=True): """Abstract base class for all adjoint methods.""" @abc.abstractmethod @@ -167,7 +167,7 @@ def _uncallable(*args, **kwargs): assert False -class RecursiveCheckpointAdjoint(AbstractAdjoint): +class RecursiveCheckpointAdjoint(AbstractAdjoint, strict=True): """Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical solution directly. This is sometimes known as "discretise-then-optimise", or described as "backpropagation through the solver". @@ -318,7 +318,7 @@ def loop( """ -class DirectAdjoint(AbstractAdjoint): +class DirectAdjoint(AbstractAdjoint, strict=True): """A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that `DirectAdjoint`: @@ -434,7 +434,7 @@ def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]: return frozenset(iter_x) -class ImplicitAdjoint(AbstractAdjoint): +class ImplicitAdjoint(AbstractAdjoint, strict=True): r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem). This is used when solving towards a steady state, typically using @@ -705,7 +705,7 @@ def __get(__aug): return a_y1, a_diff_args1, a_diff_terms1 -class BacksolveAdjoint(AbstractAdjoint): +class BacksolveAdjoint(AbstractAdjoint, strict=True): """Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous adjoint equations backwards-in-time. This is also sometimes known as "optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint diff --git a/diffrax/_brownian/base.py b/diffrax/_brownian/base.py index 96a07253..0833fdaa 100644 --- a/diffrax/_brownian/base.py +++ b/diffrax/_brownian/base.py @@ -8,7 +8,7 @@ from .._path import AbstractPath -class AbstractBrownianPath(AbstractPath): +class AbstractBrownianPath(AbstractPath, strict=True): """Abstract base class for all Brownian paths.""" levy_area: AbstractVar[LevyArea] diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index c88e9ee2..e773124c 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -19,7 +19,7 @@ from .base import AbstractBrownianPath -class UnsafeBrownianPath(AbstractBrownianPath): +class UnsafeBrownianPath(AbstractBrownianPath, strict=True): """Brownian simulation that is only suitable for certain cases. This is a very quick way to simulate Brownian motion, but can only be used when all diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index bf972526..6a4ebce0 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -60,7 +60,7 @@ _Spline: TypeAlias = Literal["sqrt", "quad", "zero"] -class _State(eqx.Module): +class _State(eqx.Module, strict=True): level: IntScalarLike # level of the tree s: RealScalarLike # starting time of the interval w_s_u_su: FloatTriple # W_s, W_u, W_{s,u} @@ -109,7 +109,7 @@ def _split_interval( return x_s, x_u, x_su -class VirtualBrownianTree(AbstractBrownianPath): +class VirtualBrownianTree(AbstractBrownianPath, strict=True): """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`. Can be initialised with `levy_area` set to `""`, or `"space-time"`. diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index ffccf64a..373995b7 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -56,7 +56,7 @@ LevyArea: TypeAlias = Literal["", "space-time"] -class LevyVal(eqx.Module): +class LevyVal(eqx.Module, strict=True): dt: PyTree W: PyTree H: Optional[PyTree] diff --git a/diffrax/_event.py b/diffrax/_event.py index a596d8ed..b4b42fe3 100644 --- a/diffrax/_event.py +++ b/diffrax/_event.py @@ -10,7 +10,7 @@ from ._step_size_controller import AbstractAdaptiveStepSizeController -class AbstractDiscreteTerminatingEvent(eqx.Module): +class AbstractDiscreteTerminatingEvent(eqx.Module, strict=True): """Evaluated at the end of each integration step. If true then the solve is stopped at that time. """ @@ -30,7 +30,7 @@ def __call__(self, state, **kwargs) -> BoolScalarLike: """ -class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent): +class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent, strict=True): """Terminates the solve if its condition is ever active.""" cond_fn: Callable[..., BoolScalarLike] @@ -50,7 +50,7 @@ def __call__(self, state, **kwargs): """ -class SteadyStateEvent(AbstractDiscreteTerminatingEvent): +class SteadyStateEvent(AbstractDiscreteTerminatingEvent, strict=True): """Terminates the solve once it reaches a steady state.""" rtol: Optional[float] = None diff --git a/diffrax/_global_interpolation.py b/diffrax/_global_interpolation.py index 9b0f2fe7..a80fce48 100644 --- a/diffrax/_global_interpolation.py +++ b/diffrax/_global_interpolation.py @@ -23,7 +23,7 @@ from ._path import AbstractPath -class AbstractGlobalInterpolation(AbstractPath): +class AbstractGlobalInterpolation(AbstractPath, strict=True): ts: AbstractVar[Real[Array, " times"]] ts_size: AbstractVar[IntScalarLike] @@ -52,7 +52,7 @@ def t1(self): return self.ts[-1] -class LinearInterpolation(AbstractGlobalInterpolation): +class LinearInterpolation(AbstractGlobalInterpolation, strict=True): """Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots at `ts`. @@ -178,7 +178,7 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]: """ -class CubicInterpolation(AbstractGlobalInterpolation): +class CubicInterpolation(AbstractGlobalInterpolation, strict=True): """Piecewise cubic spline interpolation over the interval $[t_0, t_1]$.""" ts: Real[Array, " times"] @@ -302,7 +302,7 @@ def derivative( """ -class DenseInterpolation(AbstractGlobalInterpolation): +class DenseInterpolation(AbstractGlobalInterpolation, strict=True): ts: Real[Array, " times"] # DenseInterpolations typically get `ts` and `infos` that are way longer than they # need to be, and padded with `nan`s. This means the normal way of measuring how diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index b6501e41..d9109b3f 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -47,14 +47,14 @@ from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm -class SaveState(eqx.Module): +class SaveState(eqx.Module, strict=True): saveat_ts_index: IntScalarLike ts: eqxi.MaybeBuffer[Real[Array, " times"]] ys: PyTree[eqxi.MaybeBuffer[Inexact[Array, "times ..."]]] save_index: IntScalarLike -class State(eqx.Module): +class State(eqx.Module, strict=True): # Evolving state during the solve y: PyTree[Array] tprev: FloatScalarLike diff --git a/diffrax/_local_interpolation.py b/diffrax/_local_interpolation.py index 1969cd87..6bba4b3f 100644 --- a/diffrax/_local_interpolation.py +++ b/diffrax/_local_interpolation.py @@ -1,5 +1,6 @@ from typing import Optional, TYPE_CHECKING +import equinox as eqx import jax.numpy as jnp import jax.tree_util as jtu import numpy as np @@ -17,11 +18,11 @@ from ._path import AbstractPath -class AbstractLocalInterpolation(AbstractPath): +class AbstractLocalInterpolation(AbstractPath, strict=True): pass -class LocalLinearInterpolation(AbstractLocalInterpolation): +class LocalLinearInterpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y @@ -39,7 +40,7 @@ def evaluate( return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω -class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation): +class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike coeffs: PyTree[Shaped[Array, "4 ?*dims"], "Y"] @@ -83,7 +84,9 @@ def _eval(_coeffs): return jtu.tree_map(_eval, self.coeffs) -class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation): +class FourthOrderPolynomialInterpolation( + AbstractLocalInterpolation, strict=eqx.StrictConfig(allow_abstract_name=True) +): t0: RealScalarLike t1: RealScalarLike coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"] diff --git a/diffrax/_path.py b/diffrax/_path.py index c7b90a3b..92f78945 100644 --- a/diffrax/_path.py +++ b/diffrax/_path.py @@ -15,7 +15,7 @@ from ._custom_types import RealScalarLike -class AbstractPath(eqx.Module): +class AbstractPath(eqx.Module, strict=True): """Abstract base class for all paths. Every path has a start point `t0` and an end point `t1`. In between these values diff --git a/diffrax/_root_finder/_verychord.py b/diffrax/_root_finder/_verychord.py index 4bce16a5..899e3063 100644 --- a/diffrax/_root_finder/_verychord.py +++ b/diffrax/_root_finder/_verychord.py @@ -30,7 +30,7 @@ def _converged(factor: Scalar, tol: float) -> Bool[Array, ""]: return (factor > 0) & (factor < tol) -class _VeryChordState(eqx.Module): +class _VeryChordState(eqx.Module, strict=True): linear_state: tuple[lx.AbstractLinearOperator, PyTree[Any]] diff: Y diffsize: Scalar @@ -39,7 +39,7 @@ class _VeryChordState(eqx.Module): step: Scalar -class _NoAux(eqx.Module): +class _NoAux(eqx.Module, strict=True): fn: Callable def __call__(self, y, args): @@ -48,7 +48,7 @@ def __call__(self, y, args): return out -class VeryChord(optx.AbstractRootFinder): +class VeryChord(optx.AbstractRootFinder, strict=True): """The Chord method of root finding. As `optimistix.Chord`, except that in Runge--Kutta methods, the linearisation point diff --git a/diffrax/_saveat.py b/diffrax/_saveat.py index 6ee373de..c8cb3044 100644 --- a/diffrax/_saveat.py +++ b/diffrax/_saveat.py @@ -21,7 +21,7 @@ def _convert_ts( return jnp.asarray(ts) -class SubSaveAt(eqx.Module): +class SubSaveAt(eqx.Module, strict=True): """Used for finer-grained control over what is saved. A PyTree of these should be passed to `SaveAt(subs=...)`. @@ -53,7 +53,7 @@ def __check_init__(self): """ -class SaveAt(eqx.Module): +class SaveAt(eqx.Module, strict=True): """Determines what to save as output from the differential equation solve. Instances of this class should be passed as the `saveat` argument of diff --git a/diffrax/_solution.py b/diffrax/_solution.py index b1f7e322..965e61c0 100644 --- a/diffrax/_solution.py +++ b/diffrax/_solution.py @@ -1,5 +1,6 @@ from typing import Any, Optional +import equinox as eqx import jax import optimistix as optx from jaxtyping import Array, Bool, PyTree, Real, Shaped @@ -55,7 +56,7 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS: return RESULTS.where(pred, old_result, out_result) -class Solution(AbstractPath): +class Solution(AbstractPath, strict=eqx.StrictConfig(allow_method_override=True)): """The solution to a differential equation. **Attributes:** diff --git a/diffrax/_solver/base.py b/diffrax/_solver/base.py index 3b04bfa4..4cdced19 100644 --- a/diffrax/_solver/base.py +++ b/diffrax/_solver/base.py @@ -49,7 +49,7 @@ def __instancecheck__(cls, obj): _set_metaclass = dict(metaclass=_MetaAbstractSolver) -class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass): +class AbstractSolver(eqx.Module, Generic[_SolverState], strict=True, **_set_metaclass): """Abstract base class for all differential equation solvers. Subclasses should have a class-level attribute `terms`, specifying the PyTree @@ -179,7 +179,7 @@ def func( """ -class AbstractImplicitSolver(AbstractSolver[_SolverState]): +class AbstractImplicitSolver(AbstractSolver[_SolverState], strict=True): """Indicates that this is an implicit differential equation solver, and as such that it should take a root finder as an argument. """ @@ -188,25 +188,25 @@ class AbstractImplicitSolver(AbstractSolver[_SolverState]): root_find_max_steps: AbstractVar[int] -class AbstractItoSolver(AbstractSolver[_SolverState]): +class AbstractItoSolver(AbstractSolver[_SolverState], strict=True): """Indicates that when used as an SDE solver that this solver will converge to the Itô solution. """ -class AbstractStratonovichSolver(AbstractSolver[_SolverState]): +class AbstractStratonovichSolver(AbstractSolver[_SolverState], strict=True): """Indicates that when used as an SDE solver that this solver will converge to the Stratonovich solution. """ -class AbstractAdaptiveSolver(AbstractSolver[_SolverState]): +class AbstractAdaptiveSolver(AbstractSolver[_SolverState], strict=True): """Indicates that this solver provides error estimates, and that as such it may be used with an adaptive step size controller. """ -class AbstractWrappedSolver(AbstractSolver[_SolverState]): +class AbstractWrappedSolver(AbstractSolver[_SolverState], strict=True): """Wraps another solver "transparently", in the sense that all `isinstance` checks will be forwarded on to the wrapped solver, e.g. when testing whether the solver is implicit/adaptive/SDE-compatible/etc. @@ -219,7 +219,9 @@ class if that is not desired behaviour.) class HalfSolver( - AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState] + AbstractAdaptiveSolver[_SolverState], + AbstractWrappedSolver[_SolverState], + strict=eqx.StrictConfig(allow_method_override=True), ): """Wraps another solver, trading cost in order to provide error estimates. (That is, it means the solver can be used with an adaptive step size controller, diff --git a/diffrax/_solver/bosh3.py b/diffrax/_solver/bosh3.py index 2659b207..af766fb0 100644 --- a/diffrax/_solver/bosh3.py +++ b/diffrax/_solver/bosh3.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation @@ -19,7 +20,7 @@ ) -class Bosh3(AbstractERK): +class Bosh3(AbstractERK, strict=eqx.StrictConfig(allow_method_override=True)): """Bogacki--Shampine's 3/2 method. 3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for @@ -29,6 +30,8 @@ class Bosh3(AbstractERK): Also sometimes known as "Ralston's third order method". """ + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + tableau: ClassVar[ButcherTableau] = _bosh3_tableau interpolation_cls: ClassVar[ Callable[..., ThirdOrderHermitePolynomialInterpolation] diff --git a/diffrax/_solver/dopri5.py b/diffrax/_solver/dopri5.py index e6ebff3b..a2fac03e 100644 --- a/diffrax/_solver/dopri5.py +++ b/diffrax/_solver/dopri5.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np from .._local_interpolation import FourthOrderPolynomialInterpolation @@ -32,7 +33,7 @@ ) -class _Dopri5Interpolation(FourthOrderPolynomialInterpolation): +class _Dopri5Interpolation(FourthOrderPolynomialInterpolation, strict=True): c_mid: ClassVar[np.ndarray] = np.array( [ 6025192743 / 30085553152 / 2, @@ -50,7 +51,7 @@ class _Dopri5Interpolation(FourthOrderPolynomialInterpolation): # https://www.sciencedirect.com/science/article/pii/0898122196001411 # ("An Efficient Runge--Kutta (4, 5) pair", Bogacki and Shampine 1996) # Which they claim is slightly more efficient than the one we have here. -class Dopri5(AbstractERK): +class Dopri5(AbstractERK, strict=eqx.StrictConfig(allow_method_override=True)): r"""Dormand-Prince's 5/4 method. 5th order Runge--Kutta method. Has an embedded 4th order method for adaptive step @@ -89,6 +90,8 @@ class Dopri5(AbstractERK): ``` """ + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + tableau: ClassVar[ButcherTableau] = _dopri5_tableau interpolation_cls: ClassVar[ Callable[..., _Dopri5Interpolation] diff --git a/diffrax/_solver/dopri8.py b/diffrax/_solver/dopri8.py index ab9ef64a..71fca652 100644 --- a/diffrax/_solver/dopri8.py +++ b/diffrax/_solver/dopri8.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar, Optional +from typing import ClassVar, Literal, Optional, Union +import equinox as eqx import jax import jax.numpy as jnp import numpy as np @@ -188,7 +189,7 @@ _vmap_polyval = jax.vmap(jnp.polyval, in_axes=(0, None)) -class _Dopri8Interpolation(AbstractLocalInterpolation): +class _Dopri8Interpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y @@ -301,7 +302,7 @@ def evaluate( return (self.y0**ω + vector_tree_dot(coeffs, self.k) ** ω).ω -class Dopri8(AbstractERK): +class Dopri8(AbstractERK, strict=eqx.StrictConfig(allow_method_override=True)): """Dormand--Prince's 8/7 method. 8th order Runge--Kutta method. Has an embedded 7th order method for adaptive step @@ -337,6 +338,8 @@ class Dopri8(AbstractERK): ``` """ + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + tableau: ClassVar[ButcherTableau] = _dopri8_tableau interpolation_cls: ClassVar[ Callable[..., _Dopri8Interpolation] diff --git a/diffrax/_solver/euler.py b/diffrax/_solver/euler.py index c38642e9..adf07867 100644 --- a/diffrax/_solver/euler.py +++ b/diffrax/_solver/euler.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import equinox as eqx from equinox.internal import ω from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y @@ -15,7 +16,7 @@ _SolverState: TypeAlias = None -class Euler(AbstractItoSolver): +class Euler(AbstractItoSolver, strict=eqx.StrictConfig(allow_method_override=True)): """Euler's method. 1st order explicit Runge--Kutta method. Does not support adaptive step sizing. Uses diff --git a/diffrax/_solver/euler_heun.py b/diffrax/_solver/euler_heun.py index 88b99776..a62a9666 100644 --- a/diffrax/_solver/euler_heun.py +++ b/diffrax/_solver/euler_heun.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import equinox as eqx from equinox.internal import ω from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y @@ -15,7 +16,9 @@ _SolverState: TypeAlias = None -class EulerHeun(AbstractStratonovichSolver): +class EulerHeun( + AbstractStratonovichSolver, strict=eqx.StrictConfig(allow_method_override=True) +): """Euler-Heun method. Uses a 1st order local linear interpolation scheme for dense/ts output. diff --git a/diffrax/_solver/heun.py b/diffrax/_solver/heun.py index 3b2efac8..62804a70 100644 --- a/diffrax/_solver/heun.py +++ b/diffrax/_solver/heun.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation @@ -16,7 +17,11 @@ ) -class Heun(AbstractERK, AbstractStratonovichSolver): +class Heun( + AbstractERK, + AbstractStratonovichSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Heun's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive @@ -32,6 +37,8 @@ class Heun(AbstractERK, AbstractStratonovichSolver): When used to solve SDEs, converges to the Stratonovich solution. """ + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + tableau: ClassVar[ButcherTableau] = _heun_tableau interpolation_cls: ClassVar[ Callable[..., ThirdOrderHermitePolynomialInterpolation] diff --git a/diffrax/_solver/implicit_euler.py b/diffrax/_solver/implicit_euler.py index eb3bdb00..20a60613 100644 --- a/diffrax/_solver/implicit_euler.py +++ b/diffrax/_solver/implicit_euler.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import equinox as eqx import optimistix as optx from equinox.internal import ω @@ -23,7 +24,11 @@ def _implicit_relation(z1, nonlinear_solve_args): return diff -class ImplicitEuler(AbstractImplicitSolver, AbstractAdaptiveSolver): +class ImplicitEuler( + AbstractImplicitSolver, + AbstractAdaptiveSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): r"""Implicit Euler method. A-B-L stable 1st order SDIRK method. Has an embedded 2nd order Heun method for diff --git a/diffrax/_solver/kencarp3.py b/diffrax/_solver/kencarp3.py index d24546fd..f6abad1f 100644 --- a/diffrax/_solver/kencarp3.py +++ b/diffrax/_solver/kencarp3.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar, Optional, TYPE_CHECKING +from typing import ClassVar, Literal, Optional, TYPE_CHECKING, Union +import equinox as eqx import jax import jax.numpy as jnp import numpy as np @@ -89,7 +90,7 @@ ) -class KenCarpInterpolation(AbstractLocalInterpolation): +class AbstractKenCarpInterpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y @@ -120,7 +121,7 @@ def evaluate( return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω -class _KenCarp3Interpolation(KenCarpInterpolation): +class _KenCarp3Interpolation(AbstractKenCarpInterpolation, strict=True): coeffs = np.array( [ [-215264564351 / 13552729205753, 4655552711362 / 22874653954995], @@ -131,7 +132,11 @@ class _KenCarp3Interpolation(KenCarpInterpolation): ) -class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver): +class KenCarp3( + AbstractRungeKutta, + AbstractImplicitSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Kennedy--Carpenter's 3/2 IMEX method. 3rd order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly @@ -165,6 +170,7 @@ class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver): Callable[..., _KenCarp3Interpolation] ] = _KenCarp3Interpolation + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp4.py b/diffrax/_solver/kencarp4.py index e50ce2b2..b99b0b08 100644 --- a/diffrax/_solver/kencarp4.py +++ b/diffrax/_solver/kencarp4.py @@ -1,12 +1,13 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np import optimistix as optx from .._root_finder import VeryChord, with_stepsize_controller_tols from .base import AbstractImplicitSolver -from .kencarp3 import KenCarpInterpolation +from .kencarp3 import AbstractKenCarpInterpolation from .runge_kutta import ( AbstractRungeKutta, ButcherTableau, @@ -102,7 +103,7 @@ ) -class _KenCarp4Interpolation(KenCarpInterpolation): +class _KenCarp4Interpolation(AbstractKenCarpInterpolation, strict=True): coeffs = np.array( [ [ @@ -135,7 +136,11 @@ class _KenCarp4Interpolation(KenCarpInterpolation): ) -class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver): +class KenCarp4( + AbstractRungeKutta, + AbstractImplicitSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Kennedy--Carpenter's 4/3 IMEX method. 4th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly @@ -169,6 +174,7 @@ class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver): Callable[..., _KenCarp4Interpolation] ] = _KenCarp4Interpolation + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp5.py b/diffrax/_solver/kencarp5.py index 5cfc81d6..085dd1b1 100644 --- a/diffrax/_solver/kencarp5.py +++ b/diffrax/_solver/kencarp5.py @@ -1,12 +1,13 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np import optimistix as optx from .._root_finder import VeryChord, with_stepsize_controller_tols from .base import AbstractImplicitSolver -from .kencarp3 import KenCarpInterpolation +from .kencarp3 import AbstractKenCarpInterpolation from .runge_kutta import ( AbstractRungeKutta, ButcherTableau, @@ -163,7 +164,7 @@ ) -class _KenCarp5Interpolation(KenCarpInterpolation): +class _KenCarp5Interpolation(AbstractKenCarpInterpolation, strict=True): coeffs = np.array( [ [ @@ -202,7 +203,11 @@ class _KenCarp5Interpolation(KenCarpInterpolation): ) -class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver): +class KenCarp5( + AbstractRungeKutta, + AbstractImplicitSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Kennedy--Carpenter's 5/4 IMEX method. 5th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly @@ -236,6 +241,7 @@ class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver): Callable[..., _KenCarp5Interpolation] ] = _KenCarp5Interpolation + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kvaerno3.py b/diffrax/_solver/kvaerno3.py index a91fecc7..5d443f92 100644 --- a/diffrax/_solver/kvaerno3.py +++ b/diffrax/_solver/kvaerno3.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np import optimistix as optx @@ -41,7 +42,7 @@ ) -class Kvaerno3(AbstractESDIRK): +class Kvaerno3(AbstractESDIRK, strict=eqx.StrictConfig(allow_method_override=True)): r"""Kvaerno's 3/2 method. A-L stable stiffly accurate 3rd order ESDIRK method. Has an embedded 2nd order @@ -70,6 +71,7 @@ class Kvaerno3(AbstractESDIRK): Callable[..., ThirdOrderHermitePolynomialInterpolation] ] = ThirdOrderHermitePolynomialInterpolation.from_k + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kvaerno4.py b/diffrax/_solver/kvaerno4.py index 8139acd3..230da20b 100644 --- a/diffrax/_solver/kvaerno4.py +++ b/diffrax/_solver/kvaerno4.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np import optimistix as optx @@ -79,7 +80,7 @@ def poly(*args): ) -class Kvaerno4(AbstractESDIRK): +class Kvaerno4(AbstractESDIRK, strict=eqx.StrictConfig(allow_method_override=True)): r"""Kvaerno's 4/3 method. A-L stable stiffly accurate 4th order ESDIRK method. Has an embedded 3rd order @@ -111,6 +112,7 @@ class Kvaerno4(AbstractESDIRK): Callable[..., ThirdOrderHermitePolynomialInterpolation] ] = ThirdOrderHermitePolynomialInterpolation.from_k + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kvaerno5.py b/diffrax/_solver/kvaerno5.py index 9af32a94..9e6f3307 100644 --- a/diffrax/_solver/kvaerno5.py +++ b/diffrax/_solver/kvaerno5.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np import optimistix as optx @@ -85,7 +86,7 @@ ) -class Kvaerno5(AbstractESDIRK): +class Kvaerno5(AbstractESDIRK, strict=eqx.StrictConfig(allow_method_override=True)): r"""Kvaerno's 5/4 method. A-L stable stiffly accurate 5th order ESDIRK method. Has an embedded 4th order @@ -117,6 +118,7 @@ class Kvaerno5(AbstractESDIRK): Callable[..., ThirdOrderHermitePolynomialInterpolation] ] = ThirdOrderHermitePolynomialInterpolation.from_k + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/leapfrog_midpoint.py b/diffrax/_solver/leapfrog_midpoint.py index 00ba11da..8923e37e 100644 --- a/diffrax/_solver/leapfrog_midpoint.py +++ b/diffrax/_solver/leapfrog_midpoint.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import equinox as eqx from equinox.internal import ω from jaxtyping import PyTree @@ -17,7 +18,9 @@ # TODO: support arbitrary linear multistep methods -class LeapfrogMidpoint(AbstractSolver): +class LeapfrogMidpoint( + AbstractSolver, strict=eqx.StrictConfig(allow_method_override=True) +): r"""Leapfrog/midpoint method. 2nd order linear multistep method. Uses 1st order local linear interpolation for diff --git a/diffrax/_solver/midpoint.py b/diffrax/_solver/midpoint.py index 184377f5..06b2ec5a 100644 --- a/diffrax/_solver/midpoint.py +++ b/diffrax/_solver/midpoint.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation @@ -16,7 +17,11 @@ ) -class Midpoint(AbstractERK, AbstractStratonovichSolver): +class Midpoint( + AbstractERK, + AbstractStratonovichSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Midpoint method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive @@ -28,6 +33,8 @@ class Midpoint(AbstractERK, AbstractStratonovichSolver): When used to solve SDEs, converges to the Stratonovich solution. """ + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + tableau: ClassVar[ButcherTableau] = _midpoint_tableau interpolation_cls: ClassVar[ Callable[..., ThirdOrderHermitePolynomialInterpolation] diff --git a/diffrax/_solver/milstein.py b/diffrax/_solver/milstein.py index 5eae5e5b..f7dbd289 100644 --- a/diffrax/_solver/milstein.py +++ b/diffrax/_solver/milstein.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import equinox as eqx import jax import jax.numpy as jnp import jax.tree_util as jtu @@ -27,7 +28,9 @@ # -class StratonovichMilstein(AbstractStratonovichSolver): +class StratonovichMilstein( + AbstractStratonovichSolver, strict=eqx.StrictConfig(allow_method_override=True) +): r"""Milstein's method; Stratonovich version. Used to solve SDEs, and converges to the Stratonovich solution. Uses local linear @@ -101,7 +104,9 @@ def func( return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args) -class ItoMilstein(AbstractItoSolver): +class ItoMilstein( + AbstractItoSolver, strict=eqx.StrictConfig(allow_method_override=True) +): r"""Milstein's method; Itô version. Used to solve SDEs, and converges to the Itô solution. Uses local linear diff --git a/diffrax/_solver/ralston.py b/diffrax/_solver/ralston.py index d8a9c6bf..4a51d1ba 100644 --- a/diffrax/_solver/ralston.py +++ b/diffrax/_solver/ralston.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation @@ -25,7 +26,11 @@ ) -class Ralston(AbstractERK, AbstractStratonovichSolver): +class Ralston( + AbstractERK, + AbstractStratonovichSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Ralston's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive @@ -34,6 +39,8 @@ class Ralston(AbstractERK, AbstractStratonovichSolver): When used to solve SDEs, converges to the Stratonovich solution. """ + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + tableau: ClassVar[ButcherTableau] = _ralston_tableau interpolation_cls: ClassVar[ Callable[..., ThirdOrderHermitePolynomialInterpolation] diff --git a/diffrax/_solver/reversible_heun.py b/diffrax/_solver/reversible_heun.py index 0f0a9fe9..01a8e27d 100644 --- a/diffrax/_solver/reversible_heun.py +++ b/diffrax/_solver/reversible_heun.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import equinox as eqx import jax.lax as lax from equinox.internal import ω from jaxtyping import PyTree @@ -16,7 +17,11 @@ _SolverState: TypeAlias = tuple[PyTree, PyTree] -class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): +class ReversibleHeun( + AbstractAdaptiveSolver, + AbstractStratonovichSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Reversible Heun method. Algebraically reversible 2nd order method. Has an embedded 1st order method for diff --git a/diffrax/_solver/runge_kutta.py b/diffrax/_solver/runge_kutta.py index e8c28546..f7c81fa9 100644 --- a/diffrax/_solver/runge_kutta.py +++ b/diffrax/_solver/runge_kutta.py @@ -24,6 +24,7 @@ import lineax.internal as lxi import numpy as np import optimistix as optx +from equinox import AbstractVar if TYPE_CHECKING: @@ -197,7 +198,7 @@ def __post_init__(self): """ -class MultiButcherTableau(eqx.Module): +class MultiButcherTableau(eqx.Module, strict=True): """Wraps multiple [`diffrax.ButcherTableau`][]s together. Used in some multi-tableau solvers, like IMEX methods. @@ -342,7 +343,7 @@ def _assert_same_structure(x, y): return eqx.tree_equal(x, y) is True -class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState]): +class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState], strict=True): """Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit Runge--Kutta methods, which have a different computational structure.) @@ -356,7 +357,7 @@ class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState]): instance of [`diffrax.CalculateJacobian`][]. """ - scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + scan_kind: AbstractVar[Union[None, Literal["lax", "checkpointed", "bounded"]]] tableau: AbstractClassVar[Union[ButcherTableau, MultiButcherTableau]] calculate_jacobian: AbstractClassVar[CalculateJacobian] @@ -1201,7 +1202,7 @@ def _increment(tab_i, k_i): return y1, y_error, dense_info, new_solver_state, result -class AbstractERK(AbstractRungeKutta): +class AbstractERK(AbstractRungeKutta, strict=True): """Abstract base class for all Explicit Runge--Kutta solvers. Subclasses should include a class-level attribute `tableau`, an instance of @@ -1211,7 +1212,7 @@ class AbstractERK(AbstractRungeKutta): calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.never -class AbstractDIRK(AbstractRungeKutta, AbstractImplicitSolver): +class AbstractDIRK(AbstractRungeKutta, AbstractImplicitSolver, strict=True): """Abstract base class for all Diagonal Implicit Runge--Kutta solvers. Subclasses should include a class-level attribute `tableau`, an instance of @@ -1221,7 +1222,7 @@ class AbstractDIRK(AbstractRungeKutta, AbstractImplicitSolver): calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.every_stage -class AbstractSDIRK(AbstractDIRK): +class AbstractSDIRK(AbstractDIRK, strict=True): """Abstract base class for all Singular Diagonal Implict Runge--Kutta solvers. Subclasses should include a class-level attribute `tableau`, an instance of @@ -1239,7 +1240,7 @@ def __init_subclass__(cls, **kwargs): calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.first_stage -class AbstractESDIRK(AbstractDIRK): +class AbstractESDIRK(AbstractDIRK, strict=True): """Abstract base class for all Explicit Singular Diagonal Implicit Runge--Kutta solvers. diff --git a/diffrax/_solver/semi_implicit_euler.py b/diffrax/_solver/semi_implicit_euler.py index 00b9e1db..e43fbadc 100644 --- a/diffrax/_solver/semi_implicit_euler.py +++ b/diffrax/_solver/semi_implicit_euler.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import equinox as eqx from equinox.internal import ω from jaxtyping import ArrayLike, Float, PyTree @@ -19,7 +20,9 @@ Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] -class SemiImplicitEuler(AbstractSolver): +class SemiImplicitEuler( + AbstractSolver, strict=eqx.StrictConfig(allow_method_override=True) +): """Semi-implicit Euler's method. Symplectic method. Does not support adaptive step sizing. Uses 1st order local diff --git a/diffrax/_solver/sil3.py b/diffrax/_solver/sil3.py index 7bec9d4b..63bef37e 100644 --- a/diffrax/_solver/sil3.py +++ b/diffrax/_solver/sil3.py @@ -1,5 +1,6 @@ -from typing import ClassVar +from typing import ClassVar, Literal, Union +import equinox as eqx import numpy as np import optimistix as optx from equinox.internal import ω @@ -48,7 +49,11 @@ ) -class Sil3(AbstractRungeKutta, AbstractImplicitSolver): +class Sil3( + AbstractRungeKutta, + AbstractImplicitSolver, + strict=eqx.StrictConfig(allow_method_override=True), +): """Whitaker--Kar's fast-slow IMEX method. 3rd order in the explicit (ERK) term; 2nd order in the implicit (EDIRK) term. Uses @@ -88,6 +93,7 @@ def interpolation_cls(t0, t1, y0, y1, k): t0=t0, t1=t1, y0=y0, y1=y1, k0=k0, k1=k1 ) + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/tsit5.py b/diffrax/_solver/tsit5.py index ebd96539..809b0577 100644 --- a/diffrax/_solver/tsit5.py +++ b/diffrax/_solver/tsit5.py @@ -1,6 +1,7 @@ from collections.abc import Callable -from typing import ClassVar, Optional +from typing import ClassVar, Literal, Optional, Union +import equinox as eqx import jax.numpy as jnp import numpy as np from equinox.internal import ω @@ -98,7 +99,7 @@ ) -class _Tsit5Interpolation(AbstractLocalInterpolation): +class _Tsit5Interpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y @@ -153,7 +154,7 @@ def evaluate( ).ω -class Tsit5(AbstractERK): +class Tsit5(AbstractERK, strict=eqx.StrictConfig(allow_method_override=True)): r"""Tsitouras' 5/4 method. 5th order explicit Runge--Kutta method. Has an embedded 4th order method for @@ -177,6 +178,8 @@ class Tsit5(AbstractERK): ``` """ + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None + tableau: ClassVar[ButcherTableau] = _tsit5_tableau interpolation_cls: ClassVar[ Callable[..., _Tsit5Interpolation] diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index 4b3d432a..8a1d45eb 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -89,7 +89,7 @@ def intermediate(carry): class AbstractAdaptiveStepSizeController( - AbstractStepSizeController[_ControllerState, _Dt0] + AbstractStepSizeController[_ControllerState, _Dt0], strict=True ): """Indicates an adaptive step size controller. @@ -152,7 +152,7 @@ def __repr__(self): # TODO: we don't currently offer a limiter, or a variant accept/reject scheme, as given # in Soderlind and Wang 2006. class PIDController( - AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]] + AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]], strict=True ): r"""Adapts the step size to produce a solution accurate to a given tolerance. The tolerance is calculated as `atol + rtol * y` for the evolving solution `y`. diff --git a/diffrax/_step_size_controller/base.py b/diffrax/_step_size_controller/base.py index 625bd6fb..56139383 100644 --- a/diffrax/_step_size_controller/base.py +++ b/diffrax/_step_size_controller/base.py @@ -14,7 +14,9 @@ _Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) -class AbstractStepSizeController(eqx.Module, Generic[_ControllerState, _Dt0]): +class AbstractStepSizeController( + eqx.Module, Generic[_ControllerState, _Dt0], strict=True +): """Abstract base class for all step size controllers.""" @abc.abstractmethod diff --git a/diffrax/_step_size_controller/constant.py b/diffrax/_step_size_controller/constant.py index ce8ad425..4b542a86 100644 --- a/diffrax/_step_size_controller/constant.py +++ b/diffrax/_step_size_controller/constant.py @@ -12,7 +12,9 @@ from .base import AbstractStepSizeController -class ConstantStepSize(AbstractStepSizeController[RealScalarLike, RealScalarLike]): +class ConstantStepSize( + AbstractStepSizeController[RealScalarLike, RealScalarLike], strict=True +): """Use a constant step size, equal to the `dt0` argument of [`diffrax.diffeqsolve`][]. """ @@ -61,7 +63,7 @@ def adapt_step_size( ) -class StepTo(AbstractStepSizeController[IntScalarLike, None]): +class StepTo(AbstractStepSizeController[IntScalarLike, None], strict=True): """Make steps to just prespecified times.""" ts: Real[Array, " times"] = eqx.field(converter=jnp.asarray) diff --git a/diffrax/_term.py b/diffrax/_term.py index 277089b0..875d173f 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -16,7 +16,7 @@ from ._path import AbstractPath -class AbstractTerm(eqx.Module): +class AbstractTerm(eqx.Module, strict=True): r"""Abstract base class for all terms. Let $y$ solve some differential equation with vector field $f$ and control $x$. @@ -155,7 +155,7 @@ def is_vf_expensive( return False -class ODETerm(AbstractTerm): +class ODETerm(AbstractTerm, strict=True): r"""A term representing $f(t, y(t), args) \mathrm{d}t$. That is to say, the term appearing on the right hand side of an ODE, in which the control is time. @@ -219,7 +219,7 @@ def _mul(v): """ -class _CallableToPath(AbstractPath): +class _CallableToPath(AbstractPath, strict=True): fn: Callable @property @@ -250,7 +250,7 @@ def _prod(vf, control): return jnp.tensordot(vf, control, axes=jnp.ndim(control)) -class _ControlTerm(AbstractTerm): +class _AbstractControlTerm(AbstractTerm, strict=True): vector_field: Callable[[RealScalarLike, Y, Args], VF] control: Union[AbstractPath, Callable] = eqx.field(converter=_callable_to_path) @@ -273,7 +273,7 @@ def to_ode(self) -> ODETerm: return ODETerm(vector_field=vector_field) -_ControlTerm.__init__.__doc__ = """**Arguments:** +_AbstractControlTerm.__init__.__doc__ = """**Arguments:** - `vector_field`: A callable representing the vector field. This callable takes three arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is @@ -285,7 +285,7 @@ def to_ode(self) -> ODETerm: """ -class ControlTerm(_ControlTerm): +class ControlTerm(_AbstractControlTerm, strict=True): r"""A term representing the general case of $f(t, y(t), args) \mathrm{d}x(t)$, in which the vector field - control interaction is a matrix-vector product. @@ -327,7 +327,7 @@ def prod(self, vf: VF, control: Control) -> Y: return jtu.tree_map(_prod, vf, control) -class WeaklyDiagonalControlTerm(_ControlTerm): +class WeaklyDiagonalControlTerm(_AbstractControlTerm, strict=True): r"""A term representing the case of $f(t, y(t), args) \mathrm{d}x(t)$, in which the vector field - control interaction is a matrix-vector product, and the matrix is square and diagonal. In this case we may represent the matrix as a vector @@ -353,8 +353,8 @@ def prod(self, vf: VF, control: Control) -> Y: return jtu.tree_map(operator.mul, vf, control) -class _ControlToODE(eqx.Module): - control_term: _ControlTerm +class _ControlToODE(eqx.Module, strict=True): + control_term: _AbstractControlTerm def __call__(self, t: RealScalarLike, y: Y, args: Args) -> Y: control = self.control_term.control.derivative(t) @@ -368,7 +368,9 @@ def _sum(*x): _Terms = TypeVar("_Terms", bound=tuple[AbstractTerm, ...]) -class MultiTerm(AbstractTerm, Generic[_Terms]): +class MultiTerm( + AbstractTerm, Generic[_Terms], strict=eqx.StrictConfig(allow_method_override=True) +): r"""Accumulates multiple terms into a single term. Consider the SDE @@ -436,7 +438,7 @@ def is_vf_expensive( return any(term.is_vf_expensive(t0, t1, y, args) for term in self.terms) -class WrapTerm(AbstractTerm): +class WrapTerm(AbstractTerm, strict=eqx.StrictConfig(allow_method_override=True)): term: AbstractTerm direction: IntScalarLike @@ -468,7 +470,7 @@ def is_vf_expensive( return self.term.is_vf_expensive(_t0, _t1, y, args) -class AdjointTerm(AbstractTerm): +class AdjointTerm(AbstractTerm, strict=eqx.StrictConfig(allow_method_override=True)): term: AbstractTerm def is_vf_expensive( diff --git a/pyproject.toml b/pyproject.toml index 6ca52eb4..10bc7a26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/diffrax" } -dependencies = ["jax>=0.4.18", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"] +dependencies = ["jax>=0.4.18", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.3", "lineax>=0.0.4", "optimistix>=0.0.7"] [build-system] requires = ["hatchling"]