Skip to content

Commit

Permalink
add jitable to surface compute
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Jul 2, 2024
1 parent 8edeecf commit 3a95922
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
16 changes: 14 additions & 2 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,13 @@ def compute(
if params is None:
params = get_params(names, obj=self)
if transforms is None:
transforms = get_transforms(names, obj=self, grid=grid, **kwargs)
transforms = get_transforms(
names,
obj=self,
grid=grid,
jitable=kwargs.pop("jitable", False),
**kwargs,
)
if data is None:
data = {}
profiles = {}
Expand Down Expand Up @@ -490,7 +496,13 @@ def compute(
self,
dep0d,
params=params,
transforms=get_transforms(dep0d, obj=self, grid=grid0d, **kwargs),
transforms=get_transforms(
dep0d,
obj=self,
grid=grid0d,
jitable=kwargs.pop("jitable", False),
**kwargs,
),
profiles={},
data=None,
**kwargs,
Expand Down
1 change: 1 addition & 0 deletions desc/magnetic_fields/_current_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ def _compute_magnetic_field_from_CurrentPotentialField(
basis="xyz",
params=params,
transforms=transforms,
jitable=True,
)
else:
data = compute_fun(

Check warning on line 660 in desc/magnetic_fields/_current_potential.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_current_potential.py#L660

Added line #L660 was not covered by tests
Expand Down
31 changes: 31 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from desc.equilibrium import Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.grid import LinearGrid
from desc.io import load
from desc.magnetic_fields import FourierCurrentPotentialField
from desc.objectives import (
AspectRatio,
Energy,
Expand All @@ -36,6 +38,7 @@
MeanCurvature,
ObjectiveFunction,
PlasmaVesselDistance,
QuadraticFlux,
QuasisymmetryTripleProduct,
Volume,
get_fixed_boundary_constraints,
Expand Down Expand Up @@ -1318,3 +1321,31 @@ def test_LinearConstraint_jacobian():
np.testing.assert_allclose(vjp_unscaled, vjp1, rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(vjp_unscaled, vjp2, rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(vjp_unscaled, vjp3, rtol=1e-12, atol=1e-12)


@pytest.mark.unit
def test_quad_flux_with_surface_current_field():
"""Test that QuadraticFlux does not throw an error when field has transforms."""
# this happens because in QuadraticFlux.compute, field.compute_magnetic_field
# is called. If the field needs transforms to evaluate, then these transforms
# will be created on the fly if they are not provided, resulting in an error
# unless jitable=True is passed
eq = load("./tests/inputs/vacuum_circular_tokamak.h5")
field = FourierCurrentPotentialField.from_surface(
eq.surface, Phi_mn=[1, 0], modes_Phi=[[0, 0], [1, 1]], M_Phi=1, N_Phi=1
)
obj = ObjectiveFunction(
QuadraticFlux(
eq=eq,
field=field,
vacuum=True,
eval_grid=LinearGrid(M=2, N=2, sym=True),
field_grid=LinearGrid(M=2, N=2),
),
)
constraints = FixParameters(field, {"I": True, "G": True})
opt = Optimizer("lsq-exact")
# this should run without an error
(field_modular_opt,), result = opt.optimize(
field, objective=obj, constraints=constraints, maxiter=1, copy=True
)

0 comments on commit 3a95922

Please sign in to comment.