Skip to content

Commit

Permalink
Modifications to try to train KM models
Browse files Browse the repository at this point in the history
  • Loading branch information
reverendbedford committed Nov 28, 2023
1 parent 52ad388 commit d9212ed
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
35 changes: 32 additions & 3 deletions pyoptmat/chunktime.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ def newton_raphson_chunk(
i = 0

while (i < miter) and torch.any(nR > atol) and torch.any(nR / nR0 > rtol):
x -= solver.solve(J, R)
R, J = fn(x)
nR = torch.norm(R, dim=-1)
dx = solver.solve(J, R)
x, R, J, nR, alpha = chunk_linesearch(x, dx, fn, R)
i += 1

if i == miter:
Expand All @@ -60,6 +59,36 @@ def newton_raphson_chunk(

return x

def chunk_linesearch(x, dx, fn, R0, sigma = 2.0, c=1e-3, miter = 10):
"""
Backtracking linesearch for the chunk NR algorithm.
Terminates when the Armijo criteria is reached, or you exceed some maximum iterations
Args:
x (torch.tensor): initial point
dx (torch.tensor): direction
R0 (torch.tensor): initial residual
Keyword Args:
sigma (scalar): decrease factor, i.e. alpha /= sigma
c (scalar): stopping criteria
miter (scalar): maximum iterations
"""
alpha = 1.0
nR0 = torch.norm(R0, dim = -1)
i = 0
while True:
R, J = fn(x - dx * alpha)
nR = torch.max(torch.norm(R, dim = -1)**2.0)
crit = torch.max(nR0**2.0 + 2.0 * c * alpha * torch.einsum('...i,...i', R0, dx))
i += 1
if nR <= crit or i >= miter:
break
alpha /= sigma
return x - dx * alpha, R, J, torch.max(torch.norm(R, dim = -1)), alpha



class BidiagonalOperator(torch.nn.Module):
"""
Expand Down
2 changes: 1 addition & 1 deletion pyoptmat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def RJ(erate):

return R[..., None], J[..., None, None]

erate, _ = solvers.newton_raphson(RJ, erate_guess)
erate, _ = solvers.newton_raphson(RJ, erate_guess, atol = 1.0e-2)
yp = y.clone()
yp[..., 0] = cs
ydot, J, Je, _ = self.model(t, yp, erate[..., 0], cT)
Expand Down
2 changes: 1 addition & 1 deletion pyoptmat/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __init__(
block_size=1,
rtol=1.0e-6,
atol=1.0e-8,
miter=100,
miter=200,
linear_solve_method="direct",
direct_solve_method="thomas",
direct_solve_min_size=0,
Expand Down

0 comments on commit d9212ed

Please sign in to comment.