Skip to content

Commit

Permalink
Merge branch 'master' into dp/laplace
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis authored Dec 11, 2024
2 parents 1710819 + e27015d commit bc95043
Show file tree
Hide file tree
Showing 16 changed files with 76 additions and 62 deletions.
9 changes: 9 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Please see the documentation for all configuration options:
# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

version: 2
updates:
- package-ecosystem: "pip"
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
2 changes: 2 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
pwd
lscpu
pip list
cd tests/benchmarks
python -m pytest benchmark_cpu_small.py -vv \
--benchmark-save='Latest_Commit' \
Expand All @@ -108,6 +109,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
pwd
lscpu
pip list
cd tests/benchmarks
python -m pytest benchmark_cpu_small.py -vv \
--benchmark-save='master' \
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/jax_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
run: |
pwd
lscpu
pip list
python -m pytest -m unit \
--durations=0 \
--mpl \
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/mpl_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,5 @@ jobs:
run: |
pwd
lscpu
pip list
python -m pytest tests/test_plotting.py --durations=0 --mpl --maxfail=1
1 change: 1 addition & 0 deletions .github/workflows/notebook_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
pwd
lscpu
pip list
export PYTHONPATH=$(pwd)
pytest -v --nbmake "./docs/notebooks" \
--nbmake-timeout=2000 \
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/regression_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
pwd
lscpu
pip list
python -m pytest -v -m regression\
--durations=0 \
--cov-report xml:cov.xml \
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ jobs:
source .venv-${{ matrix.combos.python_version }}/bin/activate
pwd
lscpu
pip list
python -m pytest -v -m unit \
--durations=0 \
--cov-report xml:cov.xml \
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/weekly_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
run: |
pwd
lscpu
pip list
python -m pytest -v -m unit \
--durations=0 \
--splits 4 \
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ New Feature

- Adds a new profile class ``PowerProfile`` for raising profiles to a power.

Bug Fixes

- Small bug fix to use the correct normalization length ``a`` in the BallooningStability objective

v0.13.0
-------

