Skip to content

Commit

Permalink
Added strict=True
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 9, 2024
1 parent 0ee47c9 commit 775da59
Show file tree
Hide file tree
Showing 41 changed files with 190 additions and 149 deletions.
10 changes: 5 additions & 5 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_ys(_final_state):
return final_state


class AbstractAdjoint(eqx.Module):
class AbstractAdjoint(eqx.Module, strict=True):
"""Abstract base class for all adjoint methods."""

@abc.abstractmethod
Expand Down Expand Up @@ -167,7 +167,7 @@ def _uncallable(*args, **kwargs):
assert False


class RecursiveCheckpointAdjoint(AbstractAdjoint):
class RecursiveCheckpointAdjoint(AbstractAdjoint, strict=True):
"""Backpropagate through [`diffrax.diffeqsolve`][] by differentiating the numerical
solution directly. This is sometimes known as "discretise-then-optimise", or
described as "backpropagation through the solver".
Expand Down Expand Up @@ -318,7 +318,7 @@ def loop(
"""


class DirectAdjoint(AbstractAdjoint):
class DirectAdjoint(AbstractAdjoint, strict=True):
"""A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that
`DirectAdjoint`:
Expand Down Expand Up @@ -434,7 +434,7 @@ def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
return frozenset(iter_x)


class ImplicitAdjoint(AbstractAdjoint):
class ImplicitAdjoint(AbstractAdjoint, strict=True):
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
This is used when solving towards a steady state, typically using
Expand Down Expand Up @@ -705,7 +705,7 @@ def __get(__aug):
return a_y1, a_diff_args1, a_diff_terms1


class BacksolveAdjoint(AbstractAdjoint):
class BacksolveAdjoint(AbstractAdjoint, strict=True):
"""Backpropagate through [`diffrax.diffeqsolve`][] by solving the continuous
adjoint equations backwards-in-time. This is also sometimes known as
"optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
class AbstractBrownianPath(AbstractPath, strict=True):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[LevyArea]
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .base import AbstractBrownianPath


class UnsafeBrownianPath(AbstractBrownianPath):
class UnsafeBrownianPath(AbstractBrownianPath, strict=True):
"""Brownian simulation that is only suitable for certain cases.
This is a very quick way to simulate Brownian motion, but can only be used when all
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]


class _State(eqx.Module):
class _State(eqx.Module, strict=True):
level: IntScalarLike # level of the tree
s: RealScalarLike # starting time of the interval
w_s_u_su: FloatTriple # W_s, W_u, W_{s,u}
Expand Down Expand Up @@ -109,7 +109,7 @@ def _split_interval(
return x_s, x_u, x_su


class VirtualBrownianTree(AbstractBrownianPath):
class VirtualBrownianTree(AbstractBrownianPath, strict=True):
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
Can be initialised with `levy_area` set to `""`, or `"space-time"`.
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
LevyArea: TypeAlias = Literal["", "space-time"]


class LevyVal(eqx.Module):
class LevyVal(eqx.Module, strict=True):
dt: PyTree
W: PyTree
H: Optional[PyTree]
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._step_size_controller import AbstractAdaptiveStepSizeController


class AbstractDiscreteTerminatingEvent(eqx.Module):
class AbstractDiscreteTerminatingEvent(eqx.Module, strict=True):
"""Evaluated at the end of each integration step. If true then the solve is stopped
at that time.
"""
Expand All @@ -30,7 +30,7 @@ def __call__(self, state, **kwargs) -> BoolScalarLike:
"""


class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent, strict=True):
"""Terminates the solve if its condition is ever active."""

