grad stopping criterion now inspects the difference between cost function values of two successive steps.
LutzGross committed Oct 31, 2024
1 parent 5f77380 commit a7d8272
Showing 1 changed file with 11 additions and 17 deletions.
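The change replaces the gradient-based stopping test with a test on the change of the cost function value between two successive steps. Below is a minimal standalone sketch of the new criterion, distilled from the run() hunk further down; the helper function itself is hypothetical, while the names Fm, Fm_new and the 1e-4 default follow the diff:

def cost_difference_converged(Fm, Fm_new, grad_tol=1e-4):
    # True when |F - F_old| <= grad_tol * max(|F|, |F_old|),
    # mirroring "dFm <= Ftol_abs" in the diff below.
    Ftol_abs = grad_tol * max(abs(Fm), abs(Fm_new))
    return abs(Fm - Fm_new) <= Ftol_abs

print(cost_difference_converged(100.0, 99.995))  # True: 5e-3 <= 1e-2
print(cost_difference_converged(100.0, 99.9))    # False: 1e-1 > 1e-2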
28 changes: 11 additions & 17 deletions escript/py_src/minimizer.py
@@ -771,7 +771,7 @@ def getLineSearch(self):
"""
return self.__line_search

- def setTolerance(self, m_tol=1e-4, grad_tol=1e-8):
+ def setTolerance(self, m_tol=1e-4, grad_tol=1e-4):
"""
Sets the tolerance for the stopping criterion. The minimizer stops
when an appropriate norm is less than `m_tol`.
@@ -814,7 +814,7 @@ def setOptions(self, **opts):
:default m_tol: 1e-4
:key grad_tol: tolerance for gradient relative to initial costfunction value for termination of iteration
:type grad_tol: `float`
- :default grad_tol: 1e-8
+ :default grad_tol: 1e-4
:key truncation: sets the number of previous LBFGS iterations to keep
:type truncation : `int`
:default truncation: 30
@@ -899,8 +899,6 @@ def doCallback(self, **args):
:type Fm: ``float``
:key gradFm: gradient for `m`
:type gradFm: g-type (see ``CostFunction``)
- :key norm_gradFm: (estimated) norm of gradient at `m`
- :type norm_gradFm: ``float`` or None if iterCount==0
:param args_m: arguments for `m`
:type args_m: ``tuple``
:param failed: set if the step was unsuccessful.
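With norm_gradFm removed from the callback arguments, user callbacks must no longer rely on that key. A hypothetical callback matching the keyword arguments doCallback still passes; how the callback is attached to the minimizer is not shown in this diff:

def my_callback(iterCount=0, m=None, dm=None, Fm=None, grad_Fm=None,
                norm_m=None, args_m=None, failed=False, **kwargs):
    # norm_gradFm is no longer supplied; **kwargs absorbs any extras.
    status = "failed" if failed else "ok"
    print("iter %d: F(m)=%s, |m|=%s [%s]" % (iterCount, Fm, norm_m, status))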
@@ -959,7 +957,7 @@ def run(self, m):
H_scale = None
if self._restart < self._iterMax+2:
self._restart = self._iterMax * 3
self.logger.info("INFO: Restart is currently disabled.")
self.logger.info("MinimizerLBFGS: Restart is currently disabled.")
self._result = m
args_m = self.getCostFunction().getArgumentsAndCount(m)
grad_Fm = self.getCostFunction().getGradientAndCount(m, *args_m)
@@ -969,7 +967,7 @@ def run(self, m):
self.logger.info("Initialization completed.")

self.doCallback(iterCount=0, m=m, dm=None, Fm=Fm, grad_Fm=grad_Fm,
- norm_m=norm_m, norm_gradFm=None, args_m=args_m, failed=False)
+ norm_m=norm_m, args_m=args_m, failed=False)

non_curable_break_down = False
converged = False
@@ -1023,31 +1021,27 @@ def run(self, m):
flag = norm_dm <= mtol_abs
if flag:
self.logger.info("F(m) = %g" % Fm_new)
self.logger.info("Solution has converged: dm=%g, m*m_tol=%g" % (norm_dm, mtol_abs))
self.logger.info("Solution has converged: |m-m_old|=%g < |m|*m_tol=%g" % (norm_dm, mtol_abs))
converged = True
break
else:
self.logger.info("Solution checked: dx=%g, x*m_tol=%g" % (norm_dm, mtol_abs))
self.logger.info("Solution checked: |m-m_old|=%g, |m|*m_tol=%g" % (norm_dm, mtol_abs))
# unfortunately there is more work to do!
if grad_Fm_new is None:
self.logger.debug("Calculating missing gradient.")
args_new = self.getCostFunction().getArgumentsAndCount(m_new)
grad_Fm_new = self.getCostFunction().getGradientAndCount(m_new, *args_new)

Ftol_abs = self._grad_tol * abs(max(abs(Fm), abs(Fm_new)))
- gradNorm1 = abs(self.getCostFunction().getDualProductAndCount(m_new, grad_Fm_new))/norm_m_new
- gradNorm2 = abs(self.getCostFunction().getDualProductAndCount(delta_m, grad_Fm_new))/norm_dm
- gradNorm=max(gradNorm1, gradNorm2)
- flag = gradNorm <= Ftol_abs
+ dFm = abs(Fm - Fm_new)
+ flag = dFm <= Ftol_abs
if flag:
converged = True
self.logger.info("F(m) = %g" % Fm_new)
self.logger.info("grad Fm = %g, %g" % (gradNorm1, gradNorm2))
self.logger.info("Gradient has converged: grad F=%g, grad_tol=%g" % (gradNorm, Ftol_abs))
self.logger.info("Gradient has converged: |F-Fold|=%g < g_tol*max(|F|,|Fold|)=%g" % (dFm, Ftol_abs))
break
else:
self.logger.info("grad Fm = %g, %g" % (gradNorm1, gradNorm2))
self.logger.info("Gradient checked: grad F=%g, grad_tol=%g" % (gradNorm, Ftol_abs))
self.logger.info("Gradient checked: |F-Fold|=%g, g_tol*max(|F|,|Fold|)=%g" % (dFm, Ftol_abs))

delta_g = grad_Fm_new - grad_Fm
rho = self.getCostFunction().getDualProductAndCount(delta_m, delta_g)
@@ -1066,7 +1060,7 @@ def run(self, m):
k += 1
iterCount += 1
self.doCallback(iterCount=iterCount, m=m, dm=delta_m, Fm=Fm, grad_Fm=grad_Fm,
- norm_m=norm_m, norm_gradFm=gradNorm, args_m=args_m, failed=break_down)
+ norm_m=norm_m, args_m=args_m, failed=break_down)

# delete oldest vector pair
if k > self._truncation:
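A hypothetical driver showing the tolerance API after this change; the constructor call and the inputs F and m0 are assumptions, while setTolerance, run and the new 1e-4 default come from the diff:

minimizer = MinimizerLBFGS(F)  # F: a CostFunction instance (assumed setup)
minimizer.setTolerance(m_tol=1e-4, grad_tol=1e-4)  # grad_tol default was 1e-8
m_opt = minimizer.run(m0)  # m0: initial guess (assumed)
# Iteration stops when |m - m_old| <= m_tol*|m|
# or |F - F_old| <= grad_tol*max(|F|, |F_old|).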
