Skip to content

Commit

Permalink
FourierPlanarCurve basis option
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dudt committed May 13, 2024
1 parent d9f4a8d commit 5ba6004
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 32 deletions.
15 changes: 9 additions & 6 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,17 +458,19 @@ class FourierPlanarCoil(_Coil, FourierPlanarCurve):
Parameters
----------
current : float
current through the coil, in Amperes
Current through the coil, in Amperes.
center : array-like, shape(3,)
x,y,z coordinates of center of coil
Coordinates of center of curve, in system determined by basis.
normal : array-like, shape(3,)
x,y,z components of normal vector to planar surface
Components of normal vector to planar surface, in system determined by basis.
r_n : array-like
fourier coefficients for radius from center as function of polar angle
Fourier coefficients for radius from center as function of polar angle
modes : array-like
mode numbers associated with r_n
basis : {'xyz', 'rpz'}
Coordinate system for center and normal vectors. Default = 'xyz'.
name : str
name for this coil
Name for this coil.
Examples
--------
Expand Down Expand Up @@ -514,9 +516,10 @@ def __init__(
normal=[0, 1, 0],
r_n=2,
modes=None,
basis="xyz",
name="",
):
super().__init__(current, center, normal, r_n, modes, name)
super().__init__(current, center, normal, r_n, modes, basis, name)


class SplineXYZCoil(_Coil, SplineXYZCurve):
Expand Down
52 changes: 40 additions & 12 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,15 @@ def _Z_Curve(params, transforms, profiles, data, **kwargs):
data=["s"],
parameterization="desc.geometry.curve.FourierPlanarCurve",
basis="{'rpz', 'xyz'}: Basis for returned vectors, Default 'rpz'",
basis_in="{'rpz', 'xyz'}: Basis for input params vectors, Default 'xyz'",
)
def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
if kwargs.get("basis_in", "xyz").lower() == "rpz":
center = rpz2xyz(params["center"])
normal = rpz2xyz_vec(params["normal"], phi=params["center"][1])

Check warning on line 184 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L183-L184

Added lines #L183 - L184 were not covered by tests
else:
center = params["center"]
normal = params["normal"]
# create planar curve at Z==0
r = transforms["r"].transform(params["r_n"], dz=0)
Z = jnp.zeros_like(r)
Expand All @@ -186,10 +193,10 @@ def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([X, Y, Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T) + params["center"]
coords = jnp.matmul(coords, A.T) + center
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
if kwargs.get("basis", "rpz").lower() == "rpz":
coords = xyz2rpz(coords)
Expand All @@ -213,8 +220,15 @@ def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
data=["s"],
parameterization="desc.geometry.curve.FourierPlanarCurve",
basis="{'rpz', 'xyz'}: Basis for returned vectors, Default 'rpz'",
basis_in="{'rpz', 'xyz'}: Basis for input params vectors, Default 'xyz'",
)
def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
if kwargs.get("basis_in", "xyz").lower() == "rpz":
center = rpz2xyz(params["center"])
normal = rpz2xyz_vec(params["normal"], phi=params["center"][1])

Check warning on line 228 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L227-L228

Added lines #L227 - L228 were not covered by tests
else:
center = params["center"]
normal = params["normal"]
r = transforms["r"].transform(params["r_n"], dz=0)
dr = transforms["r"].transform(params["r_n"], dz=1)
dX = dr * jnp.cos(data["s"]) - r * jnp.sin(data["s"])
Expand All @@ -223,8 +237,8 @@ def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([dX, dY, dZ]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
Expand All @@ -233,7 +247,7 @@ def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, A.T) + center
xyzcoords = (
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
Expand All @@ -259,8 +273,15 @@ def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
data=["s"],
parameterization="desc.geometry.curve.FourierPlanarCurve",
basis="{'rpz', 'xyz'}: Basis for returned vectors, Default 'rpz'",
basis_in="{'rpz', 'xyz'}: Basis for input params vectors, Default 'xyz'",
)
def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
if kwargs.get("basis_in", "xyz").lower() == "rpz":
center = rpz2xyz(params["center"])
normal = rpz2xyz_vec(params["normal"], phi=params["center"][1])

Check warning on line 281 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L280-L281

