Skip to content

Commit

Permalink
Added strict=True
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 9, 2024
1 parent 0ee47c9 commit c7420bd
Show file tree
Hide file tree
Showing 41 changed files with 208 additions and 107 deletions.
10 changes: 5 additions & 5 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_ys(_final_state):
return final_state


class AbstractAdjoint(eqx.Module):
class AbstractAdjoint(eqx.Module, strict=True):
"""Abstract base class for all adjoint methods."""

@abc.abstractmethod
Expand Down Expand Up @@ -167,7 +167,7 @@ def _uncallable(*args, **kwargs):
assert False


class RecursiveCheckpointAdjoint(AbstractAdjoint):
class RecursiveCheckpointAdjoint(AbstractAdjoint, strict=True):
"""Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical
solution directly. This is sometimes known as "discretise-then-optimise", or
described as "backpropagation through the solver".
Expand Down Expand Up @@ -318,7 +318,7 @@ def loop(
"""


class DirectAdjoint(AbstractAdjoint):
class DirectAdjoint(AbstractAdjoint, strict=True):
"""A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that
`DirectAdjoint`:
Expand Down Expand Up @@ -434,7 +434,7 @@ def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
return frozenset(iter_x)


class ImplicitAdjoint(AbstractAdjoint):
class ImplicitAdjoint(AbstractAdjoint, strict=True):
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
This is used when solving towards a steady state, typically using
Expand Down Expand Up @@ -705,7 +705,7 @@ def __get(__aug):
return a_y1, a_diff_args1, a_diff_terms1


class BacksolveAdjoint(AbstractAdjoint):
class BacksolveAdjoint(AbstractAdjoint, strict=True):
"""Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
adjoint equations backwards-in-time. This is also sometimes known as
"optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
class AbstractBrownianPath(AbstractPath, strict=True):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[LevyArea]
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .base import AbstractBrownianPath


class UnsafeBrownianPath(AbstractBrownianPath):
class UnsafeBrownianPath(AbstractBrownianPath, strict=True):
"""Brownian simulation that is only suitable for certain cases.
This is a very quick way to simulate Brownian motion, but can only be used when all
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]


class _State(eqx.Module):
class _State(eqx.Module, strict=True):
level: IntScalarLike # level of the tree
s: RealScalarLike # starting time of the interval
w_s_u_su: FloatTriple # W_s, W_u, W_{s,u}
Expand Down Expand Up @@ -109,7 +109,7 @@ def _split_interval(
return x_s, x_u, x_su


class VirtualBrownianTree(AbstractBrownianPath):
class VirtualBrownianTree(AbstractBrownianPath, strict=True):
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
Can be initialised with `levy_area` set to `""`, or `"space-time"`.
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
LevyArea: TypeAlias = Literal["", "space-time"]


class LevyVal(eqx.Module):
class LevyVal(eqx.Module, strict=True):
dt: PyTree
W: PyTree
H: Optional[PyTree]
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._step_size_controller import AbstractAdaptiveStepSizeController


class AbstractDiscreteTerminatingEvent(eqx.Module):
class AbstractDiscreteTerminatingEvent(eqx.Module, strict=True):
"""Evaluated at the end of each integration step. If true then the solve is stopped
at that time.
"""
Expand All @@ -30,7 +30,7 @@ def __call__(self, state, **kwargs) -> BoolScalarLike:
"""


class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent, strict=True):
"""Terminates the solve if its condition is ever active."""

