Releases: patrick-kidger/diffrax
Diffrax v0.6.1
Features
- Compatibility with JAX 0.4.36.
- New solvers! Added stochastic Runge--Kutta methods for solving the underdamped Langevin equation. We now have:
  - diffrax.AbstractFosterLangevinSRK
  - diffrax.ALIGN
  - diffrax.QUICSORT
  - diffrax.ShOULD

  These are used with the corresponding diffrax.UnderdampedLangevinDriftTerm and diffrax.UnderdampedLangevinDiffusionTerm.
  Huge thanks to @andyElking for carefully implementing all of these, which was a big technical task. (#453 and 2000 new lines of code!) See the Underdamped Langevin Diffusion example for more on how to use these.
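As a rough illustration of the equation these solvers target, here is a minimal pure-Python sketch that integrates a one-dimensional underdamped Langevin SDE with the plain Euler--Maruyama method. It assumes the parameterisation dx = v dt, dv = -γ v dt - u ∇f(x) dt + sqrt(2γu) dW (friction γ, inverse mass u, potential f). This is not ALIGN, QUICSORT or ShOULD, which are higher-order schemes, and the helper name is hypothetical.

```python
import math
import random
from typing import Callable, Tuple


def langevin_euler_maruyama(
    grad_f: Callable[[float], float],
    x0: float,
    v0: float,
    gamma: float,
    u: float,
    dt: float,
    num_steps: int,
    rng: random.Random,
) -> Tuple[float, float]:
    """Plain Euler--Maruyama for the 1D underdamped Langevin SDE (hypothetical helper).

    Assumed parameterisation:
        dx = v dt
        dv = -gamma * v dt - u * grad_f(x) dt + sqrt(2 * gamma * u) dW
    """
    x, v = x0, v0
    # Standard deviation of sqrt(2 * gamma * u) * dW over one step of size dt.
    noise_std = math.sqrt(2.0 * gamma * u * dt)
    for _ in range(num_steps):
        noise = noise_std * rng.gauss(0.0, 1.0)
        # Tuple assignment: both updates use the values from the start of the step.
        x, v = x + v * dt, v - gamma * v * dt - u * grad_f(x) * dt + noise
    return x, v


def grad_quadratic(x: float) -> float:
    # Potential f(x) = x^2 / 2, whose stationary law for x is N(0, 1).
    return x


rng = random.Random(0)
finals = [
    langevin_euler_maruyama(grad_quadratic, 0.0, 0.0, 1.0, 1.0, 0.02, 400, rng)[0]
    for _ in range(1000)
]
mean = sum(finals) / len(finals)
variance = sum((x - mean) ** 2 for x in finals) / len(finals)
print(mean, variance)  # should be close to 0 and 1 respectively
```

Euler--Maruyama only has weak order one here and its stationary distribution is slightly biased, which is exactly the kind of error the dedicated Langevin SRK methods are designed to reduce.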
Bugfixes
- If t0 == t1 and we have SaveAt(ts=...), then we now correctly output len(ts) copies of y0. (Thanks @dkweiss31! #488, #494)
- When using diffrax.VirtualBrownianTree on the GPU, floating-point fluctuations would sometimes produce evaluations outside of the valid [t0, t1] region, which would raise a spurious runtime error. This is now fixed. (Thanks @mattlevine22! jax-ml/jax#24807, #524, #526)
- Fixes for complex numbers in SDEs. (Thanks @Randl! #454)
- Improvements to errors, warnings, and some typo fixes. (Thanks @lockwo @ddrous! #468, #478, #495, #530)
New Contributors
Full Changelog: v0.6.0...v0.6.1
Diffrax v0.6.0
Features
- Continuous events! It is now possible to specify a condition at which point the differential equation should halt. For example, here's one finding the time at which a dropped ball hits the ground:

  ```python
  import diffrax
  import jax.numpy as jnp
  import optimistix as optx


  def vector_field(t, y, args):
      _, v = y
      return jnp.array([v, -9.81])


  def cond_fn(t, y, args, **kwargs):
      x, _ = y
      return x


  term = diffrax.ODETerm(vector_field)
  solver = diffrax.Tsit5()
  t0 = 0
  t1 = jnp.inf
  dt0 = 0.1
  y0 = jnp.array([10.0, 0.0])
  root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
  event = diffrax.Event(cond_fn, root_finder)
  sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
  print(f"Event time: {sol.ts[0]}")  # Event time: 1.42...
  print(f"Velocity at event time: {sol.ys[0, 1]}")  # Velocity at event time: -14.00...
  ```

  The solve stops when cond_fn hits zero: once cond_fn changes sign over a step, we use Optimistix to root-find the exact time at which the equation should terminate. Event handling is also fully differentiable. Getting this in was a huge amount of work from @cholberg -- thank you! -- and has been one of our longest-requested features for a while, so I'm really happy to have this in.
  (We previously only had 'discrete events', which just terminated at the end of a step, and did not do a root find.)
  See the events page in the documentation for more.
- Simulation of space-time-time Lévy area. This is a higher-order statistic of Brownian motion, used in some advanced SDE solvers. We don't have any such solvers yet, but watch this space... ;)
  This was a hugely impressive technical effort from @andyElking. Check out our arXiv paper on the topic, which discusses the technical nitty-gritty of how these statistics can be simulated in an efficient manner.
- ControlTerm now supports returning a Lineax linear operator. For example, here's how to easily create a diagonal diffusion term:

  ```python
  import jax.numpy as jnp
  import lineax
  from diffrax import ControlTerm


  def vector_field(t, y, args):
      # y is a JAX array of shape (2,)
      y1, y2 = y
      diagonal = jnp.array([y2, y1])
      # corresponds to the matrix [[y2, 0], [0, y1]]
      return lineax.DiagonalLinearOperator(diagonal)


  diffusion_term = ControlTerm(vector_field, ...)
  ```

  This should make it much easier to express SDEs with particular structure to their diffusion matrices.
  This is particularly good for efficiency reasons: the operator-specified .mv (matrix-vector product) method is used, which typically provides a more efficient implementation than filling in some zeros and using a dense matrix-vector product. Thank you to @lockwo for implementing this one!
  See the documentation on ControlTerm for more.
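The continuous-event mechanism in the first item above (step until cond_fn changes sign, then root-find inside that step) can be sketched without Diffrax or Optimistix. This is a plain-Python sketch that assumes a fixed-step classical RK4 integrator plus bisection on the length of the final step, in place of Diffrax's solvers and Optimistix's Newton root-finder; rk4_step and solve_until_event are hypothetical names. It also returns the end of the step in which the sign change happened, which is where a discrete event would have stopped.

```python
from typing import Callable, Tuple

State = Tuple[float, ...]
VectorField = Callable[[float, State], State]


def rk4_step(f: VectorField, t: float, y: State, h: float) -> State:
    """One classical fourth-order Runge--Kutta step of size h."""
    k1 = f(t, y)
    k2 = f(t + h / 2, tuple(yi + h / 2 * ki for yi, ki in zip(y, k1)))
    k3 = f(t + h / 2, tuple(yi + h / 2 * ki for yi, ki in zip(y, k2)))
    k4 = f(t + h, tuple(yi + h * ki for yi, ki in zip(y, k3)))
    return tuple(
        yi + h / 6 * (a + 2 * b + 2 * c + d)
        for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
    )


def solve_until_event(f, cond_fn, t0, y0, dt, max_steps=10_000, tol=1e-12):
    """Step until cond_fn changes sign, then bisect on the length of the final step.

    Returns (event time, state at event time, end of the step containing the event).
    """
    t, y = t0, y0
    c = cond_fn(t, y)
    for _ in range(max_steps):
        y_next = rk4_step(f, t, y, dt)
        c_next = cond_fn(t + dt, y_next)
        if c * c_next <= 0:  # cond_fn changed sign (or hit zero) during this step
            lo, hi = 0.0, dt  # sign matches c at t + lo, differs (or is zero) at t + hi
            while hi - lo > tol:
                mid = (lo + hi) / 2
                if c * cond_fn(t + mid, rk4_step(f, t, y, mid)) > 0:
                    lo = mid
                else:
                    hi = mid
            return t + hi, rk4_step(f, t, y, hi), t + dt
        t, y, c = t + dt, y_next, c_next
    raise RuntimeError("cond_fn never changed sign")


def vector_field(t, y):
    _, v = y
    return (v, -9.81)


def cond_fn(t, y):
    x, _ = y
    return x


t_event, y_event, t_step_end = solve_until_event(
    vector_field, cond_fn, 0.0, (10.0, 0.0), 0.1
)
print(t_event, y_event[1], t_step_end)  # approximately 1.4278, -14.007, 1.5
```

For this falling ball RK4 happens to be exact (the solution is quadratic in time), so the located time matches sqrt(2 · 10 / 9.81) ≈ 1.4278 to within the bisection tolerance, whereas stopping at the end of the step would give 1.5.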
Deprecations
Two APIs have now been deprecated.
Both of these APIs now have compatibility layers, so existing code should continue to work. However, they will now emit deprecation warnings, and users are encouraged to upgrade. These APIs may be removed at a later date.
- diffeqsolve(..., discrete_terminating_event=...), along with the corresponding classes AbstractDiscreteTerminatingEvent, DiscreteTerminatingEvent and SteadyStateEvent. These have been superseded by diffeqsolve(..., event=Event(...)).
- WeaklyDiagonalControlTerm has been superseded by the new behaviour for ControlTerm, and its interaction with Lineax, as discussed above.
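The replacement for WeaklyDiagonalControlTerm is a ControlTerm whose vector field returns a diagonal operator, as in the lineax.DiagonalLinearOperator example above. The pure-Python sketch below (no JAX or Lineax; all function names are hypothetical) contrasts what a diagonal operator's matrix-vector product has to compute with the dense product it replaces: both give the same result, but the diagonal version only touches n entries instead of n².

```python
from typing import List


def dense_matvec(matrix: List[List[float]], vec: List[float]) -> List[float]:
    # O(n^2): multiplies every entry, including all the off-diagonal zeros.
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, vec)) for row in matrix]


def diagonal_matvec(diagonal: List[float], vec: List[float]) -> List[float]:
    # O(n): only the diagonal is stored and multiplied.
    return [d * v for d, v in zip(diagonal, vec)]


def diag_to_dense(diagonal: List[float]) -> List[List[float]]:
    # Materialise the full matrix, as a dense diffusion term would have to.
    n = len(diagonal)
    return [[diagonal[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


y1, y2 = 2.0, 3.0
diagonal = [y2, y1]  # represents the matrix [[y2, 0], [0, y1]]
dw = [0.5, -1.0]  # e.g. a Brownian increment
print(diagonal_matvec(diagonal, dw))  # [1.5, -2.0]
print(dense_matvec(diag_to_dense(diagonal), dw))  # same result, more work
```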
Other
- Now working around an upstream bug introduced in JAX 0.4.29+, so we should be compatible with modern JAX releases.
- No longer emitting warnings coming from JAX deprecating a few old APIs. (We've migrated to the new ones.)
Full Changelog: v0.5.1...v0.6.0
Diffrax v0.5.1
New research paper
One of the new features of this release is the simulation of space-time Lévy area, over arbitrary intervals, deterministically with respect to a PRNG key. This is a required component for adaptive step-size and higher-order SDE solvers in particular.
Well, this turned out to be an unsolved research question! And so, with huge credit to @andyElking for diligently figuring out all the details, we now have a new paper on arXiv discussing them! So far as we know, this sets a new state of the art for numerical Brownian simulation, and it is what now powers all of the numerical SDE solving inside Diffrax.
If you're interested in numerical methods for SDEs, then check out the arXiv paper here.
New features for SDEs
- Added a suite of Stochastic Runge--Kutta methods! These are higher-order solvers for SDEs, in particular when the noise has a particular form (additive, commutative, ...). A huge thank-you to @andyElking for implementing all of these:
  - GeneralShARK: recommended when the drift is expensive to evaluate;
  - SEA: recommended when the noise is additive and a cheap, low-accuracy solve is wanted;
  - ShARK: recommended default choice when the noise is additive;
  - SlowRK: recommended for commutative noise;
  - SPaRK: recommended when performing adaptive time stepping;
  - SRA1: alternative to ShARK (this is a now-classical SRK method).
- Added support for simulating space-time Lévy area to VirtualBrownianTree and UnsafeBrownianPath. This is the bit discussed in the "new research paper" section above! The main thing here is the ability to sample from random variables like space-time Lévy area, which is a doubly-indexed integral of Brownian motion over time: $H_{s,t} = \frac{1}{t-s} \int_s^t \left((W_r - W_s) - \frac{r-s}{t-s} (W_t - W_s)\right) dr$.
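The definition of $H_{s,t}$ can be checked numerically. The sketch below approximates it from a finely sampled path using the trapezoidal rule; this is only an illustration of the formula, not how Diffrax's Brownian path classes generate H, and the function names are hypothetical. For Brownian motion, $H_{s,t}$ is Gaussian with mean zero and variance $(t-s)/12$, which the Monte Carlo estimate at the end should roughly reproduce.

```python
import random
from typing import List, Sequence


def space_time_levy_area(ts: Sequence[float], ws: Sequence[float]) -> float:
    """Trapezoidal-rule approximation of H_{s,t} from samples W(ts[i]) = ws[i].

    s = ts[0] and t = ts[-1]; the integrand is the bridge
    (W_r - W_s) - (r - s) / (t - s) * (W_t - W_s).
    """
    s, t = ts[0], ts[-1]
    w_s, w_t = ws[0], ws[-1]
    bridge = [(w - w_s) - (r - s) / (t - s) * (w_t - w_s) for r, w in zip(ts, ws)]
    integral = sum(
        0.5 * (b0 + b1) * (r1 - r0)
        for r0, r1, b0, b1 in zip(ts[:-1], ts[1:], bridge[:-1], bridge[1:])
    )
    return integral / (t - s)


def brownian_path(n: int, t: float, rng: random.Random) -> List[float]:
    """Sample a standard Brownian motion on an n-step uniform grid over [0, t]."""
    dt = t / n
    ws = [0.0]
    for _ in range(n):
        ws.append(ws[-1] + rng.gauss(0.0, dt ** 0.5))
    return ws


grid = [i / 1000 for i in range(1001)]
# Deterministic check: for W_r = r^2 on [0, 1], H = integral of (r^2 - r) dr = -1/6.
print(space_time_levy_area(grid, [r * r for r in grid]))  # close to -1/6

# Monte Carlo check: for Brownian motion on [0, 1], E[H^2] = 1/12.
rng = random.Random(0)
n = 200
bm_grid = [i / n for i in range(n + 1)]
hs = [space_time_levy_area(bm_grid, brownian_path(n, 1.0, rng)) for _ in range(2000)]
h_second_moment = sum(h * h for h in hs) / len(hs)
print(h_second_moment)  # close to 1/12, i.e. about 0.083
```

A path that is linear in time has a zero bridge, so its space-time Lévy area is exactly zero: H measures how far the path bulges above or below the straight line joining its endpoints.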
New features for all differential equations
- Added TextProgressMeter and TqdmProgressMeter, which can be used to track how far a differential equation solve has progressed. (Thanks @abocquet! #357, #398)
- Added support for using adaptive step size controllers on TPUs. (Thanks @stefanocortinovis! #366, #369)
- All AbstractPaths are now typing.Generics parameterised by their return type; all AbstractTerms are now typing.Generics parameterised by their vector field and control. (Thanks @tttc3! #359, #364)
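For readers unfamiliar with typing.Generic, here is a toy sketch of what being parameterised by a return type (or by vector field and control types) means for a static type checker. All class names are hypothetical, and this is not how Diffrax defines AbstractPath or AbstractTerm.

```python
from typing import Generic, TypeVar

Y = TypeVar("Y")  # what evaluate returns
VF = TypeVar("VF")  # vector-field output type
C = TypeVar("C")  # control type


class ToyAbstractPath(Generic[Y]):
    """Hypothetical stand-in for a path parameterised by its return type."""

    def evaluate(self, t: float) -> Y:
        raise NotImplementedError


class LinearPath(ToyAbstractPath[float]):
    """A path returning floats, so a type checker knows evaluate(...) -> float."""

    def __init__(self, slope: float) -> None:
        self.slope = slope

    def evaluate(self, t: float) -> float:
        return self.slope * t


class ToyAbstractTerm(Generic[VF, C]):
    """Hypothetical stand-in for a term parameterised by vector field and control."""

    def vf(self, t: float, y: float) -> VF:
        raise NotImplementedError

    def contr(self, t0: float, t1: float) -> C:
        raise NotImplementedError


class ScalarODETerm(ToyAbstractTerm[float, float]):
    """dy/dt = -y, with the control being the time increment t1 - t0."""

    def vf(self, t: float, y: float) -> float:
        return -y

    def contr(self, t0: float, t1: float) -> float:
        return t1 - t0


path: ToyAbstractPath[float] = LinearPath(2.0)
term: ToyAbstractTerm[float, float] = ScalarODETerm()
print(path.evaluate(1.5))  # 3.0
print(term.vf(0.0, 4.0) * term.contr(0.0, 0.5))  # -2.0
```

A checker such as mypy or pyright can then flag, for example, treating path.evaluate(...) as an array when the path was declared to return floats.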
Other
- Improved documentation for PIDController. (Thanks @ParticularlyPythonicBS! #371, #372)
- Now have a py.typed file to declare that we're static-type-checking compatible. (Thanks @lockwo! #408)
- Bugfix for CubicInterpolation when there is nothing to interpolate. (Thanks @allen-adastra! #360)
- Compatibility with future versions of JAX by removing the now-deprecated jax.config import. (Thanks @jakevdp! #377)
New Contributors
- @allen-adastra made their first contribution in #360
- @stefanocortinovis made their first contribution in #369
- @jakevdp made their first contribution in #377
- @ParticularlyPythonicBS made their first contribution in #372
Full Changelog: v0.5.0...v0.5.1
Diffrax v0.5.0
This is a fun release. :)
Diffrax was the very first project I ever released for the JAX ecosystem. Since then, many new libraries have grown up around it -- most notably jaxtyping, Lineax, and Optimistix.
All of these other libraries actually got their start because I wanted to use them for some purpose in Diffrax!
And with this release... we are now finally doing that. Diffrax now depends on jaxtyping for its type annotations, Lineax for linear solves, and Optimistix for root-finding!
That makes this release mostly just a huge internal refactor, so it shouldn't affect you (as a downstream user) very much at all.
Features
- Added diffrax.VeryChord, which is a chord-type quasi-Newton method typically used as part of an implicit solver. (This is the most common root-finding method used in implicit differential equation solvers.)
- Added diffrax.with_stepsize_controller_tols, which can be used to mark that a root-finder should inherit its tolerances from the stepsize_controller. For example, this is used as:

  ```python
  import diffrax

  root_finder = diffrax.with_stepsize_controller_tols(diffrax.VeryChord)()
  solver = diffrax.Kvaerno5(root_finder=root_finder)
  diffrax.diffeqsolve(..., solver=solver, ...)
  ```

  This tolerance-inheritance is the default for all implicit solvers.
  (Previously this tolerance-inheritance business was done by passing rtol/atol=None to the nonlinear solver -- and again was the default. However, now that Optimistix owns the nonlinear solvers, it's up to Diffrax to handle tolerance-inheritance in a slightly different way.)
- Added the arguments diffrax.ImplicitAdjoint(linear_solver=..., tags=...). Implicit backpropagation can now be done using any choice of Lineax solver.
- Now static-type-checking compatible. No more having your IDE yell at you for incorrect types.
- Diffrax should now be compatible with JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise. (These are JAX flags that can be used to disable dtype promotion and broadcasting, to help write more reliable code.)
- diffrax.{ControlTerm, WeaklyDiagonalControlTerm} now support using a callable as their control, in which case it is treated as the evaluate of an AbstractPath over [-inf, inf].
- Experimental support for complex numbers in explicit solvers. This may still go wrong, so please report bugs / send fixing PRs as you encounter them.
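The chord-type quasi-Newton idea behind diffrax.VeryChord (first item in the list above) is easiest to see in one dimension: it is Newton's method with the derivative evaluated once at the initial guess and then reused. In n dimensions this means the Jacobian is computed and factorised only once, trading slower (linear) convergence for much cheaper iterations. The sketch below is a generic scalar chord method with hypothetical function names, not Diffrax's VeryChord implementation.

```python
from typing import Callable, Tuple


def chord_solve(
    f: Callable[[float], float],
    df: Callable[[float], float],
    x0: float,
    tol: float = 1e-12,
    max_iter: int = 100,
) -> Tuple[float, int]:
    """Chord iteration for f(x) = 0: Newton's method with the derivative frozen at x0."""
    slope = df(x0)  # computed once (in n dimensions: one Jacobian, factorised once)
    x = x0
    for i in range(max_iter):
        dx = f(x) / slope
        x -= dx
        if abs(dx) < tol:
            return x, i + 1
    raise RuntimeError("chord iteration did not converge")


def newton_solve(
    f: Callable[[float], float],
    df: Callable[[float], float],
    x0: float,
    tol: float = 1e-12,
    max_iter: int = 100,
) -> Tuple[float, int]:
    """Plain Newton iteration, for comparison: re-evaluates the derivative every step."""
    x = x0
    for i in range(max_iter):
        dx = f(x) / df(x)
        x -= dx
        if abs(dx) < tol:
            return x, i + 1
    raise RuntimeError("Newton iteration did not converge")


def f(x: float) -> float:
    return x * x - 2.0


def df(x: float) -> float:
    return 2.0 * x


root, iters = chord_solve(f, df, 1.5)
print(root, iters)  # sqrt(2), reached with linear convergence
print(newton_solve(f, df, 1.5))  # sqrt(2), quadratic convergence, fewer iterations
```

Inside an implicit solver the initial guess is usually already close to the solution (via stage value predictors), so the extra chord iterations tend to be a good price for avoiding repeated Jacobian evaluations and factorisations.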
Breaking changes
- diffrax.{AbstractNonlinearSolver, NewtonNonlinearSolver, NonlinearSolution} have been removed in favour of using Optimistix. If you were using these explicitly, e.g. Kvaerno5(nonlinear_solver=NewtonNonlinearSolver(...)), then the equivalent behaviour is now given by Kvaerno5(root_finder=VeryChord(...)). You can also use any other Optimistix root-finder.
- The result of a solve is now an Enumeration rather than a plain integer. For example, this means that you should write something like jnp.where(sol.result == diffrax.RESULTS.successful, ...), not jnp.where(sol.result == 0, ...).
- A great many modules have been renamed from foo.py to _foo.py to explicitly indicate that they're private. Make sure to access features via the public API.
- Removed the AbstractStepSizeController.wrap_solver method.
Bugfixes
- Crash fix when using an implicit solver together with DirectAdjoint.
- Crash fix when using dt0=None, stepsize_controller=diffrax.PIDController(...) with SDEs.
- Crash fix when using adjoint=BacksolveAdjoint(...) with VirtualBrownianTree with jax.disable_jit on the TPU backend.
New Contributors
- @VIVelev made their first contribution in #298
- @rdaems made their first contribution in #311
- @packquickly made their first contribution in #325
Full Changelog: v0.4.1...v0.5.0
Diffrax v0.4.1
Minor release to fix two bugs, and to introduce a performance improvement.
- Fixed using implicit solvers with closed-over variables. (#258, #284)
- Fix bug that introduced incompatibility with Equinox v0.10.11. (#282, #283, patrick-kidger/equinox#438)
- Optimise StepTo to match performance of a naive scan. (#274, #276)
New Contributors
- @thibmonsel made their first contribution in #283
Full Changelog: v0.4.0...v0.4.1
Diffrax v0.4.0
Features
- Highlight: added IMEX solvers! These solve the "easy" part of the diffeq using an explicit solver, and the "hard" part using an implicit solver. We now have:
  - diffrax.KenCarp3
  - diffrax.KenCarp4
  - diffrax.KenCarp5
  - diffrax.Sil3

  Each of these should be called with e.g. diffeqsolve(terms=MultiTerm(explicit_term, implicit_term), solver=diffrax.KenCarp4(), ...).
- diffrax.ImplicitEuler now supports adaptive time stepping, by using an embedded Heun method. (#251)
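To make the explicit/implicit split concrete, here is a first-order IMEX Euler sketch in plain Python for a scalar problem whose stiff part is linear, so the implicit stage reduces to a division. It is a hypothetical illustration of the splitting idea only: KenCarp3/4/5 and Sil3 are higher-order IMEX Runge--Kutta schemes, and in Diffrax the two parts would be supplied as the explicit_term and implicit_term of the MultiTerm above.

```python
import math
from typing import Callable


def imex_euler(
    lam: float,
    g: Callable[[float, float], float],
    y0: float,
    t0: float,
    dt: float,
    num_steps: int,
) -> float:
    """First-order IMEX Euler for y' = lam * y + g(t, y).

    The stiff linear part lam * y is treated implicitly and g explicitly:
        y_{n+1} = y_n + dt * g(t_n, y_n) + dt * lam * y_{n+1},
    which can be solved in closed form because the implicit part is linear.
    """
    t, y = t0, y0
    for _ in range(num_steps):
        y = (y + dt * g(t, y)) / (1.0 - dt * lam)
        t += dt
    return y


def explicit_euler(
    lam: float,
    g: Callable[[float, float], float],
    y0: float,
    t0: float,
    dt: float,
    num_steps: int,
) -> float:
    """Fully explicit Euler on the same problem, for comparison."""
    t, y = t0, y0
    for _ in range(num_steps):
        y = y + dt * (lam * y + g(t, y))
        t += dt
    return y


def forcing(t: float, y: float) -> float:
    # Non-stiff part of y' = -1000 * (y - cos(t)), handled explicitly.
    return 1000.0 * math.cos(t)


y_imex = imex_euler(-1000.0, forcing, 1.0, 0.0, 0.01, 200)
y_explicit = explicit_euler(-1000.0, forcing, 1.0, 0.0, 0.01, 200)
print(y_imex, math.cos(2.0))  # close to each other: the solution tracks cos(t)
print(abs(y_explicit))  # enormous: dt = 0.01 is outside explicit Euler's stability region
```

Treating only the stiff term implicitly keeps the step size governed by accuracy rather than stability, while the explicit term never needs a nonlinear solve.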
Backward incompatibilities
- scan_stages, e.g. Tsit5(scan_stages=True), no longer exists. All Runge--Kutta solvers now scan over stages by default.
  - If you were using scan_stages, then you should simply delete this argument.
  - If you were using the interactive API together with forward-mode autodiff, then you should pass scan_kind="bounded" to the solver, e.g. Tsit5(scan_kind="bounded").
Bugfixes
- Fixed AbstractTerm.vf_prod being ignored: naive prod(vf(...), control) calls were being used instead of optimised vf-prod routines, where available. (#239)
- Implicit solvers now use the correct stage value predictors. This should help the implicit solvers converge faster, so that overall runtime is decreased. As a result, they may occasionally take a different number of steps than before -- usually fewer.
Performance
- Overall compilation should be faster. (Due to patrick-kidger/equinox#353)
- Initial step size selection should now compile faster. (#257)
- Fixed dense output consuming far too much memory. (#252)
- Backsolve adjoint should now be much more efficient (due to the vf_prod bugfix).
Full Changelog: v0.3.1...v0.4.0
Diffrax v0.3.1
See the previous v0.3.0 release notes for the most significant recent changes.
This hotfix
Hotfix for the previous release breaking backprop through SaveAt(dense=True).
Full Changelog: v0.3.0...v0.3.1
Diffrax v0.3.0
Highlights
This release is primarily a performance improvement: the default adjoint method now uses an asymptotically more efficient checkpointing implementation.
New features
- Added diffrax.citation for automatically generating BibTeX references of the numerical methods being used.
- diffrax.SaveAt can now save different selections of outputs at different times, using diffrax.SubSaveAt.
- diffrax.SaveAt now supports a fn argument for controlling what to save, e.g. only statistics of the solution. (#113, #221, thanks @joglekara in #220!)
- Can now use SaveAt(dense=True) in the edge case when t0 == t1.
Performance improvements
- The default adjoint method RecursiveCheckpointAdjoint now uses a dramatically improved implementation for reverse-mode autodifferentiation of while loops. This should be asymptotically faster, and generally produce both runtime and compile-time speed-ups.
  - The previous implementation is available as DirectAdjoint. This is still useful in a handful of less-common cases, such as using forward-mode autodifferentiation. (Once JAX gets bounded while loops as native operations then this will be tidied up further.)
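The basic trade-off that checkpointing exploits (recompute forward states rather than store them all) can be illustrated in plain Python. The sketch below uses a simple recursive-bisection schedule with a hypothetical reverse_sweep helper; it is not the algorithm used by RecursiveCheckpointAdjoint, and unlike a while loop it assumes the number of steps is known in advance. In a real reverse pass, visit would apply the vector-Jacobian product of step i.

```python
from typing import Callable, List, Tuple, TypeVar

S = TypeVar("S")


def reverse_sweep(
    step: Callable[[S], S],
    y0: S,
    num_steps: int,
    visit: Callable[[int, S], None],
) -> None:
    """Call visit(i, y_i) for i = num_steps - 1, ..., 0, where y_{i+1} = step(y_i).

    Recursive bisection checkpointing: rather than storing every y_i, recompute
    forward from a checkpoint to the midpoint, handle the second half, then the
    first half. Only O(log num_steps) states are held at once, at the cost of
    O(num_steps * log num_steps) calls to step.
    """

    def recurse(y_start: S, start: int, end: int) -> None:
        # y_start is y_{start}; visit indices end - 1, ..., start in that order.
        if end - start == 1:
            visit(start, y_start)
            return
        mid = (start + end) // 2
        y_mid = y_start
        for _ in range(start, mid):
            y_mid = step(y_mid)
        recurse(y_mid, mid, end)
        recurse(y_start, start, mid)

    if num_steps > 0:
        recurse(y0, 0, num_steps)


calls = [0]


def step(y: int) -> int:
    calls[0] += 1  # count forward (re)computations
    return y + 1


visited: List[Tuple[int, int]] = []
reverse_sweep(step, 0, 10, lambda i, y: visited.append((i, y)))
print(visited)  # [(9, 9), (8, 8), ..., (0, 0)]
print(calls[0])  # step evaluations used by the reverse sweep
```

Storing every state would need no recomputation but memory linear in the number of steps; this schedule sits at the other end, which is why practical schemes tune how many checkpoints to keep.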
Backward-incompatible changes
- Removed NoAdjoint. It existed as a performance improvement when not using autodifferentiation, but RecursiveCheckpointAdjoint (the default) has now incorporated this performance improvement automatically.
- Removed ConstantStepSize(compile_steps=...) and StepTo(compile_steps=...), as these are now unnecessary when using the new RecursiveCheckpointAdjoint.
- Removed the undocumented Fehlberg2 solver. (It's just not useful compared to Heun/Midpoint/Ralston.)
- AbstractSolver.term_structure should now be e.g. (ODETerm, AbstractTerm) rather than jtu.tree_structure((ODETerm, AbstractTerm)), i.e. it now encodes the term type as well.
- Dropped support for Python 3.7.
Fixes
- Fixed UnsafeBrownianPath and VirtualBrownianTree, which had been broken by an upstream change in JAX. (#225)
- The sum of Runge--Kutta stages now happens in HIGHEST precision, which should improve numerical stability on some accelerators.
Examples
- The documentation now has an introductory "coupled ODEs" example.
- The documentation now has an advanced "nonlinear heat PDE" example.
- The "symbolic regression" example has been updated to use sympy2jax.
New Contributors
- @slishak made their first contribution in #198
- @RehMoritz made their first contribution in #215
- @joglekara made their first contribution in #220
Full Changelog: v0.2.2...v0.3.0
Diffrax v0.2.2
Performance improvements
Fixes
- Many documentation improvements.
- Fixed several warnings about jax.{tree_map,tree_leaves,...} being moved to jax.tree_util.{tree_map,tree_leaves,...}. (Thanks @jacobusmmsmit!)
- Fixed the step size controller choking if the error is ever NaN. (#143, #152)
- Fixed some crashes due to JAX-internal changes. (If you've ever seen it throw an error about not knowing how to rewrite closed_call_p, it's this one.)
- Fixed an obscure edge-case NaN on the backward pass, which occurred if you were using an implicit solver with an adaptive step size controller, got a rejected step due to the implicit solve failing to converge, and happened to also be backpropagating with respect to the controller_state.
Other
- Added a new Kalman filter example (#159) (Thanks @SimiPixel!)
- Brownian motion classes accept pytrees for shape and dtype arguments (#183) (Thanks @ciupakabra!)
- The main change is an internal refactor: a lot of functionality has moved from diffrax.misc to equinox.internal.
New Contributors
- @jacobusmmsmit made their first contribution in #149
- @SimiPixel made their first contribution in #159
- @ciupakabra made their first contribution in #183
Full Changelog: v0.2.1...v0.2.2
Diffrax v0.2.1
Autogenerated release notes as follows:
What's Changed
- Made is_okay, is_successful, is_event public by @patrick-kidger in #134
- Fix implicit adjoints assuming array-valued state by @patrick-kidger in #136
- Replace jax tree manipulation method that are being deprecated with jax.tree_util equivalents by @mahdi-shafiei in #138
- bump version by @patrick-kidger in #141
New Contributors
- @mahdi-shafiei made their first contribution in #138
Full Changelog: v0.2.0...v0.2.1