Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with recomputing quantities on incorrect grid #1006

Merged
merged 9 commits into from
May 2, 2024
36 changes: 29 additions & 7 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,19 +877,31 @@ def compute(

if calc0d and override_grid:
grid0d = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP)
data0d_seed = {
key: data[key]
for key in data
if data_index[p][key]["coordinates"] == ""
}
data0d = compute_fun(
self,
dep0d,
params=params,
transforms=get_transforms(dep0d, obj=self, grid=grid0d, **kwargs),
profiles=get_profiles(dep0d, obj=self, grid=grid0d),
data=None,
# If a dependency of something is already computed, use it
# instead of recomputing it on a potentially bad grid.
data=data0d_seed,
**kwargs,
)
# these should all be 0d quantities so don't need to compress/expand
data0d = {key: val for key, val in data0d.items() if key in dep0d}
data.update(data0d)

data0d_seed = (
{key: data[key] for key in data if data_index[p][key]["coordinates"] == ""}
if ((calc1dr or calc1dz) and override_grid)
else {}
)
unalmis marked this conversation as resolved.
Show resolved Hide resolved
if calc1dr and override_grid:
grid1dr = LinearGrid(
rho=grid.nodes[grid.unique_rho_idx, 0],
Expand All @@ -898,15 +910,20 @@ def compute(
NFP=self.NFP,
sym=self.sym,
)
# TODO: Pass in data0d as a seed once there are 1d quantities that
# depend on 0d quantities in data_index.
data1dr_seed = {
key: grid1dr.copy_data_from_other(data[key], grid, surface_label="rho")
for key in data
if data_index[p][key]["coordinates"] == "r"
}
data1dr = compute_fun(
self,
dep1dr,
params=params,
transforms=get_transforms(dep1dr, obj=self, grid=grid1dr, **kwargs),
profiles=get_profiles(dep1dr, obj=self, grid=grid1dr),
data=None,
# If a dependency of something is already computed, use it
# instead of recomputing it on a potentially bad grid.
data=data1dr_seed | data0d_seed,
**kwargs,
)
# need to make this data broadcast with the data on the original grid
Expand All @@ -925,15 +942,20 @@ def compute(
NFP=grid.NFP, # ex: self.NFP>1 but grid.NFP=1 for plot_3d
sym=self.sym,
)
# TODO: Pass in data0d as a seed once there are 1d quantities that
# depend on 0d quantities in data_index.
data1dz_seed = {
key: grid1dz.copy_data_from_other(data[key], grid, surface_label="zeta")
for key in data
if data_index[p][key]["coordinates"] == "z"
}
data1dz = compute_fun(
self,
dep1dz,
params=params,
transforms=get_transforms(dep1dz, obj=self, grid=grid1dz, **kwargs),
profiles=get_profiles(dep1dz, obj=self, grid=grid1dz),
data=None,
# If a dependency of something is already computed, use it
# instead of recomputing it on a potentially bad grid.
data=data1dz_seed | data0d_seed,
**kwargs,
)
# need to make this data broadcast with the data on the original grid
Expand Down
Binary file modified tests/inputs/master_compute_data.pkl
Binary file not shown.
3 changes: 1 addition & 2 deletions tests/test_axis_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,7 @@ def test_limit_continuity(self):
"alpha_r": {"rtol": 1e-3},
}
zero_map = dict.fromkeys(zero_limits, {"desired_at_axis": 0})
# same as 'weaker_tolerance | zero_limit', but works on Python 3.8 (PEP 584)
kwargs = dict(weaker_tolerance, **zero_map)
kwargs = weaker_tolerance | zero_map
# fixed iota
eq = get("W7-X")
eq.change_resolution(4, 4, 4, 8, 8, 8)
Expand Down
Loading