-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ensure correct data types in getter methods #1030
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1030 +/- ##
=======================================
Coverage 94.85% 94.85%
=======================================
Files 87 87
Lines 21724 21724
=======================================
Hits 20607 20607
Misses 1117 1117
|
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +0.38 +/- 7.27 | +1.93e-03 +/- 3.67e-02 | 5.07e-01 +/- 3.5e-02 | 5.05e-01 +/- 9.6e-03 |
test_build_transform_fft_midres | -0.19 +/- 1.25 | -1.12e-03 +/- 7.35e-03 | 5.89e-01 +/- 5.5e-03 | 5.90e-01 +/- 4.8e-03 |
test_build_transform_fft_highres | +0.09 +/- 3.14 | +9.25e-04 +/- 3.07e-02 | 9.79e-01 +/- 2.9e-02 | 9.78e-01 +/- 9.9e-03 |
test_equilibrium_init_lowres | +0.26 +/- 0.51 | +9.59e-03 +/- 1.87e-02 | 3.65e+00 +/- 1.5e-02 | 3.64e+00 +/- 1.2e-02 |
test_equilibrium_init_medres | -0.22 +/- 0.84 | -9.24e-03 +/- 3.45e-02 | 4.11e+00 +/- 2.1e-02 | 4.12e+00 +/- 2.7e-02 |
test_equilibrium_init_highres | +0.15 +/- 0.55 | +8.04e-03 +/- 3.00e-02 | 5.52e+00 +/- 2.1e-02 | 5.51e+00 +/- 2.1e-02 |
test_objective_compile_dshape_current | -1.19 +/- 0.76 | -4.56e-02 +/- 2.90e-02 | 3.79e+00 +/- 2.1e-02 | 3.84e+00 +/- 2.1e-02 |
test_objective_compile_atf | -0.67 +/- 2.88 | -5.52e-02 +/- 2.36e-01 | 8.14e+00 +/- 2.1e-01 | 8.20e+00 +/- 1.0e-01 |
test_objective_compute_dshape_current | +9.15 +/- 6.18 | +1.07e-04 +/- 7.21e-05 | 1.27e-03 +/- 5.5e-05 | 1.17e-03 +/- 4.6e-05 |
test_objective_compute_atf | +3.86 +/- 4.94 | +1.56e-04 +/- 1.99e-04 | 4.19e-03 +/- 1.6e-04 | 4.03e-03 +/- 1.1e-04 |
test_objective_jac_dshape_current | -0.07 +/- 9.25 | -2.47e-05 +/- 3.48e-03 | 3.75e-02 +/- 3.1e-03 | 3.76e-02 +/- 1.6e-03 |
test_objective_jac_atf | +0.11 +/- 4.02 | +2.07e-03 +/- 7.46e-02 | 1.86e+00 +/- 5.4e-02 | 1.85e+00 +/- 5.2e-02 |
test_perturb_1 | +0.37 +/- 0.68 | +4.76e-02 +/- 8.80e-02 | 1.30e+01 +/- 1.6e-02 | 1.30e+01 +/- 8.6e-02 |
test_perturb_2 | +0.25 +/- 0.61 | +4.45e-02 +/- 1.08e-01 | 1.79e+01 +/- 9.6e-02 | 1.79e+01 +/- 4.9e-02 |
test_proximal_jac_atf | -0.00 +/- 0.94 | -3.27e-05 +/- 6.88e-02 | 7.30e+00 +/- 4.3e-02 | 7.30e+00 +/- 5.3e-02 |
test_proximal_freeb_compute | -1.55 +/- 0.94 | -2.77e-03 +/- 1.68e-03 | 1.76e-01 +/- 1.2e-03 | 1.79e-01 +/- 1.1e-03 |
test_proximal_freeb_jac | -0.19 +/- 0.99 | -1.43e-02 +/- 7.32e-02 | 7.35e+00 +/- 6.1e-02 | 7.37e+00 +/- 4.1e-02 |
test_solve_fixed_iter | -0.35 +/- 7.55 | -5.21e-02 +/- 1.11e+00 | 1.47e+01 +/- 8.6e-01 | 1.48e+01 +/- 7.1e-01 | |
Post code where error occurs on master to better diagnose this? |
Do we know why the bool ends up being in the pytree in the first place? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change setters not getters
See latest comment in the Issue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While not directly related to the underlying issue for this PR, a lot of the data type errors were coming from the _CoilObjective.build
method. The pytree stuff we had in here was really clunky, like creating a MixedCoilSet
that contained _Grid
s instead of _Coil
s. I tried to simplify the logic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've looked at it a little bit but haven't gotten the chance to understand the changes yet
desc/objectives/_coils.py
Outdated
self._grid = tree_unflatten(coil_structure, flattened_grid) | ||
assert isinstance(self._grid, list) | ||
|
||
self._dim_f = np.sum( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can't this just be sum(grid.num_nodes for grid in self._grid)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we do basically the same for quad_weights
below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually I needed to update quad_weights
to match the new logic like dim_f
. Thanks for catching that.
Before, the grids were mapped to the same tree structure as the input coil
. That was nice because it always made sure the structures matched (there was a grid for every coil), but it was also awkward because we were creating a MixedCoilSet
that had Grid
elements instead of Coils
.
My changes now avoid that mapping, and the format for specifying the grids has changed a bit too. Say we have a MixedCoilSet
that contains two CoilSet
s, the first with 2 coils and the second with 3 coils. Before you would give the grid for each individual Coil
like this:
grid = [[grid1, grid2], [grid3, grid4, grid5]]
And now you only give the grids for each CoilSet
like this:
grid = [grid1, grid2]
This is better, because all of the coils in a CoilSet
must have the same grid. Before, we were requiring the user to specify all of these individual grids and then throwing away all of them except for the first one. The downside of my new method is this extra logic to account for all of the coils within each coil set.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still necessary or can we go back to the more general pytree calls for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reverted some of my changes to _CoilObjective
but I still needed to make some major revisions. It turns out that the master
branch is not functional with general nested trees (we can't optimize a MixedCoilSet
that contains other MixedCoilSet
s). That is another bug technically different from the original intent of this PR, but I think fixed it here anyways. I also updated the tests to cover more cases like this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes to the setting stuff is generally fine, ie stuff like
self._sym = bool(sym)
But I'd prefer to get rid of the changes to the getting methods, ie
def sym(self):
return bool(self._sym)
It's basically redundant, but in the off chance that they're inconsistent it can lead to some weird bugs (and major annoyance) if self.sym
and self._sym
have different types for some reason.
(also all of the IO and JAX stuff only cares about the private _sym
attribute. the sym
property is only accessed by the user.)
See also my comment above about changing the default _unjitable
util to treat scalar bools and ints as constants.
My bigger point though is that we should try to avoid assigning static or dynamic attributes by checking data type whenever possible, since its error prone and may fail in weird edge cases like we've seen. The better way I think is to declare the attribute itself as static/dynamic, so it will be handled correctly regardless of the shape and dtype.
desc/objectives/_coils.py
Outdated
grid = [grid] * self._num_coils | ||
if isinstance(grid, list): | ||
grid = tree_leaves(grid, is_leaf=lambda g: isinstance(g, _Grid)) | ||
assert len(grid) == len(coils) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might want to make this a ValueError
with a more descriptive error message?
Or maybe put a separate check after if isinstance(grid, list):
since that's the only case where it could cause problems
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also what if the user passes in different grids for each coil in a regular CoilSet
? I think in that case the code as is now will only use the first grid?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re your second question: yes the code will only use the first grid. That is how the existing code also behaves, so this PR is not changing that. My changes earlier this week corrected this, but then I reverted them because my solution didn't work for nested coilsets. We could try to fix this in the future.
np.any([grid.num_rho > 1 or grid.num_theta > 1 for grid in self._grid]), | ||
ValueError, | ||
"Only use toroidal resolution for coil grids.", | ||
lambda c, g: get_transforms(self._data_keys, obj=c, grid=g), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor point: This will get more transforms than needed, since the ones for CoilSet
are later pruned to just the unique ones. Would be nice if we can avoid that redundant calculation but not sure how feasible it is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I've resolved this in the latest commit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly fine, just see comment about giving a more informative error message when grids/coils don't match.
Resolves #1029