Added lines #L280 - L281 were not covered by tests
else:
center = params["center"]
normal = params["normal"]
r = transforms["r"].transform(params["r_n"], dz=0)
dr = transforms["r"].transform(params["r_n"], dz=1)
d2r = transforms["r"].transform(params["r_n"], dz=2)
Expand All @@ -274,8 +295,8 @@ def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([d2X, d2Y, d2Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
Expand All @@ -284,7 +305,7 @@ def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, A.T) + center
xyzcoords = (
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
Expand All @@ -310,8 +331,15 @@ def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
data=["s"],
parameterization="desc.geometry.curve.FourierPlanarCurve",
basis="{'rpz', 'xyz'}: Basis for returned vectors, Default 'rpz'",
basis_in="{'rpz', 'xyz'}: Basis for input params vectors, Default 'xyz'",
)
def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
if kwargs.get("basis_in", "xyz").lower() == "rpz":
center = rpz2xyz(params["center"])
normal = rpz2xyz_vec(params["normal"], phi=params["center"][1])

Check warning on line 339 in desc/compute/_curve.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_curve.py#L338-L339

Added lines #L338 - L339 were not covered by tests
else:
center = params["center"]
normal = params["normal"]
r = transforms["r"].transform(params["r_n"], dz=0)
dr = transforms["r"].transform(params["r_n"], dz=1)
d2r = transforms["r"].transform(params["r_n"], dz=2)
Expand All @@ -332,8 +360,8 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
coords = jnp.array([d3X, d3Y, d3Z]).T
# rotate into place
Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis
axis = cross(Zaxis, params["normal"])
angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"])))
axis = cross(Zaxis, normal)
angle = jnp.arccos(dot(Zaxis, safenormalize(normal)))
A = rotation_matrix(axis=axis, angle=angle)
coords = jnp.matmul(coords, A.T)
coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T)
Expand All @@ -342,7 +370,7 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs):
Y = r * jnp.sin(data["s"])
Z = jnp.zeros_like(X)
xyzcoords = jnp.array([X, Y, Z]).T
xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"]
xyzcoords = jnp.matmul(xyzcoords, A.T) + center
xyzcoords = (
jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"]
)
Expand Down
6 changes: 3 additions & 3 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,17 @@ def compute(
return data

def translate(self, displacement=[0, 0, 0]):
"""Translate the curve by a rigid displacement in X, Y, Z."""
"""Translate the curve by a rigid displacement in X,Y,Z coordinates."""
self.shift = self.shift + jnp.asarray(displacement)

def rotate(self, axis=[0, 0, 1], angle=0):
"""Rotate the curve by a fixed angle about axis in X, Y, Z coordinates."""
"""Rotate the curve by a fixed angle about axis in X,Y,Z coordinates."""
R = rotation_matrix(axis=axis, angle=angle)
self.rotmat = (R @ self.rotmat.reshape(3, 3)).flatten()
self.shift = self.shift @ R.T

def flip(self, normal=[0, 0, 1]):
"""Flip the curve about the plane with specified normal."""
"""Flip the curve about the plane with specified normal in X,Y,Z coordinates."""
F = reflection_matrix(normal)
self.rotmat = (F @ self.rotmat.reshape(3, 3)).flatten()
self.shift = self.shift @ F.T
Expand Down
74 changes: 63 additions & 11 deletions desc/geometry/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,24 +556,21 @@ class FourierPlanarCurve(Curve):
Parameters
----------
center : array-like, shape(3,)
x,y,z coordinates of center of curve
Coordinates of center of curve, in system determined by basis.
normal : array-like, shape(3,)
x,y,z components of normal vector to planar surface
Components of normal vector to planar surface, in system determined by basis.
r_n : array-like
Fourier coefficients for radius from center as function of polar angle
modes : array-like
mode numbers associated with r_n
basis : {'xyz', 'rpz'}
Coordinate system for center and normal vectors. Default = 'xyz'.
name : str
name for this curve
Name for this curve.
"""

_io_attrs_ = Curve._io_attrs_ + [
"_r_n",
"_center",
"_normal",
"_r_basis",
]
_io_attrs_ = Curve._io_attrs_ + ["_r_n", "_center", "_normal", "_r_basis", "_basis"]

# Reference frame is centered at the origin with normal in the +Z direction.
# The curve is computed in this frame and then shifted/rotated to the correct frame.
Expand All @@ -583,6 +580,7 @@ def __init__(
normal=[0, 1, 0],
r_n=2,
modes=None,
basis="xyz",
name="",
):
super().__init__(name)
Expand All @@ -593,13 +591,15 @@ def __init__(
modes = np.asarray(modes)
assert issubclass(modes.dtype.type, np.integer)
assert r_n.size == modes.size, "r_n size and modes must be the same size"
assert basis.lower() in ["xyz", "rpz"]

N = np.max(abs(modes))
self._r_basis = FourierSeries(N, NFP=1, sym=False)
self._r_n = copy_coeffs(r_n, modes, self.r_basis.modes[:, 2])

self.normal = normal
self.center = center
self._basis = basis

@property
def r_basis(self):
Expand Down Expand Up @@ -632,7 +632,9 @@ def center(self, new):
self._center = np.asarray(new)
else:
raise ValueError(
"center should be a 3 element vector [cx, cy, cz], got {}".format(new)
"center should be a 3 element vector in "
+ self._basis
+ " coordinates, got {}".format(new)
)

@optimizable_parameter
Expand All @@ -647,7 +649,9 @@ def normal(self, new):
self._normal = np.asarray(new) / np.linalg.norm(new)
else:
raise ValueError(
"normal should be a 3 element vector [nx, ny, nz], got {}".format(new)
"normal should be a 3 element vector in "
+ self._basis
+ " coordinates, got {}".format(new)
)

@optimizable_parameter
Expand Down Expand Up @@ -685,6 +689,54 @@ def set_coeffs(self, n, r=None):
if rr is not None:
self.r_n = put(self.r_n, idx, rr)

def compute(
self,
names,
grid=None,
params=None,
transforms=None,
data=None,
override_grid=True,
**kwargs,
):
"""Compute the quantity given by name on grid.
Parameters
----------
names : str or array-like of str
Name(s) of the quantity(s) to compute.
grid : Grid or int, optional
Grid of coordinates to evaluate at. Defaults to a Linear grid.
If an integer, uses that many equally spaced points.
params : dict of ndarray
Parameters from the equilibrium. Defaults to attributes of self.
transforms : dict of Transform
Transforms for R, Z, lambda, etc. Default is to build from grid
data : dict of ndarray
Data computed so far, generally output from other compute functions
override_grid : bool
If True, override the user supplied grid if necessary and use a full
resolution grid to compute quantities and then downsample to user requested
grid. If False, uses only the user specified grid, which may lead to
inaccurate values for surface or volume averages.
Returns
-------
data : dict of ndarray
Computed quantity and intermediate variables.
"""
return super().compute(
names=names,
grid=grid,
params=params,
transforms=transforms,
data=data,
override_grid=override_grid,
basis_in=self._basis,
**kwargs,
)


class SplineXYZCurve(Curve):
"""Curve parameterized by spline knots in X,Y,Z.
Expand Down
24 changes: 24 additions & 0 deletions tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,30 @@ def test_coords(self):
np.testing.assert_allclose(y, -11)
np.testing.assert_allclose(z, 1)

@pytest.mark.unit
def test_basis(self):
"""Test xyz vs rpz basis."""
cxyz = FourierPlanarCurve(center=[1, 1, 0], normal=[-1, 1, 0], basis="xyz")
crpz = FourierPlanarCurve(
center=[np.sqrt(2), np.pi / 4, 0], normal=[0, 1, 0], basis="rpz"
)

x_xyz = cxyz.compute("x")["x"]
x_rpz = crpz.compute("x")["x"]
np.testing.assert_allclose(x_xyz, x_rpz)

xs_xyz = cxyz.compute("x_s")["x_s"]
xs_rpz = crpz.compute("x_s")["x_s"]
np.testing.assert_allclose(xs_xyz, xs_rpz, atol=2e-15)

xss_xyz = cxyz.compute("x_ss")["x_ss"]
xss_rpz = crpz.compute("x_ss")["x_ss"]
np.testing.assert_allclose(xss_xyz, xss_rpz, atol=2e-15)

xsss_xyz = cxyz.compute("x_sss")["x_sss"]
xsss_rpz = crpz.compute("x_sss")["x_sss"]
np.testing.assert_allclose(xsss_xyz, xsss_rpz, atol=2e-15)

@pytest.mark.unit
def test_misc(self):
"""Test getting/setting misc attributes of FourierPlanarCurve."""
Expand Down

0 comments on commit 5ba6004

Please sign in to comment.