diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 3d85ec69..fb4f20ec 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -107,7 +107,7 @@ def get_ys(_final_state): return final_state -class AbstractAdjoint(eqx.Module): +class AbstractAdjoint(eqx.Module, strict=True): """Abstract base class for all adjoint methods.""" @abc.abstractmethod @@ -167,7 +167,7 @@ def _uncallable(*args, **kwargs): assert False -class RecursiveCheckpointAdjoint(AbstractAdjoint): +class RecursiveCheckpointAdjoint(AbstractAdjoint, strict=True): """Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical solution directly. This is sometimes known as "discretise-then-optimise", or described as "backpropagation through the solver". @@ -318,7 +318,7 @@ def loop( """ -class DirectAdjoint(AbstractAdjoint): +class DirectAdjoint(AbstractAdjoint, strict=True): """A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that `DirectAdjoint`: @@ -434,7 +434,7 @@ def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]: return frozenset(iter_x) -class ImplicitAdjoint(AbstractAdjoint): +class ImplicitAdjoint(AbstractAdjoint, strict=True): r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem). This is used when solving towards a steady state, typically using @@ -705,7 +705,7 @@ def __get(__aug): return a_y1, a_diff_args1, a_diff_terms1 -class BacksolveAdjoint(AbstractAdjoint): +class BacksolveAdjoint(AbstractAdjoint, strict=True): """Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous adjoint equations backwards-in-time. This is also sometimes known as "optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint diff --git a/diffrax/_brownian/base.py b/diffrax/_brownian/base.py index 96a07253..0833fdaa 100644 --- a/diffrax/_brownian/base.py +++ b/diffrax/_brownian/base.py @@ -8,7 +8,7 @@ from .._path import AbstractPath -class AbstractBrownianPath(AbstractPath): +class AbstractBrownianPath(AbstractPath, strict=True): """Abstract base class for all Brownian paths.""" levy_area: AbstractVar[LevyArea] diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index c88e9ee2..e773124c 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -19,7 +19,7 @@ from .base import AbstractBrownianPath -class UnsafeBrownianPath(AbstractBrownianPath): +class UnsafeBrownianPath(AbstractBrownianPath, strict=True): """Brownian simulation that is only suitable for certain cases. This is a very quick way to simulate Brownian motion, but can only be used when all diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index bf972526..6a4ebce0 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -60,7 +60,7 @@ _Spline: TypeAlias = Literal["sqrt", "quad", "zero"] -class _State(eqx.Module): +class _State(eqx.Module, strict=True): level: IntScalarLike # level of the tree s: RealScalarLike # starting time of the interval w_s_u_su: FloatTriple # W_s, W_u, W_{s,u} @@ -109,7 +109,7 @@ def _split_interval( return x_s, x_u, x_su -class VirtualBrownianTree(AbstractBrownianPath): +class VirtualBrownianTree(AbstractBrownianPath, strict=True): """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`. Can be initialised with `levy_area` set to `""`, or `"space-time"`. diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index ffccf64a..373995b7 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -56,7 +56,7 @@ LevyArea: TypeAlias = Literal["", "space-time"] -class LevyVal(eqx.Module): +class LevyVal(eqx.Module, strict=True): dt: PyTree W: PyTree H: Optional[PyTree] diff --git a/diffrax/_event.py b/diffrax/_event.py index a596d8ed..b4b42fe3 100644 --- a/diffrax/_event.py +++ b/diffrax/_event.py @@ -10,7 +10,7 @@ from ._step_size_controller import AbstractAdaptiveStepSizeController -class AbstractDiscreteTerminatingEvent(eqx.Module): +class AbstractDiscreteTerminatingEvent(eqx.Module, strict=True): """Evaluated at the end of each integration step. If true then the solve is stopped at that time. """ @@ -30,7 +30,7 @@ def __call__(self, state, **kwargs) -> BoolScalarLike: """ -class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent): +class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent, strict=True): """Terminates the solve if its condition is ever active.""" cond_fn: Callable[..., BoolScalarLike] @@ -50,7 +50,7 @@ def __call__(self, state, **kwargs): """ -class SteadyStateEvent(AbstractDiscreteTerminatingEvent): +class SteadyStateEvent(AbstractDiscreteTerminatingEvent, strict=True): """Terminates the solve once it reaches a steady state.""" rtol: Optional[float] = None diff --git a/diffrax/_global_interpolation.py b/diffrax/_global_interpolation.py index 9b0f2fe7..a80fce48 100644 --- a/diffrax/_global_interpolation.py +++ b/diffrax/_global_interpolation.py @@ -23,7 +23,7 @@ from ._path import AbstractPath -class AbstractGlobalInterpolation(AbstractPath): +class AbstractGlobalInterpolation(AbstractPath, strict=True): ts: AbstractVar[Real[Array, " times"]] ts_size: AbstractVar[IntScalarLike] @@ -52,7 +52,7 @@ def t1(self): return self.ts[-1] -class LinearInterpolation(AbstractGlobalInterpolation): +class LinearInterpolation(AbstractGlobalInterpolation, strict=True): """Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots at `ts`. @@ -178,7 +178,7 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]: """ -class CubicInterpolation(AbstractGlobalInterpolation): +class CubicInterpolation(AbstractGlobalInterpolation, strict=True): """Piecewise cubic spline interpolation over the interval $[t_0, t_1]$.""" ts: Real[Array, " times"] @@ -302,7 +302,7 @@ def derivative( """ -class DenseInterpolation(AbstractGlobalInterpolation): +class DenseInterpolation(AbstractGlobalInterpolation, strict=True): ts: Real[Array, " times"] # DenseInterpolations typically get `ts` and `infos` that are way longer than they # need to be, and padded with `nan`s. This means the normal way of measuring how diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index b6501e41..d9109b3f 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -47,14 +47,14 @@ from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm -class SaveState(eqx.Module): +class SaveState(eqx.Module, strict=True): saveat_ts_index: IntScalarLike ts: eqxi.MaybeBuffer[Real[Array, " times"]] ys: PyTree[eqxi.MaybeBuffer[Inexact[Array, "times ..."]]] save_index: IntScalarLike -class State(eqx.Module): +class State(eqx.Module, strict=True): # Evolving state during the solve y: PyTree[Array] tprev: FloatScalarLike diff --git a/diffrax/_local_interpolation.py b/diffrax/_local_interpolation.py index 1969cd87..6bba4b3f 100644 --- a/diffrax/_local_interpolation.py +++ b/diffrax/_local_interpolation.py @@ -1,5 +1,6 @@ from typing import Optional, TYPE_CHECKING +import equinox as eqx import jax.numpy as jnp import jax.tree_util as jtu import numpy as np @@ -17,11 +18,11 @@ from ._path import AbstractPath -class AbstractLocalInterpolation(AbstractPath): +class AbstractLocalInterpolation(AbstractPath, strict=True): pass -class LocalLinearInterpolation(AbstractLocalInterpolation): +class LocalLinearInterpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y @@ -39,7 +40,7 @@ def evaluate( return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω -class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation): +class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike coeffs: PyTree[Shaped[Array, "4 ?*dims"], "Y"] @@ -83,7 +84,9 @@ def _eval(_coeffs): return jtu.tree_map(_eval, self.coeffs) -class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation): +class FourthOrderPolynomialInterpolation( + AbstractLocalInterpolation, strict=eqx.StrictConfig(allow_abstract_name=True) +): t0: RealScalarLike t1: RealScalarLike coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"] diff --git a/diffrax/_path.py b/diffrax/_path.py index c7b90a3b..92f78945 100644 --- a/diffrax/_path.py +++ b/diffrax/_path.py @@ -15,7 +15,7 @@ from ._custom_types import RealScalarLike -class AbstractPath(eqx.Module): +class AbstractPath(eqx.Module, strict=True): """Abstract base class for all paths. Every path has a start point `t0` and an end point `t1`. In between these values diff --git a/diffrax/_root_finder/_verychord.py b/diffrax/_root_finder/_verychord.py index 4bce16a5..899e3063 100644 --- a/diffrax/_root_finder/_verychord.py +++ b/diffrax/_root_finder/_verychord.py @@ -30,7 +30,7 @@ def _converged(factor: Scalar, tol: float) -> Bool[Array, ""]: return (factor > 0) & (factor < tol) -class _VeryChordState(eqx.Module): +class _VeryChordState(eqx.Module, strict=True): linear_state: tuple[lx.AbstractLinearOperator, PyTree[Any]] diff: Y diffsize: Scalar @@ -39,7 +39,7 @@ class _VeryChordState(eqx.Module): step: Scalar -class _NoAux(eqx.Module): +class _NoAux(eqx.Module, strict=True): fn: Callable def __call__(self, y, args): @@ -48,7 +48,7 @@ def __call__(self, y, args): return out -class VeryChord(optx.AbstractRootFinder): +class VeryChord(optx.AbstractRootFinder, strict=True): """The Chord method of root finding. As `optimistix.Chord`, except that in Runge--Kutta methods, the linearisation point diff --git a/diffrax/_saveat.py b/diffrax/_saveat.py index 6ee373de..c8cb3044 100644 --- a/diffrax/_saveat.py +++ b/diffrax/_saveat.py @@ -21,7 +21,7 @@ def _convert_ts( return jnp.asarray(ts) -class SubSaveAt(eqx.Module): +class SubSaveAt(eqx.Module, strict=True): """Used for finer-grained control over what is saved. A PyTree of these should be passed to `SaveAt(subs=...)`. @@ -53,7 +53,7 @@ def __check_init__(self): """ -class SaveAt(eqx.Module): +class SaveAt(eqx.Module, strict=True): """Determines what to save as output from the differential equation solve. Instances of this class should be passed as the `saveat` argument of diff --git a/diffrax/_solution.py b/diffrax/_solution.py index b1f7e322..965e61c0 100644 --- a/diffrax/_solution.py +++ b/diffrax/_solution.py @@ -1,5 +1,6 @@ from typing import Any, Optional +import equinox as eqx import jax import optimistix as optx from jaxtyping import Array, Bool, PyTree, Real, Shaped @@ -55,7 +56,7 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS: return RESULTS.where(pred, old_result, out_result) -class Solution(AbstractPath): +class Solution(AbstractPath, strict=eqx.StrictConfig(allow_method_override=True)): """The solution to a differential equation. **Attributes:** diff --git a/diffrax/_solver/base.py b/diffrax/_solver/base.py index 3b04bfa4..02ca3c8e 100644 --- a/diffrax/_solver/base.py +++ b/diffrax/_solver/base.py @@ -49,7 +49,7 @@ def __instancecheck__(cls, obj): _set_metaclass = dict(metaclass=_MetaAbstractSolver) -class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass): +class AbstractSolver(eqx.Module, Generic[_SolverState], strict=True, **_set_metaclass): """Abstract base class for all differential equation solvers. Subclasses should have a class-level attribute `terms`, specifying the PyTree @@ -179,7 +179,7 @@ def func( """ -class AbstractImplicitSolver(AbstractSolver[_SolverState]): +class AbstractImplicitSolver(AbstractSolver[_SolverState], strict=True): """Indicates that this is an implicit differential equation solver, and as such that it should take a root finder as an argument. """ @@ -188,25 +188,25 @@ class AbstractImplicitSolver(AbstractSolver[_SolverState]): root_find_max_steps: AbstractVar[int] -class AbstractItoSolver(AbstractSolver[_SolverState]): +class AbstractItoSolver(AbstractSolver[_SolverState], strict=True): """Indicates that when used as an SDE solver that this solver will converge to the Itô solution. """ -class AbstractStratonovichSolver(AbstractSolver[_SolverState]): +class AbstractStratonovichSolver(AbstractSolver[_SolverState], strict=True): """Indicates that when used as an SDE solver that this solver will converge to the Stratonovich solution. """ -class AbstractAdaptiveSolver(AbstractSolver[_SolverState]): +class AbstractAdaptiveSolver(AbstractSolver[_SolverState], strict=True): """Indicates that this solver provides error estimates, and that as such it may be used with an adaptive step size controller. """ -class AbstractWrappedSolver(AbstractSolver[_SolverState]): +class AbstractWrappedSolver(AbstractSolver[_SolverState], strict=True): """Wraps another solver "transparently", in the sense that all `isinstance` checks will be forwarded on to the wrapped solver, e.g. when testing whether the solver is implicit/adaptive/SDE-compatible/etc. diff --git a/diffrax/_solver/dopri5.py b/diffrax/_solver/dopri5.py index e6ebff3b..9faf5b77 100644 --- a/diffrax/_solver/dopri5.py +++ b/diffrax/_solver/dopri5.py @@ -32,7 +32,7 @@ ) -class _Dopri5Interpolation(FourthOrderPolynomialInterpolation): +class _Dopri5Interpolation(FourthOrderPolynomialInterpolation, strict=True): c_mid: ClassVar[np.ndarray] = np.array( [ 6025192743 / 30085553152 / 2, diff --git a/diffrax/_solver/dopri8.py b/diffrax/_solver/dopri8.py index ab9ef64a..74e216ca 100644 --- a/diffrax/_solver/dopri8.py +++ b/diffrax/_solver/dopri8.py @@ -188,7 +188,7 @@ _vmap_polyval = jax.vmap(jnp.polyval, in_axes=(0, None)) -class _Dopri8Interpolation(AbstractLocalInterpolation): +class _Dopri8Interpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y diff --git a/diffrax/_solver/kencarp3.py b/diffrax/_solver/kencarp3.py index d24546fd..0205de5f 100644 --- a/diffrax/_solver/kencarp3.py +++ b/diffrax/_solver/kencarp3.py @@ -89,7 +89,7 @@ ) -class KenCarpInterpolation(AbstractLocalInterpolation): +class AbstractKenCarpInterpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y @@ -120,7 +120,7 @@ def evaluate( return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω -class _KenCarp3Interpolation(KenCarpInterpolation): +class _KenCarp3Interpolation(AbstractKenCarpInterpolation, strict=True): coeffs = np.array( [ [-215264564351 / 13552729205753, 4655552711362 / 22874653954995], diff --git a/diffrax/_solver/kencarp4.py b/diffrax/_solver/kencarp4.py index e50ce2b2..2e8ee5d1 100644 --- a/diffrax/_solver/kencarp4.py +++ b/diffrax/_solver/kencarp4.py @@ -6,7 +6,7 @@ from .._root_finder import VeryChord, with_stepsize_controller_tols from .base import AbstractImplicitSolver -from .kencarp3 import KenCarpInterpolation +from .kencarp3 import AbstractKenCarpInterpolation from .runge_kutta import ( AbstractRungeKutta, ButcherTableau, @@ -102,7 +102,7 @@ ) -class _KenCarp4Interpolation(KenCarpInterpolation): +class _KenCarp4Interpolation(AbstractKenCarpInterpolation, strict=True): coeffs = np.array( [ [ diff --git a/diffrax/_solver/kencarp5.py b/diffrax/_solver/kencarp5.py index 5cfc81d6..288f3b19 100644 --- a/diffrax/_solver/kencarp5.py +++ b/diffrax/_solver/kencarp5.py @@ -6,7 +6,7 @@ from .._root_finder import VeryChord, with_stepsize_controller_tols from .base import AbstractImplicitSolver -from .kencarp3 import KenCarpInterpolation +from .kencarp3 import AbstractKenCarpInterpolation from .runge_kutta import ( AbstractRungeKutta, ButcherTableau, @@ -163,7 +163,7 @@ ) -class _KenCarp5Interpolation(KenCarpInterpolation): +class _KenCarp5Interpolation(AbstractKenCarpInterpolation, strict=True): coeffs = np.array( [ [ diff --git a/diffrax/_solver/runge_kutta.py b/diffrax/_solver/runge_kutta.py index e8c28546..49e86616 100644 --- a/diffrax/_solver/runge_kutta.py +++ b/diffrax/_solver/runge_kutta.py @@ -197,7 +197,7 @@ def __post_init__(self): """ -class MultiButcherTableau(eqx.Module): +class MultiButcherTableau(eqx.Module, strict=True): """Wraps multiple [`diffrax.ButcherTableau`][]s together. Used in some multi-tableau solvers, like IMEX methods. @@ -342,7 +342,7 @@ def _assert_same_structure(x, y): return eqx.tree_equal(x, y) is True -class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState]): +class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState], strict=True): """Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit Runge--Kutta methods, which have a different computational structure.) @@ -1201,7 +1201,7 @@ def _increment(tab_i, k_i): return y1, y_error, dense_info, new_solver_state, result -class AbstractERK(AbstractRungeKutta): +class AbstractERK(AbstractRungeKutta, strict=True): """Abstract base class for all Explicit Runge--Kutta solvers. Subclasses should include a class-level attribute `tableau`, an instance of @@ -1211,7 +1211,7 @@ class AbstractERK(AbstractRungeKutta): calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.never -class AbstractDIRK(AbstractRungeKutta, AbstractImplicitSolver): +class AbstractDIRK(AbstractRungeKutta, AbstractImplicitSolver, strict=True): """Abstract base class for all Diagonal Implicit Runge--Kutta solvers. Subclasses should include a class-level attribute `tableau`, an instance of @@ -1221,7 +1221,7 @@ class AbstractDIRK(AbstractRungeKutta, AbstractImplicitSolver): calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.every_stage -class AbstractSDIRK(AbstractDIRK): +class AbstractSDIRK(AbstractDIRK, strict=True): """Abstract base class for all Singular Diagonal Implict Runge--Kutta solvers. Subclasses should include a class-level attribute `tableau`, an instance of @@ -1239,7 +1239,7 @@ def __init_subclass__(cls, **kwargs): calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.first_stage -class AbstractESDIRK(AbstractDIRK): +class AbstractESDIRK(AbstractDIRK, strict=True): """Abstract base class for all Explicit Singular Diagonal Implicit Runge--Kutta solvers. diff --git a/diffrax/_solver/tsit5.py b/diffrax/_solver/tsit5.py index ebd96539..413a9afa 100644 --- a/diffrax/_solver/tsit5.py +++ b/diffrax/_solver/tsit5.py @@ -98,7 +98,7 @@ ) -class _Tsit5Interpolation(AbstractLocalInterpolation): +class _Tsit5Interpolation(AbstractLocalInterpolation, strict=True): t0: RealScalarLike t1: RealScalarLike y0: Y diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index 4b3d432a..8a1d45eb 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -89,7 +89,7 @@ def intermediate(carry): class AbstractAdaptiveStepSizeController( - AbstractStepSizeController[_ControllerState, _Dt0] + AbstractStepSizeController[_ControllerState, _Dt0], strict=True ): """Indicates an adaptive step size controller. @@ -152,7 +152,7 @@ def __repr__(self): # TODO: we don't currently offer a limiter, or a variant accept/reject scheme, as given # in Soderlind and Wang 2006. class PIDController( - AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]] + AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]], strict=True ): r"""Adapts the step size to produce a solution accurate to a given tolerance. The tolerance is calculated as `atol + rtol * y` for the evolving solution `y`. diff --git a/diffrax/_step_size_controller/base.py b/diffrax/_step_size_controller/base.py index 625bd6fb..56139383 100644 --- a/diffrax/_step_size_controller/base.py +++ b/diffrax/_step_size_controller/base.py @@ -14,7 +14,9 @@ _Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) -class AbstractStepSizeController(eqx.Module, Generic[_ControllerState, _Dt0]): +class AbstractStepSizeController( + eqx.Module, Generic[_ControllerState, _Dt0], strict=True +): """Abstract base class for all step size controllers.""" @abc.abstractmethod diff --git a/diffrax/_step_size_controller/constant.py b/diffrax/_step_size_controller/constant.py index ce8ad425..4b542a86 100644 --- a/diffrax/_step_size_controller/constant.py +++ b/diffrax/_step_size_controller/constant.py @@ -12,7 +12,9 @@ from .base import AbstractStepSizeController -class ConstantStepSize(AbstractStepSizeController[RealScalarLike, RealScalarLike]): +class ConstantStepSize( + AbstractStepSizeController[RealScalarLike, RealScalarLike], strict=True +): """Use a constant step size, equal to the `dt0` argument of [`diffrax.diffeqsolve`][]. """ @@ -61,7 +63,7 @@ def adapt_step_size( ) -class StepTo(AbstractStepSizeController[IntScalarLike, None]): +class StepTo(AbstractStepSizeController[IntScalarLike, None], strict=True): """Make steps to just prespecified times.""" ts: Real[Array, " times"] = eqx.field(converter=jnp.asarray) diff --git a/diffrax/_term.py b/diffrax/_term.py index 277089b0..875d173f 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -16,7 +16,7 @@ from ._path import AbstractPath -class AbstractTerm(eqx.Module): +class AbstractTerm(eqx.Module, strict=True): r"""Abstract base class for all terms. Let $y$ solve some differential equation with vector field $f$ and control $x$. @@ -155,7 +155,7 @@ def is_vf_expensive( return False -class ODETerm(AbstractTerm): +class ODETerm(AbstractTerm, strict=True): r"""A term representing $f(t, y(t), args) \mathrm{d}t$. That is to say, the term appearing on the right hand side of an ODE, in which the control is time. @@ -219,7 +219,7 @@ def _mul(v): """ -class _CallableToPath(AbstractPath): +class _CallableToPath(AbstractPath, strict=True): fn: Callable @property @@ -250,7 +250,7 @@ def _prod(vf, control): return jnp.tensordot(vf, control, axes=jnp.ndim(control)) -class _ControlTerm(AbstractTerm): +class _AbstractControlTerm(AbstractTerm, strict=True): vector_field: Callable[[RealScalarLike, Y, Args], VF] control: Union[AbstractPath, Callable] = eqx.field(converter=_callable_to_path) @@ -273,7 +273,7 @@ def to_ode(self) -> ODETerm: return ODETerm(vector_field=vector_field) -_ControlTerm.__init__.__doc__ = """**Arguments:** +_AbstractControlTerm.__init__.__doc__ = """**Arguments:** - `vector_field`: A callable representing the vector field. This callable takes three arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is @@ -285,7 +285,7 @@ def to_ode(self) -> ODETerm: """ -class ControlTerm(_ControlTerm): +class ControlTerm(_AbstractControlTerm, strict=True): r"""A term representing the general case of $f(t, y(t), args) \mathrm{d}x(t)$, in which the vector field - control interaction is a matrix-vector product. @@ -327,7 +327,7 @@ def prod(self, vf: VF, control: Control) -> Y: return jtu.tree_map(_prod, vf, control) -class WeaklyDiagonalControlTerm(_ControlTerm): +class WeaklyDiagonalControlTerm(_AbstractControlTerm, strict=True): r"""A term representing the case of $f(t, y(t), args) \mathrm{d}x(t)$, in which the vector field - control interaction is a matrix-vector product, and the matrix is square and diagonal. In this case we may represent the matrix as a vector @@ -353,8 +353,8 @@ def prod(self, vf: VF, control: Control) -> Y: return jtu.tree_map(operator.mul, vf, control) -class _ControlToODE(eqx.Module): - control_term: _ControlTerm +class _ControlToODE(eqx.Module, strict=True): + control_term: _AbstractControlTerm def __call__(self, t: RealScalarLike, y: Y, args: Args) -> Y: control = self.control_term.control.derivative(t) @@ -368,7 +368,9 @@ def _sum(*x): _Terms = TypeVar("_Terms", bound=tuple[AbstractTerm, ...]) -class MultiTerm(AbstractTerm, Generic[_Terms]): +class MultiTerm( + AbstractTerm, Generic[_Terms], strict=eqx.StrictConfig(allow_method_override=True) +): r"""Accumulates multiple terms into a single term. Consider the SDE @@ -436,7 +438,7 @@ def is_vf_expensive( return any(term.is_vf_expensive(t0, t1, y, args) for term in self.terms) -class WrapTerm(AbstractTerm): +class WrapTerm(AbstractTerm, strict=eqx.StrictConfig(allow_method_override=True)): term: AbstractTerm direction: IntScalarLike @@ -468,7 +470,7 @@ def is_vf_expensive( return self.term.is_vf_expensive(_t0, _t1, y, args) -class AdjointTerm(AbstractTerm): +class AdjointTerm(AbstractTerm, strict=eqx.StrictConfig(allow_method_override=True)): term: AbstractTerm def is_vf_expensive( diff --git a/pyproject.toml b/pyproject.toml index 6ca52eb4..10bc7a26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/diffrax" } -dependencies = ["jax>=0.4.18", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"] +dependencies = ["jax>=0.4.18", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.3", "lineax>=0.0.4", "optimistix>=0.0.7"] [build-system] requires = ["hatchling"]