Skip to content

Commit

Permalink
Okay, almost done
Browse files Browse the repository at this point in the history
  • Loading branch information
reverendbedford committed Jan 13, 2024
1 parent c50b0ec commit e06af04
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
14 changes: 12 additions & 2 deletions pyoptmat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,11 @@ class ModelIntegrator(nn.Module):
"""

def __init__(self, model, *args, use_adjoint=True, bisect_first=False, **kwargs):
def __init__(self, model, *args, use_adjoint=True, bisect_first=False, throw_on_scalar_fail = False, **kwargs):
super().__init__(*args)
self.model = model
self.use_adjoint = use_adjoint
self.throw_on_scalar_fail = throw_on_scalar_fail
self.kwargs_for_integration = kwargs

if self.use_adjoint:
Expand Down Expand Up @@ -325,6 +326,7 @@ def solve_both(self, times, temperatures, idata, control):
temperatures,
control,
bisect_first=self.bisect_first,
throw_on_scalar_fail = self.throw_on_scalar_fail
)

return self.imethod(bmodel, init, times, **self.kwargs_for_integration)
Expand Down Expand Up @@ -405,6 +407,8 @@ def solve_stress(self, times, stresses, temperatures):
stress_rate_interpolator,
stress_interpolator,
temperature_interpolator,
bisect_first = self.bisect_first,
throw_on_scalar_fail = self.throw_on_scalar_fail
)

return self.imethod(smodel, init, times, **self.kwargs_for_integration)
Expand Down Expand Up @@ -442,6 +446,7 @@ def __init__(
temps,
control,
bisect_first=False,
throw_on_scalar_fail = False,
*args,
**kwargs
):
Expand Down Expand Up @@ -473,6 +478,7 @@ def __init__(
times[..., self.scontrol], temps[..., self.scontrol]
),
bisect_first=bisect_first,
throw_on_scalar_fail = throw_on_scalar_fail
)

def forward(self, t, y):
Expand Down Expand Up @@ -556,6 +562,7 @@ def __init__(
max_erate=1e3,
guess_erate=1.0e-3,
bisect_first=False,
throw_on_scalar_fail = False,
*args,
**kwargs
):
Expand All @@ -568,6 +575,7 @@ def __init__(
self.max_erate = max_erate
self.bisect_first = bisect_first
self.guess_erate = guess_erate
self.throw_on_scalar_fail = throw_on_scalar_fail

def forward(self, t, y):
"""
Expand Down Expand Up @@ -597,9 +605,11 @@ def RJ(erate):
RJ,
torch.ones_like(y[..., 0]) * self.min_erate,
torch.ones_like(y[..., 0]) * self.max_erate,
throw_on_fail = self.throw_on_scalar_fail
)
else:
erate = solvers.scalar_newton(RJ, torch.sign(csr) * self.guess_erate)
erate = solvers.scalar_newton(RJ, torch.sign(csr) * self.guess_erate,
throw_on_fail = self.throw_on_scalar_fail)

ydot, J, Je, _ = self.model(
t, torch.cat([cs.unsqueeze(-1), y[..., 1:]], dim=-1), erate, cT
Expand Down
8 changes: 5 additions & 3 deletions pyoptmat/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def scalar_bisection(fn, a, b, atol=1.0e-6, miter=100):
return c


def scalar_newton(fn, x0, atol=1.0e-6, miter=100):
def scalar_newton(fn, x0, atol=1.0e-6, miter=100, throw_on_fail = False):
"""
Solve logically scalar equations with Newton's method
Expand All @@ -68,14 +68,16 @@ def scalar_newton(fn, x0, atol=1.0e-6, miter=100):

R, J = fn(x)
else:
if throw_on_fail:
raise RuntimeError("Scalar solve failed")
warnings.warn(
"Scalar implicit solve did not succeed. Results may be inaccurate..."
)

return x


def scalar_bisection_newton(fn, a, b, atol=1.0e-6, miter=100, biter=10):
def scalar_bisection_newton(fn, a, b, atol=1.0e-6, miter=100, biter=10, throw_on_fail = False):
"""
Solve logically scalar equations by switching from bisection to Newton's method
Expand All @@ -90,7 +92,7 @@ def scalar_bisection_newton(fn, a, b, atol=1.0e-6, miter=100, biter=10):
miter (int): max number of iterations for Newton's method
"""
x = scalar_bisection(fn, a, b, atol=atol, miter=biter)
return scalar_newton(fn, x, atol=atol, miter=miter)
return scalar_newton(fn, x, atol=atol, miter=miter, throw_on_fail = throw_on_fail)


def newton_raphson_bt(
Expand Down

0 comments on commit e06af04

Please sign in to comment.