Skip to content

Commit

Permalink
Merge branch 'master' into rc/examples
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Nov 21, 2024
2 parents b3eac3f + 0cc9c65 commit db1ab5d
Show file tree
Hide file tree
Showing 136 changed files with 981,402 additions and 471 deletions.
1 change: 1 addition & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.9.2
- name: Benchmark with pytest-benchmark (PR)
if: env.has_changes == 'true'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/cache_dependencies.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
pip install matplotlib==3.9.2
- name: Cache Python environment
id: cache-env
Expand Down
35 changes: 35 additions & 0 deletions .github/workflows/changelog_update.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: Check changelog updated

on:
pull_request:
branches:
- master
types: [opened, synchronize, labeled, unlabeled]

jobs:
check_changelog_updated:
runs-on: ubuntu-latest
steps:
- name: Filter changes
id: changes
uses: dorny/paths-filter@v3
with:
filters: |
has_changes:
- 'desc/**'
- 'requirements.txt'
- 'requirements_conda.yml'
- '.github/workflows/changelog_update.yml'
- name: Check for relevant changes
id: check_changes
run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV

- uses: actions/checkout@v4

- if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip_changelog') && env.has_changes == 'true'}}
uses: danieljimeneznz/[email protected]
with:
require-changes-to: |
CHANGELOG.md
token: ${{ secrets.GITHUB_TOKEN }}
1 change: 0 additions & 1 deletion .github/workflows/jax_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ jobs:
sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
cat ./requirements.txt
pip install -r ./devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Verify dependencies
run: |
python --version
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/notebook_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
pip install matplotlib==3.9.2
- name: Test notebooks with pytest and nbmake
if: env.has_changes == 'true'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/regression_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
pip install matplotlib==3.9.2
- name: Set Swap Space
if: env.has_changes == 'true'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:
source .venv-${{ matrix.combos.python_version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
pip install matplotlib==3.9.2
- name: Set Swap Space
if: env.has_changes == 'true'
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/weekly_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.9.2
- name: Set Swap Space
uses: pierotofy/set-swap-space@master
with:
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ New Features
Bug Fixes

- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version
- Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before.
- Fixes bug in ``softmin/softmax`` implementation.


v0.12.3
Expand Down
6 changes: 3 additions & 3 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
if use_jax: # noqa: C901
from jax import custom_jvp, jit, vmap

imap = jax.lax.map
Expand All @@ -76,7 +76,7 @@
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
from jax.scipy.special import gammaln
from jax.tree_util import (
register_pytree_node,
tree_flatten,
Expand Down Expand Up @@ -445,7 +445,7 @@ def tangent_solve(g, y):
qr,
solve_triangular,
)
from scipy.special import gammaln, logsumexp # noqa: F401
from scipy.special import gammaln # noqa: F401
from scipy.special import softmax as softargmax # noqa: F401

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
Expand Down
9 changes: 5 additions & 4 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,13 +1098,13 @@ def evaluate(
if not len(modes):
return np.array([]).reshape((len(nodes), 0))

# TODO: avoid duplicate calculations when mixing derivatives
# TODO(#1243): avoid duplicate calculations when mixing derivatives
r, t, z = nodes.T
l, m, n = modes.T
lm = modes[:, :2]

if unique:
# TODO: can avoid this here by using grid.unique_idx etc
# TODO(#1243): can avoid this here by using grid.unique_idx etc
# and adding unique_modes attributes to basis
_, ridx, routidx = np.unique(
r, return_index=True, return_inverse=True, axis=0
Expand Down Expand Up @@ -1364,7 +1364,6 @@ def polyval_vec(p, x, prec=None):
def _polyval_exact(p, x, prec):
p = np.atleast_2d(p)
x = np.atleast_1d(x).flatten()
# TODO: possibly multithread this bit
mpmath.mp.dps = prec
y = np.array([np.asarray(mpmath.polyval(list(pi), x)) for pi in p])
return y.astype(float)
Expand Down Expand Up @@ -1440,7 +1439,9 @@ def zernike_radial_coeffs(l, m, exact=True):
# hence they are all integers. So, we can use exact arithmetic with integer
# division instead of floating point division.
# [1]https://en.wikipedia.org/wiki/Zernike_polynomials#Other_representations
coeffs[ii, s] = ((-1) ** ((ll - s) // 2) * factorial((ll + s) // 2)) // (
coeffs[ii, s] = (
int((-1) ** ((ll - s) // 2)) * factorial((ll + s) // 2)
) // (
factorial((ll - s) // 2)
* factorial((s + mm) // 2)
* factorial((s - mm) // 2)
Expand Down
1 change: 0 additions & 1 deletion desc/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def f_(carry, x):
return res_append


# TODO in_axes a la vmap?
def _scanmap(fun, scan_fun, argnums=0):
"""A helper function to wrap f with a scan_fun."""

Expand Down
37 changes: 28 additions & 9 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,14 @@ def compute_magnetic_vector_potential(

@classmethod
def linspaced_angular(
cls, coil, current=None, axis=[0, 0, 1], angle=2 * np.pi, n=10, endpoint=False
cls,
coil,
current=None,
axis=[0, 0, 1],
angle=2 * np.pi,
n=10,
endpoint=False,
check_intersection=True,
):
"""Create a CoilSet by repeating a coil at equal spacing around the torus.
Expand All @@ -1651,6 +1658,8 @@ def linspaced_angular(
Number of copies of original coil.
endpoint : bool
Whether to include a coil at final rotation angle. Default = False.
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
"""
assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet)
Expand All @@ -1664,11 +1673,17 @@ def linspaced_angular(
coili.rotate(axis=axis, angle=phi[i])
coili.current = currents[i]
coils.append(coili)
return cls(*coils)
return cls(*coils, check_intersection=check_intersection)

@classmethod
def linspaced_linear(
cls, coil, current=None, displacement=[2, 0, 0], n=4, endpoint=False
cls,
coil,
current=None,
displacement=[2, 0, 0],
n=4,
endpoint=False,
check_intersection=True,
):
"""Create a CoilSet by repeating a coil at equal spacing in a straight line.
Expand All @@ -1685,6 +1700,8 @@ def linspaced_linear(
Number of copies of original coil.
endpoint : bool
Whether to include a coil at final displacement location. Default = False.
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
"""
assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet)
Expand All @@ -1699,10 +1716,10 @@ def linspaced_linear(
coili.translate(a[i] * displacement)
coili.current = currents[i]
coils.append(coili)
return cls(*coils)
return cls(*coils, check_intersection=check_intersection)

@classmethod
def from_symmetry(cls, coils, NFP=1, sym=False):
def from_symmetry(cls, coils, NFP=1, sym=False, check_intersection=True):
"""Create a coil group by reflection and symmetry.
Given coils over one field period, repeat coils NFP times between
Expand All @@ -1721,6 +1738,8 @@ def from_symmetry(cls, coils, NFP=1, sym=False):
sym : bool (optional)
Whether to enforce stellarator symmetry.
If True, the coils will be duplicated 2*NFP times. Default = False.
check_intersection : bool
whether to check the resulting coilsets for intersecting coils.
Returns
-------
Expand Down Expand Up @@ -1763,7 +1782,7 @@ def from_symmetry(cls, coils, NFP=1, sym=False):
rotated_coils.rotate(axis=[0, 0, 1], angle=2 * jnp.pi * k / NFP)
coilset += rotated_coils

return cls(*coilset)
return cls(*coilset, check_intersection=check_intersection)

@classmethod
def from_makegrid_coilfile(cls, coil_file, method="cubic", check_intersection=True):
Expand Down Expand Up @@ -1901,8 +1920,8 @@ def save_in_makegrid_format(self, coilsFilename, NFP=None, grid=None):
if None, will default to the coil compute functions's
default grid
"""
# TODO: name each group based off of CoilSet name?
# TODO: have CoilGroup be automatically assigned based off of
# TODO(#1376): name each group based off of CoilSet name?
# TODO(#1376): have CoilGroup be automatically assigned based off of
# CoilSet if current coilset is a collection of coilsets?

NFP = 1 if NFP is None else NFP
Expand Down Expand Up @@ -2697,7 +2716,7 @@ def insert(self, i, new_item):
self._coils.insert(i, new_item)

@classmethod
def from_makegrid_coilfile( # noqa: C901 - FIXME: simplify this
def from_makegrid_coilfile( # noqa: C901
cls, coil_file, method="cubic", ignore_groups=False, check_intersection=True
):
"""Create a MixedCoilSet of SplineXYZCoils from a MAKEGRID coil txtfile.
Expand Down
6 changes: 3 additions & 3 deletions desc/compute/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _A_of_z(params, transforms, profiles, data, **kwargs):
data=["Z", "n_rho", "e_theta|r,p", "rho"],
parameterization=["desc.geometry.surface.FourierRZToroidalSurface"],
resolution_requirement="rt", # just need max(rho) near 1
# FIXME: Add source grid requirement once omega is nonzero.
# TODO(#568): Add source grid requirement once omega is nonzero.
)
def _A_of_z_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
# Denote any vector v = [vᴿ, v^ϕ, vᶻ] with a tuple of its contravariant components.
Expand All @@ -213,7 +213,7 @@ def _A_of_z_FourierRZToroidalSurface(params, transforms, profiles, data, **kwarg
line_integrals(
transforms["grid"],
data["Z"] * n[:, 2] * safenorm(data["e_theta|r,p"], axis=-1),
# FIXME: Works currently for omega = zero, but for nonzero omega
# TODO(#568): Works currently for omega = zero, but for nonzero omega
# we need to integrate over theta at constant phi.
# Should be simple once we have coordinate mapping and source grid
# logic from GitHub pull request #1024.
Expand Down Expand Up @@ -449,7 +449,7 @@ def _perimeter_of_z(params, transforms, profiles, data, **kwargs):
line_integrals(
transforms["grid"],
safenorm(data["e_theta|r,p"], axis=-1),
# FIXME: Works currently for omega = zero, but for nonzero omega
# TODO(#568): Works currently for omega = zero, but for nonzero omega
# we need to integrate over theta at constant phi.
# Should be simple once we have coordinate mapping and source grid
# logic from GitHub pull request #1024.
Expand Down
4 changes: 2 additions & 2 deletions desc/compute/_omnigenity.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def fitfun(x):
return data


# TODO: do math to change definition of nu so that we can just use B_zeta_mn here
# TODO (#568): do math to change definition of nu so that we can just use B_zeta_mn here
@register_compute_fun(
name="B_phi_mn",
label="B_{\\phi, m, n}",
Expand All @@ -63,7 +63,7 @@ def fitfun(x):
data=["B_phi|r,t"],
resolution_requirement="tz",
grid_requirement={"is_meshgrid": True},
aliases="B_zeta_mn", # TODO: remove when phi != zeta
aliases="B_zeta_mn", # TODO(#568): remove when phi != zeta
M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M",
N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N",
)
Expand Down
6 changes: 3 additions & 3 deletions desc/compute/_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,7 @@ def _iota_num_rrr(params, transforms, profiles, data, **kwargs):
- beta * data["sqrt(g)_rrr"],
data["sqrt(g)"],
),
# Todo: axis limit of beta_rrr
# TODO(#587): axis limit of beta_rrr
# Computed with four applications of l’Hôpital’s rule.
# Requires sqrt(g)_rrrr and fourth derivatives of basis vectors.
jnp.nan,
Expand Down Expand Up @@ -1656,7 +1656,7 @@ def _iota_den_rrr(params, transforms, profiles, data, **kwargs):
- gamma * data["sqrt(g)_rrr"],
data["sqrt(g)"],
),
# Todo: axis limit
# TODO(#587): axis limit
# Computed with four applications of l’Hôpital’s rule.
# Requires sqrt(g)_rrrr and fourth derivatives of basis vectors.
jnp.nan,
Expand Down Expand Up @@ -1713,7 +1713,7 @@ def _q(params, transforms, profiles, data, **kwargs):
return data


# TODO: add K(rho,theta,zeta)*grad(rho) term
# TODO (#1381): add K(rho,theta,zeta)*grad(rho) term
@register_compute_fun(
name="I",
label="I",
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _D_current(params, transforms, profiles, data, **kwargs):
/ data["|grad(psi)|"] ** 3
* dot(Xi, data["B"])
),
# Todo: implement equivalent of equation 4.3 in desc coordinates
# TODO(#671): implement equivalent of equation 4.3 in desc coordinates
jnp.nan,
)
)
Expand Down
4 changes: 2 additions & 2 deletions desc/compute/_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .data_index import register_compute_fun
from .geom_utils import rpz2xyz

# TODO: review when zeta no longer equals phi
# TODO(#568): review when zeta no longer equals phi


@register_compute_fun(
Expand All @@ -27,7 +27,7 @@
def _x_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
R = transforms["R"].transform(params["R_lmn"])
Z = transforms["Z"].transform(params["Z_lmn"])
# TODO: change when zeta no longer equals phi
# TODO(#568): change when zeta no longer equals phi
phi = transforms["grid"].nodes[:, 2]
coords = jnp.stack([R, phi, Z], axis=1)
# default basis for "x" is rpz, the conversion will be done
Expand Down
1 change: 0 additions & 1 deletion desc/continuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,6 @@ def solve_continuation( # noqa: C901
if len(deltas) > 0:
if verbose > 0:
print("Perturbing equilibrium")
# TODO: pass Jx if available
eqp = eqfam[ii - 1].copy()
objective_i = get_equilibrium_objective(
eq=eqp, mode=objective, jac_chunk_size=jac_chunk_size
Expand Down
Loading

0 comments on commit db1ab5d

Please sign in to comment.