cond_fn: Callable[..., BoolScalarLike]
Expand All @@ -50,7 +50,7 @@ def __call__(self, state, **kwargs):
"""


class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
class SteadyStateEvent(AbstractDiscreteTerminatingEvent, strict=True):
"""Terminates the solve once it reaches a steady state."""

rtol: Optional[float] = None
Expand Down
8 changes: 4 additions & 4 deletions diffrax/_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._path import AbstractPath


class AbstractGlobalInterpolation(AbstractPath):
class AbstractGlobalInterpolation(AbstractPath, strict=True):
ts: AbstractVar[Real[Array, " times"]]
ts_size: AbstractVar[IntScalarLike]

Expand Down Expand Up @@ -52,7 +52,7 @@ def t1(self):
return self.ts[-1]


class LinearInterpolation(AbstractGlobalInterpolation):
class LinearInterpolation(AbstractGlobalInterpolation, strict=True):
"""Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots
at `ts`.
Expand Down Expand Up @@ -178,7 +178,7 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
"""


class CubicInterpolation(AbstractGlobalInterpolation):
class CubicInterpolation(AbstractGlobalInterpolation, strict=True):
"""Piecewise cubic spline interpolation over the interval $[t_0, t_1]$."""

ts: Real[Array, " times"]
Expand Down Expand Up @@ -302,7 +302,7 @@ def derivative(
"""


class DenseInterpolation(AbstractGlobalInterpolation):
class DenseInterpolation(AbstractGlobalInterpolation, strict=True):
ts: Real[Array, " times"]
# DenseInterpolations typically get `ts` and `infos` that are way longer than they
# need to be, and padded with `nan`s. This means the normal way of measuring how
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm


class SaveState(eqx.Module):
class SaveState(eqx.Module, strict=True):
saveat_ts_index: IntScalarLike
ts: eqxi.MaybeBuffer[Real[Array, " times"]]
ys: PyTree[eqxi.MaybeBuffer[Inexact[Array, "times ..."]]]
save_index: IntScalarLike


class State(eqx.Module):
class State(eqx.Module, strict=True):
# Evolving state during the solve
y: PyTree[Array]
tprev: FloatScalarLike
Expand Down
99 changes: 53 additions & 46 deletions diffrax/_local_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Optional, TYPE_CHECKING

import jax.numpy as jnp
Expand All @@ -17,11 +18,11 @@
from ._path import AbstractPath


class AbstractLocalInterpolation(AbstractPath):
class AbstractLocalInterpolation(AbstractPath, strict=True):
pass


class LocalLinearInterpolation(AbstractLocalInterpolation):
class LocalLinearInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
y0: Y
Expand All @@ -39,7 +40,7 @@ def evaluate(
return (coeff * (self.y1**ω - self.y0**ω)).call(jnp.asarray).ω


class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation):
class ThirdOrderHermitePolynomialInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
coeffs: PyTree[Shaped[Array, "4 ?*dims"], "Y"]
Expand Down Expand Up @@ -83,46 +84,52 @@ def _eval(_coeffs):
return jtu.tree_map(_eval, self.coeffs)


class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation):
t0: RealScalarLike
t1: RealScalarLike
coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"]

c_mid: AbstractVar[np.ndarray]

def __init__(
self,
*,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
y1: Y,
k: PyTree[Shaped[Array, "order ?*y"], "Y"],
):
def _calculate(_y0, _y1, _k):
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
_f0 = _k[0]
_f1 = _k[-1]
# TODO: rewrite as matrix-vector product?
_a = 2 * (_f1 - _f0) - 8 * (_y1 + _y0) + 16 * _ymid
_b = 5 * _f0 - 3 * _f1 + 18 * _y0 + 14 * _y1 - 32 * _ymid
_c = _f1 - 4 * _f0 - 11 * _y0 - 5 * _y1 + 16 * _ymid
return jnp.stack([_a, _b, _c, _f0, _y0])

self.t0 = t0
self.t1 = t1
self.coeffs = jtu.tree_map(_calculate, y0, y1, k)

def evaluate(
self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True
) -> PyTree[Array]:
del left
if t1 is not None:
return self.evaluate(t1) - self.evaluate(t0)

t = linear_rescale(self.t0, t0, self.t1)

def _eval(_coeffs):
return jnp.polyval(_coeffs, t)