cond_fn: Callable[..., BoolScalarLike]
Expand All @@ -50,7 +50,7 @@ def __call__(self, state, **kwargs):
"""


class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
class SteadyStateEvent(AbstractDiscreteTerminatingEvent, strict=True):
"""Terminates the solve once it reaches a steady state."""

rtol: Optional[float] = None
Expand Down
8 changes: 4 additions & 4 deletions diffrax/_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._path import AbstractPath


class AbstractGlobalInterpolation(AbstractPath):
class AbstractGlobalInterpolation(AbstractPath, strict=True):
ts: AbstractVar[Real[Array, " times"]]
ts_size: AbstractVar[IntScalarLike]

Expand Down Expand Up @@ -52,7 +52,7 @@ def t1(self):
return self.ts[-1]


class LinearInterpolation(AbstractGlobalInterpolation):
class LinearInterpolation(AbstractGlobalInterpolation, strict=True):
"""Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots
at `ts`.
Expand Down Expand Up @@ -178,7 +178,7 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
"""


class CubicInterpolation(AbstractGlobalInterpolation):
class CubicInterpolation(AbstractGlobalInterpolation, strict=True):
"""Piecewise cubic spline interpolation over the interval $[t_0, t_1]$."""

ts: Real[Array, " times"]
Expand Down Expand Up @@ -302,7 +302,7 @@ def derivative(
"""


class DenseInterpolation(AbstractGlobalInterpolation):
class DenseInterpolation(AbstractGlobalInterpolation, strict=True):
ts: Real[Array, " times"]
# DenseInterpolations typically get `ts` and `infos` that are way longer than they
# need to be, and padded with `nan`s. This means the normal way of measuring how
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm


class SaveState(eqx.Module):
class SaveState(eqx.Module, strict=True):
saveat_ts_index: IntScalarLike
ts: eqxi.MaybeBuffer[Real[Array, " times"]]
ys: PyTree[eqxi.MaybeBuffer[Inexact[Array, "times ..."]]]
save_index: IntScalarLike


class State(eqx.Module):
class State(eqx.Module, strict=True):
# Evolving state during the solve
y: PyTree[Array]
tprev: FloatScalarLike
Expand Down
11 changes: 7 additions & 4 deletions diffrax/_local_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, TYPE_CHECKING

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
Expand All @@ -17,11 +18,11 @@
from ._path import AbstractPath


class AbstractLocalInterpolation(AbstractPath):
class AbstractLocalInterpolation(AbstractPath, strict=True):
pass


class LocalLinearInterpolation(AbstractLocalInterpolation):
class LocalLinearInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
y0: Y
Expand All @@ -39,7 +40,7 @@ def evaluate(
return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω


class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation):
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
coeffs: PyTree[Shaped[Array, "4 ?*dims"], "Y"]
Expand Down Expand Up @@ -83,7 +84,9 @@ def _eval(_coeffs):
return jtu.tree_map(_eval, self.coeffs)


class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation):
class FourthOrderPolynomialInterpolation(
AbstractLocalInterpolation, strict=eqx.StrictConfig(allow_abstract_name=True)
):
t0: RealScalarLike
t1: RealScalarLike
coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"]
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._custom_types import RealScalarLike


