Skip to content

Commit

Permalink
Store model times in result when path collection enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
chrhansk committed Jul 12, 2024
1 parent 6f222bb commit 82ce740
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 26 deletions.
23 changes: 9 additions & 14 deletions pygradflow/integration/integration_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def perform_integration(
assert (path[-1][:, -1] == curr_z).all()
assert (path[-1][:, -1] == ivp_result.y[:, 0]).all()
path.append(ivp_result.y[:, 1:])
self.path_times.append(ivp_result.t[1:])
path[-1][:, -1] = next_z

dist = np.linalg.norm(
Expand Down Expand Up @@ -364,6 +365,7 @@ def solve(self, x0: Optional[np.ndarray] = None, y0: Optional[np.ndarray] = None
initial_iterate = self.transform.initial_iterate

self.path = [initial_iterate.z[:, None]]
self.path_times = [np.array([0.0])]

print_problem_stats(problem, initial_iterate)

Expand Down Expand Up @@ -476,18 +478,6 @@ def solve(self, x0: Optional[np.ndarray] = None, y0: Optional[np.ndarray] = None
(curr_x, curr_y) = self.flow.split_states(curr_z)
iterate = Iterate(problem, params, curr_x, curr_y)

result_props = dict()

if params.collect_path:
complete_path = np.hstack(self.path)
self.path = None

num_vars = problem.num_vars

result_props["path"] = complete_path
result_props["primal_path"] = complete_path[:num_vars, :]
result_props["dual_path"] = complete_path[num_vars:, :]

x = iterate.x
y = iterate.y
d = iterate.bounds_dual
Expand All @@ -506,7 +496,8 @@ def solve(self, x0: Optional[np.ndarray] = None, y0: Optional[np.ndarray] = None

(x, y, d) = self.transform.restore_sol(x, y, d)

return SolverResult(
solver_result = SolverResult(
problem,
x,
y,
d,
Expand All @@ -515,5 +506,9 @@ def solve(self, x0: Optional[np.ndarray] = None, y0: Optional[np.ndarray] = None
num_accepted_steps=accepted_steps,
total_time=total_time,
dist_factor=dist_factor,
**result_props,
)

if params.collect_path:
solver_result._set_path(np.hstack(self.path), np.hstack(self.path_times))

return solver_result
45 changes: 44 additions & 1 deletion pygradflow/result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

from pygradflow.problem import Problem
from pygradflow.status import SolverStatus


Expand All @@ -11,6 +12,7 @@ class SolverResult:

def __init__(
self,
problem: Problem,
x: np.ndarray,
y: np.ndarray,
d: np.ndarray,
Expand All @@ -21,6 +23,7 @@ def __init__(
dist_factor: float,
**attrs
):
self.problem = problem
self._attrs = attrs

self._x = x
Expand All @@ -32,6 +35,38 @@ def __init__(
self.total_time = total_time
self.dist_factor = dist_factor

def _set_path(self, path, model_times):
self._attrs["path"] = path
self._attrs["model_times"] = model_times

num_vars = self.problem.num_vars
num_cons = self.problem.num_cons

assert model_times.ndim == 1
assert path.shape == (num_vars + num_cons, len(model_times))

self._attrs["primal_path"] = lambda: path[:num_vars]
self._attrs["dual_path"] = lambda: path[num_vars:]

def speed():
return np.linalg.norm(np.diff(self.path, axis=1), axis=0) / np.diff(
model_times
)

def primal_speed():
return np.linalg.norm(np.diff(self.primal_path, axis=1), axis=0) / np.diff(
model_times
)

def dual_speed():
return np.linalg.norm(np.diff(self.dual_path, axis=1), axis=0) / np.diff(
model_times
)

self._attrs["model_speed"] = speed
self._attrs["primal_model_speed"] = primal_speed
self._attrs["dual_model_speed"] = dual_speed

@property
def status(self) -> SolverStatus:
"""
Expand All @@ -41,7 +76,15 @@ def status(self) -> SolverStatus:

def __getattr__(self, name):
attrs = super().__getattribute__("_attrs")
return attrs.get(name, None)
val = attrs.get(name, None)

if val is None:
return val

if callable(val):
return val()

return val

def __setitem__(self, name, value):
self._attrs[name] = value
Expand Down
22 changes: 11 additions & 11 deletions pygradflow/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def solve(

if params.collect_path:
path: Optional[List[np.ndarray]] = [initial_iterate.z]
path_times = [0.0]
else:
path = None

Expand Down Expand Up @@ -331,6 +332,7 @@ def solve(

if path is not None:
path.append(next_iterate.z)
path_times.append(path_times[-1] + (1.0 / lamb))

iterate = next_iterate

Expand Down Expand Up @@ -370,16 +372,8 @@ def solve(

(x, y, d) = self.transform.restore_sol(x, y, d)

result_props = dict()

if path is not None:
complete_path: np.ndarray = np.vstack(path).T
num_vars = problem.num_vars
result_props["path"] = complete_path
result_props["primal_path"] = complete_path[:num_vars, :]
result_props["dual_path"] = complete_path[num_vars:, :]

return SolverResult(
result = SolverResult(
problem,
x,
y,
d,
Expand All @@ -391,5 +385,11 @@ def solve(
final_scaled_obj=iterate.obj,
final_stat_res=iterate.stat_res,
final_cons_violation=iterate.cons_violation,
**result_props,
)

if path is not None:
complete_path = np.vstack(path).T
model_times = np.hstack(path_times)
result._set_path(complete_path, model_times)

return result

0 comments on commit 82ce740

Please sign in to comment.