Linear v3 #450

Closed
wants to merge 13 commits into from
3 changes: 1 addition & 2 deletions diffrax/_solver/srk.py
Expand Up @@ -203,8 +203,7 @@ class AbstractSRK(AbstractSolver[_SolverState]):
r"""A general Stochastic Runge-Kutta method.

This accepts `terms` of the form
`MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))` or
`MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))`.
`MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))`.
Depending on the solver, the Brownian motion might need to generate
different types of Lévy areas, specified by the `minimal_levy_area` attribute.

36 changes: 35 additions & 1 deletion diffrax/_term.py
@@ -1,12 +1,14 @@
import abc
import operator
import warnings
from collections.abc import Callable
from typing import cast, Generic, Optional, TypeVar, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
import numpy as np
from equinox.internal import ω
from jaxtyping import ArrayLike, PyTree, PyTreeDef
Expand Down Expand Up @@ -288,7 +290,8 @@ def to_ode(self) -> ODETerm:
- `vector_field`: A callable representing the vector field. This callable takes three
arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is
the evolving state of the system. `args` are any static arguments as passed to
[`diffrax.diffeqsolve`][].
[`diffrax.diffeqsolve`][]. The `vector_field` may return either a JAX array or
any [lineax `AbstractLinearOperator`](https://docs.kidger.site/lineax/api/operators/#lineax.AbstractLinearOperator).
- `control`: The control. Should either be (A) a [`diffrax.AbstractPath`][], in which
case its `evaluate(t0, t1)` method will be used to give the increment of the control
over a time interval `[t0, t1]`, or (B) a callable `(t0, t1) -> increment`, which
Expand All @@ -310,6 +313,26 @@ class ControlTerm(_AbstractControlTerm[_VF, _Control]):
A common special case is when `y0` and `control` are vector-valued, and
`vector_field` is matrix-valued.

To make a weakly diagonal control term, simply have your vector field
callable return an `lx.DiagonalLinearOperator`.

!!! info

Why "weakly" diagonal? Consider the matrix representation of the vector field,
as a square diagonal matrix. In general, the (i,i)-th element may depend
upon any of the values of `y`. It is only if the (i,i)-th element only depends
upon the i-th element of `y` that the vector field is said to be "diagonal",
without the "weak". (This stronger property is useful in some SDE solvers.)

!!! example

```python
control = UnsafeBrownianPath(shape=(2,), key=...)
vector_field = lambda t, y, args: lx.DiagonalLinearOperator(jnp.ones_like(y))
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(diffusion_term, ...)
```

!!! example

```python
Expand All @@ -335,6 +358,8 @@ class ControlTerm(_AbstractControlTerm[_VF, _Control]):
"""

def prod(self, vf: _VF, control: _Control) -> Y:
if isinstance(vf, lx.AbstractLinearOperator):
return vf.mv(control)
return jtu.tree_map(_prod, vf, control)


Expand All @@ -360,6 +385,15 @@ class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]):
without the "weak". (This stronger property is useful in some SDE solvers.)
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"WeaklyDiagonalControlTerm is pending deprecation and may be removed "
"in future versions. Consider using a ControlTerm whose vector field "
"returns lx.DiagonalLinearOperator(...) instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)

def prod(self, vf: _VF, control: _Control) -> Y:
with jax.numpy_dtype_promotion("standard"):
return jtu.tree_map(operator.mul, vf, control)
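
The deprecation pattern in `__init__` can be exercised in isolation with only the standard library. The sketch below uses hypothetical `NewTerm` and `OldTerm` classes (not Diffrax classes) and passes `stacklevel=2` so that the warning is attributed to the caller rather than to the constructor itself. Note that Python only shows `DeprecationWarning` by default when it is triggered from `__main__`, so a test typically has to enable it explicitly, as done here with `simplefilter("always")`.

```python
import warnings


class NewTerm:
    """Hypothetical replacement class."""

    def __init__(self, vector_field, control):
        self.vector_field = vector_field
        self.control = control


class OldTerm(NewTerm):
    """Hypothetical class that is being phased out."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "OldTerm is pending deprecation; consider NewTerm instead.",
            DeprecationWarning,
            stacklevel=2,  # report the caller's line, not this one
        )
        super().__init__(*args, **kwargs)


# Record warnings instead of printing them, and make sure none are filtered out.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    term = OldTerm(lambda t, y, args: y, None)

messages = [
    str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)
]
```

Python also has a separate `PendingDeprecationWarning` category, which may be worth considering given that the message says "pending deprecation".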
6 changes: 3 additions & 3 deletions docs/api/solvers/sde_solvers.md
Expand Up @@ -6,7 +6,7 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast

The type of solver chosen determines how the `terms` argument of `diffeqsolve` should be laid out.

Most solvers handle both ODEs and SDEs in the same way, and expect a single term. So for an ODE you would pass `terms=ODETerm(vector_field)`, and for an SDE you would pass `terms=MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))` or `terms=MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))`. For example:
Most solvers handle both ODEs and SDEs in the same way, and expect a single term. So for an ODE you would pass `terms=ODETerm(vector_field)`, and for an SDE you would pass `terms=MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))`. For example:

```python
drift = lambda t, y, args: -y
Expand All @@ -18,7 +18,7 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast

For any individual solver, this is documented below and is also available programmatically under `<solver>.term_structure`.

For advanced users, note that we typically accept any `AbstractTerm` for the diffusion, so it could be a custom one that implements more-efficient behaviour for the structure of your diffusion matrix. (Much like how [`diffrax.WeaklyDiagonalControlTerm`][] is more efficient than [`diffrax.ControlTerm`][] for diagonal diffusions.)
For advanced users, note that we typically accept any `AbstractTerm` for the diffusion, so it could be a custom one that implements more-efficient behaviour for the structure of your diffusion matrix.

---

Expand Down Expand Up @@ -52,7 +52,7 @@ These solvers can be used to solve SDEs just as well as they can be used to solv

!!! info "Term structure"

These solvers are SDE-specific. For these, `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (Typically `SomeOTherTerm` will be a `ControlTerm` or `WeaklyDiagonalControlTerm`) representing the drift and diffusion specifically.
These solvers are SDE-specific. For these, `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (typically `SomeOtherTerm` will be a `ControlTerm`), representing the drift and diffusion respectively.


::: diffrax.EulerHeun
8 changes: 1 addition & 7 deletions docs/api/terms.md
Expand Up @@ -44,7 +44,7 @@ Each solver is capable of handling certain classes of problems, as described by
---

!!! note
You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure, and you want to use a specialised more efficient matrix-vector product algorithm in `prod`. For example this is what [`diffrax.WeaklyDiagonalControlTerm`][] does, as compared to just [`diffrax.ControlTerm`][].
You can create your own terms if appropriate: e.g. if a diffusion matrix has some particular structure, and you want to use a specialised, more efficient matrix-vector product algorithm in `prod`.

::: diffrax.ODETerm
selection:
Expand All @@ -57,12 +57,6 @@ Each solver is capable of handling certain classes of problems, as described by
- __init__
- to_ode

::: diffrax.WeaklyDiagonalControlTerm
selection:
members:
- __init__
- to_ode

::: diffrax.MultiTerm
selection:
members:
2 changes: 1 addition & 1 deletion docs/usage/extending.md
Expand Up @@ -24,7 +24,7 @@ The main points of extension are as follows:
- **Custom controls** (e.g. **custom interpolation schemes** analogous to [`diffrax.CubicInterpolation`][]) should inherit from [`diffrax.AbstractPath`][].

- **Custom terms** should inherit from [`diffrax.AbstractTerm`][].
- For example, if the vector field - control interaction is a matrix-vector product, but the matrix is known to have special structure, then you may wish to create a custom term that can calculate this interaction more efficiently than is given by a full matrix-vector product. For example this is done with [`diffrax.WeaklyDiagonalControlTerm`][] as compared to [`diffrax.ControlTerm`][].
- For example, if the vector field - control interaction is a matrix-vector product, but the matrix is known to have special structure, then you may wish to create a custom term that can calculate this interaction more efficiently than is given by a full matrix-vector product. Because [lineax](https://docs.kidger.site/lineax/) implements a large suite of linear operators, all of which are fully supported by [`diffrax.ControlTerm`][], this is likely to be necessary only rarely.

In each case we recommend looking up existing solvers/etc. in Diffrax to understand how to implement them.

16 changes: 14 additions & 2 deletions test/test_adjoint.py
Expand Up @@ -8,6 +8,7 @@
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax as lx
import optax
import pytest
from jaxtyping import Array
Expand Down Expand Up @@ -329,7 +330,11 @@ def run(model):
run(mlp)


def test_sde_against(getkey):
@pytest.mark.parametrize(
"diffusion_fn",
["weak", "lineax"],
)
def test_sde_against(diffusion_fn, getkey):
def f(t, y, args):
k0, _ = args
return -k0 * y
Expand All @@ -338,14 +343,21 @@ def g(t, y, args):
_, k1 = args
return k1 * y

def g_lx(t, y, args):
_, k1 = args
return lx.DiagonalLinearOperator(k1 * y)

t0 = 0
t1 = 1
dt0 = 0.001
tol = 1e-5
shape = (2,)
bm = diffrax.VirtualBrownianTree(t0, t1, tol, shape, key=getkey())
drift = diffrax.ODETerm(f)
diffusion = diffrax.WeaklyDiagonalControlTerm(g, bm)
if diffusion_fn == "weak":
diffusion = diffrax.WeaklyDiagonalControlTerm(g, bm)
else:
diffusion = diffrax.ControlTerm(g_lx, bm)
terms = diffrax.MultiTerm(drift, diffusion)
solver = diffrax.Heun()

35 changes: 35 additions & 0 deletions test/test_integrate.py
Expand Up @@ -9,6 +9,7 @@
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax as lx
import pytest
import scipy.stats
from diffrax import ControlTerm, MultiTerm, ODETerm
Expand Down Expand Up @@ -644,6 +645,9 @@ class TestSolver(diffrax.AbstractSolver):
"e": diffrax.MultiTerm[
tuple[diffrax.ODETerm, diffrax.AbstractTerm[Any, Float[Array, " 5"]]]
],
"f": diffrax.MultiTerm[
tuple[diffrax.ODETerm, diffrax.AbstractTerm[Any, Float[Array, " 5"]]]
],
}
interpolation_cls = diffrax.LocalLinearInterpolation

Expand Down Expand Up @@ -676,13 +680,21 @@ def func(self, terms, t0, y0, args):
lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(5)
),
),
"f": diffrax.MultiTerm(
ode_term,
diffrax.ControlTerm(
lambda t, y, args: lx.DiagonalLinearOperator(-y),
lambda t0, t1: jnp.array(t1 - t0).repeat(5),
),
),
}
compatible_y0 = {
"a": jnp.array(1.0),
"b": jnp.array(2.0),
"c": jnp.arange(3.0),
"d": jnp.arange(4.0),
"e": jnp.arange(5.0),
"f": jnp.arange(5.0),
}
diffrax.diffeqsolve(compatible_term, solver, 0.0, 1.0, 0.1, compatible_y0)

Expand All @@ -698,6 +710,13 @@ def func(self, terms, t0, y0, args):
lambda t0, t1: t1 - t0, # wrong control shape
),
),
"f": diffrax.MultiTerm(
ode_term,
diffrax.ControlTerm(
lambda t, y, args: lx.DiagonalLinearOperator(-y),
lambda t0, t1: jnp.array(t1 - t0).repeat(5),
),
),
}
incompatible_term2 = {
"a": ode_term,
Expand All @@ -710,6 +729,13 @@ def func(self, terms, t0, y0, args):
lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3)
),
),
"f": diffrax.MultiTerm(
ode_term,
diffrax.ControlTerm(
lambda t, y, args: lx.DiagonalLinearOperator(-y),
lambda t0, t1: jnp.array(t1 - t0).repeat(5),
),
),
}
incompatible_term3 = {
"a": ode_term,
Expand All @@ -720,6 +746,13 @@ def func(self, terms, t0, y0, args):
"e": diffrax.WeaklyDiagonalControlTerm(
lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3)
),
"f": diffrax.MultiTerm(
ode_term,
diffrax.ControlTerm(
lambda t, y, args: lx.DiagonalLinearOperator(-y),
lambda t0, t1: jnp.array(t1 - t0).repeat(5),
),
),
}

incompatible_y0_1 = {
Expand All @@ -728,13 +761,15 @@ def func(self, terms, t0, y0, args):
"c": jnp.arange(4.0), # of length 4, not 3
"d": jnp.arange(4.0),
"e": jnp.arange(5.0),
"f": jnp.arange(5.0),
}
incompatible_y0_2 = {
"a": jnp.array(1.0),
"b": jnp.array(2.0),
"c": jnp.arange(3.0),
# Missing "d" piece
"e": jnp.arange(5.0),
"f": jnp.arange(5.0),
}
incompatible_y0_3 = jnp.array(4.0) # Completely the wrong structure!
for term in (
54 changes: 50 additions & 4 deletions test/test_sde.py
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax as lx
import pytest
from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm

Expand Down Expand Up @@ -270,18 +271,63 @@ def _drift(t, y, args):
assert solution.ys.shape == (1, 3)


def _lineax_weakly_diagonal_noise_helper(solver, dtype):
w_shape = (3,)
args = (0.5, 1.2)

def _diffusion(t, y, args):
a, b = args
return lx.DiagonalLinearOperator(jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype))

def _drift(t, y, args):
a, b = args
return -a * y

y0 = jnp.ones(w_shape, dtype)

bm = diffrax.VirtualBrownianTree(
0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea
)

terms = MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm))
saveat = diffrax.SaveAt(t1=True)
solution = diffrax.diffeqsolve(
terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat
)
assert solution.ys is not None
assert solution.ys.shape == (1, 3)


@pytest.mark.parametrize("solver_ctr", _solvers())
@pytest.mark.parametrize(
"dtype",
(jnp.float64, jnp.complex128),
)
def test_weakly_diagonal_noise(solver_ctr, dtype):
_weakly_diagonal_noise_helper(solver_ctr(), dtype)
@pytest.mark.parametrize(
"weak_type",
("old", "lineax"),
)
def test_weakly_diagonal_noise(solver_ctr, dtype, weak_type):
if weak_type == "old":
_weakly_diagonal_noise_helper(solver_ctr(), dtype)
elif weak_type == "lineax":
_lineax_weakly_diagonal_noise_helper(solver_ctr(), dtype)
else:
raise ValueError("Invalid weak_type")


@pytest.mark.parametrize(
"dtype",
(jnp.float64, jnp.complex128),
)
def test_halfsolver_term_compatible(dtype):
_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
@pytest.mark.parametrize(
"weak_type",
("old", "lineax"),
)
def test_halfsolver_term_compatible(dtype, weak_type):
if weak_type == "old":
_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
elif weak_type == "lineax":
_lineax_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
else:
raise ValueError("Invalid weak_type")