class AbstractPath(eqx.Module):
class AbstractPath(eqx.Module, strict=True):
"""Abstract base class for all paths.
Every path has a start point `t0` and an end point `t1`. In between these values
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_root_finder/_verychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _converged(factor: Scalar, tol: float) -> Bool[Array, ""]:
return (factor > 0) & (factor < tol)


class _VeryChordState(eqx.Module):
class _VeryChordState(eqx.Module, strict=True):
linear_state: tuple[lx.AbstractLinearOperator, PyTree[Any]]
diff: Y
diffsize: Scalar
Expand All @@ -39,7 +39,7 @@ class _VeryChordState(eqx.Module):
step: Scalar


class _NoAux(eqx.Module):
class _NoAux(eqx.Module, strict=True):
fn: Callable

def __call__(self, y, args):
Expand All @@ -48,7 +48,7 @@ def __call__(self, y, args):
return out


class VeryChord(optx.AbstractRootFinder):
class VeryChord(optx.AbstractRootFinder, strict=True):
"""The Chord method of root finding.
As `optimistix.Chord`, except that in Runge--Kutta methods, the linearisation point
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_saveat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _convert_ts(
return jnp.asarray(ts)


class SubSaveAt(eqx.Module):
class SubSaveAt(eqx.Module, strict=True):
"""Used for finer-grained control over what is saved. A PyTree of these should be
passed to `SaveAt(subs=...)`.
Expand Down Expand Up @@ -53,7 +53,7 @@ def __check_init__(self):
"""


class SaveAt(eqx.Module):
class SaveAt(eqx.Module, strict=True):
"""Determines what to save as output from the differential equation solve.
Instances of this class should be passed as the `saveat` argument of
Expand Down
3 changes: 2 additions & 1 deletion diffrax/_solution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Optional

import equinox as eqx
import jax
import optimistix as optx
from jaxtyping import Array, Bool, PyTree, Real, Shaped
Expand Down Expand Up @@ -55,7 +56,7 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
return RESULTS.where(pred, old_result, out_result)


class Solution(AbstractPath):
class Solution(AbstractPath, strict=eqx.StrictConfig(allow_method_override=True)):
"""The solution to a differential equation.
**Attributes:**
Expand Down
16 changes: 9 additions & 7 deletions diffrax/_solver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __instancecheck__(cls, obj):
_set_metaclass = dict(metaclass=_MetaAbstractSolver)


class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass):
class AbstractSolver(eqx.Module, Generic[_SolverState], strict=True, **_set_metaclass):
"""Abstract base class for all differential equation solvers.
Subclasses should have a class-level attribute `terms`, specifying the PyTree
Expand Down Expand Up @@ -179,7 +179,7 @@ def func(
"""


class AbstractImplicitSolver(AbstractSolver[_SolverState]):
class AbstractImplicitSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that this is an implicit differential equation solver, and as such
that it should take a root finder as an argument.
"""
Expand All @@ -188,25 +188,25 @@ class AbstractImplicitSolver(AbstractSolver[_SolverState]):
root_find_max_steps: AbstractVar[int]


class AbstractItoSolver(AbstractSolver[_SolverState]):
class AbstractItoSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that when used as an SDE solver that this solver will converge to the
Itô solution.
"""


class AbstractStratonovichSolver(AbstractSolver[_SolverState]):
class AbstractStratonovichSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that when used as an SDE solver that this solver will converge to the
Stratonovich solution.
"""


class AbstractAdaptiveSolver(AbstractSolver[_SolverState]):
class AbstractAdaptiveSolver(AbstractSolver[_SolverState], strict=True):
"""Indicates that this solver provides error estimates, and that as such it may be
used with an adaptive step size controller.
"""


class AbstractWrappedSolver(AbstractSolver[_SolverState]):
class AbstractWrappedSolver(AbstractSolver[_SolverState], strict=True):
"""Wraps another solver "transparently", in the sense that all `isinstance` checks
will be forwarded on to the wrapped solver, e.g. when testing whether the solver is
implicit/adaptive/SDE-compatible/etc.
Expand All @@ -219,7 +219,9 @@ class if that is not desired behaviour.)


class HalfSolver(
AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState]
AbstractAdaptiveSolver[_SolverState],
AbstractWrappedSolver[_SolverState],
strict=eqx.StrictConfig(allow_method_override=True),
):
"""Wraps another solver, trading cost in order to provide error estimates. (That
is, it means the solver can be used with an adaptive step size controller,
Expand Down
7 changes: 5 additions & 2 deletions diffrax/_solver/bosh3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import ClassVar
from typing import ClassVar, Literal, Union

import equinox as eqx
import numpy as np

from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation
Expand All @@ -19,7 +20,7 @@
)


class Bosh3(AbstractERK):
class Bosh3(AbstractERK, strict=eqx.StrictConfig(allow_method_override=True)):
"""Bogacki--Shampine's 3/2 method.
3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for
Expand All @@ -29,6 +30,8 @@ class Bosh3(AbstractERK):
Also sometimes known as "Ralston's third order method".
"""

scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None

tableau: ClassVar[ButcherTableau] = _bosh3_tableau
interpolation_cls: ClassVar[
Callable[..., ThirdOrderHermitePolynomialInterpolation]
Expand Down
Loading

0 comments on commit c7420bd

Please sign in to comment.