Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Nov 27, 2024
1 parent 2741269 commit 9b3f1bf
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 9 deletions.
16 changes: 15 additions & 1 deletion desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
A = A_augmented[:, :-1]
b = np.atleast_1d(A_augmented[:, -1].squeeze())

A_nondegenerate = A.copy()

# will store the global index of the unfixed rows, idx
indices_row = np.arange(A.shape[0])
indices_idx = np.arange(A.shape[1])
Expand Down Expand Up @@ -244,7 +246,19 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
"or be due to floating point error.",
)

return xp, A, b, Z, D, unfixed_idx, project, recover
return (
xp,
A,
b,
Z,
D,
unfixed_idx,
project,
recover,
A_inv,
A_nondegenerate,
row_idx_to_delete,
)


class _Project(IOAble):
Expand Down
36 changes: 35 additions & 1 deletion desc/optimize/_constraint_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def build(self, use_jit=None, verbose=1):
self._unfixed_idx,
self._project,
self._recover,
self._Ainv,
self._A_full_nondegenerate,
self._degenerate_idx, # maybe we need those for b_new
) = factorize_linear_constraints(
self._objective,
self._constraint,
Expand Down Expand Up @@ -164,6 +167,28 @@ def unpack_state(self, x, per_objective=True):
x = self.recover(x)
return self._objective.unpack_state(x, per_objective)

def update_constraint_target(self, eq_new):
"""Update the target of the constraint."""
for con in self._constraint.objectives:
if hasattr(con, "update_target"):
con.update_target(eq_new)

x0 = jnp.zeros(self._constraint.dim_x)
b_new = -self._constraint.compute_scaled_error(x0)
b_new = np.delete(b_new, self._degenerate_idx)
xp_new = jnp.zeros_like(self._xp)
fixed_idx = np.setdiff1d(np.arange(self._xp.size), self._unfixed_idx)
xp_new[fixed_idx] = b_new[fixed_idx]
xp_new[self._unfixed_idx] = self._Ainv @ (
b_new - self._A_full_nondegenerate[:, fixed_idx] @ xp_new[fixed_idx]
)
from desc.objectives.utils import _Project, _Recover

self._project = _Project(self._Z, self._D, xp_new, self._unfixed_idx)
self._recover = _Recover(
self._Z, self._D, xp_new, self._unfixed_idx, self._objective.dim_x
)

def compute_unscaled(self, x_reduced, constants=None):
"""Compute the unscaled form of the objective function.
Expand Down Expand Up @@ -533,7 +558,7 @@ def _set_eq_state_vector(self):
self._args.remove(arg)
linear_constraint = ObjectiveFunction(self._linear_constraints)
linear_constraint.build()
_, _, _, self._Z, self._D, self._unfixed_idx, _, _ = (
(_, _, _, self._Z, self._D, self._unfixed_idx, *_) = (
factorize_linear_constraints(self._constraint, linear_constraint)
)

Expand Down Expand Up @@ -592,6 +617,12 @@ def build(self, use_jit=None, verbose=1): # noqa: C901
for constraint in self._linear_constraints:
constraint.build(use_jit=use_jit, verbose=verbose)

self._eq_solve_objective = LinearConstraintProjection(
self._constraint,
ObjectiveFunction(self._linear_constraints),
)
self._eq_solve_objective.build()

errorif(
self._constraint.things != [eq],
ValueError,
Expand Down Expand Up @@ -759,6 +790,9 @@ def _update_equilibrium(self, x, store=False):
x_dict = x_list[self._eq_idx]
x_dict_old = x_list_old[self._eq_idx]
deltas = {str(key): x_dict[key] - x_dict_old[key] for key in x_dict}
# Add some logic to perturb and solve to take single
# LinearConstraintProjection!
self._eq_solve_objective.update_constraint_target(self._eq)
self._eq = self._eq.perturb(
objective=self._constraint,
constraints=self._linear_constraints,
Expand Down
6 changes: 4 additions & 2 deletions desc/perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def perturb( # noqa: C901
if verbose > 0:
print("Factorizing linear constraints")
timer.start("linear constraint factorize")
xp, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, _, _, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)
timer.stop("linear constraint factorize")
Expand Down Expand Up @@ -750,7 +750,9 @@ def optimal_perturb( # noqa: C901
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
_, _, _, _, _, _, _, recover = factorize_linear_constraints(objective_f, constraint)
_, _, _, _, _, _, _, recover, *_ = factorize_linear_constraints(
objective_f, constraint
)

# update other attributes
dx_reduced = dx1_reduced + dx2_reduced
Expand Down
2 changes: 1 addition & 1 deletion desc/vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def load(
constraints = maybe_add_self_consistency(eq, constraints)
objective = ObjectiveFunction(constraints)
objective.build(verbose=0)
_, _, _, _, _, _, project, recover = factorize_linear_constraints(
_, _, _, _, _, _, project, recover, *_ = factorize_linear_constraints(
objective, objective
)
args = objective.unpack_state(recover(project(objective.x(eq))), False)[0]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test_correct_indexing_passed_modes():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down Expand Up @@ -508,7 +508,7 @@ def test_correct_indexing_passed_modes_and_passed_target():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down Expand Up @@ -568,7 +568,7 @@ def test_correct_indexing_passed_modes_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down Expand Up @@ -697,7 +697,7 @@ def test_correct_indexing_passed_modes_and_passed_target_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down

0 comments on commit 9b3f1bf

Please sign in to comment.