From d9212ed076e0d5c8803b781eb0fda2a5be2a5096 Mon Sep 17 00:00:00 2001 From: Mark Messner Date: Tue, 28 Nov 2023 14:03:46 -0600 Subject: [PATCH] Modifications to try to train KM models --- pyoptmat/chunktime.py | 35 ++++++++++++++++++++++++++++++++--- pyoptmat/models.py | 2 +- pyoptmat/ode.py | 2 +- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/pyoptmat/chunktime.py b/pyoptmat/chunktime.py index a84e0c1..bd1c161 100644 --- a/pyoptmat/chunktime.py +++ b/pyoptmat/chunktime.py @@ -48,9 +48,8 @@ def newton_raphson_chunk( i = 0 while (i < miter) and torch.any(nR > atol) and torch.any(nR / nR0 > rtol): - x -= solver.solve(J, R) - R, J = fn(x) - nR = torch.norm(R, dim=-1) + dx = solver.solve(J, R) + x, R, J, nR, alpha = chunk_linesearch(x, dx, fn, R) i += 1 if i == miter: @@ -60,6 +59,36 @@ def newton_raphson_chunk( return x +def chunk_linesearch(x, dx, fn, R0, sigma = 2.0, c=1e-3, miter = 10): + """ + Backtracking linesearch for the chunk NR algorithm. + + Terminates when the Armijo criteria is reached, or you exceed some maximum iterations + + Args: + x (torch.tensor): initial point + dx (torch.tensor): direction + R0 (torch.tensor): initial residual + + Keyword Args: + sigma (scalar): decrease factor, i.e. alpha /= sigma + c (scalar): stopping criteria + miter (scalar): maximum iterations + """ + alpha = 1.0 + nR0 = torch.norm(R0, dim = -1) + i = 0 + while True: + R, J = fn(x - dx * alpha) + nR = torch.max(torch.norm(R, dim = -1)**2.0) + crit = torch.max(nR0**2.0 + 2.0 * c * alpha * torch.einsum('...i,...i', R0, dx)) + i += 1 + if nR <= crit or i >= miter: + break + alpha /= sigma + return x - dx * alpha, R, J, torch.max(torch.norm(R, dim = -1)), alpha + + class BidiagonalOperator(torch.nn.Module): """ diff --git a/pyoptmat/models.py b/pyoptmat/models.py index 19d4d1d..a8d9e6a 100644 --- a/pyoptmat/models.py +++ b/pyoptmat/models.py @@ -545,7 +545,7 @@ def RJ(erate): return R[..., None], J[..., None, None] - erate, _ = solvers.newton_raphson(RJ, erate_guess) + erate, _ = solvers.newton_raphson(RJ, erate_guess, atol = 1.0e-2) yp = y.clone() yp[..., 0] = cs ydot, J, Je, _ = self.model(t, yp, erate[..., 0], cT) diff --git a/pyoptmat/ode.py b/pyoptmat/ode.py index cb7989f..0b0b946 100644 --- a/pyoptmat/ode.py +++ b/pyoptmat/ode.py @@ -293,7 +293,7 @@ def __init__( block_size=1, rtol=1.0e-6, atol=1.0e-8, - miter=100, + miter=200, linear_solve_method="direct", direct_solve_method="thomas", direct_solve_min_size=0,