Skip to content

Commit

Permalink
Merge branch 'master' into yge/tr_direct
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Dec 4, 2024
2 parents ba92240 + 17d4939 commit e64f764
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 15 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

New Feature

- Adds a new profile class ``PowerProfile`` for raising profiles to a power.

v0.13.0
-------

Expand All @@ -24,7 +28,6 @@ New Features
- Adds tutorial notebook showcasing QFM surface capability.
- Adds ``rotate_zeta`` function to ``desc.compat`` to rotate an ``Equilibrium`` around Z axis.


Bug Fixes

- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version.
Expand Down
156 changes: 147 additions & 9 deletions desc/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,17 @@ def __sub__(self, x):
"""Subtract another profile from this one."""
return self.__add__(-x)

def __pow__(self, x):
"""Raise this profile to a power."""
if np.isscalar(x):
return PowerProfile(x, self)
else:
raise NotImplementedError()

def __rpow__(self, x):
"""Raise this profile to a power."""
return self.__pow__(x)


class ScaledProfile(_Profile):
"""Profile times a constant value.
Expand All @@ -252,10 +263,10 @@ class ScaledProfile(_Profile):
Parameters
----------
profile : Profile
Base profile to scale.
scale : float
Scale factor.
profile : Profile
Base profile to scale.
"""

Expand Down Expand Up @@ -335,6 +346,128 @@ def __repr__(self):
return s


class PowerProfile(_Profile):
"""Profile raised to a power.
f_1(x) = f(x)**a
Parameters
----------
power : float
Exponent of the new profile.
profile : Profile
Base profile to raise to a power.
"""

_io_attrs_ = _Profile._io_attrs_ + ["_profile", "_power"]

def __init__(self, power, profile, **kwargs):
assert isinstance(
profile, _Profile
), "profile in a PowerProfile must be a Profile or subclass, got {}.".format(
str(profile)
)
assert np.isscalar(power), "power must be a scalar."

self._profile = profile.copy()
self._power = power

self._check_params()

kwargs.setdefault("name", profile.name)
super().__init__(**kwargs)

def _check_params(self, params=None):
"""Check params and throw warnings or errors if necessary."""
params = self.params if params is None else params
power, params = self._parse_params(params)
warnif(
power < 0,
UserWarning,
"This profile may be undefined at some points because power < 0.",
)

@property
def params(self):
"""ndarray: Parameters for computation [power, profile.params]."""
return jnp.concatenate([jnp.atleast_1d(self._power), self._profile.params])

@params.setter
def params(self, x):
self._check_params(x)
self._power, self._profile.params = self._parse_params(x)

def _parse_params(self, x):
if x is None:
power = self._power
params = self._profile.params
elif isinstance(x, (tuple, list)) and len(x) == 2:
params = x[1]
power = x[0]
elif np.isscalar(x):
power = x
params = self._profile.params
elif len(x) == len(self._profile.params):
power = self._power
params = x
elif len(x) == len(self.params):
power = x[0]
params = x[1:]
else:
raise ValueError("Got wrong number of parameters for PowerProfile")
return power, params

def compute(self, grid, params=None, dr=0, dt=0, dz=0):
"""Compute values of profile at specified nodes.
Parameters
----------
grid : Grid
locations to compute values at.
params : array-like
Parameters to use. If not given, uses the
values given by the self.params attribute.
dr, dt, dz : int
derivative order in rho, theta, zeta.
Returns
-------
values : ndarray
values of the profile or its derivative at the points specified.
"""
if dt > 0 or dz > 0:
raise NotImplementedError(
"Poloidal and toroidal derivatives of PowerProfile have not been "
+ "implemented yet."
)
power, params = self._parse_params(params)
f0 = self._profile.compute(grid, params, 0, dt, dz)
if dr >= 1:
df1 = self._profile.compute(grid, params, 1, dt, dz) # df/dr
fn1 = self.compute(grid, (power - 1, params), 0, dt, dz) # f^(n-1)
if dr >= 2:
df2 = self._profile.compute(grid, params, 2, dt, dz) # d^2f/dr^2
fn2 = self.compute(grid, (power - 2, params), 0, dt, dz) # f^(n-2)
if dr == 0:
f = f0**power
elif dr == 1:
f = power * fn1 * df1
elif dr == 2:
f = power * ((power - 1) * fn2 * df1**2 + fn1 * df2)
else:
raise NotImplementedError("dr > 2 not implemented for PowerProfile!")
return f

def __repr__(self):
"""Get the string form of the object."""
s = super().__repr__()
s = s[:-1]
s += ", power={})".format(self._power)
return s