Expand Down
16 changes: 14 additions & 2 deletions desc/io/optimizable_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def _unjittable(x):
# strings and functions can't be args to jitted functions, and ints/bools are pretty
# much always flags or array sizes which also need to be a compile time constant
if isinstance(x, (list, tuple)):
return any([_unjittable(y) for y in x])
return all([_unjittable(y) or y is None for y in x])
if isinstance(x, dict):
return any([_unjittable(y) for y in x.values()])
return all([_unjittable(y) or y is None for y in x.values()])
if hasattr(x, "dtype") and np.ndim(x) == 0:
return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_)
return isinstance(
Expand All @@ -94,6 +94,12 @@ def _unjittable(x):

def _make_hashable(x):
# turn unhashable ndarray of ints into a hashable tuple
if isinstance(x, list):
return [_make_hashable(y) for y in x]
if isinstance(x, tuple):
return tuple([_make_hashable(y) for y in x])
if isinstance(x, dict):
return {key: _make_hashable(val) for key, val in x.items()}
if hasattr(x, "shape"):
return ("ndarray", x.shape, tuple(x.flatten()))
return x
Expand All @@ -103,6 +109,12 @@ def _unmake_hashable(x):
# turn tuple of ints and shape to ndarray
if isinstance(x, tuple) and x[0] == "ndarray":
return np.array(x[2]).reshape(x[1])
if isinstance(x, list):
return [_unmake_hashable(y) for y in x]
if isinstance(x, tuple):
return tuple([_unmake_hashable(y) for y in x])
if isinstance(x, dict):
return {key: _unmake_hashable(val) for key, val in x.items()}
return x


Expand Down
5 changes: 3 additions & 2 deletions desc/objectives/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from desc.backend import jnp
from desc.compute import get_params, get_profiles, get_transforms
from desc.compute.utils import _compute as compute_fun
from desc.grid import LinearGrid
from desc.grid import LinearGrid, QuadratureGrid
from desc.utils import Timer, errorif, setdefault, warnif

from .normalization import compute_scaling_factors
Expand Down Expand Up @@ -491,7 +491,8 @@ def build(self, eq=None, use_jit=True, verbose=1):
iota_transforms = get_transforms(self._iota_keys, obj=eq, grid=iota_grid)

# Separate grid to calculate the right length scale for normalization
len_grid = LinearGrid(rho=1.0, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP)
len_grid = QuadratureGrid(L=eq.L, M=eq.M, N=eq.N, NFP=eq.NFP)

self._len_keys = ["a"]
len_profiles = get_profiles(self._len_keys, obj=eq, grid=len_grid)
len_transforms = get_transforms(self._len_keys, obj=eq, grid=len_grid)
Expand Down
18 changes: 0 additions & 18 deletions devtools/dev-requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,10 @@ name: desc-env
channels:
- conda-forge
dependencies:
- colorama
- h5py >= 3.0.0, < 4.0
- matplotlib >= 3.5.0, < 4.0.0
- mpmath >= 1.0.0, < 2.0
- netcdf4 >= 1.5.4, < 2.0
- numpy >= 1.20.0
- psutil
- scipy >= 1.7.0
- termcolor
- pip
- pip:
# Conda only parses a single list of pip requirements.
# If two pip lists are given, all but the last list is skipped.
- jax >= 0.4.24, <= 0.4.35
- diffrax >= 0.4.1
- interpax >= 0.3.3
- nvgpu
- orthax
- plotly >= 5.16, < 6.0
- pylatexenc >= 2.0, < 3.0
- quadax >= 0.2.2
- scikit-image
# building the docs
- sphinx-github-style >= 1.0, < 2.0
# testing and benchmarking
Expand Down
6 changes: 2 additions & 4 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,9 @@ Option 2: Using conda to install packages (this will only install DESC + JAX wit

.. code-block:: sh
# only need to do one of these conda env create commands, not both
# option A: without developer requirements
conda env create --file requirements_conda.yml
# option B: with developer requirements (if you want to run tests)
conda env create --file devtools/dev-requirements_conda.yml
# optionally: install developer requirements (if you want to run tests)
conda install --file devtools/dev-requirements_conda.yml
# to add DESC to your Python path
conda activate desc-env
Expand Down
34 changes: 17 additions & 17 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
jax >= 0.4.24, <= 0.4.35
colorama
diffrax >= 0.4.1
h5py >= 3.0.0, < 4.0
interpax >= 0.3.3
matplotlib >= 3.5.0, < 4.0.0
mpmath >= 1.0.0, < 2.0
netcdf4 >= 1.5.4, < 2.0
numpy >= 1.20.0
nvgpu
orthax
plotly >= 5.16, < 6.0
psutil
pylatexenc >= 2.0, < 3.0
quadax >= 0.2.2
scikit-image
scipy >= 1.7.0
termcolor
colorama <= 0.4.6
diffrax >= 0.4.1, <= 0.6.0
h5py >= 3.0.0, <= 3.12.1
interpax >= 0.3.3, <= 0.3.4
matplotlib >= 3.5.0, <= 3.9.3
mpmath >= 1.0.0, <= 1.3.0
netcdf4 >= 1.5.4, <= 1.7.2
numpy >= 1.20.0, <= 2.1.3
nvgpu <= 0.10.0
orthax <= 0.2.1
plotly >= 5.16, <= 5.24.1
psutil <= 6.1.0
pylatexenc >= 2.0, <= 2.10
quadax >= 0.2.2, <= 0.2.4
scikit-image <= 0.24.0
scipy >= 1.7.0, <= 1.14.1
termcolor <= 2.5.0
34 changes: 17 additions & 17 deletions requirements_conda.yml
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
name: desc-env
dependencies:
# standard install
- colorama
- diffrax >= 0.4.1
- h5py >= 3.0.0, < 4.0
- matplotlib >= 3.5.0, < 4.0.0
- mpmath >= 1.0.0, < 2.0
- netcdf4 >= 1.5.4, < 2.0
- numpy >= 1.20.0
- psutil
- scipy >= 1.7.0
- termcolor
- colorama <= 0.4.6
- diffrax >= 0.4.1, <= 0.6.0
- h5py >= 3.0.0, <= 3.12.1
- matplotlib >= 3.5.0, <= 3.9.3
- mpmath >= 1.0.0, <= 1.3.0
- netcdf4 >= 1.5.4, <= 1.7.2
- numpy >= 1.20.0, <= 2.1.3
- psutil <= 6.1.0
- scipy >= 1.7.0, <= 1.14.1
- termcolor <= 2.5.0
- pip
- pip:
# Conda only parses a single list of pip requirements.
# If two pip lists are given, all but the last list is skipped.
- jax >= 0.4.24, <= 0.4.35
- interpax >= 0.3.3
- nvgpu
- orthax
- plotly >= 5.16, < 6.0
- pylatexenc >= 2.0, < 3.0
- quadax >= 0.2.2
- scikit-image
- interpax >= 0.3.3, <= 0.3.4
- nvgpu <= 0.10.0
- orthax <= 0.2.1
- plotly >= 5.16, <= 5.24.1
- pylatexenc >= 2.0, <= 2.10
- quadax >= 0.2.2, <= 0.2.4
- scikit-image <= 0.24.0
4 changes: 2 additions & 2 deletions tests/test_stability_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,8 @@ def test_ballooning_stability_eval():
# different numerics than "ideal ballooning lambda" so that we can verify them
# against one another
psi_b = data01["Psi"][-1] / (2 * jnp.pi)
# Calculating a_N accurately requires a QuadratureGrid
# which is automatically accounted for inside of eq.compute
a_N = data01["a"]
B_N = 2 * psi_b / a_N**2

Expand Down Expand Up @@ -755,8 +757,6 @@ def find_root_simple(x, y):
# Flux surfaces on which to evaluate ballooning stability
surfaces = [0.98, 0.985, 0.99, 0.995, 1.0]

grid = LinearGrid(rho=jnp.array(surfaces), NFP=eq.NFP)

Nalpha = 8 # Number of field lines

assert Nalpha == int(8), "Nalpha in the compute function hard-coded to 8!"
Expand Down

0 comments on commit bc95043

Please sign in to comment.