return jtu.tree_map(_eval, self.coeffs)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "Abstract")

class FourthOrderPolynomialInterpolation(AbstractLocalInterpolation, strict=True):
t0: RealScalarLike
t1: RealScalarLike
coeffs: PyTree[Shaped[Array, "5 ?*y"], "Y"]

c_mid: AbstractVar[np.ndarray]

def __init__(
self,
*,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
y1: Y,
k: PyTree[Shaped[Array, "order ?*y"], "Y"],
):
def _calculate(_y0, _y1, _k):
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
_f0 = _k[0]
_f1 = _k[-1]
# TODO: rewrite as matrix-vector product?
_a = 2 * (_f1 - _f0) - 8 * (_y1 + _y0) + 16 * _ymid
_b = 5 * _f0 - 3 * _f1 + 18 * _y0 + 14 * _y1 - 32 * _ymid
_c = _f1 - 4 * _f0 - 11 * _y0 - 5 * _y1 + 16 * _ymid
return jnp.stack([_a, _b, _c, _f0, _y0])

self.t0 = t0
self.t1 = t1
self.coeffs = jtu.tree_map(_calculate, y0, y1, k)

def evaluate(
self,
t0: RealScalarLike,
t1: Optional[RealScalarLike] = None,
left: bool = True,
) -> PyTree[Array]:
del left
if t1 is not None:
return self.evaluate(t1) - self.evaluate(t0)

t = linear_rescale(self.t0, t0, self.t1)

def _eval(_coeffs):
return jnp.polyval(_coeffs, t)

return jtu.tree_map(_eval, self.coeffs)
3 changes: 2 additions & 1 deletion diffrax/_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ._custom_types import RealScalarLike


class AbstractPath(eqx.Module):
class AbstractPath(eqx.Module, strict=True):
"""Abstract base class for all paths.
Every path has a start point `t0` and an end point `t1`. In between these values
Expand Down Expand Up @@ -77,6 +77,7 @@ def evaluate(
The increment of the path between `t0` and `t1`.
"""

@eqx.strict_default_method
def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
r"""Evaluate the derivative of the path. Essentially equivalent
to `jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),))` (and indeed this is its
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_root_finder/_verychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _converged(factor: Scalar, tol: float) -> Bool[Array, ""]:
return (factor > 0) & (factor < tol)


class _VeryChordState(eqx.Module):
class _VeryChordState(eqx.Module, strict=True):
linear_state: tuple[lx.AbstractLinearOperator, PyTree[Any]]
diff: Y
diffsize: Scalar
Expand All @@ -39,7 +39,7 @@ class _VeryChordState(eqx.Module):
step: Scalar


class _NoAux(eqx.Module):
class _NoAux(eqx.Module, strict=True):
fn: Callable

def __call__(self, y, args):
Expand All @@ -48,7 +48,7 @@ def __call__(self, y, args):
return out


class VeryChord(optx.AbstractRootFinder):
class VeryChord(optx.AbstractRootFinder, strict=True):
"""The Chord method of root finding.
As `optimistix.Chord`, except that in Runge--Kutta methods, the linearisation point
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_saveat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _convert_ts(
return jnp.asarray(ts)


class SubSaveAt(eqx.Module):
class SubSaveAt(eqx.Module, strict=True):
"""Used for finer-grained control over what is saved. A PyTree of these should be
passed to `SaveAt(subs=...)`.
Expand Down Expand Up @@ -53,7 +53,7 @@ def __check_init__(self):
"""


class SaveAt(eqx.Module):
class SaveAt(eqx.Module, strict=True):
"""Determines what to save as output from the differential equation solve.
Instances of this class should be passed as the `saveat` argument of
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS:
return RESULTS.where(pred, old_result, out_result)


class Solution(AbstractPath):
class Solution(AbstractPath, strict=True):
"""The solution to a differential equation.
**Attributes:**
Expand Down
Loading

0 comments on commit 775da59

Please sign in to comment.