diff --git a/firedrake/adjoint/ensemble_reduced_functional.py b/firedrake/adjoint/ensemble_reduced_functional.py index a2f9ea7915..d0272d8e09 100644 --- a/firedrake/adjoint/ensemble_reduced_functional.py +++ b/firedrake/adjoint/ensemble_reduced_functional.py @@ -28,7 +28,9 @@ class EnsembleReducedFunctional(ReducedFunctional): operation is employed to sum the functionals and their gradients over an ensemble communicator. - If gather_functional is present, then all the values of J are communicated to all ensemble ranks, and passed in a list to gather_functional, which is a reduced functional that expects a list of that size of the relevant types. + If gather_functional is present, then all the values of J are communicated to all ensemble + ranks, and passed in a list to gather_functional, which is a reduced functional that expects + a list of that size of the relevant types. Parameters ---------- @@ -45,7 +47,33 @@ class EnsembleReducedFunctional(ReducedFunctional): ``Ensemble.ensemble comm``. gather_functional : An instance of the :class:`pyadjoint.ReducedFunctional`. that takes in all of the Js. - + derivative_components : list of int + The indices of the controls that the derivative should be computed with respect to. + If present, it overwrites ``derivative_cb_pre`` and ``derivative_cb_post``. + scale : float + A scaling factor applied to the functional and its gradient(with respect to the control). + tape : pyadjoint.Tape + A tape object that the reduced functional will use to evaluate the functional and + its gradients (or derivatives). + eval_cb_pre : :func: + Callback function before evaluating the functional. Input is a list of Controls. + eval_cb_pos : :func: + Callback function after evaluating the functional. Inputs are the functional value + and a list of Controls. + derivative_cb_pre : :func: + Callback function before evaluating gradients (or derivatives). Input is a list of + gradients (or derivatives). Should return a list of Controls (usually the same list as + the input) to be passed to :func:`pyadjoint.compute_gradient`. + derivative_cb_post : :func: + Callback function after evaluating derivatives. Inputs are the functional, a list of + gradients (or derivatives), and controls. All of them are the checkpointed versions. + Should return a list of gradients (or derivatives) (usually the same list as the input) + to be returned from ``self.derivative``. + hessian_cb_pre : :func: + Callback function before evaluating the Hessian. Input is a list of Controls. + hessian_cb_post : :func: + Callback function after evaluating the Hessian. Inputs are the functional, a list of + Hessian, and controls. See Also -------- @@ -59,8 +87,19 @@ class EnsembleReducedFunctional(ReducedFunctional): `_. """ def __init__(self, J, control, ensemble, scatter_control=True, - gather_functional=None): - super(EnsembleReducedFunctional, self).__init__(J, control) + gather_functional=None, derivative_components=None, + scale=1.0, tape=None, eval_cb_pre=lambda *args: None, + eval_cb_post=lambda *args: None, + derivative_cb_pre=lambda controls: controls, + derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components, + hessian_cb_pre=lambda *args: None, hessian_cb_post=lambda *args: None): + super(EnsembleReducedFunctional, self).__init__( + J, control, derivative_components=derivative_components, + scale=scale, tape=tape, eval_cb_pre=eval_cb_pre, + eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre, + derivative_cb_post=derivative_cb_post, + hessian_cb_pre=hessian_cb_pre, hessian_cb_post=hessian_cb_post) + self.ensemble = ensemble self.scatter_control = scatter_control self.gather_functional = gather_functional