Skip to content

Commit

Permalink
Merge branch 'master' into dp/vmecio-asym
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Aug 23, 2024
2 parents 9d91529 + bdc5de4 commit 08c314b
Show file tree
Hide file tree
Showing 11 changed files with 495 additions and 214 deletions.
81 changes: 67 additions & 14 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def is_meshgrid(self):
Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal
coordinate value. The is_meshgrid flag denotes whether any coordinate
can be iterated over along the relevant axis of the reshaped grid:
nodes.reshape(num_radial, num_poloidal, num_toroidal, 3).
nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F").
"""
return self.__dict__.setdefault("_is_meshgrid", False)

Expand Down Expand Up @@ -598,6 +598,52 @@ def replace_at_axis(self, x, y, copy=False, **kwargs):
)
return x

def meshgrid_reshape(self, x, order):
"""Reshape data to match grid coordinates.
Given flattened data on a tensor product grid, reshape the data such that
the axes of the array correspond to coordinate values on the grid.
Parameters
----------
x : ndarray, shape(N,) or shape(N,3)
Data to reshape.
order : str
Desired order of axes for returned data. Should be a permutation of
``grid.coordinates``, eg ``order="rtz"`` has the first axis of the returned
data correspond to different rho coordinates, the second axis to different
theta, etc. ``order="trz"`` would have the first axis correspond to theta,
and so on.
Returns
-------
x : ndarray
Data reshaped to align with grid nodes.
"""
errorif(
not self.is_meshgrid,
ValueError,
"grid is not a tensor product grid, so meshgrid_reshape doesn't "
"make any sense",
)
errorif(
sorted(order) != sorted(self.coordinates),
ValueError,
f"order should be a permutation of {self.coordinates}, got {order}",
)
shape = (self.num_poloidal, self.num_rho, self.num_zeta)
vec = False
if x.ndim > 1:
vec = True
shape += (-1,)
x = x.reshape(shape, order="F")
x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc
newax = tuple(self.coordinates.index(c) for c in order)
if vec:
newax += (3,)
x = jnp.transpose(x, newax)
return x


class Grid(_Grid):
"""Collocation grid with custom node placement.
Expand Down Expand Up @@ -632,7 +678,7 @@ class Grid(_Grid):
Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal
coordinate value. The is_meshgrid flag denotes whether any coordinate
can be iterated over along the relevant axis of the reshaped grid:
nodes.reshape(num_radial, num_poloidal, num_toroidal, 3).
nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F").
jitable : bool
Whether to skip certain checks and conditionals that don't work under jit.
Allows grid to be created on the fly with custom nodes, but weights, symmetry
Expand Down Expand Up @@ -762,11 +808,16 @@ def create_meshgrid(
dc = _periodic_spacing(c, period[2])[1] * NFP
else:
da, db, dc = spacing

bb, aa, cc = jnp.meshgrid(b, a, c, indexing="ij")

nodes = jnp.column_stack(
list(map(jnp.ravel, jnp.meshgrid(a, b, c, indexing="ij")))
[aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")]
)
bb, aa, cc = jnp.meshgrid(db, da, dc, indexing="ij")

spacing = jnp.column_stack(
list(map(jnp.ravel, jnp.meshgrid(da, db, dc, indexing="ij")))
[aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")]
)
weights = (
spacing.prod(axis=1)
Expand All @@ -776,19 +827,18 @@ def create_meshgrid(
else None
)

unique_a_idx = jnp.arange(a.size) * b.size * c.size
unique_b_idx = jnp.arange(b.size) * c.size
unique_c_idx = jnp.arange(c.size)
inverse_a_idx = repeat(
unique_a_idx // (b.size * c.size),
b.size * c.size,
total_repeat_length=a.size * b.size * c.size,
unique_a_idx = jnp.arange(a.size) * b.size
unique_b_idx = jnp.arange(b.size)
unique_c_idx = jnp.arange(c.size) * a.size * b.size
inverse_a_idx = jnp.tile(
repeat(unique_a_idx // b.size, b.size, total_repeat_length=a.size * b.size),
c.size,
)
inverse_b_idx = jnp.tile(
repeat(unique_b_idx // c.size, c.size, total_repeat_length=b.size * c.size),
a.size,
unique_b_idx,
a.size * c.size,
)
inverse_c_idx = jnp.tile(unique_c_idx, a.size * b.size)
inverse_c_idx = repeat(unique_c_idx // (a.size * b.size), (a.size * b.size))
return Grid(
nodes=nodes,
spacing=spacing,
Expand Down Expand Up @@ -908,6 +958,7 @@ def __init__(
self._toroidal_endpoint = False
self._node_pattern = "linear"
self._coordinates = "rtz"
self._is_meshgrid = True
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
self._nodes, self._spacing = self._create_nodes(
L=L,
Expand Down Expand Up @@ -1200,6 +1251,7 @@ def __init__(self, L, M, N, NFP=1):
self._sym = False
self._node_pattern = "quad"
self._coordinates = "rtz"
self._is_meshgrid = True
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
self._nodes, self._spacing = self._create_nodes(L=L, M=M, N=N, NFP=NFP)
# symmetry is never enforced for Quadrature Grid
Expand Down Expand Up @@ -1341,6 +1393,7 @@ def __init__(self, L, M, N, NFP=1, sym=False, axis=False, node_pattern="jacobi")
self._sym = sym
self._node_pattern = node_pattern
self._coordinates = "rtz"
self._is_meshgrid = False
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
self._nodes, self._spacing = self._create_nodes(
L=L, M=M, N=N, NFP=NFP, axis=axis, node_pattern=node_pattern
Expand Down
5 changes: 4 additions & 1 deletion desc/io/optimizable_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functions and methods for saving and loading equilibria and other objects."""

import copy
import functools
import os
import pickle
import pydoc
Expand Down Expand Up @@ -86,7 +87,9 @@ def _unjittable(x):
return any([_unjittable(y) for y in x.values()])
if hasattr(x, "dtype") and np.ndim(x) == 0:
return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_)
return isinstance(x, (str, types.FunctionType, bool, int, np.int_))
return isinstance(
x, (str, types.FunctionType, functools.partial, bool, int, np.int_)
)


def _make_hashable(x):
Expand Down
11 changes: 7 additions & 4 deletions desc/objectives/linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def update_target(self, thing):
assert len(new_target) == len(self.target)
self.target = new_target
self._target_from_user = self.target # in case the Objective is re-built
if self._use_jit:
self.jit()
if not self._use_jit:
self._unjit()

def _parse_target_from_user(
self, target_from_user, default_target, default_bounds, idx
Expand Down Expand Up @@ -232,8 +232,8 @@ def update_target(self, thing):
"""
self.target = self.compute(thing.params_dict)
if self._use_jit:
self.jit()
if not self._use_jit:
self._unjit()


class BoundaryRSelfConsistency(_Objective):
Expand Down Expand Up @@ -3184,6 +3184,7 @@ class FixNearAxisR(_FixedObjective):
"""

_static_attrs = ["_nae_eq"]
_target_arg = "R_lmn"
_fixed = False # not "diagonal", since its fixing a sum
_units = "(m)"
Expand Down Expand Up @@ -3320,6 +3321,7 @@ class FixNearAxisZ(_FixedObjective):
"""

_static_attrs = ["_nae_eq"]
_target_arg = "Z_lmn"
_fixed = False # not "diagonal", since its fixing a sum
_units = "(m)"
Expand Down Expand Up @@ -3462,6 +3464,7 @@ class FixNearAxisLambda(_FixedObjective):
"""

_static_attrs = ["_nae_eq"]
_target_arg = "L_lmn"
_fixed = False # not "diagonal", since its fixing a sum
_units = "(dimensionless)"
Expand Down
Loading

0 comments on commit 08c314b

Please sign in to comment.