Skip to content

Commit

Permalink
Docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Dec 8, 2024
1 parent 5a96fcd commit eb30cc7
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions firedrake/adjoint/ensemble_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class EnsembleReducedFunctional(ReducedFunctional):
operation is employed to sum the functionals and their gradients over an ensemble
communicator.
If gather_functional is present, then all the values of J are communicated to all ensemble ranks, and passed in a list to gather_functional, which is a reduced functional that expects a list of that size of the relevant types.
If gather_functional is present, then all the values of J are communicated to all ensemble
ranks, and passed in a list to gather_functional, which is a reduced functional that expects
a list of that size of the relevant types.
Parameters
----------
Expand All @@ -45,6 +47,24 @@ class EnsembleReducedFunctional(ReducedFunctional):
``Ensemble.ensemble comm``.
gather_functional : An instance of the :class:`pyadjoint.ReducedFunctional`.
that takes in all of the Js.
derivative_components : list of int
The indices of the controls that the derivative should be computed with respect to.
If present, it overwrites ``derivative_cb_pre`` and ``derivative_cb_post``.
scale : float
A scaling factor applied to the functional and its gradient(with respect to the control).
tape : pyadjoint.Tape
A tape object that the reduced functional will use to evaluate the functional and
its gradient (or gradients).
eval_cb_pre : callable
Callback function before evaluating the functional. Input is a list of Controls.
derivative_cb_pre : callable
Callback function before evaluating derivatives. Input is a list of derivatives.
Should return a list of Controls (usually the same list as the input) to be passed
to :func:`pyadjoint.compute_gradient`.
derivative_cb_post : callable
Callback function after evaluating derivatives. Inputs are the functional, the derivative,
and the controls. All of them are the checkpointed versions. Should return a list of
derivatives (usually the same list as the input)to be returned from ``self.derivative``.
See Also
Expand All @@ -59,22 +79,17 @@ class EnsembleReducedFunctional(ReducedFunctional):
<https://www.firedrakeproject.org/parallelism.html#id8>`_.
"""
def __init__(self, J, control, ensemble, scatter_control=True,
gather_functional=None,
derivative_components=None,
scale=1.0, tape=None,
eval_cb_pre=lambda *args: None,
gather_functional=None, derivative_components=None,
scale=1.0, tape=None, eval_cb_pre=lambda *args: None,
eval_cb_post=lambda *args: None,
derivative_cb_pre=lambda controls: controls,
derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
hessian_cb_pre=lambda *args: None,
hessian_cb_post=lambda *args: None):
derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components
):
super(EnsembleReducedFunctional, self).__init__(
J, control, derivative_components=derivative_components,
scale=scale, tape=tape, eval_cb_pre=eval_cb_pre,
eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre,
derivative_cb_post=derivative_cb_post,
hessian_cb_pre=hessian_cb_pre,
hessian_cb_post=hessian_cb_post)
derivative_cb_post=derivative_cb_post)
self.ensemble = ensemble
self.scatter_control = scatter_control
self.gather_functional = gather_functional
Expand Down

0 comments on commit eb30cc7

Please sign in to comment.