Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add easier to read print option #1189

Merged
merged 55 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
ee53746
add easier to digest print option
YigitElma Aug 13, 2024
89184d6
add description for values
YigitElma Aug 13, 2024
bdfbcc6
fix out of index issue for test_all_optimizers, make print_value_fmt …
YigitElma Aug 14, 2024
dbbb420
Merge branch 'master' into yge/print
YigitElma Aug 14, 2024
889bc34
add back * to objectives that have custom print_values
YigitElma Aug 14, 2024
04630c3
update print_value method of freeb objectives
YigitElma Aug 14, 2024
511a74f
remove old print, change print_value arg names, clean up
YigitElma Aug 14, 2024
76feaa2
update tests for new print
YigitElma Aug 14, 2024
e89914d
fix test, second print didn't have precompute transforms, add verbose=0
YigitElma Aug 14, 2024
5989ee1
rerun basic_equilibruim tutorial
YigitElma Aug 14, 2024
d5087d9
rerun basic_optimization tutorial, needed to resave a file to resolve…
YigitElma Aug 14, 2024
c7c1423
rerun bootstrap_current tutorial
YigitElma Aug 14, 2024
2445f89
rerun coil_stage_two_optimization tutorial
YigitElma Aug 14, 2024
f24536d
rerun conrinuation_step_by_step tutorial
YigitElma Aug 14, 2024
04a6cde
rerun free_boundary_equilibrium tutorial
YigitElma Aug 14, 2024
ee1fcb6
rerun nae_constraints tutorial
YigitElma Aug 14, 2024
3b2d0cc
rerun omnigeneity tutorial
YigitElma Aug 14, 2024
9a7da3c
rerun use_outputs tutorial
YigitElma Aug 14, 2024
ada9a83
Merge branch 'master' into yge/print
dpanici Aug 14, 2024
a9676d3
Merge branch 'yge/print' of github.com:PlasmaControl/DESC into yge/print
YigitElma Aug 14, 2024
77b184c
remove unwanted files came from cluster
YigitElma Aug 14, 2024
2e79d84
Merge branch 'master' into yge/print
dpanici Aug 15, 2024
c5681da
fix kernel issue, reduce 3d plot resolution for file size
YigitElma Aug 15, 2024
95ffe87
Merge branch 'yge/print' of github.com:PlasmaControl/DESC into yge/print
YigitElma Aug 15, 2024
e4f0fa3
try updating kernels again
YigitElma Aug 15, 2024
07dc49d
extend test to increase coverage
YigitElma Aug 15, 2024
b172d56
add space between units and end value
YigitElma Aug 15, 2024
3531741
delete input file generated by the tutorial
YigitElma Aug 15, 2024
e54e44d
fix the typo in boundary error
YigitElma Aug 15, 2024
afa906d
rerun free boundary tutorial
YigitElma Aug 16, 2024
2ebf016
Ensure notebooks are black formatted
f0uriest Aug 16, 2024
ca8e6a7
Merge branch 'master' into yge/print
f0uriest Aug 16, 2024
cc55816
add divider
YigitElma Aug 18, 2024
d9962fe
make all _print_value_fmt and _units consistent for axis objectives
YigitElma Aug 18, 2024
03ed085
move value part out of the _print_value_fmt to align every value
YigitElma Aug 18, 2024
58bc31d
fix formatting
YigitElma Aug 18, 2024
ff76800
Merge branch 'master' into yge/print
YigitElma Aug 18, 2024
b0a422a
add : for sum of squares
YigitElma Aug 18, 2024
80f50a8
fix the tests with new formatting
YigitElma Aug 18, 2024
880228b
add test for print str width
YigitElma Aug 18, 2024
334c93f
print total initialization time
YigitElma Aug 18, 2024
91d82b5
address PR reviews
YigitElma Aug 18, 2024
43559f0
re run notebooks on cluster
YigitElma Aug 18, 2024
d918f63
re run tutorials
YigitElma Aug 18, 2024
12b540b
update kernel
YigitElma Aug 18, 2024
a814013
Merge branch 'master' into yge/print
f0uriest Aug 19, 2024
594a30f
Merge branch 'master' into yge/print
dpanici Aug 20, 2024
8b5b787
Merge branch 'master' into yge/print
YigitElma Aug 21, 2024
bccf5fd
remove new lines
YigitElma Aug 21, 2024
7bd3df7
Merge branch 'master' into yge/print
YigitElma Aug 21, 2024
9ebb660
Merge branch 'master' into yge/print
YigitElma Aug 22, 2024
9f9e25d
Merge branch 'master' into yge/print
YigitElma Aug 22, 2024
4e04966
Merge remote-tracking branch 'origin' into yge/print
YigitElma Aug 22, 2024
8a87c2b
re run coil stage two tutorial
YigitElma Aug 22, 2024
4e4627e
Merge branch 'master' into yge/print
YigitElma Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 79 additions & 14 deletions desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,15 @@
f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2
return f

