From e3a3bfa39a2706dd7ae67091feca3c1f88a20ba7 Mon Sep 17 00:00:00 2001 From: LutzGross Date: Wed, 6 Nov 2024 14:15:47 +0800 Subject: [PATCH] In line search the initial alpha can be restricted to avoid overflow in the forward model if one has this information. It saves unnecessary alpha back tracking in particular in the initial step in BFGS. --- escript/py_src/costfunctions.py | 18 +++++++++++++++++- escript/py_src/minimizer.py | 11 +++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/escript/py_src/costfunctions.py b/escript/py_src/costfunctions.py index 88c478c76f..feb7c2a6fc 100644 --- a/escript/py_src/costfunctions.py +++ b/escript/py_src/costfunctions.py @@ -81,7 +81,6 @@ def getStatistics(self): out="Number of cost function evaluations: %d\n" % self.Value_calls out+="Number of gradient evaluations: %d\n" % self.Gradient_calls out+="Number of inverse Hessian evaluations: %d\n" % self.InverseHessianApproximation_calls - out+="Number of gradient evaluations: %d\n" % self.Gradient_calls out+="Number of inner product evaluations: %d\n" % self.DualProduct_calls out+="Number of argument evaluations: %d\n" % self.Arguments_calls out+="Number of norm evaluations: %d" % self.Norm_calls @@ -264,6 +263,23 @@ def getGradient(self, m, *args): """ raise NotImplementedError + def getSqueezeFactor(self, m, p, *args): + """ + The new solution is calculated as m+a*p with a>0. This function allows to provide an upper bound + for a to make sure that m+a*p is valid typically to avoid overflow when the cost function is evaluated. + The solver will take action to make sure that the value of a is not too small. + + :param m: a solution approximation + :type m: m-type + :param p: an increment to the solution + :type p: m-type + :param args: pre-calculated values for ``m`` from `getArgumentsAndCount()` + :rtype: positive ``float`` or None + + :note: Overwrite this method to implement a cost function. 
+ """ + return None + def getInverseHessianApproximation(self, r, m, *args, initializeHessian = True): """ returns an approximate evaluation *p* of the inverse of the Hessian diff --git a/escript/py_src/minimizer.py b/escript/py_src/minimizer.py index 4ac3e3a03c..027e778758 100644 --- a/escript/py_src/minimizer.py +++ b/escript/py_src/minimizer.py @@ -992,8 +992,15 @@ def run(self, m): else: alphaMin = self._relAlphaMin self.getLineSearch().setOptions(alphaMin=alphaMin) - alpha = max(alpha, alphaMin*1.10) - self.logger.debug("Starting line search with alphaMin, alpha = %g, %g" % (alphaMin, alpha)) + alpha = max(alpha, alphaMin * 1.10) + alpha_s = self.getCostFunction().getSqueezeFactor(m, p) + if not alpha_s is None: + assert alpha_s >0 + self.logger.debug("Safe alpha value given as %g." % (alpha_s,)) + alpha_s*=0.5 + if alpha_s > alphaMin: + alpha = min(alpha, alpha_s) + self.logger.info("Starting line search with alpha = %g."% (alpha, )) try: phi_new = self.getLineSearch().run(phi, alpha) alpha = phi_new.alpha