Skip to content

Commit

Permalink
Merge branch 'master' into dp/vector-potential
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Sep 3, 2024
2 parents 2b99cc8 + d2e9a2c commit 98053b3
Show file tree
Hide file tree
Showing 47 changed files with 3,892 additions and 295 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ New Features
- Add ability to save and load vector potential information from ``mgrid`` files.
- Changes ``ToroidalFlux`` objective to default using a 1D loop integral of the vector potential
to compute the toroidal flux when possible, as opposed to a 2D surface integral of the magnetic field dotted with ``n_zeta``.
- Allow specification of Nyquist spectrum maximum modenumbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file

Bug Fixes

- Fixes bugs that occur when saving asymmetric equilibria as wout files
- Fixes bug that occurs when using ``VMECIO.plot_vmec_comparison`` to compare to an asymmetric wout file



v0.12.1
Expand Down
8 changes: 4 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ Contribute
:target: https://desc-docs.readthedocs.io/en/latest/?badge=latest
:alt: Documentation

.. |UnitTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/unittest.yml/badge.svg
:target: https://github.com/PlasmaControl/DESC/actions/workflows/unittest.yml
.. |UnitTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/unit_tests.yml/badge.svg
:target: https://github.com/PlasmaControl/DESC/actions/workflows/unit_tests.yml
:alt: UnitTests

.. |RegressionTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/regression_test.yml/badge.svg
:target: https://github.com/PlasmaControl/DESC/actions/workflows/regression_test.yml
.. |RegressionTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/regression_tests.yml/badge.svg
:target: https://github.com/PlasmaControl/DESC/actions/workflows/regression_tests.yml
:alt: RegressionTests

.. |Codecov| image:: https://codecov.io/gh/PlasmaControl/DESC/branch/master/graph/badge.svg?token=5LDR4B1O7Z
Expand Down
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ comment: # this is a top-level key
require_changes: false # if true: only post the comment if coverage changes
require_base: true # [true :: must have a base report to post]
require_head: true # [true :: must have a head report to post]
after_n_builds: 10
after_n_builds: 14
coverage:
status:
patch:
Expand Down
119 changes: 77 additions & 42 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,23 @@
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
jit = jax.jit
fori_loop = jax.lax.fori_loop
cond = jax.lax.cond
switch = jax.lax.switch
while_loop = jax.lax.while_loop
vmap = jax.vmap
bincount = jnp.bincount
repeat = jnp.repeat
take = jnp.take
scan = jax.lax.scan
from jax import custom_jvp
from jax import custom_jvp, jit, vmap

imap = jax.lax.map
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.lax import cond, fori_loop, scan, switch, while_loop
from jax.nn import softmax as softargmax
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import (
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import (
register_pytree_node,
Expand All @@ -90,6 +94,10 @@
treedef_is_leaf,
)

trapezoid = (
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)

def put(arr, inds, vals):
"""Functional interface for array "fancy indexing".
Expand Down Expand Up @@ -328,6 +336,8 @@ def root(
This routine may be used on over or under-determined systems, in which case it
will solve it in a least squares / least norm sense.
"""
from desc.compute.utils import safenorm

if fixup is None:
fixup = lambda x, *args: x
if jac is None:
Expand Down Expand Up @@ -392,7 +402,7 @@ def tangent_solve(g, y):
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (jnp.linalg.norm(res), niter)
return x, (safenorm(res), niter)


# we can't really test the numpy backend stuff in automated testing, so we ignore it
Expand All @@ -401,15 +411,54 @@ def tangent_solve(g, y):
jit = lambda func, *args, **kwargs: func
execute_on_cpu = lambda func: func
import scipy.optimize
from numpy.fft import irfft, rfft, rfft2 # noqa: F401
from scipy.fft import dct, idct # noqa: F401
from scipy.integrate import odeint # noqa: F401
from scipy.linalg import ( # noqa: F401
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from scipy.special import gammaln, logsumexp # noqa: F401
from scipy.special import softmax as softargmax # noqa: F401

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

def imap(f, xs, batch_size=None, in_axes=0, out_axes=0):
"""Generalizes jax.lax.map; uses numpy."""
if not isinstance(xs, np.ndarray):
raise NotImplementedError(
"Require numpy array input, or install jax to support pytrees."
)
xs = np.moveaxis(xs, source=in_axes, destination=0)
return np.stack([f(x) for x in xs], axis=out_axes)

def vmap(fun, in_axes=0, out_axes=0):
"""A numpy implementation of jax.lax.map whose API is a subset of jax.vmap.
Like Python's builtin map,
except inputs and outputs are in the form of stacked arrays,
and the returned object is a vectorized version of the input function.
Parameters
----------
fun: callable
Function (A -> B)
in_axes: int
Axis to map over.
out_axes: int
An integer indicating where the mapped axis should appear in the output.
Returns
-------
fun_vmap: callable
Vectorized version of fun.
"""
return lambda xs: imap(fun, xs, in_axes=in_axes, out_axes=out_axes)

def tree_stack(*args, **kwargs):
"""Stack pytree for numpy backend."""
Expand Down Expand Up @@ -592,32 +641,6 @@ def while_loop(cond_fun, body_fun, init_val):
val = body_fun(val)
return val

def vmap(fun, out_axes=0):
"""A numpy implementation of jax.lax.map whose API is a subset of jax.vmap.
Like Python's builtin map,
except inputs and outputs are in the form of stacked arrays,
and the returned object is a vectorized version of the input function.
Parameters
----------
fun: callable
Function (A -> B)
out_axes: int
An integer indicating where the mapped axis should appear in the output.
Returns
-------
fun_vmap: callable
Vectorized version of fun.
"""

def fun_vmap(fun_inputs):
return np.stack([fun(fun_input) for fun_input in fun_inputs], axis=out_axes)

return fun_vmap

def scan(f, init, xs, length=None, reverse=False, unroll=1):
"""Scan a function over leading array axes while carrying along state.
Expand Down Expand Up @@ -657,9 +680,14 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1):
ys.append(y)
return carry, np.stack(ys)

def bincount(x, weights=None, minlength=None, length=None):
"""Same as np.bincount but with a dummy parameter to match jnp.bincount API."""
return np.bincount(x, weights, minlength)
def bincount(x, weights=None, minlength=0, length=None):
"""A numpy implementation of jnp.bincount."""
x = np.clip(x, 0, None)
if length is None:
length = max(minlength, x.max() + 1)
else:
minlength = max(minlength, length)
return np.bincount(x, weights, minlength)[:length]

def repeat(a, repeats, axis=None, total_repeat_length=None):
"""A numpy implementation of jnp.repeat."""
Expand Down Expand Up @@ -778,6 +806,13 @@ def root(
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out

def flatnonzero(a, size=None, fill_value=0):
"""A numpy implementation of jnp.flatnonzero."""
nz = np.flatnonzero(a)
if size is not None:
nz = np.pad(nz, (0, max(size - nz.size, 0)), constant_values=fill_value)
return nz

def take(
a,
indices,
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.special import roots_legendre

from ..backend import fori_loop, jnp
from ..integrals import surface_averages_map
from ..integrals.surface_integral import surface_averages_map
from .data_index import register_compute_fun


Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import (
from ..integrals.surface_integral import (
surface_averages,
surface_integrals_map,
surface_max,
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
Loading

0 comments on commit 98053b3

Please sign in to comment.