def print_value(self, x, constants=None):
def print_value(self, x, x0=None, constants=None):
"""Print the value(s) of the objective.

Parameters
----------
x : ndarray
State vector.
x0 : ndarray, optional
Initial state vector before optimization.
constants : list
Constant parameters passed to sub-objectives.

Expand All @@ -351,12 +353,28 @@
constants = self.constants
if self.compiled and self._compile_mode in {"scalar", "all"}:
f = self.compute_scalar(x, constants=constants)
if x0 is not None:
f0 = self.compute_scalar(x0, constants=constants)

Check warning on line 357 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L356-L357

Added lines #L356 - L357 were not covered by tests
else:
f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2
print("Total (sum of squares): {:10.3e}, ".format(f))
if x0 is not None:
f0 = (
jnp.sum(self.compute_scaled_error(x0, constants=constants) ** 2) / 2
)
if x0 is not None:
print("Total (sum of squares): {:10.3e} --> {:10.3e}, ".format(f0, f))
else:
print("Total (sum of squares): {:10.3e}, ".format(f))

Check warning on line 367 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L367

Added line #L367 was not covered by tests
params = self.unpack_state(x)
for par, obj, const in zip(params, self.objectives, constants):
obj.print_value(*par, constants=const)
if x0 is not None:
params0 = self.unpack_state(x0)
for par, par0, obj, const in zip(
params, params0, self.objectives, constants
):
obj.print_value(par, par0, constants=const)
else:
for par, obj, const in zip(params, self.objectives, constants):
obj.print_value(*par, constants=const)

Check warning on line 377 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L376-L377

Added lines #L376 - L377 were not covered by tests
return None

def unpack_state(self, x, per_objective=True):
Expand Down Expand Up @@ -1075,18 +1093,37 @@
def print_value(self, *args, **kwargs):
"""Print the value of the objective."""
# compute_unscaled is jitted so better to use than than bare compute
f = self.compute_unscaled(*args, **kwargs)
if len(args) == 2:
arg, arg0 = args
f = self.compute_unscaled(*arg, **kwargs)
f0 = self.compute_unscaled(*arg0, **kwargs)
self._print_value_fmt = self._print_value_fmt + " --> {:10.3e}"
else:
arg0 = None
f = self.compute_unscaled(*args, **kwargs)
f0 = None

Check warning on line 1104 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L1102-L1104

Added lines #L1102 - L1104 were not covered by tests

if self.linear:
# probably a Fixed* thing, just need to know norm
f = jnp.linalg.norm(self._shift(f))
print(self._print_value_fmt.format(f) + self._units)
f0 = jnp.linalg.norm(self._shift(f0)) if f0 is not None else f

print(self._print_value_fmt.format(f0, f) + self._units)

elif self.scalar:
# dont need min/max/mean of a scalar
print(self._print_value_fmt.format(f.squeeze()) + self._units)
fs = f.squeeze()
f0s = f0.squeeze() if f0 is not None else fs
print(self._print_value_fmt.format(f0s, fs) + self._units)
if self._normalize and self._units != "(dimensionless)":
fs_norm = self._scale(self._shift(f)).squeeze()
f0s_norm = (
self._scale(self._shift(f0)).squeeze()
if f0 is not None
else fs_norm
)
print(
self._print_value_fmt.format(self._scale(self._shift(f)).squeeze())
self._print_value_fmt.format(f0s_norm, fs_norm)
+ "(normalized error)"
)

Expand All @@ -1106,42 +1143,70 @@
fmin = jnp.min(f)
fmean = jnp.mean(f * w) / jnp.mean(w)

if arg0 is not None:
f0 = jnp.abs(f0) if abserr else f0
f0max = jnp.max(f0)
f0min = jnp.min(f0)
f0mean = jnp.mean(f0 * w) / jnp.mean(w)
else:
f0 = f
f0max = fmax
f0min = fmin
f0mean = fmean

Check warning on line 1155 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L1152-L1155

Added lines #L1152 - L1155 were not covered by tests

print(
"Maximum "
+ ("absolute " if abserr else "")
+ self._print_value_fmt.format(fmax)
+ self._print_value_fmt.format(f0max, fmax)
+ self._units
)
print(
"Minimum "
+ ("absolute " if abserr else "")
+ self._print_value_fmt.format(fmin)
+ self._print_value_fmt.format(f0min, fmin)
+ self._units
)
print(
"Average "
+ ("absolute " if abserr else "")
+ self._print_value_fmt.format(fmean)
+ self._print_value_fmt.format(f0mean, fmean)
+ self._units
)

if self._normalize and self._units != "(dimensionless)":
if arg0 is not None:
fmax_norm = fmax / jnp.mean(self.normalization)
fmin_norm = fmin / jnp.mean(self.normalization)
fmean_norm = fmean / jnp.mean(self.normalization)

f0max_norm = f0max / jnp.mean(self.normalization)
f0min_norm = f0min / jnp.mean(self.normalization)
f0mean_norm = f0mean / jnp.mean(self.normalization)
else:
f0max_norm = fmax / jnp.mean(self.normalization)
f0min_norm = fmin / jnp.mean(self.normalization)
f0mean_norm = fmean / jnp.mean(self.normalization)

Check warning on line 1188 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L1186-L1188

Added lines #L1186 - L1188 were not covered by tests

fmax_norm = jnp.inf
fmin_norm = jnp.inf
fmean_norm = jnp.inf

Check warning on line 1192 in desc/objectives/objective_funs.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/objective_funs.py#L1190-L1192

Added lines #L1190 - L1192 were not covered by tests

print(
"Maximum "
+ ("absolute " if abserr else "")
+ self._print_value_fmt.format(fmax / jnp.mean(self.normalization))
+ self._print_value_fmt.format(f0max_norm, fmax_norm)
+ "(normalized)"
)
print(
"Minimum "
+ ("absolute " if abserr else "")
+ self._print_value_fmt.format(fmin / jnp.mean(self.normalization))
+ self._print_value_fmt.format(f0min_norm, fmin_norm)
+ "(normalized)"
)
print(
"Average "
+ ("absolute " if abserr else "")
+ self._print_value_fmt.format(fmean / jnp.mean(self.normalization))
+ self._print_value_fmt.format(f0mean_norm, fmean_norm)
+ "(normalized)"
)

Expand Down
60 changes: 38 additions & 22 deletions desc/optimize/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,45 @@
things[ind].params_dict = params

if verbose > 0:
print("Start of solver")
# need to check index of things bc things0 contains copies of
# things, so they are not the same exact Python objects
objective.print_value(
objective.x(*[things0[things.index(t)] for t in objective.things])
)
for con in constraints:
arg_inds_for_this_con = [
things.index(t) for t in things if t in con.things
]
args_for_this_con = [things0[ind] for ind in arg_inds_for_this_con]
con.print_value(*con.xs(*args_for_this_con))

print("End of solver")
objective.print_value(
objective.x(*[things[things.index(t)] for t in objective.things])
state_0 = [things0[things.index(t)] for t in objective.things]
state = [things[things.index(t)] for t in objective.things]

print_pretty = (
YigitElma marked this conversation as resolved.
Show resolved Hide resolved
True # if True print before and end values next to each other
)
for con in constraints:
arg_inds_for_this_con = [
things.index(t) for t in things if t in con.things
]
args_for_this_con = [things[ind] for ind in arg_inds_for_this_con]
con.print_value(*con.xs(*args_for_this_con))

if print_pretty:
print("Start of solver --> End of solver")
objective.print_value(objective.x(*state), objective.x(*state_0))
for con in constraints:
arg_inds_for_this_con = [
things.index(t) for t in things if t in con.things
]
args_for_this_con = [things[ind] for ind in arg_inds_for_this_con]
args0_for_this_con = [things0[ind] for ind in arg_inds_for_this_con]
con.print_value(
con.xs(*args_for_this_con), con.xs(*args0_for_this_con)
)
else:
print("Start of solver")

Check warning on line 372 in desc/optimize/optimizer.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/optimizer.py#L372

Added line #L372 was not covered by tests
# need to check index of things bc things0 contains copies of
# things, so they are not the same exact Python objects
objective.print_value(objective.x(*state_0))
for con in constraints:
arg_inds_for_this_con = [

Check warning on line 377 in desc/optimize/optimizer.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/optimizer.py#L375-L377

Added lines #L375 - L377 were not covered by tests
things.index(t) for t in things if t in con.things
]
args_for_this_con = [things0[ind] for ind in arg_inds_for_this_con]
con.print_value(*con.xs(*args_for_this_con))

Check warning on line 381 in desc/optimize/optimizer.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/optimizer.py#L380-L381

Added lines #L380 - L381 were not covered by tests

print("End of solver")
objective.print_value(objective.x(*state))
for con in constraints:
arg_inds_for_this_con = [

Check warning on line 386 in desc/optimize/optimizer.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/optimizer.py#L383-L386

Added lines #L383 - L386 were not covered by tests
things.index(t) for t in things if t in con.things
]
args_for_this_con = [things[ind] for ind in arg_inds_for_this_con]
con.print_value(*con.xs(*args_for_this_con))

Check warning on line 390 in desc/optimize/optimizer.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/optimizer.py#L389-L390

Added lines #L389 - L390 were not covered by tests

if copy:
# need to swap things and things0, since things should be unchanged
Expand Down
Loading