Skip to content

Commit

Permalink
Add all reduced functional arguments. (#3908)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci authored Dec 11, 2024
1 parent be411c5 commit a420cd0
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions firedrake/adjoint/ensemble_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class EnsembleReducedFunctional(ReducedFunctional):
operation is employed to sum the functionals and their gradients over an ensemble
communicator.
If gather_functional is present, then all the values of J are communicated to all ensemble ranks, and passed in a list to gather_functional, which is a reduced functional that expects a list of that size of the relevant types.
If gather_functional is present, then all the values of J are communicated to all ensemble
ranks, and passed in a list to gather_functional, which is a reduced functional that expects
a list of that size of the relevant types.
Parameters
----------
Expand All @@ -45,7 +47,33 @@ class EnsembleReducedFunctional(ReducedFunctional):
``Ensemble.ensemble comm``.
gather_functional : An instance of the :class:`pyadjoint.ReducedFunctional`.
that takes in all of the Js.
derivative_components : list of int
The indices of the controls that the derivative should be computed with respect to.
If present, it overwrites ``derivative_cb_pre`` and ``derivative_cb_post``.
scale : float
A scaling factor applied to the functional and its gradient(with respect to the control).
tape : pyadjoint.Tape
A tape object that the reduced functional will use to evaluate the functional and
its gradients (or derivatives).
eval_cb_pre : :func:
Callback function before evaluating the functional. Input is a list of Controls.
eval_cb_pos : :func:
Callback function after evaluating the functional. Inputs are the functional value
and a list of Controls.
derivative_cb_pre : :func:
Callback function before evaluating gradients (or derivatives). Input is a list of
gradients (or derivatives). Should return a list of Controls (usually the same list as
the input) to be passed to :func:`pyadjoint.compute_gradient`.
derivative_cb_post : :func:
Callback function after evaluating derivatives. Inputs are the functional, a list of
gradients (or derivatives), and controls. All of them are the checkpointed versions.
Should return a list of gradients (or derivatives) (usually the same list as the input)
to be returned from ``self.derivative``.
hessian_cb_pre : :func:
Callback function before evaluating the Hessian. Input is a list of Controls.
hessian_cb_post : :func:
Callback function after evaluating the Hessian. Inputs are the functional, a list of
Hessian, and controls.
See Also
--------
Expand All @@ -59,8 +87,19 @@ class EnsembleReducedFunctional(ReducedFunctional):
<https://www.firedrakeproject.org/parallelism.html#id8>`_.
"""
def __init__(self, J, control, ensemble, scatter_control=True,
gather_functional=None):
super(EnsembleReducedFunctional, self).__init__(J, control)
gather_functional=None, derivative_components=None,
scale=1.0, tape=None, eval_cb_pre=lambda *args: None,
eval_cb_post=lambda *args: None,
derivative_cb_pre=lambda controls: controls,
derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
hessian_cb_pre=lambda *args: None, hessian_cb_post=lambda *args: None):
super(EnsembleReducedFunctional, self).__init__(
J, control, derivative_components=derivative_components,
scale=scale, tape=tape, eval_cb_pre=eval_cb_pre,
eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre,
derivative_cb_post=derivative_cb_post,
hessian_cb_pre=hessian_cb_pre, hessian_cb_post=hessian_cb_post)

self.ensemble = ensemble
self.scatter_control = scatter_control
self.gather_functional = gather_functional
Expand Down

0 comments on commit a420cd0

Please sign in to comment.