Commit: More progress

reverendbedford committed Nov 30, 2023
1 parent d00999e commit 205ff6a
Showing 7 changed files with 150 additions and 21 deletions.
@@ -94,6 +94,7 @@ def make(n, eta, s0, R, d, C, g, **kwargs):
        scale_scale_priors,
        eps,
        include_noise=False,
+        use_cached_guess = True
    ).to(device)

    # 4) Get the guide
2 changes: 1 addition & 1 deletion examples/structural-inference/tension/statistical/infer.py
@@ -91,7 +91,7 @@ def make(n, eta, s0, R, d, **kwargs):
    model = optimize.HierarchicalStatisticalModel(
        lambda *args, **kwargs: make(*args, block_size = time_chunk_size, direct_solve_method = linear_solve_method,
            **kwargs),
-        names, loc_loc_priors, loc_scale_priors, scale_scale_priors, eps
+        names, loc_loc_priors, loc_scale_priors, scale_scale_priors, eps, use_cached_guess = False
    ).to(device)

    # 4) Get the guide
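Both example drivers exercise the new option: the first script turns the warm start on, the second leaves it off. A minimal sketch of the toggle, assuming the `make` factory, priors, and solver settings defined earlier in each example script (nothing here beyond the `use_cached_guess` keyword is new to this commit):

    # Hedged sketch: the same constructor call as in the diffs above, with
    # the new keyword spelled out. All other names come from the examples.
    model = optimize.HierarchicalStatisticalModel(
        lambda *args, **kwargs: make(
            *args,
            block_size=time_chunk_size,
            direct_solve_method=linear_solve_method,
            **kwargs,
        ),
        names,
        loc_loc_priors,
        loc_scale_priors,
        scale_scale_priors,
        eps,
        use_cached_guess=True,  # False recovers the old cold-start behavior
    ).to(device)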
1 change: 0 additions & 1 deletion pyoptmat/chunktime.py
@@ -57,7 +57,6 @@ def newton_raphson_chunk(
        x -= dx
        R, J = fn(x)
        nR = torch.norm(R, dim = -1)
-        print(i, torch.max(nR))
        i += 1

    if i == miter:
33 changes: 22 additions & 11 deletions pyoptmat/models.py
@@ -283,7 +283,7 @@ class ModelIntegrator(nn.Module):
"""

def __init__(self, model, *args, use_adjoint=True, **kwargs):
def __init__(self, model, *args, use_adjoint=True, bisect_first = False, **kwargs):
super().__init__(*args)
self.model = model
self.use_adjoint = use_adjoint
@@ -294,6 +294,8 @@ def __init__(self, model, *args, use_adjoint=True, **kwargs):
        else:
            self.imethod = ode.odeint

+        self.bisect_first = bisect_first
+
    def solve_both(self, times, temperatures, idata, control):
        """
        Solve for either strain or stress control at once
@@ -327,6 +329,7 @@ def solve_both(self, times, temperatures, idata, control):
            base_interpolator,
            temperature_interpolator,
            control,
+            bisect_first = self.bisect_first
        )

        return self.imethod(bmodel, init, times, **self.kwargs_for_integration)
@@ -435,7 +438,7 @@ class BothBasedModel(nn.Module):
        indices: split into strain and stress control
    """

-    def __init__(self, model, rate_fn, base_fn, T_fn, control, *args, **kwargs):
+    def __init__(self, model, rate_fn, base_fn, T_fn, control, bisect_first = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        self.rate_fn = rate_fn
@@ -445,7 +448,7 @@ def __init__(self, model, rate_fn, base_fn, T_fn, control, *args, **kwargs):

        self.emodel = StrainBasedModel(self.model, self.rate_fn, self.T_fn)
        self.smodel = StressBasedModel(
-            self.model, self.rate_fn, self.base_fn, self.T_fn
+            self.model, self.rate_fn, self.base_fn, self.T_fn, bisect_first = bisect_first
        )

    def forward(self, t, y):
@@ -514,12 +517,15 @@ class StressBasedModel(nn.Module):
        T_fn: T(t)
    """

-    def __init__(self, model, srate_fn, stress_fn, T_fn, *args, **kwargs):
+    def __init__(self, model, srate_fn, stress_fn, T_fn, min_erate = -1e2, max_erate = 1e3, bisect_first = False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        self.srate_fn = srate_fn
        self.stress_fn = stress_fn
        self.T_fn = T_fn
+        self.min_erate = min_erate
+        self.max_erate = max_erate
+        self.bisect_first = bisect_first

    def forward(self, t, y):
        """
@@ -533,28 +539,33 @@ def forward(self, t, y):
        cs = self.stress_fn(t)
        cT = self.T_fn(t)

-        erate_guess = torch.zeros_like(y[..., 0])[..., None]
-
        def RJ(erate):
            yp = y.clone()
            yp[..., 0] = cs
-            ydot, _, Je, _ = self.model(t, yp, erate[..., 0], cT)
+            ydot, _, Je, _ = self.model(t, yp, erate, cT)

            R = ydot[..., 0] - csr
            J = Je[..., 0]

-            return R[..., None], J[..., None, None]
+            return R, J

+        if self.bisect_first:
+            erate = solvers.scalar_bisection_newton(RJ,
+                torch.ones_like(y[...,0]) * self.min_erate,
+                torch.ones_like(y[...,0]) * self.max_erate)
+        else:
+            erate = solvers.scalar_newton(RJ,
+                torch.zeros_like(y[...,0]))
+
-        erate, _ = solvers.newton_raphson(RJ, erate_guess, atol = 1.0e-2)
        yp = y.clone()
        yp[..., 0] = cs
-        ydot, J, Je, _ = self.model(t, yp, erate[..., 0], cT)
+        ydot, J, Je, _ = self.model(t, yp, erate, cT)

        # Rescale the jacobian
        J[..., 0, :] = -J[..., 0, :] / Je[..., 0][..., None]
        J[..., :, 0] = 0

        # Insert the strain rate
-        ydot[..., 0] = erate[..., 0]
+        ydot[..., 0] = erate

        return ydot, J
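The net effect of the changes in this file: under stress control the strain rate is an implicit unknown at every step, found by driving the scalar residual (model stress rate minus the prescribed stress rate) to zero, and `bisect_first` brackets that root between `min_erate` and `max_erate` before handing off to Newton. Below is a self-contained sketch of the same bracket-then-polish pattern; the solver calls are the ones added to `pyoptmat.solvers` in this commit, while the residual function is a stand-in, not the real material model:

    import torch

    from pyoptmat import solvers

    def RJ(erate):
        # Stand-in residual R(erate) and derivative J = dR/derate; the real
        # forward() gets these from the material model instead
        R = 200.0 * torch.tanh(erate) - 150.0
        J = 200.0 / torch.cosh(erate) ** 2
        return R, J

    # Batched bounds playing the role of min_erate / max_erate
    lo = torch.full((4,), -1.0e2)
    hi = torch.full((4,), 1.0e3)

    # A few bisection steps shrink the bracket, then Newton finishes quickly
    erate = solvers.scalar_bisection_newton(RJ, lo, hi)

Bracketing first matters because a pure Newton iteration started from zero can shoot far outside the physical range when the residual is stiff; the bounds guarantee the iterate begins near the root.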
8 changes: 7 additions & 1 deletion pyoptmat/ode.py
@@ -283,6 +283,7 @@ class FixedGridBlockSolver:
        direct_solve_min_size (int): minimum PCR block size for the hybrid approach
        adjoint_params: parameters to track for the adjoint backward pass
        guess_type (string): strategy for initial guess, options are "zero" and "previous"
+        guess_history (torch.tensor): complete load history used for the guess; overrides guess_type
        throw_on_fail (bool): if true, throw an exception if the implicit solve fails
        offset_step (int): use a special, smaller chunk size for the first step
    """
@@ -302,6 +303,7 @@ def __init__(
        direct_solve_min_size=0,
        adjoint_params=None,
        guess_type="zero",
+        guess_history=None,
        throw_on_fail=False,
        offset_step = 0,
        **kwargs,
@@ -348,6 +350,7 @@ def __init__(

        # Initial guess for integration
        self.guess_type = guess_type
+        self.guess_history = guess_history

        # Throw exception on failed solve
        self.throw_on_fail = throw_on_fail
@@ -380,6 +383,7 @@ def integrate(self, t, cache_adjoint=False):
        result[0] = self.y0
        incs = self._gen_increments(t)

+
        for k1, k2 in zip(incs[:-1], incs[1:]):
            result[k1 : k2] = self.block_update(
                t[k1 : k2],
@@ -420,7 +424,9 @@ def _initial_guess(self, result, k, nchunk):
            k (int): current time step
            nchunk (int): current chunk size
        """
-        if self.guess_type == "zero":
+        if self.guess_history is not None:
+            guess = self.guess_history[k : k + nchunk] - result[k - 1]
+        elif self.guess_type == "zero":
            guess = torch.zeros_like(result[k : k + nchunk])
        elif self.guess_type == "previous":
            if k - nchunk - 1 < 0:
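The warm start threads through `FixedGridBlockSolver` as follows: a supplied `guess_history` takes precedence over `guess_type`, and each chunk's guess becomes the cached trajectory re-expressed as an increment from the last accepted state. A sketch of that selection logic in isolation (the free function and its arguments here are illustrative, not pyoptmat API):

    import torch

    def initial_guess(result, k, nchunk, guess_history=None, guess_type="zero"):
        # Cached solution wins: shift it so it is an increment relative to
        # the last accepted state, the form the block solver expects
        if guess_history is not None:
            return guess_history[k : k + nchunk] - result[k - 1]
        # Fall back to the original cold-start strategy
        if guess_type == "zero":
            return torch.zeros_like(result[k : k + nchunk])
        raise ValueError(f"Unknown guess strategy {guess_type}")

On the first call `guess_history` is still `None` (nothing has been cached yet), so the solver silently falls back to `guess_type`; the optimize.py changes below rely on exactly that behavior.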
47 changes: 40 additions & 7 deletions pyoptmat/optimize.py
@@ -124,7 +124,7 @@ class DeterministicModel(Module):
        ics (list(torch.tensor)): initial conditions to use for each parameter
    """

-    def __init__(self, maker, names, ics):
+    def __init__(self, maker, names, ics, use_cached_guess = False):
        super().__init__()

        self.maker = maker
@@ -136,6 +136,9 @@ def __init__(self, maker, names, ics):
        for name, ic in zip(names, ics):
            setattr(self, name, Parameter(ic))

+        self.use_cached_guess = use_cached_guess
+        self.cached_solution = None
+
    def get_params(self):
        """
        Return the parameters for input to the model
@@ -155,12 +158,18 @@ def forward(self, exp_data, exp_cycles, exp_types, exp_control):
            exp_types (torch.tensor): experiment types, as integers
            exp_control (torch.tensor): stress/strain control flag
        """
-        model = self.maker(*self.get_params())
+        if self.use_cached_guess:
+            model = self.maker(*self.get_params(), guess_history = self.cached_solution)
+        else:
+            model = self.maker(*self.get_params())

        predictions = model.solve_both(
            exp_data[0], exp_data[1], exp_data[2], exp_control
        )

+        if self.use_cached_guess:
+            self.cached_solution = predictions.detach().clone()
+
        return experiments.convert_results(predictions[:, :, 0], exp_cycles, exp_types)
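The same three-line pattern recurs in all three classes in this file: build the model with the previous epoch's detached predictions as `guess_history`, solve, then cache the new predictions. Reduced to its essentials, as a sketch assuming any `maker` that accepts a `guess_history` keyword and returns an object with a `solve_both`-style method:

    class WarmStartCache:
        # Sketch of the cache-and-reuse pattern shared by DeterministicModel,
        # StatisticalModel, and HierarchicalStatisticalModel
        def __init__(self, maker, use_cached_guess=False):
            self.maker = maker
            self.use_cached_guess = use_cached_guess
            self.cached_solution = None  # None on the first call: cold start

        def solve(self, *exp_data):
            if self.use_cached_guess:
                model = self.maker(guess_history=self.cached_solution)
            else:
                model = self.maker()
            predictions = model.solve_both(*exp_data)
            if self.use_cached_guess:
                # detach().clone() keeps the warm start out of the autograd
                # graph, so no gradients flow through the previous epoch
                self.cached_solution = predictions.detach().clone()
            return predictions

Because the parameters move only slightly between optimizer steps, the previous trajectory is usually an excellent guess and the implicit solves converge in far fewer iterations.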


@@ -189,7 +198,7 @@ class StatisticalModel(PyroModule):
        entry i represents the noise in test type i
    """

-    def __init__(self, maker, names, locs, scales, eps, nan_num=False):
+    def __init__(self, maker, names, locs, scales, eps, nan_num=False, use_cached_guess = False):
        super().__init__()

        self.maker = maker
@@ -205,6 +214,9 @@ def __init__(self, maker, names, locs, scales, eps, nan_num=False):

        self.nan_num = nan_num

+        self.use_cached_guess = use_cached_guess
+        self.cached_solution = None
+
    def get_params(self):
        """
        Return the sampled parameters for input to the model
@@ -229,10 +241,17 @@ def forward(self, exp_data, exp_cycles, exp_types, exp_control, exp_results=None
        Keyword Args:
            exp_results (torch.tensor): true results for conditioning
        """
-        model = self.maker(*self.get_params())
+        if self.use_cached_guess:
+            model = self.maker(*self.get_params(), guess_history = self.cached_solution)
+        else:
+            model = self.maker(*self.get_params())

        predictions = model.solve_both(
            exp_data[0], exp_data[1], exp_data[2], exp_control
        )
+        if self.use_cached_guess:
+            self.cached_solution = predictions.detach().clone()
+
        results = experiments.convert_results(
            predictions[:, :, 0], exp_cycles, exp_types
        )
@@ -318,6 +337,7 @@ def __init__(
param_suffix="_param",
include_noise=False,
weights=None,
use_cached_guess = False
):
super().__init__()

@@ -395,6 +415,9 @@ def __init__(
        # This annoyance is required to make the adjoint solver work
        self.extra_param_names = []

+        self.use_cached_guess = use_cached_guess
+        self.cached_solution = None
+
    @property
    def nparams(self):
        """
@@ -531,13 +554,23 @@ def forward(self, exp_data, exp_cycles, exp_types, exp_control, exp_results=None
            scale=self._make_weight_tensor(exp_types)
        ):
            # Sample the bottom level parameters
-            bmodel = self.maker(
-                *self.sample_bot(), extra_params=self.get_extra_params()
-            )
+            if self.use_cached_guess:
+                bmodel = self.maker(
+                    *self.sample_bot(), extra_params=self.get_extra_params(),
+                    guess_history = self.cached_solution
+                )
+            else:
+                bmodel = self.maker(
+                    *self.sample_bot(), extra_params=self.get_extra_params()
+                )

            # Generate the results
            predictions = bmodel.solve_both(
                exp_data[0], exp_data[1], exp_data[2], exp_control
            )
+            if self.use_cached_guess:
+                self.cached_solution = predictions.detach().clone()
+
            # Process the results
            results = experiments.convert_results(
                predictions[:, :, 0], exp_cycles, exp_types
79 changes: 79 additions & 0 deletions pyoptmat/solvers.py
@@ -7,6 +7,85 @@
import warnings
import torch

+def scalar_bisection(fn, a, b, atol = 1.0e-6, miter = 100):
+    """
+    Solve logically scalar equations with bisection
+    Args:
+        fn (function): function returning the scalar residual and jacobian
+        a (torch.tensor): lower bound
+        b (torch.tensor): upper bound
+    Keyword Args:
+        atol (float): absolute tolerance for convergence
+        miter (int): maximum number of iterations
+    """
+    Ra, _ = fn(a)
+    Rb, _ = fn(b)
+
+    if not torch.all((torch.sign(Ra) + torch.sign(Rb)) == 0):
+        raise RuntimeError("Initial values do not bracket a root in bisection solver")
+
+    c = (a + b) / 2.0
+    Rc, _ = fn(c)
+
+    for i in range(miter):
+        if torch.all(torch.abs(Rc) < atol):
+            break
+
+        ac = torch.sign(Ra) == torch.sign(Rc)
+        bc = torch.sign(Rb) == torch.sign(Rc)
+        a[ac] = c[ac]
+        b[bc] = c[bc]
+
+        c = (a + b) / 2.0
+        Rc, _ = fn(c)
+
+    return c
+
+def scalar_newton(fn, x0, atol = 1.0e-6, miter = 100):
+    """
+    Solve logically scalar equations with Newton's method
+    Args:
+        fn (function): function returning the scalar residual and jacobian
+        x0 (torch.tensor): initial guess
+    Keyword Args:
+        atol (float): absolute tolerance for convergence
+        miter (int): maximum number of iterations
+    """
+    x = x0
+    R, J = fn(x)
+
+    for i in range(miter):
+        if torch.all(torch.abs(R) < atol):
+            break
+
+        x -= R / J
+
+        R, J = fn(x)
+    else:
+        warnings.warn("Scalar implicit solve did not succeed. Results may be inaccurate...")
+
+    return x
+
+def scalar_bisection_newton(fn, a, b, atol = 1.0e-6, miter = 100, biter = 10):
+    """
+    Solve logically scalar equations by switching from bisection to Newton's method
+    Args:
+        fn (function): function returning the scalar residual and jacobian
+        a (torch.tensor): lower bound
+        b (torch.tensor): upper bound
+    Keyword Args:
+        atol (float): absolute tolerance for convergence
+        biter (int): initial number of bisection iterations
+        miter (int): maximum number of Newton iterations
+    """
+    x = scalar_bisection(fn, a, b, atol = atol, miter = biter)
+    return scalar_newton(fn, x, atol = atol, miter = miter)

def newton_raphson_bt(
    fn, x0, linsolver="lu", rtol=1e-6, atol=1e-10, miter=100, max_bt=5
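A quick usage sketch for the new solvers, assuming this commit's `pyoptmat.solvers` module: solve x**3 = 5 elementwise for a small batch. Note that `scalar_bisection` requires the initial residuals to have opposite signs and shrinks the bracket by writing into `a` and `b` in place:

    import torch

    from pyoptmat import solvers

    def RJ(x):
        # Residual and jacobian for x**3 - 5 = 0
        return x ** 3 - 5.0, 3.0 * x ** 2

    a = torch.zeros(3)         # R(a) = -5 < 0
    b = torch.full((3,), 4.0)  # R(b) = 59 > 0, so each entry brackets the root
    x = solvers.scalar_bisection_newton(RJ, a, b, atol=1.0e-8)
    # x is approximately 1.70998 (the cube root of 5) in every entry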
