
Releases: patrick-kidger/diffrax

Diffrax v0.6.1

09 Dec 11:22

Features

  • Compatibility with JAX 0.4.36.

  • New solvers! Added stochastic Runge--Kutta methods for solving the underdamped Langevin equation. We now have:

    • diffrax.AbstractFosterLangevinSRK
    • diffrax.ALIGN
    • diffrax.QUICSORT
    • diffrax.ShOULD

    and these are used with the corresponding terms:

    • diffrax.UnderdampedLangevinDriftTerm
    • diffrax.UnderdampedLangevinDiffusionTerm

    Huge thanks to @andyElking for carefully implementing all of these; it was a major technical undertaking. (#453 and 2000 new lines of code!) See the Underdamped Langevin Diffusion example for more on how to use these.

Bugfixes

  • If t0 == t1 and we have SaveAt(ts=...) then we now correctly output len(ts) copies of y0. (Thanks @dkweiss31! #488, #494)
  • When using diffrax.VirtualBrownianTree on the GPU, floating-point fluctuations could sometimes produce evaluations outside the valid [t0, t1] region, which would raise a spurious runtime error. This is now fixed. (Thanks @mattlevine22! jax-ml/jax#24807, #524, #526)
  • Fixes for complex-number support in SDEs. (Thanks @Randl! #454)
  • Improvements to errors and warnings, and some typo fixes. (Thanks @lockwo @ddrous! #468, #478, #495, #530)

Full Changelog: v0.6.0...v0.6.1

Diffrax v0.6.0

01 Jul 09:16

Features

  • Continuous events! It is now possible to specify a condition at which point the differential equation should halt. For example, here's one finding the time at which a dropped ball hits the ground:

    import diffrax
    import jax.numpy as jnp
    import optimistix as optx
    
    def vector_field(t, y, args):
        _, v = y
        return jnp.array([v, -9.81])
    
    def cond_fn(t, y, args, **kwargs):
        x, _ = y
        return x
    
    term = diffrax.ODETerm(vector_field)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf
    dt0 = 0.1
    y0 = jnp.array([10.0, 0.0])
    root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
    event = diffrax.Event(cond_fn, root_finder)
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
    print(f"Event time: {sol.ts[0]}") # Event time: 1.42...
    print(f"Velocity at event time: {sol.ys[0, 1]}") # Velocity at event time: -14.00...

    When cond_fn hits zero, the solve stops. Once the event function changes sign, we use Optimistix to perform a root find and locate the exact time at which the equation should terminate. Event handling is also fully differentiable.

    Getting this in was a huge amount of work from @cholberg -- thank you! -- and it has been one of our longest-requested features, so I'm really happy to finally have it.

    (We previously only had 'discrete events', which just terminated at the end of a step, and did not do a root find.)

    See the events page in the documentation for more.

  • Simulation of space-time-time Lévy area. This is a higher-order statistic of Brownian motion, used in some advanced SDE solvers. We don't have any such solvers yet, but watch this space... ;)

    This was a hugely impressive technical effort from @andyElking. Check out our arXiv paper on the topic, which discusses the technical nitty-gritty of how these statistics can be simulated in an efficient manner.

  • ControlTerm now supports returning a Lineax linear operator. For example, here's how to easily create a diagonal diffusion term:

    import jax.numpy as jnp
    import lineax

    def vector_field(t, y, args):
        # y is a JAX array of shape (2,)
        y1, y2 = y
        diagonal = jnp.array([y2, y1])
        return lineax.DiagonalLinearOperator(diagonal)  # corresponds to the matrix [[y2, 0], [0, y1]]

    diffusion_term = ControlTerm(vector_field, ...)  # `...` is the control, e.g. a Brownian motion

    This should make it much easier to express SDEs with particular structure to their diffusion matrices.

    This is particularly good for efficiency reasons: the operator-specified .mv (matrix-vector product) method is used, which typically provides a more efficient implementation than that given by filling in some zeros and using a dense matrix-vector product.

    Thank you to @lockwo for implementing this one!

    See the documentation on ControlTerm for more.

Deprecations

Two APIs have now been deprecated.

Both of these APIs now have compatibility layers, so existing code should continue to work. However, they will now emit deprecation warnings, and users are encouraged to upgrade. These APIs may be removed at a later date.

  • diffeqsolve(..., discrete_terminating_event=...), along with the corresponding classes AbstractDiscreteTerminatingEvent + DiscreteTerminatingEvent + SteadyStateEvent. These have been superseded by diffeqsolve(..., event=Event(...)).

  • WeaklyDiagonalControlTerm has been superseded by the new behaviour for ControlTerm, and its interaction with Lineax, as discussed above.

Other

  • Now working around an upstream bug introduced in JAX 0.4.29+, so we should be compatible with modern JAX releases.
  • No longer emitting warnings coming from JAX deprecating a few old APIs. (We've migrated to the new ones.)

Full Changelog: v0.5.1...v0.6.0

Diffrax v0.5.1

19 May 15:41

New research paper

One of the new features of this release is the simulation of space-time Lévy area, over arbitrary intervals, deterministically with respect to a PRNG key. This is a required component for adaptive step-size and higher-order SDE solvers in particular.

Well, this turned out to be an unsolved research question! And so, with huge credit to @andyElking for diligently figuring it all out, we now have a new paper on arXiv discussing all the details! So far as we know this sets a new state of the art for numerical Brownian simulation, and it is what now powers all of the numerical SDE solving inside Diffrax.

If you're interested in numerical methods for SDEs, then check out the arXiv paper here.

New features for SDEs

  • Added a suite of Stochastic Runge--Kutta methods! These are higher-order solvers for SDEs, in particular when the noise has a particular form (additive, commutative, ...). A huge thank-you to @andyElking for implementing all of these:

    • GeneralShARK: recommended when the drift is expensive to evaluate;
    • SEA: recommended when the noise is additive and a cheap, low-accuracy solve is acceptable;
    • ShARK: recommended default choice when the noise is additive;
    • SlowRK: recommended for commutative noise;
    • SPaRK: recommended when performing adaptive time stepping;
    • SRA1: alternative to ShARK (this is a now-classical SRK method).
  • Added support for simulating space-time Lévy area to VirtualBrownianTree and UnsafeBrownianPath. This is the bit discussed in the "new research paper" section above! The main thing here is the ability to sample from random variables like space-time Lévy area, which is a doubly-indexed integral of Brownian motion over time: $H_{s,t} = \frac{1}{t-s} \int_s^t ((W_r - W_s) - \frac{r-s}{t-s} (W_t - W_s)) dr$.
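To make that formula concrete, here is a small NumPy sanity check (deliberately not using Diffrax) that a Riemann-sum discretisation of $H_{s,t}$ is centred with variance $(t-s)/12$, the known distribution of space-time Lévy area:

```python
import numpy as np

rng = np.random.default_rng(0)
n_paths, n_steps = 4000, 400
s, t = 0.0, 1.0
dt = (t - s) / n_steps

# Sample Brownian paths on [s, t], with W_s = 0.
dW = rng.normal(0.0, np.sqrt(dt), size=(n_paths, n_steps))
W = np.cumsum(dW, axis=1)  # W at times s + dt, s + 2*dt, ..., t
W_t = W[:, -1]

# Riemann sum for H_{s,t} = 1/(t-s) * int_s^t ((W_r - W_s) - (r-s)/(t-s) (W_t - W_s)) dr.
r = s + dt * np.arange(1, n_steps + 1)
frac = (r - s) / (t - s)
H = (W - frac[None, :] * W_t[:, None]).sum(axis=1) * dt / (t - s)

# H should have mean ~0 and variance ~(t - s)/12.
print(np.mean(H), np.var(H))
```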

New features for all differential equations

  • Added TextProgressMeter and TqdmProgressMeter, which can be used to track how far through a differential equation solve things have progressed. (Thanks @abocquet! #357, #398)
  • Added support for using adaptive step size controllers on TPUs (Thanks @stefanocortinovis! #366, #369)
  • All AbstractPaths are now typing.Generics parameterised by their return type; all AbstractTerms are now typing.Generics parameterised by their vector field and control. (Thanks @tttc3! #359, #364)

Other

  • Improved documentation for PIDController (Thanks @ParticularlyPythonicBS! #371, #372)
  • Now have a py.typed file to declare that we're static-type-checking compatible. (Thanks @lockwo! #408)
  • Bugfix for CubicInterpolation when there is nothing to interpolate. (Thanks @allen-adastra! #360)
  • Compatibility with future versions of JAX by removing the now-deprecated jax.config import. (Thanks @jakevdp! #377)

Full Changelog: v0.5.0...v0.5.1

Diffrax v0.5.0

08 Jan 23:23

This is a fun release. :)

Diffrax was the very first project I ever released for the JAX ecosystem. Since then, many new libraries have grown up around it -- most notably jaxtyping, Lineax, and Optimistix.

All of these other libraries actually got their start because I wanted to use them for some purpose in Diffrax!

And with this release... we are now finally doing that. Diffrax now depends on jaxtyping for its type annotations, Lineax for linear solves, and Optimistix for root-finding!

That makes this release mostly just a huge internal refactor, so it shouldn't affect you (as a downstream user) very much at all.

Features

  • Added diffrax.VeryChord, which is a chord-type quasi-Newton method typically used as part of an implicit solver. (This is the most common root-finding method used in implicit differential equation solvers.)
  • Added diffrax.with_stepsize_controller_tols, which can be used to mark that a root-finder should inherit its tolerances from the stepsize_controller. For example, this is used as:
    root_finder = diffrax.with_stepsize_controller_tols(diffrax.VeryChord)()
    solver = diffrax.Kvaerno5(root_finder=root_finder)
    diffrax.diffeqsolve(..., solver=solver, ...)
    This tolerance-inheritance is the default for all implicit solvers.
    (Previously this tolerance-inheritance business was done by passing rtol/atol=None to the nonlinear solver -- and again was the default. However now that Optimistix owns the nonlinear solvers, it's up to Diffrax to handle tolerance-inheritance in a slightly different way.)
  • Added the arguments diffrax.ImplicitAdjoint(linear_solver=..., tags=...). Implicit backpropagation can now be done using any choice of Lineax solver.
  • Now static-type-checking compatible. No more having your IDE yell at you for incorrect types.
  • Diffrax should now be compatible with JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise. (These are JAX flags that can be used to disable dtype promotion and broadcasting, to help write more reliable code.)
  • diffrax.{ControlTerm, WeaklyDiagonalControlTerm} now support using a callable as their control, in which case it is treated as the evaluate of an AbstractPath over [-inf, inf].
  • Experimental support for complex numbers in explicit solvers. This may still go wrong, so please report bugs / send fixing PRs as you encounter them.

Breaking changes

  • diffrax.{AbstractNonlinearSolver, NewtonNonlinearSolver, NonlinearSolution} have been removed in favour of using Optimistix. If you were using these explicitly, e.g. Kvaerno5(nonlinear_solver=NewtonNonlinearSolver(...)), then the equivalent behaviour is now given by Kvaerno5(root_finder=VeryChord(...)). You can also use any other Optimistix root-finder too.
  • The result of a solve is now an Enumeration rather than a plain integer. For example, this means that you should write something like jnp.where(sol.result == diffrax.RESULTS.successful, ...), not jnp.where(sol.result == 0, ...).
  • A great many modules have been renamed from foo.py to _foo.py to explicitly indicate that they're private. Make sure to access features via the public API.
  • Removed the AbstractStepSizeController.wrap_solver method.

Bugfixes

  • Crash fix when using an implicit solver together with DirectAdjoint.
  • Crash fix when using dt0=None, stepsize_controller=diffrax.PIDController(...) with SDEs.
  • Crash fix when using adjoint=BacksolveAdjoint(...) with VirtualBrownianTree with jax.disable_jit on the TPU backend.

Full Changelog: v0.4.1...v0.5.0

Diffrax v0.4.1

03 Aug 20:45

Minor release to fix two bugs, and to introduce a performance improvement.

Full Changelog: v0.4.0...v0.4.1

Diffrax v0.4.0

22 May 15:33
a5e160a

Features

  • Highlight: added IMEX solvers! These solve the "easy" part of the diffeq using an explicit solver, and the "hard" part using an implicit solver. We now have:
    • diffrax.KenCarp3
    • diffrax.KenCarp4
    • diffrax.KenCarp5
    • diffrax.Sil3
      Each of these should be called with e.g. diffeqsolve(terms=MultiTerm(explicit_term, implicit_term), solver=diffrax.KenCarp4(), ...)
  • diffrax.ImplicitEuler now supports adaptive time stepping, by using an embedded Heun method. (#251)

Backward incompatibilities

  • scan_stages, e.g. Tsit5(scan_stages=True), no longer exists. All Runge--Kutta solvers now scan-over-stages by default.
    • If you were using scan_stages, then you should simply delete this argument.
    • If you were using the interactive API together with forward-mode autodiff then you should pass scan_kind="bounded" to the solver, e.g. Tsit5(scan_kind="bounded").

Bugfixes

  • Fixed AbstractTerm.vf_prod being ignored, which meant that naive prod(vf(...), control) calls were used instead of the optimised vf_prod routines, where available. (#239)
  • Implicit solvers now use the correct stage value predictors. This should help the implicit solvers converge faster, so that overall runtime is decreased. This should mean that they occasionally take a different number of steps than before -- usually fewer.

Performance

  • Overall compilation should be faster. (Due to patrick-kidger/equinox#353)
  • Initial step size selection should now compile faster. (#257)
  • Fixed dense output consuming far too much memory. (#252)
  • Backsolve adjoint should now be much more efficient (due to the vf_prod bugfix).

Full Changelog: v0.3.1...v0.4.0

Diffrax v0.3.1

23 Feb 03:53
16b08d5

See the previous v0.3.0 release notes for the most significant recent changes.

This hotfix

Hotfix for the previous release breaking backprop through SaveAt(dense=True).

Full Changelog: v0.3.0...v0.3.1

Diffrax v0.3.0

21 Feb 03:30
9280c3a

Highlights

This release is primarily a performance improvement: the default adjoint method now uses an asymptotically more efficient checkpointing implementation.

New features

  • Added diffrax.citation for automatically generating BibTeX references of the numerical methods being used.
  • diffrax.SaveAt can now save different selections of outputs at different times, using diffrax.SubSaveAt.
  • diffrax.SaveAt now supports a fn argument for controlling what to save, e.g. only statistics of the solution. (#113, #221, thanks @joglekara in #220!)
  • Can now use SaveAt(dense=True) in the edge case when t0 == t1.

Performance improvements

  • The default adjoint method RecursiveCheckpointAdjoint now uses a dramatically improved implementation for reverse-mode autodifferentiation of while loops. This should be asymptotically faster, and generally produces both runtime and compile-time speed-ups.
    • The previous implementation is available as DirectAdjoint. This is still useful in a handful of less-common cases, such as using forward-mode autodifferentiation. (Once JAX gets bounded while loops as native operations then this will be tidied up further.)

Backward-incompatible changes

  • Removed NoAdjoint. It existed as a performance improvement when not using autodifferentiation, but RecursiveCheckpointAdjoint (the default) has now incorporated this performance improvement automatically.
  • Removed ConstantStepSize(compile_steps=...) and StepTo(compile_steps=...), as these are now unnecessary when using the new RecursiveCheckpointAdjoint.
  • Removed the undocumented Fehlberg2 solver. (It's just not useful compared to Heun/Midpoint/Ralston.)
  • AbstractSolver.term_structure should now be e.g. (ODETerm, AbstractTerm) rather than jtu.tree_structure((ODETerm, AbstractTerm)), i.e. it now encodes the term type as well.
  • Dropped support for Python 3.7.

Fixes

  • Fixed an upstream change in JAX that was breaking UnsafeBrownianPath and VirtualBrownianTree (#225).
  • The sum of Runge--Kutta stages now happens at the highest available precision (jax.lax.Precision.HIGHEST), which should improve numerical stability on some accelerators.

Examples

  • The documentation now has an introductory "coupled ODEs" example.
  • The documentation now has an advanced "nonlinear heat PDE" example.
  • The "symbolic regression" example has been updated to use sympy2jax.

Full Changelog: v0.2.2...v0.3.0

Diffrax v0.2.2

15 Nov 18:39
ea1bdc9

Performance improvements

  • Now make fewer vector field traces in several cases (#172, #174)

Fixes

  • Many documentation improvements.
  • Fixed several warnings about jax.{tree_map,tree_leaves,...} being moved to jax.tree_util.{tree_map,tree_leaves,...}. (Thanks @jacobusmmsmit!)
  • Fixed the step size controller choking if the error is ever NaN. (#143, #152)
  • Fixed some crashes due to JAX-internal changes. (If you've ever seen it throw an error about not knowing how to rewrite closed_call_p, it's this one.)
  • Fixed an obscure edge-case NaN on the backward pass, if you were using an implicit solver with an adaptive step size controller, got a rejected step due to the implicit solve failing to converge, and happened to also be backpropagating wrt the controller_state.

Other

  • Added a new Kalman filter example (#159) (Thanks @SimiPixel!)
  • Brownian motion classes accept pytrees for shape and dtype arguments (#183) (Thanks @ciupakabra!)
  • The main change is an internal refactor: a lot of functionality has moved from diffrax.misc to equinox.internal.

Full Changelog: v0.2.1...v0.2.2

Diffrax v0.2.1

03 Aug 22:59
7548c49

Autogenerated release notes; see the full changelog below.


Full Changelog: v0.2.0...v0.2.1