From e06af04720fe010347c56e9562419ffaa1246207 Mon Sep 17 00:00:00 2001 From: Mark Messner Date: Fri, 12 Jan 2024 19:48:57 -0600 Subject: [PATCH] Okay, almost done --- pyoptmat/models.py | 14 ++++++++++++-- pyoptmat/solvers.py | 8 +++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pyoptmat/models.py b/pyoptmat/models.py index 13a0f93..c475741 100644 --- a/pyoptmat/models.py +++ b/pyoptmat/models.py @@ -283,10 +283,11 @@ class ModelIntegrator(nn.Module): """ - def __init__(self, model, *args, use_adjoint=True, bisect_first=False, **kwargs): + def __init__(self, model, *args, use_adjoint=True, bisect_first=False, throw_on_scalar_fail = False, **kwargs): super().__init__(*args) self.model = model self.use_adjoint = use_adjoint + self.throw_on_scalar_fail = throw_on_scalar_fail self.kwargs_for_integration = kwargs if self.use_adjoint: @@ -325,6 +326,7 @@ def solve_both(self, times, temperatures, idata, control): temperatures, control, bisect_first=self.bisect_first, + throw_on_scalar_fail = self.throw_on_scalar_fail ) return self.imethod(bmodel, init, times, **self.kwargs_for_integration) @@ -405,6 +407,8 @@ def solve_stress(self, times, stresses, temperatures): stress_rate_interpolator, stress_interpolator, temperature_interpolator, + bisect_first = self.bisect_first, + throw_on_scalar_fail = self.throw_on_scalar_fail ) return self.imethod(smodel, init, times, **self.kwargs_for_integration) @@ -442,6 +446,7 @@ def __init__( temps, control, bisect_first=False, + throw_on_scalar_fail = False, *args, **kwargs ): @@ -473,6 +478,7 @@ def __init__( times[..., self.scontrol], temps[..., self.scontrol] ), bisect_first=bisect_first, + throw_on_scalar_fail = throw_on_scalar_fail ) def forward(self, t, y): @@ -556,6 +562,7 @@ def __init__( max_erate=1e3, guess_erate=1.0e-3, bisect_first=False, + throw_on_scalar_fail = False, *args, **kwargs ): @@ -568,6 +575,7 @@ def __init__( self.max_erate = max_erate self.bisect_first = bisect_first self.guess_erate = guess_erate + self.throw_on_scalar_fail = throw_on_scalar_fail def forward(self, t, y): """ @@ -597,9 +605,11 @@ def RJ(erate): RJ, torch.ones_like(y[..., 0]) * self.min_erate, torch.ones_like(y[..., 0]) * self.max_erate, + throw_on_fail = self.throw_on_scalar_fail ) else: - erate = solvers.scalar_newton(RJ, torch.sign(csr) * self.guess_erate) + erate = solvers.scalar_newton(RJ, torch.sign(csr) * self.guess_erate, + throw_on_fail = self.throw_on_scalar_fail) ydot, J, Je, _ = self.model( t, torch.cat([cs.unsqueeze(-1), y[..., 1:]], dim=-1), erate, cT diff --git a/pyoptmat/solvers.py b/pyoptmat/solvers.py index 8082e2c..96a9600 100644 --- a/pyoptmat/solvers.py +++ b/pyoptmat/solvers.py @@ -45,7 +45,7 @@ def scalar_bisection(fn, a, b, atol=1.0e-6, miter=100): return c -def scalar_newton(fn, x0, atol=1.0e-6, miter=100): +def scalar_newton(fn, x0, atol=1.0e-6, miter=100, throw_on_fail = False): """ Solve logically scalar equations with Newton's method @@ -68,6 +68,8 @@ def scalar_newton(fn, x0, atol=1.0e-6, miter=100): R, J = fn(x) else: + if throw_on_fail: + raise RuntimeError("Scalar solve failed") warnings.warn( "Scalar implicit solve did not succeed. Results may be inaccurate..." ) @@ -75,7 +77,7 @@ def scalar_newton(fn, x0, atol=1.0e-6, miter=100): return x -def scalar_bisection_newton(fn, a, b, atol=1.0e-6, miter=100, biter=10): +def scalar_bisection_newton(fn, a, b, atol=1.0e-6, miter=100, biter=10, throw_on_fail = False): """ Solve logically scalar equations by switching from bisection to Newton's method @@ -90,7 +92,7 @@ def scalar_bisection_newton(fn, a, b, atol=1.0e-6, miter=100, biter=10): miter (int): max number of iterations for Newton's method """ x = scalar_bisection(fn, a, b, atol=atol, miter=biter) - return scalar_newton(fn, x, atol=atol, miter=miter) + return scalar_newton(fn, x, atol=atol, miter=miter, throw_on_fail = throw_on_fail) def newton_raphson_bt(