[Feature request] differential algebraic equations #62
In principle, absolutely -- this would be great. Realistically, however, this is something I'm unlikely to add myself in the near future. But if this is important to you, I'd be very happy to work with you on a PR adding support for this kind of thing. At least for semi-explicit DAEs, most of the support needed from Diffrax should already be present. In fact I think it should be possible to make things work through the existing …

Untested and hastily thrown together, a first approach might look something like this. We'll solve the semi-explicit DAE

$$\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y, z), \qquad 0 = g(t, y, z)$$

by first solving

$$\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y, z_0)$$

over $[t_0, t_1]$ with a single step of an existing solver, and then solving

$$g(t_1, y_1, z_1) = 0$$

for $z_1$ via a nonlinear solver. (If the …)

For the end user, the code would be used like so.

```python
from diffrax import Kvaerno5, ODETerm, diffeqsolve


def vector_field(t, y, z__args):
    z, args = z__args
    ...
    return dy_dt


def constraint(t, y, z__args):
    z, args = z__args
    ...
    return value_that_should_be_zero


term = ConstrainedTerm(ODETerm(vector_field), constraint)
solver = SemiExplicitConstrainedSolver(Kvaerno5())
diffeqsolve(term, solver, ...)  # as normal; the state passed in is the pair (y0, z0)
```

And finally the implementation is as follows.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp
from diffrax import AbstractImplicitSolver, AbstractTerm, AbstractWrappedSolver
from diffrax import LocalLinearInterpolation


# First just wrap together a term and a constraint.
#
# We arrange it so that the `z` component of the DAE is passed through `args` to the
# user-specified vector field and constraint.
class ConstrainedTerm(AbstractTerm):
    term: AbstractTerm
    constraint: Callable  # constraint(t, y, (z, args)) -> value that should be zero

    def vf(self, t, y, args):
        y, z = y
        return self.term.vf(t, y, (z, args))

    def contr(self, t0, t1):
        return self.term.contr(t0, t1)

    def prod(self, vf, control):
        return self.term.prod(vf, control)

    def vf_prod(self, t, y, args, control):
        y, z = y
        return self.term.vf_prod(t, y, (z, args), control)

    def constr(self, t, y, args):
        y, z = y
        return self.constraint(t, y, (z, args))


# Holds `z` fixed so that the wrapped solver only ever sees (and steps) `y`.
class _FrozenZTerm(AbstractTerm):
    term: ConstrainedTerm
    z: Any

    def vf(self, t, y, args):
        return self.term.vf(t, (y, self.z), args)

    def contr(self, t0, t1):
        return self.term.contr(t0, t1)

    def prod(self, vf, control):
        return self.term.prod(vf, control)


def _implicit_relation(z1, nonlinear_solve_args):
    # Residual g(t1, y1, z1), viewed as a function of z1 alone.
    constraint_fn, t1, y1, args = nonlinear_solve_args
    return constraint_fn(t1, (y1, z1), args)


# AbstractWrappedSolver gives us access to self.solver (an ODE/SDE/etc. solver)
# AbstractImplicitSolver gives us access to self.nonlinear_solver
# (listed first so the required `solver` field precedes any defaulted field)
class SemiExplicitConstrainedSolver(AbstractImplicitSolver, AbstractWrappedSolver):
    term_structure = jax.tree_util.tree_structure(0)  # class attribute, not a field
    interpolation_cls = LocalLinearInterpolation

    def order(self, terms):
        return self.solver.order(terms)

    def strong_order(self, terms):
        return self.solver.strong_order(terms)

    def error_order(self, terms):
        return self.solver.error_order(terms)

    def init(self, terms, t0, t1, y0, args):
        y0, z0 = y0
        return self.solver.init(_FrozenZTerm(terms, z0), t0, t1, y0, args)

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        y0, z0 = y0
        # 1. One step of the wrapped solver for y, with z held at z0.
        y1, y_error, _, solver_state, result = self.solver.step(
            _FrozenZTerm(terms, z0), t0, t1, y0, args, solver_state, made_jump
        )
        # 2. Solve g(t1, y1, z1) = 0 for z1, using z0 as the initial guess.
        nl_args = (terms.constr, t1, y1, args)
        jac = self.nonlinear_solver.jac(_implicit_relation, z0, nl_args)
        nonlinear_sol = self.nonlinear_solver(_implicit_relation, z0, nl_args, jac)
        z1 = nonlinear_sol.root
        z_error = jax.tree_util.tree_map(jnp.zeros_like, z1)
        dense_info = dict(y0=(y0, z0), y1=(y1, z1))
        result = jnp.maximum(result, nonlinear_sol.result)
        return (y1, z1), (y_error, z_error), dense_info, solver_state, result
```

Various comments on this implementation: …
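To make the two-phase step concrete without depending on Diffrax internals, here is a toy sketch in plain JAX. It assumes a scalar DAE chosen purely for illustration, $\mathrm{d}y/\mathrm{d}t = z$, $0 = z + y$ (so $z = -y$ and $y(t) = e^{-t}$ for $y(0) = 1$). It swaps in explicit Euler for "a single step of an existing solver" and a hand-rolled scalar Newton iteration for the nonlinear solver, and it also makes the initial $z$ consistent with the constraint before stepping. All helper names (`newton_solve`, `step`, `solve`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def f(t, y, z):
    # Differential part of the toy DAE: dy/dt = z
    return z


def g(t, y, z):
    # Algebraic constraint of the toy DAE: 0 = z + y
    return z + y


def newton_solve(fn, z_guess, num_iters=20):
    # Fixed number of scalar Newton iterations for fn(z) = 0.
    def body(_, z):
        return z - fn(z) / jax.grad(fn)(z)

    return jax.lax.fori_loop(0, num_iters, body, z_guess)


def step(t0, t1, y0, z0):
    # 1. Advance y over [t0, t1] with z frozen at z0 (explicit Euler as a stand-in).
    y1 = y0 + (t1 - t0) * f(t0, y0, z0)
    # 2. Solve g(t1, y1, z1) = 0 for z1, starting from z0.
    z1 = newton_solve(lambda z: g(t1, y1, z), z0)
    return y1, z1


def solve(y0, t0=0.0, t1=1.0, num_steps=1000):
    # Start from a consistent z: solve g(t0, y0, z) = 0.
    z0 = newton_solve(lambda z: g(t0, y0, z), jnp.zeros_like(y0))
    dt = (t1 - t0) / num_steps

    def body(i, carry):
        y, z = carry
        ta = t0 + i * dt
        return step(ta, ta + dt, y, z)

    return jax.lax.fori_loop(0, num_steps, body, (y0, z0))


y_end, z_end = solve(jnp.array(1.0))
print(y_end, z_end)  # y(1) should be close to exp(-1), and z close to -y
```

Because $\partial g / \partial z = 1$ here, Newton converges after one iteration. For a general semi-explicit DAE the $z$-solve is only well posed where $\partial g / \partial z$ is invertible (the index-1 case), which is what the nonlinear-solve step implicitly relies on.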
Thanks for your answer! I'll try to digest it first.
@patrick-kidger Thanks for the great library. I'm in need of this functionality, so I can make this PR if you'd be willing to help.
Sure, I'd be happy to see what you come up with. I'm imagining an API looking something like

```python
diffeqsolve(
    ...
    constraints=Constraints(
        constraints=<pytree of constraint functions>,
        z0=<value for extra state>,
        ...  # any other options, e.g. choice of nonlinear solver for the projection step
    ),
)
```

FWIW we're currently implementing delay diffeqs over in #169, and I imagine this will have some overlap. (E.g. binding the extra …)
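For illustration only, the proposed `Constraints` argument could be modelled as a small frozen container like the one below. This is hypothetical and not a Diffrax API: the field names follow the proposal in the comment above, while the `residual` method and the call signature `g(t, y, z, args)` for each constraint function are assumptions made for this sketch.

```python
from dataclasses import dataclass
from typing import Any, Optional

import jax


@dataclass(frozen=True)
class Constraints:
    # Hypothetical container mirroring the proposed `constraints=` argument.
    constraints: Any  # pytree of constraint functions, each assumed to be g(t, y, z, args)
    z0: Any  # initial value for the extra (algebraic) state z
    nonlinear_solver: Optional[Any] = None  # e.g. solver used for the projection step

    def residual(self, t, y, z, args):
        # Functions are pytree leaves, so tree_map calls each one at the same point.
        return jax.tree_util.tree_map(lambda g: g(t, y, z, args), self.constraints)


c = Constraints(constraints={"circle": lambda t, y, z, args: y**2 + z**2 - 1.0}, z0=0.0)
print(c.residual(0.0, 1.0, 0.0, None))
```

Keeping the constraints as a pytree means the residual has the same structure as the constraint functions, which is the structure a root finder for $z$ would then need to drive to zero.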
Hi, I'm trying to define a new solver wrapper just like the one shown in the example above:

```python
class CustomSolver(AbstractWrappedSolver, AbstractImplicitSolver):
    term_structure = jax.tree_util.tree_structure(0)
    interpolation_cls = ThirdOrderHermitePolynomialInterpolation.from_k  # Like the one used for Kvaerno5

    def order(self, terms): ...
    def strong_order(self, terms): ...
    def error_order(self, terms): ...
    def init(self, terms, t0, t1, y0, args): ...
    def step(self, terms, t0, t1, y0, args, solver_state, made_jump): ...
    def func(self, terms, t0, y0, args): ...
```

However, I get the following error:

```
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_better_abstract.py:239, in dataclass.<locals>.make_dataclass(cls)
238 try:
--> 239 annotations = cls.__dict__["__annotations__"]
240 except KeyError:

KeyError: '__annotations__'

During handling of the above exception, another exception occurred:

TypeError Traceback (most recent call last)
Cell In[12], line 1
----> 1 class ExplicitDAESolver(AbstractWrappedSolver, AbstractImplicitSolver):
2 # Use the interpolation scheme from the Kvaerno5 solver
3 # since we will use that and no other solvers.
4 term_structure = jax.tree_util.tree_structure(0)
5 interpolation_cls = ThirdOrderHermitePolynomialInterpolation.from_k

File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_module.py:107, in _ModuleMeta.__new__(mcs, name, bases, dict_)
105 if _init:
106 init_doc = cls.__init__.__doc__
--> 107 cls = dataclass(eq=False, repr=False, frozen=True, init=_init)(
108 cls # pyright: ignore
109 )
110 if _init:
111 cls.__init__.__doc__ = init_doc # pyright: ignore

File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_better_abstract.py:241, in dataclass.<locals>.make_dataclass(cls)
239 annotations = cls.__dict__["__annotations__"]
240 except KeyError:
--> 241 cls = dataclasses.dataclass(**kwargs)(cls)
242 else:
243 new_annotations = dict(annotations)

File /usr/lib/python3.10/dataclasses.py:1176, in dataclass.<locals>.wrap(cls)
1175 def wrap(cls):
-> 1176 return _process_class(cls, init, repr, eq, order, unsafe_hash,
1177 frozen, match_args, kw_only, slots)

File /usr/lib/python3.10/dataclasses.py:1025, in _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)
1020 if init:
1021 # Does this class have a post-init function?
1022 has_post_init = hasattr(cls, _POST_INIT_NAME)
1024 _set_new_attribute(cls, '__init__',
-> 1025 _init_fn(all_init_fields,
1026 std_init_fields,
1027 kw_only_init_fields,
1028 frozen,
1029 has_post_init,
1030 # The name to use for the "self"
1031 # param in __init__. Use "self"
1032 # if possible.
1033 '__dataclass_self__' if 'self' in fields
1034 else 'self',
1035 globals,
1036 slots,
1037 ))
1039 # Get the fields as a list, and include only real fields. This is
1040 # used in all of the following methods.
1041 field_list = [f for f in fields.values() if f._field_type is _FIELD]

File /usr/lib/python3.10/dataclasses.py:546, in _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, self_name, globals, slots)
544 seen_default = True
545 elif seen_default:
--> 546 raise TypeError(f'non-default argument {f.name!r} '
547 'follows default argument')
549 locals = {f'_type_{f.name}': f.type for f in fields}
550 locals.update({
551 'MISSING': MISSING,
552 '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
553 })

TypeError: non-default argument 'solver' follows default argument
```

If I remove the parent … Any suggestions?
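The `TypeError` comes from how `dataclasses` (which Equinox calls under the hood, as the traceback shows) orders inherited fields: it walks the MRO starting from the most basic class, so with bases `(AbstractWrappedSolver, AbstractImplicitSolver)` the fields of `AbstractImplicitSolver` are placed first. Assuming `AbstractImplicitSolver` declares its nonlinear-solver field with a default while `AbstractWrappedSolver.solver` has none (which is what the error message suggests), a required field then follows a defaulted one. The stdlib-only sketch below reproduces this with hypothetical stand-in classes `WrappedBase` and `ImplicitBase`, and shows that listing the bases in the opposite order avoids the error.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class WrappedBase:
    # Stand-in for a wrapper solver: a required field with no default.
    solver: object


@dataclasses.dataclass(frozen=True)
class ImplicitBase:
    # Stand-in for an implicit solver: a field that has a default.
    nonlinear_solver: object = "newton"


def make_combined(*bases):
    """Build a frozen dataclass inheriting from `bases`; return it, or the TypeError raised."""
    try:
        return dataclasses.dataclass(frozen=True)(type("Combined", bases, {}))
    except TypeError as err:
        return err


# Same base order as the failing snippet: fields are collected in reverse MRO order,
# so ImplicitBase's defaulted field comes first and the required `solver` follows it.
bad = make_combined(WrappedBase, ImplicitBase)
print(repr(bad))

# Listing the class with the defaulted field first puts `solver` ahead of it.
good = make_combined(ImplicitBase, WrappedBase)
print([f.name for f in dataclasses.fields(good)])
print(good(solver="kvaerno5"))
```

Whether swapping the base order is the right fix for the real Diffrax classes is an assumption: it changes both the field order and the method resolution order, so it is worth checking that nothing else depends on the original order.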
Hi again, I've found this repository useful too. One thing that would be ideal to have is a differential algebraic equation (DAE) solver (http://www.scholarpedia.org/article/Differential-algebraic_equations), at least for the semi-explicit form. Is there a plan to add this to Diffrax?