class SumProfile(_Profile):
"""Sum of two or more Profiles.
Expand Down Expand Up @@ -724,17 +857,24 @@ def __init__(self, params=None, name=""):
params = [0, 1, 1]
self._params = np.atleast_1d(params)

self._check_params()

def _check_params(self, params=None):
"""Check params and throw warnings or errors if necessary."""
params = self.params if params is None else params
errorif(
self._params.size != 3, ValueError, "params must be an array of size 3."
params.size != 3,
ValueError,
f"params must be an array of size 3, got {len(params)}.",
)
warnif(
self._params[1] < 1,
params[1] < 1,
UserWarning,
"Derivatives of this profile will be infinite at rho=0 "
+ "because params[1] < 1.",
)
warnif(
self._params[2] < 1,
params[2] < 1,
UserWarning,
"Derivatives of this profile will be infinite at rho=1 "
+ "because params[2] < 1.",
Expand All @@ -748,10 +888,8 @@ def params(self):
@params.setter
def params(self, new):
new = jnp.atleast_1d(jnp.asarray(new))
if new.size == 3:
self._params = jnp.asarray(new)
else:
raise ValueError(f"params should be an array of size 3, got {len(new)}.")
self._check_params(new)
self._params = new

def compute(self, grid, params=None, dr=0, dt=0, dz=0):
"""Compute values of profile at specified nodes.
Expand Down
10 changes: 6 additions & 4 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -292,14 +292,16 @@ Profiles
:recursive:
:template: class.rst

desc.profiles.PowerSeriesProfile
desc.profiles.TwoPowerProfile
desc.profiles.SplineProfile
desc.profiles.FourierZernikeProfile
desc.profiles.HermiteSplineProfile
desc.profiles.MTanhProfile
desc.profiles.PowerProfile
desc.profiles.PowerSeriesProfile
desc.profiles.ProductProfile
desc.profiles.ScaledProfile
desc.profiles.SplineProfile
desc.profiles.SumProfile
desc.profiles.ProductProfile
desc.profiles.TwoPowerProfile

Transform
*********
Expand Down
5 changes: 4 additions & 1 deletion docs/api_equilibrium.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ profiles together by addition, multiplication, or scaling.
desc.profiles.PowerSeriesProfile
desc.profiles.SplineProfile
desc.profiles.MTanhProfile
desc.profiles.TwoPowerProfile
desc.profiles.HermiteSplineProfile
desc.profiles.ScaledProfile
desc.profiles.SumProfile
desc.profiles.ProductProfile
desc.profiles.SumProfile
desc.profiles.PowerProfile


Utilities
Expand Down
46 changes: 46 additions & 0 deletions tests/test_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_repr(self):
assert "SumProfile" in str(pp + zp)
assert "ProductProfile" in str(pp * zp)
assert "ScaledProfile" in str(2 * zp)
assert "PowerProfile" in str(zp**2)

@pytest.mark.unit
def test_get_set(self):
Expand Down Expand Up @@ -347,6 +348,47 @@ def test_scaled_profiles(self):
np.testing.assert_allclose(pp.params, [1, -2, 1])
np.testing.assert_allclose(f(x), 8 * (pp(x)), atol=1e-3)

@pytest.mark.unit
def test_powered_profiles(self):
"""Test raising profiles to a power."""
pp = PowerSeriesProfile(
modes=np.array([0, 1, 2, 4]), params=np.array([1, 0, -2, 1]), sym="auto"
)

f = pp**3
x = np.linspace(0, 1, 50)
np.testing.assert_allclose(f(x), (pp(x)) ** 3, atol=1e-3)

params = f.params
assert params[0] == 3
assert all(params[1:] == pp.params)

f.params = 2
np.testing.assert_allclose(f(x), (pp(x)) ** 2, atol=1e-3)

f.params = 0.5
np.testing.assert_allclose(f(x), np.sqrt(pp(x)), atol=1e-3)

@pytest.mark.unit
def test_powered_profiles_derivative(self):
"""Test that powered profiles computes the derivative correctly."""
x = np.linspace(0, 1, 50)
p1 = PowerSeriesProfile(
modes=np.array([0, 1, 2, 4]), params=np.array([1, 3, -2, 4]), sym="auto"
)
p2 = p1 * p1
p3 = p1 * p2

f3 = p1**3
np.testing.assert_allclose(f3(x, dr=0), p3(x, dr=0))
np.testing.assert_allclose(f3(x, dr=1), p3(x, dr=1))
np.testing.assert_allclose(f3(x, dr=2), p3(x, dr=2))

f2 = f3 ** (2 / 3)
np.testing.assert_allclose(f2(x, dr=0), p2(x, dr=0))
np.testing.assert_allclose(f2(x, dr=1), p2(x, dr=1))
np.testing.assert_allclose(f2(x, dr=2), p2(x, dr=2))

@pytest.mark.unit
def test_profile_errors(self):
"""Test error checking when creating and working with profiles."""
Expand Down Expand Up @@ -383,6 +425,10 @@ def test_profile_errors(self):
tp.compute(grid, dr=3)
with pytest.raises(NotImplementedError):
mp.compute(grid, dr=3)
with pytest.raises(UserWarning):
tp.params = [1, 0.3, 0.7]
with pytest.raises(UserWarning):
a = sp**-1

@pytest.mark.unit
def test_default_profiles(self):
Expand Down

0 comments on commit e64f764

Please sign in to comment.