diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index ad00eeef5d..68c9442106 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -4,9 +4,6 @@ on: pull_request: branches: - master - paths-ignore: - - 'docs/**' - - 'devtools/**' workflow_dispatch: inputs: debug_enabled: @@ -22,6 +19,12 @@ concurrency: jobs: benchmark: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.9'] + group: [1, 2] steps: # Enable tmate debugging of manually-triggered workflows if the input option was provided @@ -31,41 +34,91 @@ jobs: - uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - - name: Set up Python 3.9 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + has_changes: + - 'desc/**' + - 'tests/benchmarks/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/benchmark.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9 - - name: Install dependencies + python-version: ${{ matrix.python-version }} + + - name: Restore Python environment cache + if: env.has_changes == 'true' + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' && env.has_changes == 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt - - name: Benchmark with 
pytest-benchmark + + - name: Benchmark with pytest-benchmark (PR) + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate pwd lscpu cd tests/benchmarks python -m pytest benchmark_cpu_small.py -vv \ --benchmark-save='Latest_Commit' \ --durations=0 \ - --benchmark-save-data + --benchmark-save-data \ + --splits 2 \ + --group ${{ matrix.group }} \ + --splitting-algorithm least_duration + - name: Checkout current master + if: env.has_changes == 'true' uses: actions/checkout@v4 with: ref: master clean: false + - name: Checkout benchmarks from PR head + if: env.has_changes == 'true' run: git checkout ${{ github.event.pull_request.head.sha }} -- tests/benchmarks - - name: Benchmark with pytest-benchmark + + - name: Benchmark with pytest-benchmark (MASTER) + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate pwd lscpu cd tests/benchmarks python -m pytest benchmark_cpu_small.py -vv \ --benchmark-save='master' \ --durations=0 \ - --benchmark-save-data - - name: put benchmark results in same folder + --benchmark-save-data \ + --splits 2 \ + --group ${{ matrix.group }} \ + --splitting-algorithm least_duration + + - name: Put benchmark results in same folder + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate pwd cd tests/benchmarks find .benchmarks/ -type f -printf "%T@ %p\n" | sort -n | cut -d' ' -f 2- | tail -n 1 > temp1 @@ -75,22 +128,36 @@ jobs: mkdir compare_results cp $t1 compare_results cp $t2 compare_results + + - name: Download artifact + if: always() && env.has_changes == 'true' + uses: actions/download-artifact@v4 + with: + pattern: benchmark_artifact_* + path: tests/benchmarks + - name: Compare latest commit results to the master branch results + if: env.has_changes == 'true' run: | - pwd + source .venv-${{ matrix.python-version }}/bin/activate cd tests/benchmarks + pwd python compare_bench_results.py cat commit_msg.txt - - name: comment PR 
with the results + + - name: Comment PR with the results + if: env.has_changes == 'true' uses: thollander/actions-comment-pull-request@v2 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: filePath: tests/benchmarks/commit_msg.txt comment_tag: benchmark + - name: Upload benchmark data - if: always() + if: always() && env.has_changes == 'true' uses: actions/upload-artifact@v4 with: - name: benchmark_artifact + name: benchmark_artifact_${{ matrix.group }} path: tests/benchmarks/.benchmarks + include-hidden-files: true diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 993e56ee4b..ae0be3dbb7 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -6,26 +6,48 @@ jobs: black_format: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.10'] + steps: - uses: actions/checkout@v4 - - name: Set up Python 3.10 + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' - - name: Install dependencies + python-version: ${{ matrix.python-version }} + + - name: Restore Python environment cache + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt + - name: Check files using the black formatter run: | + source .venv-${{ matrix.python-version }}/bin/activate black --version black --check desc/ tests/ || black_return_code=$? 
echo "BLACK_RETURN_CODE=$black_return_code" >> $GITHUB_ENV black desc/ tests/ + - name: Annotate diff changes using reviewdog uses: reviewdog/action-suggester@v1 with: tool_name: blackfmt + - name: Fail if not formatted run: | exit ${{ env.BLACK_RETURN_CODE }} diff --git a/.github/workflows/cache_dependencies.yml b/.github/workflows/cache_dependencies.yml new file mode 100644 index 0000000000..55da0c2e6d --- /dev/null +++ b/.github/workflows/cache_dependencies.yml @@ -0,0 +1,56 @@ +name: Cache dependencies +# This workflow is triggered every 2 days and updates the Python +# and pip dependencies cache +on: + schedule: + - cron: '30 8 */2 * *' # This triggers the workflow at 4:30 AM ET every 2 days + # cron syntax uses UTC time, so 4:30 AM ET is 8:30 AM UTC (for daylight time) + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Delete old cached file with same python version + run: | + echo "Current Cached files list" + gh cache list + echo "Deleting cached files with pattern: ${{ runner.os }}-venv-${{ matrix.python-version }}-" + for cache_key in $(gh cache list --json key -q ".[] | select(.key | startswith(\"${{ runner.os }}-venv-${{ matrix.python-version }}-\")) | .key"); do + echo "Deleting cache with key: $cache_key" + gh cache delete "$cache_key" + done + + - name: Set up virtual environment + run: | + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate + python -m pip install --upgrade pip + pip install -r devtools/dev-requirements.txt + + - name: Cache Python environment + id: cache-env + uses: actions/cache@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ 
matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Verify virtual environment activation + run: | + source .venv-${{ matrix.python-version }}/bin/activate + python --version + pip --version + pip list diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index c3ea0f96e9..2eef77dcc0 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -5,17 +5,24 @@ on: [pull_request, workflow_dispatch] jobs: flake8_linting: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.10'] + name: Linting steps: - uses: actions/checkout@v4 - - name: Set up Python 3.10 + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: ${{ matrix.python-version }} + + # For some reason, loading venv makes this way slower - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt + - name: flake8 Lint uses: reviewdog/action-flake8@v3 with: diff --git a/.github/workflows/nbtests.yml b/.github/workflows/nbtests.yml deleted file mode 100644 index 6d1fc6ca24..0000000000 --- a/.github/workflows/nbtests.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Notebook tests - -on: - push: - branches: - - master - - dev - pull_request: - branches: - - master - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - notebook_tests: - - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['3.10'] - group: [1, 2] - - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r devtools/dev-requirements.txt - - name: Test notebooks with pytest and nbmake - run: | - pwd - lscpu - export 
PYTHONPATH=$(pwd) - pytest -v --nbmake "./docs/notebooks" \ - --nbmake-timeout=2000 \ - --ignore=./docs/notebooks/zernike_eval.ipynb \ - --splits 2 \ - --group ${{ matrix.group }} \ diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml new file mode 100644 index 0000000000..88c6570c75 --- /dev/null +++ b/.github/workflows/notebook_tests.yml @@ -0,0 +1,83 @@ +name: Notebook tests + +on: + push: + branches: + - master + - dev + pull_request: + branches: + - master + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + notebook_tests: + + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.10'] + group: [1, 2, 3] + + steps: + - uses: actions/checkout@v4 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + has_changes: + - 'desc/**' + - 'docs/notebooks/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/notebook_tests.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Restore Python environment cache + if: env.has_changes == 'true' + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' && env.has_changes == 'true' + run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate + python 
-m pip install --upgrade pip + pip install -r devtools/dev-requirements.txt + + - name: Test notebooks with pytest and nbmake + if: env.has_changes == 'true' + run: | + source .venv-${{ matrix.python-version }}/bin/activate + pwd + lscpu + export PYTHONPATH=$(pwd) + pytest -v --nbmake "./docs/notebooks" \ + --nbmake-timeout=2000 \ + --ignore=./docs/notebooks/zernike_eval.ipynb \ + --splits 3 \ + --group ${{ matrix.group }} \ + --splitting-algorithm least_duration diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_tests.yml similarity index 57% rename from .github/workflows/regression_test.yml rename to .github/workflows/regression_tests.yml index d9ef1072d6..12ef17b1e0 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_tests.yml @@ -18,6 +18,8 @@ jobs: regression_tests: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} strategy: matrix: python-version: ['3.10'] @@ -25,24 +27,61 @@ jobs: steps: - uses: actions/checkout@v4 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + has_changes: + - 'desc/**' + - 'tests/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/regression_tests.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + + - name: Restore Python environment cache + if: env.has_changes == 'true' + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' 
&& env.has_changes == 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt pip install matplotlib==3.7.2 + - name: Set Swap Space + if: env.has_changes == 'true' uses: pierotofy/set-swap-space@master with: swap-size-gb: 10 + - name: Test with pytest + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate + pip install matplotlib==3.7.2 pwd lscpu - python -m pytest -v -m regression \ + python -m pytest -v -m regression\ --durations=0 \ --cov-report xml:cov.xml \ --cov-config=setup.cfg \ @@ -54,8 +93,9 @@ jobs: --group ${{ matrix.group }} \ --splitting-algorithm least_duration \ --db ./prof.db + - name: save coverage file and plot comparison results - if: always() + if: always() && env.has_changes == 'true' uses: actions/upload-artifact@v4 with: name: regression_test_artifact-${{ matrix.python-version }}-${{ matrix.group }} @@ -63,7 +103,9 @@ jobs: ./cov.xml ./mpl_results.html ./prof.db + - name: Upload coverage + if: env.has_changes == 'true' id : codecov uses: Wandalen/wretry.action@v1.3.0 with: diff --git a/.github/workflows/unittest.yml b/.github/workflows/unit_tests.yml similarity index 54% rename from .github/workflows/unittest.yml rename to .github/workflows/unit_tests.yml index 57e5881a05..fe58f21953 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unit_tests.yml @@ -18,30 +18,73 @@ jobs: unit_tests: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} strategy: matrix: combos: [{group: 1, python_version: '3.9'}, {group: 2, python_version: '3.10'}, {group: 3, python_version: '3.11'}, - {group: 4, python_version: '3.12'}] + {group: 4, python_version: '3.12'}, + {group: 5, python_version: '3.12'}, + {group: 6, python_version: '3.12'}, + {group: 7, python_version: '3.12'}, + {group: 8, python_version: '3.12'}] steps: - uses: 
actions/checkout@v4 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + has_changes: + - 'desc/**' + - 'tests/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/unit_tests.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + - name: Set up Python ${{ matrix.combos.python_version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.combos.python_version }} - - name: Install dependencies + + - name: Restore Python environment cache + if: env.has_changes == 'true' + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.combos.python_version }} + key: ${{ runner.os }}-venv-${{ matrix.combos.python_version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' && env.has_changes == 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.combos.python_version }} + source .venv-${{ matrix.combos.python_version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt pip install matplotlib==3.7.2 + - name: Set Swap Space + if: env.has_changes == 'true' uses: pierotofy/set-swap-space@master with: swap-size-gb: 10 + - name: Test with pytest + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.combos.python_version }}/bin/activate + pip install matplotlib==3.7.2 pwd lscpu python -m pytest -v -m unit \ @@ -52,12 +95,13 @@ jobs: --mpl \ --mpl-results-path=mpl_results.html \ --mpl-generate-summary=html \ - --splits 4 \ + --splits 8 \ --group ${{ matrix.combos.group }} \ --splitting-algorithm least_duration \ --db ./prof.db + - name: save coverage file and plot comparison results - if: always() + if: always() && env.has_changes == 'true' uses: 
actions/upload-artifact@v4 with: name: unit_test_artifact-${{ matrix.combos.python_version }}-${{ matrix.combos.group }} @@ -65,7 +109,9 @@ jobs: ./cov.xml ./mpl_results.html ./prof.db + - name: Upload coverage + if: env.has_changes == 'true' id : codecov uses: Wandalen/wretry.action@v1.3.0 with: diff --git a/.github/workflows/scheduled.yml b/.github/workflows/weekly_tests.yml similarity index 75% rename from .github/workflows/scheduled.yml rename to .github/workflows/weekly_tests.yml index a584db5cb8..2fb309bd8a 100644 --- a/.github/workflows/scheduled.yml +++ b/.github/workflows/weekly_tests.yml @@ -11,11 +11,10 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - combos: [{group: 1, python_version: '3.8'}, - {group: 2, python_version: '3.9'}, - {group: 3, python_version: '3.10'}, - {group: 4, python_version: '3.11'}, - {group: 5, python_version: '3.12'}] + combos: [{group: 1, python_version: '3.9'}, + {group: 2, python_version: '3.10'}, + {group: 3, python_version: '3.11'}, + {group: 4, python_version: '3.12'}] steps: - uses: actions/checkout@v4 @@ -37,6 +36,6 @@ jobs: lscpu python -m pytest -v -m unit \ --durations=0 \ - --splits 5 \ + --splits 4 \ --group ${{ matrix.combos.group }} \ --splitting-algorithm least_duration diff --git a/CHANGELOG.md b/CHANGELOG.md index aab80a4173..ccf156230a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,21 @@ Changelog New Features - Add ``use_signed_distance`` flag to ``PlasmaVesselDistance`` which will use a signed distance as the target, which is positive when the plasma is inside of the vessel surface and negative if the plasma is outside of the vessel surface, to allow optimizer to distinguish if the equilbrium surface exits the vessel surface and guard against it by targeting a positive signed distance. +- Add ``VectorPotentialField`` class to allow calculation of magnetic fields from a user-specified + vector potential function. 
+- Add ``compute_magnetic_vector_potential`` methods to most ``MagneticField`` objects to allow vector potential + computation. +- Add ability to save and load vector potential information from ``mgrid`` files. +- Changes ``ToroidalFlux`` objective to default using a 1D loop integral of the vector potential +to compute the toroidal flux when possible, as opposed to a 2D surface integral of the magnetic field dotted with ``n_zeta``. +- Allow specification of Nyquist spectrum maximum modenumbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file + +Bug Fixes + +- Fixes bugs that occur when saving asymmetric equilibria as wout files +- Fixes bug that occurs when using ``VMECIO.plot_vmec_comparison`` to compare to an asymmetric wout file + + v0.12.1 ------- diff --git a/README.rst b/README.rst index 18e4400c79..24e56055ad 100644 --- a/README.rst +++ b/README.rst @@ -111,12 +111,12 @@ Contribute :target: https://desc-docs.readthedocs.io/en/latest/?badge=latest :alt: Documentation -.. |UnitTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/unittest.yml/badge.svg - :target: https://github.com/PlasmaControl/DESC/actions/workflows/unittest.yml +.. |UnitTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/unit_tests.yml/badge.svg + :target: https://github.com/PlasmaControl/DESC/actions/workflows/unit_tests.yml :alt: UnitTests -.. |RegressionTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/regression_test.yml/badge.svg - :target: https://github.com/PlasmaControl/DESC/actions/workflows/regression_test.yml +.. |RegressionTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/regression_tests.yml/badge.svg + :target: https://github.com/PlasmaControl/DESC/actions/workflows/regression_tests.yml :alt: RegressionTests .. 
|Codecov| image:: https://codecov.io/gh/PlasmaControl/DESC/branch/master/graph/badge.svg?token=5LDR4B1O7Z diff --git a/codecov.yml b/codecov.yml index 8d3a272f14..e4c14a2bc1 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,7 +4,7 @@ comment: # this is a top-level key require_changes: false # if true: only post the comment if coverage changes require_base: true # [true :: must have a base report to post] require_head: true # [true :: must have a head report to post] - after_n_builds: 10 + after_n_builds: 14 coverage: status: patch: diff --git a/desc/backend.py b/desc/backend.py index c237ba1504..3b47cba4c5 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -71,7 +71,7 @@ imap = jax.lax.map from jax.experimental.ode import odeint from jax.lax import cond, fori_loop, scan, switch, while_loop - from jax.nn import softmax + from jax.nn import softmax as softargmax from jax.numpy import bincount, flatnonzero, repeat, take from jax.numpy.fft import irfft, rfft, rfft2 from jax.scipy.fft import dct, idct @@ -336,7 +336,7 @@ def root( This routine may be used on over or under-determined systems, in which case it will solve it in a least squares / least norm sense. 
""" - from desc.compute.utils import safenorm + from desc.utils import safenorm if fixup is None: fixup = lambda x, *args: x @@ -422,7 +422,8 @@ def tangent_solve(g, y): qr, solve_triangular, ) - from scipy.special import gammaln, logsumexp, softmax # noqa: F401 + from scipy.special import gammaln, logsumexp # noqa: F401 + from scipy.special import softmax as softargmax # noqa: F401 trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz diff --git a/desc/coils.py b/desc/coils.py index f184365918..9ffc5015c7 100644 --- a/desc/coils.py +++ b/desc/coils.py @@ -19,7 +19,6 @@ from desc.compute import get_params, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec from desc.compute.geom_utils import reflection_matrix from desc.compute.utils import _compute as compute_fun -from desc.compute.utils import safenorm from desc.geometry import ( FourierPlanarCurve, FourierRZCurve, @@ -29,7 +28,7 @@ from desc.grid import LinearGrid from desc.magnetic_fields import _MagneticField from desc.optimizable import Optimizable, OptimizableCollection, optimizable_parameter -from desc.utils import equals, errorif, flatten_list, warnif +from desc.utils import equals, errorif, flatten_list, safenorm, warnif @jit @@ -82,6 +81,53 @@ def biot_savart_hh(eval_pts, coil_pts_start, coil_pts_end, current): return B +@jit +def biot_savart_vector_potential_hh(eval_pts, coil_pts_start, coil_pts_end, current): + """Biot-Savart law for vector potential for filamentary coils following [1]. + + The coil is approximated by a series of straight line segments + and an analytic expression is used to evaluate the vector potential from each + segment. This expression assumes the Coulomb gauge. + + Parameters + ---------- + eval_pts : array-like shape(n,3) + Evaluation points in cartesian coordinates + coil_pts_start, coil_pts_end : array-like shape(m,3) + Points in cartesian space defining the start and end of each segment. 
+ Should be a closed curve, such that coil_pts_start[0] == coil_pts_end[-1] + though this is not checked. + current : float + Current through the coil (in Amps). + + Returns + ------- + A : ndarray, shape(n,3) + Magnetic vector potential in cartesian components at specified points + + [1] Hanson & Hirshman, "Compact expressions for the Biot-Savart + fields of a filamentary segment" (2002) + """ + d_vec = coil_pts_end - coil_pts_start + L = jnp.linalg.norm(d_vec, axis=-1) + d_vec_over_L = ((1 / L) * d_vec.T).T + + Ri_vec = eval_pts[jnp.newaxis, :] - coil_pts_start[:, jnp.newaxis, :] + Ri = jnp.linalg.norm(Ri_vec, axis=-1) + Rf = jnp.linalg.norm( + eval_pts[jnp.newaxis, :] - coil_pts_end[:, jnp.newaxis, :], axis=-1 + ) + Ri_p_Rf = Ri + Rf + + eps = L[:, jnp.newaxis] / (Ri_p_Rf) + + A_mag = 1.0e-7 * current * jnp.log((1 + eps) / (1 - eps)) # 1.0e-7 == mu_0/(4 pi) + + # Now just need to multiply by e^ = d_vec/L = (x_f - x_i)/L + A = jnp.sum(A_mag[:, :, jnp.newaxis] * d_vec_over_L[:, jnp.newaxis, :], axis=0) + return A + + @jit def biot_savart_quad(eval_pts, coil_pts, tangents, current): """Biot-Savart law for filamentary coil using numerical quadrature. @@ -123,6 +169,42 @@ def biot_savart_quad(eval_pts, coil_pts, tangents, current): return B +@jit +def biot_savart_vector_potential_quad(eval_pts, coil_pts, tangents, current): + """Biot-Savart law (for A) for filamentary coil using numerical quadrature. + + This expression assumes the Coulomb gauge. + + Parameters + ---------- + eval_pts : array-like shape(n,3) + Evaluation points in cartesian coordinates + coil_pts : array-like shape(m,3) + Points in cartesian space defining coil + tangents : array-like, shape(m,3) + Tangent vectors to the coil at coil_pts. If the curve is given + by x(s) with curve parameter s, coil_pts = x, tangents = dx/ds*ds where + ds is the spacing between points. + current : float + Current through the coil (in Amps). 
+ + Returns + ------- + A : ndarray, shape(n,3) + Magnetic vector potential in cartesian components at specified points. + """ + dl = tangents + R_vec = eval_pts[jnp.newaxis, :] - coil_pts[:, jnp.newaxis, :] + R_mag = jnp.linalg.norm(R_vec, axis=-1) + + vec = dl[:, jnp.newaxis, :] + denom = R_mag + + # 1e-7 == mu_0/(4 pi) + A = jnp.sum(1.0e-7 * current * vec / denom[:, :, None], axis=0) + return A + + class _Coil(_MagneticField, Optimizable, ABC): """Base class representing a magnetic field coil. @@ -187,10 +269,16 @@ def _compute_position(self, params=None, grid=None, **kwargs): x = x.at[:, :, 1].set(jnp.mod(x[:, :, 1], 2 * jnp.pi)) return x - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. The coil current may be overridden by including `current` in the `params` dictionary. @@ -208,6 +296,9 @@ def compute_magnetic_field( points. Should NOT include endpoint at 2pi. transforms : dict of Transform or array-like Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns @@ -223,6 +314,14 @@ def compute_magnetic_field( may not be zero if not fully converged. 
""" + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + op = {"B": biot_savart_quad, "A": biot_savart_vector_potential_quad}[ + compute_A_or_B + ] assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis.lower() == "rpz": @@ -256,13 +355,87 @@ def compute_magnetic_field( data["x_s"] = rpz2xyz_vec(data["x_s"], phi=data["x"][:, 1]) data["x"] = rpz2xyz(data["x"]) - B = biot_savart_quad( - coords, data["x"], data["x_s"] * data["ds"][:, None], current - ) + AB = op(coords, data["x"], data["x_s"] * data["ds"][:, None], current) if basis.lower() == "rpz": - B = xyz2rpz_vec(B, phi=phi) - return B + AB = xyz2rpz_vec(AB, phi=phi) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + + Returns + ------- + field : ndarray, shape(n,3) + magnetic field at specified points, in either rpz or xyz coordinates + + Notes + ----- + Uses direct quadrature of the Biot-Savart integral for filamentary coils with + tangents provided by the underlying curve class. Convergence should be + exponential in the number of points used to discretize the curve, though curl(B) + may not be zero if not fully converged. 
+ + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + vector_potential : ndarray, shape(n,3) + Magnetic vector potential at specified points, in either rpz or + xyz coordinates. + + Notes + ----- + Uses direct quadrature of the Biot-Savart integral for filamentary coils with + tangents provided by the underlying curve class. Convergence should be + exponential in the number of points used to discretize the curve, though curl(B) + may not be zero if not fully converged. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") def __repr__(self): """Get the string form of the object.""" @@ -783,10 +956,16 @@ def __init__( ): super().__init__(current, X, Y, Z, knots, method, name) - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. 
The coil current may be overridden by including `current` in the `params` dictionary. @@ -804,6 +983,9 @@ def compute_magnetic_field( points. Should NOT include endpoint at 2pi. transforms : dict of Transform or array-like Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns ------- @@ -817,6 +999,12 @@ def compute_magnetic_field( is approximately quadratic in the number of coil points. """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + op = {"B": biot_savart_hh, "A": biot_savart_vector_potential_hh}[compute_A_or_B] assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis == "rpz": @@ -826,7 +1014,9 @@ def compute_magnetic_field( else: current = params.pop("current", self.current) - data = self.compute(["x"], grid=source_grid, params=params, basis="xyz") + data = self.compute( + ["x"], grid=source_grid, params=params, basis="xyz", transforms=transforms + ) # need to make sure the curve is closed. If it's already closed, this doesn't # do anything (effectively just adds a segment of zero length which has no # effect on the overall result) @@ -837,11 +1027,85 @@ def compute_magnetic_field( # coils curvature which is a 2nd derivative of the position, and doing that # with only possibly c1 cubic splines is inaccurate, so we don't do it # (for now, maybe in the future?) 
- B = biot_savart_hh(coords, coil_pts_start, coil_pts_end, current) + AB = op(coords, coil_pts_start, coil_pts_end, current) if basis == "rpz": - B = xyz2rpz_vec(B, x=coords[:, 0], y=coords[:, 1]) - return B + AB = xyz2rpz_vec(AB, x=coords[:, 0], y=coords[:, 1]) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + field : ndarray, shape(n,3) + magnetic field at specified points, in either rpz or xyz coordinates + + Notes + ----- + Discretizes the coil into straight segments between grid points, and uses the + Hanson-Hirshman expression for exact field from a straight segment. Convergence + is approximately quadratic in the number of coil points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate magnetic vector potential at in [R,phi,Z] + or [X,Y,Z] coordinates. 
+ params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + A : ndarray, shape(n,3) + Magnetic vector potential at specified points, in either + rpz or xyz coordinates + + Notes + ----- + Discretizes the coil into straight segments between grid points, and uses the + Hanson-Hirshman expression for exact vector potential from a straight segment. + Convergence is approximately quadratic in the number of coil points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") @classmethod def from_values( @@ -1153,8 +1417,14 @@ def _compute_position(self, params=None, grid=None, **kwargs): x = rpz return x - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): """Compute magnetic field at a set of points. @@ -1171,13 +1441,22 @@ def compute_magnetic_field( points. Should NOT include endpoint at 2pi. transforms : dict of Transform or array-like Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns ------- field : ndarray, shape(n,3) - Magnetic field at specified nodes, in [R,phi,Z] or [X,Y,Z] coordinates. + Magnetic field or vector potential at specified nodes, in [R,phi,Z] + or [X,Y,Z] coordinates. 
""" + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if params is None: @@ -1207,31 +1486,89 @@ def compute_magnetic_field( # field period rotation is easiest in [R,phi,Z] coordinates coords_rpz = xyz2rpz(coords_xyz) + op = { + "B": self[0].compute_magnetic_field, + "A": self[0].compute_magnetic_vector_potential, + }[compute_A_or_B] # sum the magnetic fields from each field period - def nfp_loop(k, B): + def nfp_loop(k, AB): coords_nfp = coords_rpz + jnp.array([0, 2 * jnp.pi * k / self.NFP, 0]) - def body(B, x): - B += self[0].compute_magnetic_field( - coords_nfp, params=x, basis="rpz", source_grid=source_grid - ) - return B, None + def body(AB, x): + AB += op(coords_nfp, params=x, basis="rpz", source_grid=source_grid) + return AB, None - B += scan(body, jnp.zeros(coords_nfp.shape), tree_stack(params))[0] - return B + AB += scan(body, jnp.zeros(coords_nfp.shape), tree_stack(params))[0] + return AB - B = fori_loop(0, self.NFP, nfp_loop, jnp.zeros_like(coords_rpz)) + AB = fori_loop(0, self.NFP, nfp_loop, jnp.zeros_like(coords_rpz)) - # sum the magnetic fields from both halves of the symmetric field period + # sum the magnetic field/potential from both halves of + # the symmetric field period if self.sym: - B = B[: coords.shape[0], :] + B[coords.shape[0] :, :] * jnp.array( + AB = AB[: coords.shape[0], :] + AB[coords.shape[0] :, :] * jnp.array( [-1, 1, 1] ) if basis.lower() == "xyz": - B = rpz2xyz_vec(B, x=coords[:, 0], y=coords[:, 1]) - return B + AB = rpz2xyz_vec(AB, x=coords[:, 0], y=coords[:, 1]) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. 
+ params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coils. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + field : ndarray, shape(n,3) + Magnetic field at specified nodes, in [R,phi,Z] or [X,Y,Z] coordinates. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coils. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. 
+ + Returns + ------- + vector_potential : ndarray, shape(n,3) + magnetic vector potential at specified points, in either rpz + or xyz coordinates + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") @classmethod def linspaced_angular( @@ -2002,6 +2339,65 @@ def _compute_position(self, params=None, grid=None, **kwargs): ) return x + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", + ): + """Compute magnetic field or vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + If array-like, should be 1 value per coil. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize coils. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + If array-like, should be 1 value per coil. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" + + Returns + ------- + field : ndarray, shape(n,3) + magnetic field or vector potential at specified points, in either rpz + or xyz coordinates + + """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + params = self._make_arraylike(params) + source_grid = self._make_arraylike(source_grid) + transforms = self._make_arraylike(transforms) + + AB = 0 + if compute_A_or_B == "B": + for coil, par, grd, tr in zip(self.coils, params, source_grid, transforms): + AB += coil.compute_magnetic_field( + coords, par, basis, grd, transforms=tr + ) + elif compute_A_or_B == "A": + for coil, par, grd, tr in zip(self.coils, params, source_grid, transforms): + AB += coil.compute_magnetic_vector_potential( + coords, par, basis, grd, transforms=tr + ) + return AB + def compute_magnetic_field( self, coords, params=None, basis="rpz", source_grid=None, transforms=None ): @@ -2029,15 +2425,37 @@ def compute_magnetic_field( magnetic field at specified points, in either rpz or xyz coordinates """ - params = self._make_arraylike(params) - source_grid = self._make_arraylike(source_grid) - transforms = self._make_arraylike(transforms) + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") - B = 0 - for coil, par, grd, tr in zip(self.coils, params, source_grid, transforms): - B += coil.compute_magnetic_field(coords, par, basis, grd, transforms=tr) + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + If array-like, should be 1 value per coil. 
+ basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize coils. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + If array-like, should be 1 value per coil. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. - return B + Returns + ------- + vector_potential : ndarray, shape(n,3) + magnetic vector potential at specified points, in either rpz + or xyz coordinates + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") def to_FourierPlanar( self, N=10, grid=None, basis="xyz", name="", check_intersection=False diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index 8fca1346d2..72803a1d7a 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -11,8 +11,8 @@ from desc.backend import jnp +from ..utils import cross, dot, safediv from .data_index import register_compute_fun -from .utils import cross, dot, safediv @register_compute_fun( diff --git a/desc/compute/_core.py b/desc/compute/_core.py index 7bc6e8acb3..c61b6974e6 100644 --- a/desc/compute/_core.py +++ b/desc/compute/_core.py @@ -2764,6 +2764,25 @@ def _phi_rr(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="phi_rrz", + label="\\partial_{\\rho \\rho \\zeta} \\phi", + units="rad", + units_long="radians", + description="Toroidal angle in lab frame, second derivative wrt radial coordinate " + "and first wrt DESC toroidal coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["omega_rrz"], +) +def _phi_rrz(params, transforms, profiles, data, **kwargs): + data["phi_rrz"] = data["omega_rrz"] + return data + + @register_compute_fun( name="phi_rt", label="\\partial_{\\rho \\theta} \\phi", @@ -2783,6 +2802,25 @@ def _phi_rt(params, 
transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="phi_rtz", + label="\\partial_{\\rho \\theta \\zeta} \\phi", + units="rad", + units_long="radians", + description="Toroidal angle in lab frame, third derivative wrt radial, " + "poloidal, and toroidal coordinates", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["omega_rtz"], +) +def _phi_rtz(params, transforms, profiles, data, **kwargs): + data["phi_rtz"] = data["omega_rtz"] + return data + + @register_compute_fun( name="phi_rz", label="\\partial_{\\rho \\zeta} \\phi", @@ -2802,6 +2840,25 @@ def _phi_rz(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="phi_rzz", + label="\\partial_{\\rho \\zeta \\zeta} \\phi", + units="rad", + units_long="radians", + description="Toroidal angle in lab frame, first derivative wrt radial and " + "second derivative wrt DESC toroidal coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["omega_rzz"], +) +def _phi_rzz(params, transforms, profiles, data, **kwargs): + data["phi_rzz"] = data["omega_rzz"] + return data + + @register_compute_fun( name="phi_t", label="\\partial_{\\theta} \\phi", @@ -2843,12 +2900,31 @@ def _phi_tt(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="phi_ttz", + label="\\partial_{\\theta \\theta \\zeta} \\phi", + units="rad", + units_long="radians", + description="Toroidal angle in lab frame, second derivative wrt poloidal " + "coordinate and first derivative wrt toroidal coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["omega_ttz"], +) +def _phi_ttz(params, transforms, profiles, data, **kwargs): + data["phi_ttz"] = data["omega_ttz"] + return data + + @register_compute_fun( name="phi_tz", label="\\partial_{\\theta \\zeta} \\phi", units="rad", units_long="radians", - description="Toroidal angle in lab frame, second derivative wrt 
poloidal and " + description="Toroidal angle in lab frame, derivative wrt poloidal and " "toroidal coordinate", dim=1, params=[], @@ -2862,6 +2938,25 @@ def _phi_tz(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="phi_tzz", + label="\\partial_{\\theta \\zeta \\zeta} \\phi", + units="rad", + units_long="radians", + description="Toroidal angle in lab frame, derivative wrt poloidal coordinate and " + "second derivative wrt toroidal coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["omega_tzz"], +) +def _phi_tzz(params, transforms, profiles, data, **kwargs): + data["phi_tzz"] = data["omega_tzz"] + return data + + @register_compute_fun( name="phi_z", label="\\partial_{\\zeta} \\phi", @@ -2903,6 +2998,25 @@ def _phi_zz(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="phi_zzz", + label="\\partial_{\\zeta \\zeta \\zeta} \\phi", + units="rad", + units_long="radians", + description="Toroidal angle in lab frame, third derivative wrt toroidal " + "coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["omega_zzz"], +) +def _phi_zzz(params, transforms, profiles, data, **kwargs): + data["phi_zzz"] = data["omega_zzz"] + return data + + @register_compute_fun( name="rho", label="\\rho", @@ -2986,6 +3100,83 @@ def _theta_PEST_r(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="theta_PEST_rt", + label="\\partial_{\\rho \\theta} \\vartheta", + units="rad", + units_long="radians", + description="PEST straight field line poloidal angular coordinate, derivative wrt " + "radial and DESC poloidal coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["lambda_rt"], +) +def _theta_PEST_rt(params, transforms, profiles, data, **kwargs): + data["theta_PEST_rt"] = data["lambda_rt"] + return data + + +@register_compute_fun( + name="theta_PEST_rrt", 
+ label="\\partial_{\\rho \\rho \\theta} \\vartheta", + units="rad", + units_long="radians", + description="PEST straight field line poloidal angular coordinate, second " + "derivative wrt radial coordinate and first derivative wrt DESC poloidal " + "coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["lambda_rrt"], +) +def _theta_PEST_rrt(params, transforms, profiles, data, **kwargs): + data["theta_PEST_rrt"] = data["lambda_rrt"] + return data + + +@register_compute_fun( + name="theta_PEST_rtz", + label="\\partial_{\\rho \\theta \\zeta} \\vartheta", + units="rad", + units_long="radians", + description="PEST straight field line poloidal angular coordinate, derivative wrt " + "radial and DESC poloidal and toroidal coordinates", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["lambda_rtz"], +) +def _theta_PEST_rtz(params, transforms, profiles, data, **kwargs): + data["theta_PEST_rtz"] = data["lambda_rtz"] + return data + + +@register_compute_fun( + name="theta_PEST_rtt", + label="\\partial_{\\rho \\theta \\theta} \\vartheta", + units="rad", + units_long="radians", + description="PEST straight field line poloidal angular coordinate, derivative wrt " + "radial coordinate once and DESC poloidal coordinate twice", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["lambda_rtt"], +) +def _theta_PEST_rtt(params, transforms, profiles, data, **kwargs): + data["theta_PEST_rtt"] = data["lambda_rtt"] + return data + + @register_compute_fun( name="theta_PEST_t", label="\\partial_{\\theta} \\vartheta", @@ -3024,6 +3215,25 @@ def _theta_PEST_tt(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="theta_PEST_ttt", + label="\\partial_{\\theta \\theta \\theta} \\vartheta", + units="rad", + units_long="radians", + description="PEST straight field line poloidal angular coordinate, third " + "derivative wrt poloidal coordinate", + dim=1, + 
params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["lambda_ttt"], +) +def _theta_PEST_ttt(params, transforms, profiles, data, **kwargs): + data["theta_PEST_ttt"] = data["lambda_ttt"] + return data + + @register_compute_fun( name="theta_PEST_tz", label="\\partial_{\\theta \\zeta} \\vartheta", @@ -3043,6 +3253,25 @@ def _theta_PEST_tz(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="theta_PEST_tzz", + label="\\partial_{\\theta \\zeta \\zeta} \\vartheta", + units="rad", + units_long="radians", + description="PEST straight field line poloidal angular coordinate, derivative wrt " + "poloidal coordinate once and toroidal coordinate twice", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["lambda_tzz"], +) +def _theta_PEST_tzz(params, transforms, profiles, data, **kwargs): + data["theta_PEST_tzz"] = data["lambda_tzz"] + return data + + @register_compute_fun( name="theta_PEST_z", label="\\partial_{\\zeta} \\vartheta", @@ -3081,6 +3310,25 @@ def _theta_PEST_zz(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="theta_PEST_ttz", + label="\\partial_{\\theta \\theta \\zeta} \\vartheta", + units="rad", + units_long="radians", + description="PEST straight field line poloidal angular coordinate, second " + "derivative wrt poloidal coordinate and derivative wrt toroidal coordinate", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["lambda_ttz"], +) +def _theta_PEST_ttz(params, transforms, profiles, data, **kwargs): + data["theta_PEST_ttz"] = data["lambda_ttz"] + return data + + @register_compute_fun( name="zeta", label="\\zeta", diff --git a/desc/compute/_curve.py b/desc/compute/_curve.py index 2e96e7a767..f8e4bbeb8c 100644 --- a/desc/compute/_curve.py +++ b/desc/compute/_curve.py @@ -2,9 +2,9 @@ from desc.backend import jnp, sign +from ..utils import cross, dot, safenormalize from .data_index import 
register_compute_fun from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec -from .utils import cross, dot, safenormalize @register_compute_fun( diff --git a/desc/compute/_equil.py b/desc/compute/_equil.py index 123b7d39fe..de0fad7797 100644 --- a/desc/compute/_equil.py +++ b/desc/compute/_equil.py @@ -15,8 +15,8 @@ from desc.backend import jnp from ..integrals.surface_integral import surface_averages +from ..utils import cross, dot, safediv, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safediv, safenorm @register_compute_fun( @@ -625,7 +625,7 @@ def _e_sup_helical_times_sqrt_g_mag(params, transforms, profiles, data, **kwargs @register_compute_fun( name="F_anisotropic", - label="F_{anisotropic}", + label="F_{\\mathrm{anisotropic}}", units="N \\cdot m^{-3}", units_long="Newtons / cubic meter", description="Anisotropic force balance error", diff --git a/desc/compute/_field.py b/desc/compute/_field.py index 8af2e8368c..97cb44515f 100644 --- a/desc/compute/_field.py +++ b/desc/compute/_field.py @@ -19,8 +19,8 @@ surface_max, surface_min, ) +from ..utils import cross, dot, safediv, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safediv, safenorm @register_compute_fun( @@ -74,11 +74,12 @@ def _B_sup_rho(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["psi_r/sqrt(g)", "iota", "lambda_z", "omega_z"], + data=["psi_r/sqrt(g)", "iota", "phi_z", "lambda_z"], ) def _B_sup_theta(params, transforms, profiles, data, **kwargs): + # Assumes θ = ϑ − λ. 
data["B^theta"] = data["psi_r/sqrt(g)"] * ( - data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"] + data["iota"] * data["phi_z"] - data["lambda_z"] ) return data @@ -94,11 +95,12 @@ def _B_sup_theta(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["psi_r/sqrt(g)", "iota", "lambda_t", "omega_t"], + data=["psi_r/sqrt(g)", "iota", "theta_PEST_t", "omega_t"], ) def _B_sup_zeta(params, transforms, profiles, data, **kwargs): + # Assumes ζ = ϕ − ω. data["B^zeta"] = data["psi_r/sqrt(g)"] * ( - -data["iota"] * data["omega_t"] + data["lambda_t"] + 1 + -data["iota"] * data["omega_t"] + data["theta_PEST_t"] ) return data @@ -224,21 +226,18 @@ def _psi_r_over_sqrtg_r(params, transforms, profiles, data, **kwargs): "(psi_r/sqrt(g))_r", "iota", "iota_r", + "phi_rz", + "phi_z", "lambda_rz", "lambda_z", - "omega_rz", - "omega_z", ], ) def _B_sup_theta_r(params, transforms, profiles, data, **kwargs): data["B^theta_r"] = data["psi_r/sqrt(g)"] * ( - data["iota"] * data["omega_rz"] - + data["iota_r"] * data["omega_z"] - + data["iota_r"] + data["iota"] * data["phi_rz"] + + data["iota_r"] * data["phi_z"] - data["lambda_rz"] - ) + data["(psi_r/sqrt(g))_r"] * ( - data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"] - ) + ) + data["(psi_r/sqrt(g))_r"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) return data @@ -261,8 +260,8 @@ def _B_sup_theta_r(params, transforms, profiles, data, **kwargs): "(psi_r/sqrt(g))_r", "iota", "iota_r", - "lambda_rt", - "lambda_t", + "theta_PEST_rt", + "theta_PEST_t", "omega_rt", "omega_t", ], @@ -271,9 +270,9 @@ def _B_sup_zeta_r(params, transforms, profiles, data, **kwargs): data["B^zeta_r"] = data["psi_r/sqrt(g)"] * ( -data["iota"] * data["omega_rt"] - data["iota_r"] * data["omega_t"] - + data["lambda_rt"] + + data["theta_PEST_rt"] ) + data["(psi_r/sqrt(g))_r"] * ( - -data["iota"] * data["omega_t"] + data["lambda_t"] + 1 + -data["iota"] * data["omega_t"] + data["theta_PEST_t"] ) 
return data @@ -352,16 +351,14 @@ def _psi_r_over_sqrtg_t(params, transforms, profiles, data, **kwargs): "iota", "lambda_tz", "lambda_z", - "omega_tz", - "omega_z", + "phi_tz", + "phi_z", ], ) def _B_sup_theta_t(params, transforms, profiles, data, **kwargs): data["B^theta_t"] = data["psi_r/sqrt(g)"] * ( - data["iota"] * data["omega_tz"] - data["lambda_tz"] - ) + data["(psi_r/sqrt(g))_t"] * ( - data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"] - ) + data["iota"] * data["phi_tz"] - data["lambda_tz"] + ) + data["(psi_r/sqrt(g))_t"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) return data @@ -383,17 +380,17 @@ def _B_sup_theta_t(params, transforms, profiles, data, **kwargs): "psi_r/sqrt(g)", "(psi_r/sqrt(g))_t", "iota", - "lambda_t", - "lambda_tt", + "theta_PEST_t", + "theta_PEST_tt", "omega_t", "omega_tt", ], ) def _B_sup_zeta_t(params, transforms, profiles, data, **kwargs): data["B^zeta_t"] = data["psi_r/sqrt(g)"] * ( - -data["iota"] * data["omega_tt"] + data["lambda_tt"] + -data["iota"] * data["omega_tt"] + data["theta_PEST_tt"] ) + data["(psi_r/sqrt(g))_t"] * ( - -data["iota"] * data["omega_t"] + data["lambda_t"] + 1 + -data["iota"] * data["omega_t"] + data["theta_PEST_t"] ) return data @@ -472,16 +469,14 @@ def _psi_r_over_sqrtg_z(params, transforms, profiles, data, **kwargs): "iota", "lambda_z", "lambda_zz", - "omega_z", - "omega_zz", + "phi_z", + "phi_zz", ], ) def _B_sup_theta_z(params, transforms, profiles, data, **kwargs): data["B^theta_z"] = data["psi_r/sqrt(g)"] * ( - data["iota"] * data["omega_zz"] - data["lambda_zz"] - ) + data["(psi_r/sqrt(g))_z"] * ( - data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"] - ) + data["iota"] * data["phi_zz"] - data["lambda_zz"] + ) + data["(psi_r/sqrt(g))_z"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) return data @@ -503,17 +498,17 @@ def _B_sup_theta_z(params, transforms, profiles, data, **kwargs): "psi_r/sqrt(g)", "(psi_r/sqrt(g))_z", "iota", - "lambda_t", - "lambda_tz", + 
"theta_PEST_t", + "theta_PEST_tz", "omega_t", "omega_tz", ], ) def _B_sup_zeta_z(params, transforms, profiles, data, **kwargs): data["B^zeta_z"] = data["psi_r/sqrt(g)"] * ( - -data["iota"] * data["omega_tz"] + data["lambda_tz"] + -data["iota"] * data["omega_tz"] + data["theta_PEST_tz"] ) + data["(psi_r/sqrt(g))_z"] * ( - -data["iota"] * data["omega_t"] + data["lambda_t"] + 1 + -data["iota"] * data["omega_t"] + data["theta_PEST_t"] ) return data @@ -650,31 +645,28 @@ def _psi_r_over_sqrtg_rr(params, transforms, profiles, data, **kwargs): "lambda_rrz", "lambda_rz", "lambda_z", - "omega_rrz", - "omega_rz", - "omega_z", + "phi_rrz", + "phi_rz", + "phi_z", ], ) def _B_sup_theta_rr(params, transforms, profiles, data, **kwargs): data["B^theta_rr"] = ( data["psi_r/sqrt(g)"] * ( - data["iota"] * data["omega_rrz"] - + 2 * data["iota_r"] * data["omega_rz"] - + data["iota_rr"] * data["omega_z"] - + data["iota_rr"] + data["iota"] * data["phi_rrz"] + + 2 * data["iota_r"] * data["phi_rz"] + + data["iota_rr"] * data["phi_z"] - data["lambda_rrz"] ) + 2 * data["(psi_r/sqrt(g))_r"] * ( - data["iota"] * data["omega_rz"] - + data["iota_r"] * data["omega_z"] - + data["iota_r"] + data["iota"] * data["phi_rz"] + + data["iota_r"] * data["phi_z"] - data["lambda_rz"] ) - + data["(psi_r/sqrt(g))_rr"] - * (data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"]) + + data["(psi_r/sqrt(g))_rr"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) ) return data @@ -700,9 +692,9 @@ def _B_sup_theta_rr(params, transforms, profiles, data, **kwargs): "iota", "iota_r", "iota_rr", - "lambda_rrt", - "lambda_rt", - "lambda_t", + "theta_PEST_rrt", + "theta_PEST_rt", + "theta_PEST_t", "omega_rrt", "omega_rt", "omega_t", @@ -715,17 +707,17 @@ def _B_sup_zeta_rr(params, transforms, profiles, data, **kwargs): data["iota"] * data["omega_rrt"] + 2 * data["iota_r"] * data["omega_rt"] + data["iota_rr"] * data["omega_t"] - - data["lambda_rrt"] + - data["theta_PEST_rrt"] ) - 2 * data["(psi_r/sqrt(g))_r"] * 
( data["iota"] * data["omega_rt"] + data["iota_r"] * data["omega_t"] - - data["lambda_rt"] + - data["theta_PEST_rt"] ) + data["(psi_r/sqrt(g))_rr"] - * (-data["iota"] * data["omega_t"] + data["lambda_t"] + 1) + * (-data["iota"] * data["omega_t"] + data["theta_PEST_t"]) ) return data @@ -820,19 +812,18 @@ def _psi_r_over_sqrtg_tt(params, transforms, profiles, data, **kwargs): "lambda_ttz", "lambda_tz", "lambda_z", - "omega_ttz", - "omega_tz", - "omega_z", + "phi_ttz", + "phi_tz", + "phi_z", ], ) def _B_sup_theta_tt(params, transforms, profiles, data, **kwargs): data["B^theta_tt"] = ( - data["psi_r/sqrt(g)"] * (data["iota"] * data["omega_ttz"] - data["lambda_ttz"]) + data["psi_r/sqrt(g)"] * (data["iota"] * data["phi_ttz"] - data["lambda_ttz"]) + 2 * data["(psi_r/sqrt(g))_t"] - * (data["iota"] * data["omega_tz"] - data["lambda_tz"]) - + data["(psi_r/sqrt(g))_tt"] - * (data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"]) + * (data["iota"] * data["phi_tz"] - data["lambda_tz"]) + + data["(psi_r/sqrt(g))_tt"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) ) return data @@ -856,9 +847,9 @@ def _B_sup_theta_tt(params, transforms, profiles, data, **kwargs): "(psi_r/sqrt(g))_t", "(psi_r/sqrt(g))_tt", "iota", - "lambda_t", - "lambda_tt", - "lambda_ttt", + "theta_PEST_t", + "theta_PEST_tt", + "theta_PEST_ttt", "omega_t", "omega_tt", "omega_ttt", @@ -866,12 +857,13 @@ def _B_sup_theta_tt(params, transforms, profiles, data, **kwargs): ) def _B_sup_zeta_tt(params, transforms, profiles, data, **kwargs): data["B^zeta_tt"] = ( - -data["psi_r/sqrt(g)"] * (data["iota"] * data["omega_ttt"] - data["lambda_ttt"]) + -data["psi_r/sqrt(g)"] + * (data["iota"] * data["omega_ttt"] - data["theta_PEST_ttt"]) - 2 * data["(psi_r/sqrt(g))_t"] - * (data["iota"] * data["omega_tt"] - data["lambda_tt"]) + * (data["iota"] * data["omega_tt"] - data["theta_PEST_tt"]) + data["(psi_r/sqrt(g))_tt"] - * (-data["iota"] * data["omega_t"] + data["lambda_t"] + 1) + * (-data["iota"] * 
data["omega_t"] + data["theta_PEST_t"]) ) return data @@ -966,19 +958,18 @@ def _psi_r_over_sqrtg_zz(params, transforms, profiles, data, **kwargs): "lambda_z", "lambda_zz", "lambda_zzz", - "omega_z", - "omega_zz", - "omega_zzz", + "phi_z", + "phi_zz", + "phi_zzz", ], ) def _B_sup_theta_zz(params, transforms, profiles, data, **kwargs): data["B^theta_zz"] = ( - data["psi_r/sqrt(g)"] * (data["iota"] * data["omega_zzz"] - data["lambda_zzz"]) + data["psi_r/sqrt(g)"] * (data["iota"] * data["phi_zzz"] - data["lambda_zzz"]) + 2 * data["(psi_r/sqrt(g))_z"] - * (data["iota"] * data["omega_zz"] - data["lambda_zz"]) - + data["(psi_r/sqrt(g))_zz"] - * (data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"]) + * (data["iota"] * data["phi_zz"] - data["lambda_zz"]) + + data["(psi_r/sqrt(g))_zz"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) ) return data @@ -1002,9 +993,9 @@ def _B_sup_theta_zz(params, transforms, profiles, data, **kwargs): "(psi_r/sqrt(g))_z", "(psi_r/sqrt(g))_zz", "iota", - "lambda_t", - "lambda_tz", - "lambda_tzz", + "theta_PEST_t", + "theta_PEST_tz", + "theta_PEST_tzz", "omega_t", "omega_tz", "omega_tzz", @@ -1012,12 +1003,13 @@ def _B_sup_theta_zz(params, transforms, profiles, data, **kwargs): ) def _B_sup_zeta_zz(params, transforms, profiles, data, **kwargs): data["B^zeta_zz"] = ( - -data["psi_r/sqrt(g)"] * (data["iota"] * data["omega_tzz"] - data["lambda_tzz"]) + -data["psi_r/sqrt(g)"] + * (data["iota"] * data["omega_tzz"] - data["theta_PEST_tzz"]) - 2 * data["(psi_r/sqrt(g))_z"] - * (data["iota"] * data["omega_tz"] - data["lambda_tz"]) + * (data["iota"] * data["omega_tz"] - data["theta_PEST_tz"]) + data["(psi_r/sqrt(g))_zz"] - * (-data["iota"] * data["omega_t"] + data["lambda_t"] + 1) + * (-data["iota"] * data["omega_t"] + data["theta_PEST_t"]) ) return data @@ -1120,31 +1112,29 @@ def _psi_r_over_sqrtg_rt(params, transforms, profiles, data, **kwargs): "lambda_rz", "lambda_tz", "lambda_z", - "omega_rtz", - "omega_rz", - "omega_tz", - 
"omega_z", + "phi_rtz", + "phi_rz", + "phi_tz", + "phi_z", ], ) def _B_sup_theta_rt(params, transforms, profiles, data, **kwargs): data["B^theta_rt"] = ( data["psi_r/sqrt(g)"] * ( - data["iota"] * data["omega_rtz"] - + data["iota_r"] * data["omega_tz"] + data["iota"] * data["phi_rtz"] + + data["iota_r"] * data["phi_tz"] - data["lambda_rtz"] ) + data["(psi_r/sqrt(g))_r"] - * (data["iota"] * data["omega_tz"] - data["lambda_tz"]) + * (data["iota"] * data["phi_tz"] - data["lambda_tz"]) + data["(psi_r/sqrt(g))_t"] * ( - data["iota"] * data["omega_rz"] - + data["iota_r"] * data["omega_z"] - + data["iota_r"] + data["iota"] * data["phi_rz"] + + data["iota_r"] * data["phi_z"] - data["lambda_rz"] ) - + data["(psi_r/sqrt(g))_rt"] - * (data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"]) + + data["(psi_r/sqrt(g))_rt"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) ) return data @@ -1170,10 +1160,10 @@ def _B_sup_theta_rt(params, transforms, profiles, data, **kwargs): "(psi_r/sqrt(g))_t", "iota", "iota_r", - "lambda_rt", - "lambda_rtt", - "lambda_t", - "lambda_tt", + "theta_PEST_rt", + "theta_PEST_rtt", + "theta_PEST_t", + "theta_PEST_tt", "omega_rt", "omega_rtt", "omega_t", @@ -1186,18 +1176,18 @@ def _B_sup_zeta_rt(params, transforms, profiles, data, **kwargs): * ( data["iota"] * data["omega_rtt"] + data["iota_r"] * data["omega_tt"] - - data["lambda_rtt"] + - data["theta_PEST_rtt"] ) - data["(psi_r/sqrt(g))_r"] - * (data["iota"] * data["omega_tt"] - data["lambda_tt"]) + * (data["iota"] * data["omega_tt"] - data["theta_PEST_tt"]) - data["(psi_r/sqrt(g))_t"] * ( data["iota"] * data["omega_rt"] + data["iota_r"] * data["omega_t"] - - data["lambda_rt"] + - data["theta_PEST_rt"] ) + data["(psi_r/sqrt(g))_rt"] - * (-data["iota"] * data["omega_t"] + data["lambda_t"] + 1) + * (-data["iota"] * data["omega_t"] + data["theta_PEST_t"]) ) return data @@ -1307,21 +1297,20 @@ def _psi_r_over_sqrtg_tz(params, transforms, profiles, data, **kwargs): "lambda_tzz", "lambda_z", 
"lambda_zz", - "omega_tz", - "omega_tzz", - "omega_z", - "omega_zz", + "phi_tz", + "phi_tzz", + "phi_z", + "phi_zz", ], ) def _B_sup_theta_tz(params, transforms, profiles, data, **kwargs): data["B^theta_tz"] = ( - data["psi_r/sqrt(g)"] * (data["iota"] * data["omega_tzz"] - data["lambda_tzz"]) + data["psi_r/sqrt(g)"] * (data["iota"] * data["phi_tzz"] - data["lambda_tzz"]) + data["(psi_r/sqrt(g))_t"] - * (data["iota"] * data["omega_zz"] - data["lambda_zz"]) + * (data["iota"] * data["phi_zz"] - data["lambda_zz"]) + data["(psi_r/sqrt(g))_z"] - * (data["iota"] * data["omega_tz"] - data["lambda_tz"]) - + data["(psi_r/sqrt(g))_tz"] - * (data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"]) + * (data["iota"] * data["phi_tz"] - data["lambda_tz"]) + + data["(psi_r/sqrt(g))_tz"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) ) return data @@ -1346,10 +1335,10 @@ def _B_sup_theta_tz(params, transforms, profiles, data, **kwargs): "(psi_r/sqrt(g))_tz", "(psi_r/sqrt(g))_z", "iota", - "lambda_t", - "lambda_tt", - "lambda_ttz", - "lambda_tz", + "theta_PEST_t", + "theta_PEST_tt", + "theta_PEST_ttz", + "theta_PEST_tz", "omega_t", "omega_tt", "omega_ttz", @@ -1358,13 +1347,14 @@ def _B_sup_theta_tz(params, transforms, profiles, data, **kwargs): ) def _B_sup_zeta_tz(params, transforms, profiles, data, **kwargs): data["B^zeta_tz"] = ( - -data["psi_r/sqrt(g)"] * (data["iota"] * data["omega_ttz"] - data["lambda_ttz"]) + -data["psi_r/sqrt(g)"] + * (data["iota"] * data["omega_ttz"] - data["theta_PEST_ttz"]) - data["(psi_r/sqrt(g))_t"] - * (data["iota"] * data["omega_tz"] - data["lambda_tz"]) + * (data["iota"] * data["omega_tz"] - data["theta_PEST_tz"]) - data["(psi_r/sqrt(g))_z"] - * (data["iota"] * data["omega_tt"] - data["lambda_tt"]) + * (data["iota"] * data["omega_tt"] - data["theta_PEST_tt"]) + data["(psi_r/sqrt(g))_tz"] - * (-data["iota"] * data["omega_t"] + data["lambda_t"] + 1) + * (-data["iota"] * data["omega_t"] + data["theta_PEST_t"]) ) return data @@ -1473,31 
+1463,29 @@ def _psi_r_over_sqrtg_rz(params, transforms, profiles, data, **kwargs): "lambda_rzz", "lambda_z", "lambda_zz", - "omega_rz", - "omega_rzz", - "omega_z", - "omega_zz", + "phi_rz", + "phi_rzz", + "phi_z", + "phi_zz", ], ) def _B_sup_theta_rz(params, transforms, profiles, data, **kwargs): data["B^theta_rz"] = ( data["psi_r/sqrt(g)"] * ( - data["iota"] * data["omega_rzz"] - + data["iota_r"] * data["omega_zz"] + data["iota"] * data["phi_rzz"] + + data["iota_r"] * data["phi_zz"] - data["lambda_rzz"] ) + data["(psi_r/sqrt(g))_r"] - * (data["iota"] * data["omega_zz"] - data["lambda_zz"]) + * (data["iota"] * data["phi_zz"] - data["lambda_zz"]) + data["(psi_r/sqrt(g))_z"] * ( - data["iota"] * data["omega_rz"] - + data["iota_r"] * data["omega_z"] - + data["iota_r"] + data["iota"] * data["phi_rz"] + + data["iota_r"] * data["phi_z"] - data["lambda_rz"] ) - + data["(psi_r/sqrt(g))_rz"] - * (data["iota"] * data["omega_z"] + data["iota"] - data["lambda_z"]) + + data["(psi_r/sqrt(g))_rz"] * (data["iota"] * data["phi_z"] - data["lambda_z"]) ) return data @@ -1523,10 +1511,10 @@ def _B_sup_theta_rz(params, transforms, profiles, data, **kwargs): "(psi_r/sqrt(g))_z", "iota", "iota_r", - "lambda_rt", - "lambda_rtz", - "lambda_t", - "lambda_tz", + "theta_PEST_rt", + "theta_PEST_rtz", + "theta_PEST_t", + "theta_PEST_tz", "omega_rt", "omega_rtz", "omega_t", @@ -1539,18 +1527,18 @@ def _B_sup_zeta_rz(params, transforms, profiles, data, **kwargs): * ( data["iota"] * data["omega_rtz"] + data["iota_r"] * data["omega_tz"] - - data["lambda_rtz"] + - data["theta_PEST_rtz"] ) - data["(psi_r/sqrt(g))_r"] - * (data["iota"] * data["omega_tz"] - data["lambda_tz"]) + * (data["iota"] * data["omega_tz"] - data["theta_PEST_tz"]) - data["(psi_r/sqrt(g))_z"] * ( data["iota"] * data["omega_rt"] + data["iota_r"] * data["omega_t"] - - data["lambda_rt"] + - data["theta_PEST_rt"] ) + data["(psi_r/sqrt(g))_rz"] - * (-data["iota"] * data["omega_t"] + data["lambda_t"] + 1) + * (-data["iota"] * 
data["omega_t"] + data["theta_PEST_t"]) ) return data @@ -1654,6 +1642,25 @@ def _B_sub_zeta(params, transforms, profiles, data, **kwargs): return data +@register_compute_fun( + name="B_phi|r,t", + label="B_{\\phi} = B \\cdot \\mathbf{e}_{\\phi} |_{\\rho, \\theta}", + units="T \\cdot m", + units_long="Tesla * meters", + description="Covariant toroidal component of magnetic field in (ρ,θ,ϕ) " + "coordinates.", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["B", "e_phi|r,t"], +) +def _B_sub_phi_rt(params, transforms, profiles, data, **kwargs): + data["B_phi|r,t"] = dot(data["B"], data["e_phi|r,t"]) + return data + + @register_compute_fun( name="B_rho_r", label="\\partial_{\\rho} B_{\\rho}", @@ -2262,7 +2269,7 @@ def _B_sub_zeta_rz(params, transforms, profiles, data, **kwargs): @register_compute_fun( name="<|B|>_axis", - label="\\lange |\\mathbf{B}| \\rangle_{axis}", + label="\\langle |\\mathbf{B}| \\rangle_{axis}", units="T", units_long="Tesla", description="Average magnitude of magnetic field on the magnetic axis", diff --git a/desc/compute/_geometry.py b/desc/compute/_geometry.py index 139f91f537..662413501b 100644 --- a/desc/compute/_geometry.py +++ b/desc/compute/_geometry.py @@ -12,8 +12,8 @@ from desc.backend import jnp from ..integrals.surface_integral import line_integrals, surface_integrals +from ..utils import cross, dot, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safenorm @register_compute_fun( diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index ceb6703386..ed4ea48145 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -14,8 +14,8 @@ from desc.backend import jnp from ..integrals.surface_integral import surface_averages +from ..utils import cross, dot, safediv, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safediv, safenorm @register_compute_fun( diff --git a/desc/compute/_omnigenity.py 
b/desc/compute/_omnigenity.py index fa43976898..db37766882 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -13,51 +13,69 @@ from desc.backend import jnp, sign, vmap +from ..utils import cross, dot, safediv from .data_index import register_compute_fun -from .utils import cross, dot, safediv @register_compute_fun( name="B_theta_mn", label="B_{\\theta, m, n}", - units="T \\cdot m}", + units="T \\cdot m", units_long="Tesla * meters", description="Fourier coefficients for covariant poloidal component of " - + "magnetic field", + "magnetic field.", dim=1, params=[], - transforms={"B": [[0, 0, 0]]}, + transforms={"B": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["B_theta"], + resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", - resolution_requirement="tz", ) def _B_theta_mn(params, transforms, profiles, data, **kwargs): - data["B_theta_mn"] = transforms["B"].fit(data["B_theta"]) + B_theta = transforms["grid"].meshgrid_reshape(data["B_theta"], "rtz") + + def fitfun(x): + return transforms["B"].fit(x.flatten(order="F")) + + B_theta_mn = vmap(fitfun)(B_theta) + # modes stored as shape(rho, mn) flattened + data["B_theta_mn"] = B_theta_mn.flatten() return data +# TODO: do math to change definition of nu so that we can just use B_zeta_mn here @register_compute_fun( - name="B_zeta_mn", - label="B_{\\zeta, m, n}", - units="T \\cdot m}", + name="B_phi_mn", + label="B_{\\phi, m, n}", + units="T \\cdot m", units_long="Tesla * meters", description="Fourier coefficients for covariant toroidal component of " - + "magnetic field", + "magnetic field in (ρ,θ,ϕ) coordinates.", dim=1, params=[], transforms={"B": [[0, 0, 0]]}, profiles=[], coordinates="rtz", - data=["B_zeta"], + data=["B_phi|r,t"], + resolution_requirement="tz", + grid_requirement={"is_meshgrid": 
True}, + aliases="B_zeta_mn", # TODO: remove when phi != zeta M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", - resolution_requirement="tz", ) -def _B_zeta_mn(params, transforms, profiles, data, **kwargs): - data["B_zeta_mn"] = transforms["B"].fit(data["B_zeta"]) +def _B_phi_mn(params, transforms, profiles, data, **kwargs): + B_phi = transforms["grid"].meshgrid_reshape(data["B_phi|r,t"], "rtz") + + def fitfun(x): + return transforms["B"].fit(x.flatten(order="F")) + + B_zeta_mn = vmap(fitfun)(B_phi) + # modes stored as shape(rho, mn) flattened + data["B_phi_mn"] = B_zeta_mn.flatten() return data @@ -70,15 +88,16 @@ def _B_zeta_mn(params, transforms, profiles, data, **kwargs): + "Boozer Coordinates'", dim=1, params=[], - transforms={"w": [[0, 0, 0]], "B": [[0, 0, 0]]}, + transforms={"w": [[0, 0, 0]], "B": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", - data=["B_theta_mn", "B_zeta_mn"], + data=["B_theta_mn", "B_phi_mn"], + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. 
Default 2*eq.N", ) def _w_mn(params, transforms, profiles, data, **kwargs): - w_mn = jnp.zeros((transforms["w"].basis.num_modes,)) + w_mn = jnp.zeros((transforms["grid"].num_rho, transforms["w"].basis.num_modes)) Bm = transforms["B"].basis.modes[:, 1] Bn = transforms["B"].basis.modes[:, 2] wm = transforms["w"].basis.modes[:, 1] @@ -87,15 +106,19 @@ def _w_mn(params, transforms, profiles, data, **kwargs): mask_t = (Bm[:, None] == -wm) & (Bn[:, None] == wn) & (wm != 0) mask_z = (Bm[:, None] == wm) & (Bn[:, None] == -wn) & (wm == 0) & (wn != 0) - num_t = (mask_t @ sign(wn)) * data["B_theta_mn"] + num_t = (mask_t @ sign(wn)) * data["B_theta_mn"].reshape( + (transforms["grid"].num_rho, -1) + ) den_t = mask_t @ jnp.abs(wm) - num_z = (mask_z @ sign(wm)) * data["B_zeta_mn"] + num_z = (mask_z @ sign(wm)) * data["B_phi_mn"].reshape( + (transforms["grid"].num_rho, -1) + ) den_z = mask_z @ jnp.abs(NFP * wn) - w_mn = jnp.where(mask_t.any(axis=0), mask_t.T @ safediv(num_t, den_t), w_mn) - w_mn = jnp.where(mask_z.any(axis=0), mask_z.T @ safediv(num_z, den_z), w_mn) + w_mn = jnp.where(mask_t.any(axis=0), (mask_t.T @ safediv(num_t, den_t).T).T, w_mn) + w_mn = jnp.where(mask_z.any(axis=0), (mask_z.T @ safediv(num_z, den_z).T).T, w_mn) - data["w_Boozer_mn"] = w_mn + data["w_Boozer_mn"] = w_mn.flatten() return data @@ -108,16 +131,22 @@ def _w_mn(params, transforms, profiles, data, **kwargs): + "'Transformation from VMEC to Boozer Coordinates'", dim=1, params=[], - transforms={"w": [[0, 0, 0]]}, + transforms={"w": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["w_Boozer_mn"], resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. 
Default 2*eq.N", ) def _w(params, transforms, profiles, data, **kwargs): - data["w_Boozer"] = transforms["w"].transform(data["w_Boozer_mn"]) + grid = transforms["grid"] + w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1)) + w = vmap(transforms["w"].transform)(w_mn) # shape(rho, theta*zeta) + w = w.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F") + w = jnp.moveaxis(w, 0, 1) + data["w_Boozer"] = w.flatten(order="F") return data @@ -130,16 +159,24 @@ def _w(params, transforms, profiles, data, **kwargs): + "'Transformation from VMEC to Boozer Coordinates', poloidal derivative", dim=1, params=[], - transforms={"w": [[0, 1, 0]]}, + transforms={"w": [[0, 1, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["w_Boozer_mn"], resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", ) def _w_t(params, transforms, profiles, data, **kwargs): - data["w_Boozer_t"] = transforms["w"].transform(data["w_Boozer_mn"], dt=1) + grid = transforms["grid"] + w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1)) + # need to close over dt which can't be vmapped + fun = lambda x: transforms["w"].transform(x, dt=1) + w_t = vmap(fun)(w_mn) # shape(rho, theta*zeta) + w_t = w_t.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F") + w_t = jnp.moveaxis(w_t, 0, 1) + data["w_Boozer_t"] = w_t.flatten(order="F") return data @@ -152,16 +189,24 @@ def _w_t(params, transforms, profiles, data, **kwargs): + "'Transformation from VMEC to Boozer Coordinates', toroidal derivative", dim=1, params=[], - transforms={"w": [[0, 0, 1]]}, + transforms={"w": [[0, 0, 1]], "grid": []}, profiles=[], coordinates="rtz", data=["w_Boozer_mn"], resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. 
Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", ) def _w_z(params, transforms, profiles, data, **kwargs): - data["w_Boozer_z"] = transforms["w"].transform(data["w_Boozer_mn"], dz=1) + grid = transforms["grid"] + w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1)) + # need to close over dz which can't be vmapped + fun = lambda x: transforms["w"].transform(x, dz=1) + w_z = vmap(fun)(w_mn) # shape(rho, theta*zeta) + w_z = w_z.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F") + w_z = jnp.moveaxis(w_z, 0, 1) + data["w_Boozer_z"] = w_z.flatten(order="F") return data @@ -233,10 +278,10 @@ def _nu_z(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["theta", "lambda", "iota", "nu"], + data=["theta_PEST", "iota", "nu"], ) def _theta_B(params, transforms, profiles, data, **kwargs): - data["theta_B"] = data["theta"] + data["lambda"] + data["iota"] * data["nu"] + data["theta_B"] = data["theta_PEST"] + data["iota"] * data["nu"] return data @@ -251,10 +296,10 @@ def _theta_B(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["zeta", "nu"], + data=["phi", "nu"], ) def _zeta_B(params, transforms, profiles, data, **kwargs): - data["zeta_B"] = data["zeta"] + data["nu"] + data["zeta_B"] = data["phi"] + data["nu"] return data @@ -263,18 +308,20 @@ def _zeta_B(params, transforms, profiles, data, **kwargs): label="\\sqrt{g}_{B}", units="~", units_long="None", - description="Jacobian determinant of Boozer coordinates", + description="Jacobian determinant from Boozer to DESC coordinates", dim=1, params=[], transforms={}, profiles=[], coordinates="rtz", - data=["lambda_t", "lambda_z", "nu_t", "nu_z", "iota"], + data=["theta_PEST_t", "theta_PEST_z", "phi_t", "phi_z", "nu_t", "nu_z", "iota"], ) def _sqrtg_B(params, transforms, profiles, data, **kwargs): - data["sqrt(g)_B"] = (1 + data["lambda_t"]) * (1 + data["nu_z"]) + ( - 
data["iota"] - data["lambda_z"] - ) * data["nu_t"] + data["sqrt(g)_B"] = ( + data["theta_PEST_t"] * (data["phi_z"] + data["nu_z"]) + - data["theta_PEST_z"] * (data["phi_t"] + data["nu_t"]) + + data["iota"] * (data["nu_t"] * data["phi_z"] - data["nu_z"] * data["phi_t"]) + ) return data @@ -286,21 +333,38 @@ def _sqrtg_B(params, transforms, profiles, data, **kwargs): description="Boozer harmonics of magnetic field", dim=1, params=[], - transforms={"B": [[0, 0, 0]]}, + transforms={"B": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["sqrt(g)_B", "|B|", "rho", "theta_B", "zeta_B"], + resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", ) def _B_mn(params, transforms, profiles, data, **kwargs): - nodes = jnp.array([data["rho"], data["theta_B"], data["zeta_B"]]).T norm = 2 ** (3 - jnp.sum((transforms["B"].basis.modes == 0), axis=1)) - data["|B|_mn"] = ( - norm # 1 if m=n=0, 2 if m=0 or n=0, 4 if m!=0 and n!=0 - * (transforms["B"].basis.evaluate(nodes).T @ (data["sqrt(g)_B"] * data["|B|"])) - / transforms["B"].grid.num_nodes + grid = transforms["grid"] + + def fun(rho, theta_B, zeta_B, sqrtg_B, B): + # this fits Boozer modes on a single surface + nodes = jnp.array([rho, theta_B, zeta_B]).T + B_mn = ( + norm # 1 if m=n=0, 2 if m=0 or n=0, 4 if m!=0 and n!=0 + * (transforms["B"].basis.evaluate(nodes).T @ (sqrtg_B * B)) + / transforms["B"].grid.num_nodes + ) + return B_mn + + def reshape(x): + return grid.meshgrid_reshape(x, "rtz").reshape((grid.num_rho, -1)) + + rho, theta_B, zeta_B, sqrtg_B, B = map( + reshape, + (data["rho"], data["theta_B"], data["zeta_B"], data["sqrt(g)_B"], data["|B|"]), ) + B_mn = vmap(fun)(rho, theta_B, zeta_B, sqrtg_B, B) + data["|B|_mn"] = B_mn.flatten() return data diff --git a/desc/compute/_profiles.py b/desc/compute/_profiles.py index 
40b875a9d7..65bca54b59 100644 --- a/desc/compute/_profiles.py +++ b/desc/compute/_profiles.py @@ -14,8 +14,8 @@ from desc.backend import cond, jnp from ..integrals.surface_integral import surface_averages, surface_integrals +from ..utils import cumtrapz, dot, safediv from .data_index import register_compute_fun -from .utils import cumtrapz, dot, safediv @register_compute_fun( @@ -764,6 +764,7 @@ def _iota(params, transforms, profiles, data, **kwargs): data["iota"] = profiles["iota"].compute(transforms["grid"], params["i_l"], dr=0) elif profiles["current"] is not None: # See the document attached to GitHub pull request #556 for the math. + # Assumes ζ = ϕ − ω and θ = ϑ − λ. data["iota"] = transforms["grid"].replace_at_axis( safediv(data["iota_num"], data["iota_den"]), lambda: safediv(data["iota_num_r"], data["iota_den_r"]), @@ -792,6 +793,7 @@ def _iota_r(params, transforms, profiles, data, **kwargs): ) elif profiles["current"] is not None: # See the document attached to GitHub pull request #556 for the math. + # Assumes ζ = ϕ − ω and θ = ϑ − λ. data["iota_r"] = transforms["grid"].replace_at_axis( safediv( data["iota_num_r"] * data["iota_den"] @@ -835,6 +837,7 @@ def _iota_rr(params, transforms, profiles, data, **kwargs): ) elif profiles["current"] is not None: # See the document attached to GitHub pull request #556 for the math. + # Assumes ζ = ϕ − ω and θ = ϑ − λ. 
data["iota_rr"] = transforms["grid"].replace_at_axis( safediv( data["iota_num_rr"] * data["iota_den"] ** 2 @@ -922,7 +925,7 @@ def _iota_num_current(params, transforms, profiles, data, **kwargs): iota = profiles["iota"].compute(transforms["grid"], params["i_l"], dr=0) data["iota_num current"] = iota * data["iota_den"] - data["iota_num vacuum"] elif profiles["current"] is not None: - # 4π^2 I = 4π^2 (mu_0 current / 2π) = 2π mu_0 current + # 4π² I = 4π² (μ₀ current / 2π) = 2π μ₀ current current = profiles["current"].compute(transforms["grid"], params["c_l"], dr=0) current_r = profiles["current"].compute(transforms["grid"], params["c_l"], dr=1) data["iota_num current"] = ( @@ -954,6 +957,7 @@ def _iota_num_current(params, transforms, profiles, data, **kwargs): ) def _iota_num_vacuum(params, transforms, profiles, data, **kwargs): """Vacuum contribution to the numerator of rotational transform formula.""" + # Assumes ζ = ϕ − ω and θ = ϑ − λ. iota_num_vacuum = transforms["grid"].replace_at_axis( safediv( data["lambda_z"] * data["g_tt"] - (1 + data["lambda_t"]) * data["g_tz"], @@ -1035,6 +1039,7 @@ def _iota_num_r_current(params, transforms, profiles, data, **kwargs): resolution_requirement="tz", ) def _iota_num_r_vacuum(params, transforms, profiles, data, **kwargs): + # Assumes ζ = ϕ − ω and θ = ϑ − λ. iota_num_vacuum = safediv( data["lambda_z"] * data["g_tt"] - (1 + data["lambda_t"]) * data["g_tz"], data["sqrt(g)"], @@ -1154,11 +1159,11 @@ def _iota_num_rr(params, transforms, profiles, data, **kwargs): Computes d2(𝛼+𝛽)/d𝜌2 as defined in the document attached to the description of GitHub pull request #556. 𝛼 supplements the rotational transform with an additional term to account for the enclosed net toroidal current. + Assumes ζ = ϕ − ω and θ = ϑ − λ. 
""" if profiles["iota"] is not None: data["iota_num_rr"] = jnp.nan * data["0"] elif profiles["current"] is not None: - # 4π^2 I = 4π^2 (mu_0 current / 2π) = 2π mu_0 current current = profiles["current"].compute(transforms["grid"], params["c_l"], dr=0) current_r = profiles["current"].compute(transforms["grid"], params["c_l"], dr=1) current_rr = profiles["current"].compute( @@ -1167,6 +1172,7 @@ def _iota_num_rr(params, transforms, profiles, data, **kwargs): current_rrr = profiles["current"].compute( transforms["grid"], params["c_l"], dr=3 ) + # 4π² I = 4π² (μ₀ current / 2π) = 2π μ₀ current alpha_rr = ( jnp.pi * mu_0 @@ -1283,6 +1289,7 @@ def _iota_num_rrr(params, transforms, profiles, data, **kwargs): Computes d3(𝛼+𝛽)/d𝜌3 as defined in the document attached to the description of GitHub pull request #556. 𝛼 supplements the rotational transform with an additional term to account for the enclosed net toroidal current. + Assumes ζ = ϕ − ω and θ = ϑ − λ. """ if profiles["iota"] is not None: data["iota_num_rrr"] = jnp.nan * data["0"] @@ -1298,7 +1305,7 @@ def _iota_num_rrr(params, transforms, profiles, data, **kwargs): current_rrrr = profiles["current"].compute( transforms["grid"], params["c_l"], dr=4 ) - # 4π^2 I = 4π^2 (mu_0 current / 2π) = 2π mu_0 current + # 4π² I = 4π² (μ₀ current / 2π) = 2π μ₀ current alpha_rrr = ( jnp.pi * mu_0 @@ -1402,14 +1409,14 @@ def _iota_den(params, transforms, profiles, data, **kwargs): """Denominator of rotational transform formula. Computes 𝛾 as defined in the document attached to the description - of GitHub pull request #556. + of GitHub pull request #556. Assumes ζ = ϕ − ω and θ = ϑ − λ. """ gamma = safediv( (1 + data["omega_z"]) * data["g_tt"] - data["omega_t"] * data["g_tz"], data["sqrt(g)"], ) # Assumes toroidal stream function behaves such that the magnetic axis limit - # of gamma is zero (as it would if omega = 0 identically). + # of γ is zero (as it would if ω = 0 identically). 
gamma = transforms["grid"].replace_at_axis( surface_integrals(transforms["grid"], gamma), 0 ) @@ -1447,7 +1454,7 @@ def _iota_den_r(params, transforms, profiles, data, **kwargs): """Denominator of rotational transform formula, first radial derivative. Computes d𝛾/d𝜌 as defined in the document attached to the description - of GitHub pull request #556. + of GitHub pull request #556. Assumes ζ = ϕ − ω and θ = ϑ − λ. """ gamma = safediv( (1 + data["omega_z"]) * data["g_tt"] - data["omega_t"] * data["g_tz"], @@ -1514,7 +1521,7 @@ def _iota_den_rr(params, transforms, profiles, data, **kwargs): """Denominator of rotational transform formula, second radial derivative. Computes d2𝛾/d𝜌2 as defined in the document attached to the description - of GitHub pull request #556. + of GitHub pull request #556. Assumes ζ = ϕ − ω and θ = ϑ − λ. """ gamma = safediv( (1 + data["omega_z"]) * data["g_tt"] - data["omega_t"] * data["g_tz"], @@ -1609,7 +1616,7 @@ def _iota_den_rrr(params, transforms, profiles, data, **kwargs): """Denominator of rotational transform formula, third radial derivative. Computes d3𝛾/d𝜌3 as defined in the document attached to the description - of GitHub pull request #556. + of GitHub pull request #556. Assumes ζ = ϕ − ω and θ = ϑ − λ. """ gamma = safediv( (1 + data["omega_z"]) * data["g_tt"] - data["omega_t"] * data["g_tz"], @@ -1675,9 +1682,12 @@ def _iota_den_rrr(params, transforms, profiles, data, **kwargs): axis_limit_data=["iota_rr", "psi_rr"], ) def _iota_psi(params, transforms, profiles, data, **kwargs): - # Existence of limit at magnetic axis requires ∂ᵨ iota = 0 at axis. - # Assume iota may be expanded as an even power series of ρ so that this - # condition is satisfied. + """∂ι/∂ψ. + + Existence of limit at magnetic axis requires ∂ι/∂ρ = 0 at axis. + Assume ι may be expanded as an even power series of ρ so that this + condition is satisfied. 
+ """ data["iota_psi"] = transforms["grid"].replace_at_axis( safediv(data["iota_r"], data["psi_r"]), lambda: safediv(data["iota_rr"], data["psi_rr"]), diff --git a/desc/compute/_stability.py b/desc/compute/_stability.py index 3b820f83b0..1757fee0ba 100644 --- a/desc/compute/_stability.py +++ b/desc/compute/_stability.py @@ -14,8 +14,8 @@ from desc.backend import jnp from ..integrals.surface_integral import surface_integrals_map +from ..utils import dot from .data_index import register_compute_fun -from .utils import dot @register_compute_fun( diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 26341ec587..f8f30fa36d 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -63,6 +63,7 @@ def register_compute_fun( # noqa: C901 aliases=None, parameterization="desc.equilibrium.equilibrium.Equilibrium", resolution_requirement="", + grid_requirement=None, source_grid_requirement=None, **kwargs, ): @@ -110,6 +111,11 @@ def register_compute_fun( # noqa: C901 If the computation simply performs pointwise operations, instead of a reduction (such as integration) over a coordinate, then an empty string may be used to indicate no requirements. + grid_requirement : dict + Attributes of the grid that the compute function requires. + Also assumes dependencies were computed on such a grid. + As an example, quantities that require tensor product grids over 2 or more + coordinates may specify ``grid_requirement={"is_meshgrid": True}``. source_grid_requirement : dict Attributes of the source grid that the compute function requires. Also assumes dependencies were computed on such a grid. 
@@ -130,6 +136,8 @@ def register_compute_fun( # noqa: C901 aliases = [] if source_grid_requirement is None: source_grid_requirement = {} + if grid_requirement is None: + grid_requirement = {} if not isinstance(parameterization, (tuple, list)): parameterization = [parameterization] if not isinstance(aliases, (tuple, list)): @@ -168,6 +176,7 @@ def _decorator(func): "dependencies": deps, "aliases": aliases, "resolution_requirement": resolution_requirement, + "grid_requirement": grid_requirement, "source_grid_requirement": source_grid_requirement, } for p in parameterization: diff --git a/desc/compute/geom_utils.py b/desc/compute/geom_utils.py index fc5e1dab83..eeda658b61 100644 --- a/desc/compute/geom_utils.py +++ b/desc/compute/geom_utils.py @@ -4,7 +4,7 @@ from desc.backend import jnp -from .utils import safenorm, safenormalize +from ..utils import safenorm, safenormalize def reflection_matrix(normal): diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 0c6e2f7de3..b5bbe8cbbc 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -33,7 +33,9 @@ def _parse_parameterization(p): return module + "." + klass.__qualname__ -def compute(parameterization, names, params, transforms, profiles, data=None, **kwargs): +def compute( # noqa: C901 + parameterization, names, params, transforms, profiles, data=None, **kwargs +): """Compute the quantity given by name on grid. 
Parameters @@ -88,6 +90,15 @@ def compute(parameterization, names, params, transforms, profiles, data=None, ** if "grid" in transforms: def check_fun(name): + reqs = data_index[p][name]["grid_requirement"] + for req in reqs: + errorif( + not hasattr(transforms["grid"], req) + or reqs[req] != getattr(transforms["grid"], req), + AttributeError, + f"Expected grid with '{req}:{reqs[req]}' to compute {name}.", + ) + reqs = data_index[p][name]["source_grid_requirement"] errorif( reqs and not hasattr(transforms["grid"], "source_grid"), @@ -517,6 +528,7 @@ def get_transforms( """ from desc.basis import DoubleFourierSeries + from desc.grid import LinearGrid from desc.transform import Transform method = "jitable" if jitable or kwargs.get("method") == "jitable" else "auto" @@ -556,8 +568,15 @@ def get_transforms( ) transforms[c] = c_transform elif c == "B": # used for Boozer transform + # assume grid is a meshgrid but only care about a single surface + if grid.num_rho > 1: + theta = grid.nodes[grid.unique_theta_idx, 1] + zeta = grid.nodes[grid.unique_zeta_idx, 2] + grid_B = LinearGrid(theta=theta, zeta=zeta, NFP=grid.NFP, sym=grid.sym) + else: + grid_B = grid transforms["B"] = Transform( - grid, + grid_B, DoubleFourierSeries( M=kwargs.get("M_booz", 2 * obj.M), N=kwargs.get("N_booz", 2 * obj.N), @@ -570,8 +589,15 @@ def get_transforms( method=method, ) elif c == "w": # used for Boozer transform + # assume grid is a meshgrid but only care about a single surface + if grid.num_rho > 1: + theta = grid.nodes[grid.unique_theta_idx, 1] + zeta = grid.nodes[grid.unique_zeta_idx, 2] + grid_w = LinearGrid(theta=theta, zeta=zeta, NFP=grid.NFP, sym=grid.sym) + else: + grid_w = grid transforms["w"] = Transform( - grid, + grid_w, DoubleFourierSeries( M=kwargs.get("M_booz", 2 * obj.M), N=kwargs.get("N_booz", 2 * obj.N), @@ -685,187 +711,3 @@ def _has_transforms(qty, transforms, parameterization): [d in transforms[key].derivatives.tolist() for d in derivs[key]] ).all() return 
all(flags.values()) - - -def dot(a, b, axis=-1): - """Batched vector dot product. - - Parameters - ---------- - a : array-like - First array of vectors. - b : array-like - Second array of vectors. - axis : int - Axis along which vectors are stored. - - Returns - ------- - y : array-like - y = sum(a*b, axis=axis) - - """ - return jnp.sum(a * b, axis=axis, keepdims=False) - - -def cross(a, b, axis=-1): - """Batched vector cross product. - - Parameters - ---------- - a : array-like - First array of vectors. - b : array-like - Second array of vectors. - axis : int - Axis along which vectors are stored. - - Returns - ------- - y : array-like - y = a x b - - """ - return jnp.cross(a, b, axis=axis) - - -def safenorm(x, ord=None, axis=None, fill=0, threshold=0): - """Like jnp.linalg.norm, but without nan gradient at x=0. - - Parameters - ---------- - x : ndarray - Vector or array to norm. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of norm. - axis : {None, int, 2-tuple of ints}, optional - Axis to take norm along. - fill : float, ndarray, optional - Value to return where x is zero. - threshold : float >= 0 - How small is x allowed to be. - - """ - is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) - y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero - n = jnp.linalg.norm(y, ord=ord, axis=axis) - n = jnp.where(is_zero.squeeze(), fill, n) # replace norm with zero if is_zero - return n - - -def safenormalize(x, ord=None, axis=None, fill=0, threshold=0): - """Normalize a vector to unit length, but without nan gradient at x=0. - - Parameters - ---------- - x : ndarray - Vector or array to norm. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of norm. - axis : {None, int, 2-tuple of ints}, optional - Axis to take norm along. - fill : float, ndarray, optional - Value to return where x is zero. - threshold : float >= 0 - How small is x allowed to be. 
- - """ - is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) - y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero - n = safenorm(x, ord, axis, fill, threshold) * jnp.ones_like(x) - # return unit vector with equal components if norm <= threshold - return jnp.where(n <= threshold, jnp.ones_like(y) / jnp.sqrt(y.size), y / n) - - -def safediv(a, b, fill=0, threshold=0): - """Divide a/b with guards for division by zero. - - Parameters - ---------- - a, b : ndarray - Numerator and denominator. - fill : float, ndarray, optional - Value to return where b is zero. - threshold : float >= 0 - How small is b allowed to be. - """ - mask = jnp.abs(b) <= threshold - num = jnp.where(mask, fill, a) - den = jnp.where(mask, 1, b) - return num / den - - -def cumtrapz(y, x=None, dx=1.0, axis=-1, initial=None): - """Cumulatively integrate y(x) using the composite trapezoidal rule. - - Taken from SciPy, but changed NumPy references to JAX.NumPy: - https://github.com/scipy/scipy/blob/v1.10.1/scipy/integrate/_quadrature.py - - Parameters - ---------- - y : array_like - Values to integrate. - x : array_like, optional - The coordinate to integrate along. If None (default), use spacing `dx` - between consecutive elements in `y`. - dx : float, optional - Spacing between elements of `y`. Only used if `x` is None. - axis : int, optional - Specifies the axis to cumulate. Default is -1 (last axis). - initial : scalar, optional - If given, insert this value at the beginning of the returned result. - Typically, this value should be 0. Default is None, which means no - value at ``x[0]`` is returned and `res` has one element less than `y` - along the axis of integration. - - Returns - ------- - res : ndarray - The result of cumulative integration of `y` along `axis`. - If `initial` is None, the shape is such that the axis of integration - has one less value than `y`. If `initial` is given, the shape is equal - to that of `y`. 
- - """ - y = jnp.asarray(y) - if x is None: - d = dx - else: - x = jnp.asarray(x) - if x.ndim == 1: - d = jnp.diff(x) - # reshape to correct shape - shape = [1] * y.ndim - shape[axis] = -1 - d = d.reshape(shape) - elif len(x.shape) != len(y.shape): - raise ValueError("If given, shape of x must be 1-D or the " "same as y.") - else: - d = jnp.diff(x, axis=axis) - - if d.shape[axis] != y.shape[axis] - 1: - raise ValueError( - "If given, length of x along axis must be the " "same as y." - ) - - def tupleset(t, i, value): - l = list(t) - l[i] = value - return tuple(l) - - nd = len(y.shape) - slice1 = tupleset((slice(None),) * nd, axis, slice(1, None)) - slice2 = tupleset((slice(None),) * nd, axis, slice(None, -1)) - res = jnp.cumsum(d * (y[slice1] + y[slice2]) / 2.0, axis=axis) - - if initial is not None: - if not jnp.isscalar(initial): - raise ValueError("`initial` parameter should be a scalar.") - - shape = list(res.shape) - shape[axis] = 1 - res = jnp.concatenate( - [jnp.full(shape, initial, dtype=res.dtype), res], axis=axis - ) - - return res diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 9d722ce52d..d21c1cd73b 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -598,8 +598,8 @@ def to_sfl( M_grid = M_grid or int(2 * M) N_grid = N_grid or int(2 * N) - grid = ConcentricGrid(L_grid, M_grid, N_grid, node_pattern="ocs") - bdry_grid = LinearGrid(M=M, N=N, rho=1.0) + grid = ConcentricGrid(L_grid, M_grid, N_grid, node_pattern="ocs", NFP=eq.NFP) + bdry_grid = LinearGrid(M=M, N=N, rho=1.0, NFP=eq.NFP) toroidal_coords = eq.compute(["R", "Z", "lambda"], grid=grid) theta = grid.nodes[:, 1] @@ -685,12 +685,12 @@ def get_rtz_grid( rvp : rho, theta_PEST, phi rtz : rho, theta, zeta period : tuple of float - Assumed periodicity for each quantity in inbasis. + Assumed periodicity for functions of the given coordinates. Use ``np.inf`` to denote no periodicity. 
jitable : bool, optional If false the returned grid has additional attributes. Required to be false to retain nodes at magnetic axis. - kwargs : dict + kwargs Additional parameters to supply to the coordinate mapping function. See ``desc.equilibrium.coords.map_coordinates``. diff --git a/desc/geometry/curve.py b/desc/geometry/curve.py index e2c583ea6f..2e141710fb 100644 --- a/desc/geometry/curve.py +++ b/desc/geometry/curve.py @@ -203,13 +203,16 @@ def Z_n(self, new): ) @classmethod - def from_input_file(cls, path): + def from_input_file(cls, path, **kwargs): """Create a axis curve from Fourier coefficients in a DESC or VMEC input file. Parameters ---------- path : Path-like or str Path to DESC or VMEC input file. + **kwargs : dict, optional + keyword arguments to pass to the constructor of the + FourierRZCurve being created. Returns ------- @@ -227,6 +230,7 @@ def from_input_file(cls, path): inputs["axis"][:, 0].astype(int), inputs["NFP"], inputs["sym"], + **kwargs, ) return curve diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 79f1b871a9..2f74200aaa 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -298,13 +298,16 @@ def set_coeffs(self, m, n=0, R=None, Z=None): self.Z_lmn = put(self.Z_lmn, idxZ, ZZ) @classmethod - def from_input_file(cls, path): + def from_input_file(cls, path, **kwargs): """Create a surface from Fourier coefficients in a DESC or VMEC input file. Parameters ---------- path : Path-like or str Path to DESC or VMEC input file. + **kwargs : dict, optional + keyword arguments to pass to the constructor of the + FourierRZToroidalSurface being created. 
Returns ------- @@ -328,6 +331,7 @@ def from_input_file(cls, path): inputs["surface"][:, 1:3].astype(int), inputs["NFP"], inputs["sym"], + **kwargs, ) return surf diff --git a/desc/grid.py b/desc/grid.py index 2eb22a6c5c..6a8ab78fe3 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -638,7 +638,8 @@ def meshgrid_reshape(self, x, order): vec = True shape += (-1,) x = x.reshape(shape, order="F") - x = jnp.swapaxes(x, 1, 0) # now shape rtz/raz etc + # swap to change shape from trz/arz to rtz/raz etc. + x = jnp.swapaxes(x, 1, 0) newax = tuple(self.coordinates.index(c) for c in order) if vec: newax += (3,) diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py index 6a40b0bc8f..b9ec3c78c4 100644 --- a/desc/integrals/basis.py +++ b/desc/integrals/basis.py @@ -47,16 +47,12 @@ def _subtract(c, k): @partial(jnp.vectorize, signature="(m),(m)->(m)") -def _in_epigraph_and(is_intersect, df_dy_sign): +def _in_epigraph_and(is_intersect, df_dy_sign, /): """Set and epigraph of function f with the given set of points. Used to return only intersects where the straight line path between adjacent intersects resides in the epigraph of a continuous map ``f``. - Warnings - -------- - Does not support keyword arguments. - Parameters ---------- is_intersect : jnp.ndarray @@ -71,6 +67,11 @@ def _in_epigraph_and(is_intersect, df_dy_sign): Boolean array indicating whether element is an intersect and satisfies the stated condition. + Examples + -------- + See ``desc/integrals/bounce_utils.py::bounce_points``. + This is used there to ensure the domains of integration are magnetic wells. + """ # The pairs ``y1`` and ``y2`` are boundaries of an integral only if ``y1 <= y2``. # For the integrals to be over wells, it is required that the first intersect @@ -80,7 +81,7 @@ def _in_epigraph_and(is_intersect, df_dy_sign): # must be at the first pair. 
To correct the inversion, it suffices to disqualify the # first intersect as a right boundary, except under an edge case of a series of # inflection points. - idx = flatnonzero(is_intersect, size=2, fill_value=-1) # idx of first 2 intersects + idx = flatnonzero(is_intersect, size=2, fill_value=-1) edge_case = ( (df_dy_sign[idx[0]] == 0) & (df_dy_sign[idx[1]] < 0) @@ -700,9 +701,9 @@ def _plot_intersect(ax, legend, z1, z2, k, k_transparency, klabel): for i in range(k.size): _z1, _z2 = z1[i], z2[i] if _z1.size == _z2.size: - mask = (z1 - z2) != 0.0 - _z1 = z1[mask] - _z2 = z2[mask] + mask = (_z1 - _z2) != 0.0 + _z1 = _z1[mask] + _z2 = _z2[mask] _add2legend( legend, ax.scatter( diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index bb2240a970..1826791a9d 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -1,15 +1,17 @@ """Methods for computing bounce integrals (singular or otherwise).""" -from interpax import CubicHermiteSpline +from interpax import CubicHermiteSpline, PPoly from orthax.legendre import leggauss from desc.backend import jnp, rfft2 from desc.integrals.basis import FourierChebyshevBasis from desc.integrals.bounce_utils import ( + _bounce_quadrature, _check_bounce_points, + _set_default_plot_kwargs, bounce_points, - bounce_quadrature, get_alpha, + get_pitch_inv, interp_to_argmin, plot_ppoly, ) @@ -20,6 +22,7 @@ get_quadrature, grad_automorphism_sin, ) +from desc.io import IOAble from desc.utils import errorif, flatten_matrix, setdefault, warnif @@ -569,15 +572,15 @@ def _integrate(self, z1, z2, pitch, integrand, f): return result -class Bounce1D: +class Bounce1D(IOAble): """Computes bounce integrals using one-dimensional local spline methods. 
- The bounce integral is defined as ∫ f(ℓ) dℓ, where + The bounce integral is defined as ∫ f(λ, ℓ) dℓ, where dℓ parameterizes the distance along the field line in meters, - f(ℓ) is the quantity to integrate along the field line, - and the boundaries of the integral are bounce points ζ₁, ζ₂ s.t. λ|B|(ζᵢ) = 1, - where λ is a constant proportional to the magnetic moment over energy - and |B| is the norm of the magnetic field. + f(λ, ℓ) is the quantity to integrate along the field line, + and the boundaries of the integral are bounce points ℓ₁, ℓ₂ s.t. λ|B|(ℓᵢ) = 1, + where λ is a constant defining the integral proportional to the magnetic moment + over energy and |B| is the norm of the magnetic field. For a particle with fixed λ, bounce points are defined to be the location on the field line such that the particle's velocity parallel to the magnetic field is zero. @@ -611,9 +614,11 @@ class Bounce1D: cannot support reconstruction of the function near the origin. As the functions of interest do not vanish at infinity, pseudo-spectral techniques are not used. Instead, function approximation is done with local splines. - This is useful if one can efficiently obtain data along field lines. + This is useful if one can efficiently obtain data along field lines and + most efficient if the number of toroidal transits to follow a field line is + not too large. - After obtaining the bounce points, the supplied quadrature is performed. + After computing the bounce points, the supplied quadrature is performed. By default, this is a Gauss quadrature after removing the singularity. Local splines interpolate functions in the integrand to the quadrature nodes. @@ -621,31 +626,27 @@ class Bounce1D: -------- Bounce2D : Uses two-dimensional pseudo-spectral techniques for the same task. - Warnings - -------- - The supplied data must be from a Clebsch coordinate (ρ, α, ζ) tensor-product grid. 
- The ζ coordinates (the unique values prior to taking the tensor-product) must be - strictly increasing and preferably uniformly spaced. These are used as knots to - construct splines; a reference knot density is 100 knots per toroidal transit. - Examples -------- - See ``tests/test_integrals.py::TestBounce1D::test_integrate_checks``. + See ``tests/test_integrals.py::TestBounce1D::test_bounce1d_checks``. Attributes ---------- - _B : jnp.ndarray - TODO: Make this (4, M, L, N-1) now that tensor product in rho and alpha - required as well after GitHub PR #1214. - Shape (4, L * M, N - 1). + required_names : list + Names in ``data_index`` required to compute bounce integrals. + B : jnp.ndarray + Shape (M, L, N - 1, B.shape[-1]). Polynomial coefficients of the spline of |B| in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Last axis enumerates the coefficients of power series. For a polynomial + given by ∑ᵢⁿ cᵢ xⁱ, coefficient cᵢ is stored at ``B[...,n-i]``. + Third axis enumerates the polynomials that compose a particular spline. + Second axis enumerates flux surfaces. + First axis enumerates field lines of a particular flux surface. """ - plot_ppoly = staticmethod(plot_ppoly) + required_names = ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"] + get_pitch_inv = staticmethod(get_pitch_inv) def __init__( self, @@ -655,6 +656,8 @@ def __init__( automorphism=(automorphism_sin, grad_automorphism_sin), Bref=1.0, Lref=1.0, + *, + is_reshaped=False, check=False, **kwargs, ): @@ -664,11 +667,14 @@ def __init__( ---------- grid : Grid Clebsch coordinate (ρ, α, ζ) tensor-product grid. - Note that below shape notation defines + The ζ coordinates (the unique values prior to taking the tensor-product) + must be strictly increasing and preferably uniformly spaced. These are used + as knots to construct splines. 
A reference knot density is 100 knots per + toroidal transit. Note that below shape notation defines L = ``grid.num_rho``, M = ``grid.num_alpha``, and N = ``grid.num_zeta``. data : dict[str, jnp.ndarray] Data evaluated on ``grid``. - Must include names in ``Bounce1D.required_names()``. + Must include names in ``Bounce1D.required_names``. quad : (jnp.ndarray, jnp.ndarray) Quadrature points xₖ and weights wₖ for the approximate evaluation of an integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points. @@ -681,6 +687,13 @@ def __init__( Optional. Reference magnetic field strength for normalization. Lref : float Optional. Reference length scale for normalization. + is_reshaped : bool + Whether the arrays in ``data`` are already reshaped to the expected form of + shape (..., N) or (..., L, N) or (M, L, N). This option can be used to + iteratively compute bounce integrals one field line or one flux surface + at a time, respectively, potentially reducing memory usage. To do so, + set to true and provide only those axes of the reshaped data. + Default is false. check : bool Flag for debugging. Must be false for JAX transformations. @@ -703,15 +716,16 @@ def __init__( "|B|": data["|B|"] / Bref, "|B|_z|r,a": data["|B|_z|r,a"] / Bref, # This is already the correct sign. } - self._data = { - key: grid.meshgrid_reshape(val, "raz").reshape(-1, grid.num_zeta) - for key, val in data.items() - } + self._data = ( + data + if is_reshaped + else dict(zip(data.keys(), Bounce1D.reshape_data(grid, *data.values()))) + ) self._x, self._w = get_quadrature(quad, automorphism) # Compute local splines. 
self._zeta = grid.compress(grid.nodes[:, 2], surface_label="zeta") - self._B = jnp.moveaxis( + self.B = jnp.moveaxis( CubicHermiteSpline( x=self._zeta, y=self._data["|B|"], @@ -719,56 +733,52 @@ def __init__( axis=-1, check=check, ).c, - source=1, - destination=-1, + source=(0, 1), + destination=(-1, -2), ) - self._dB_dz = polyder_vec(self._B) - degree = 3 - assert self._B.shape[0] == degree + 1 - assert self._dB_dz.shape[0] == degree - assert self._B.shape[-1] == self._dB_dz.shape[-1] == grid.num_zeta - 1 + self._dB_dz = polyder_vec(self.B) - @staticmethod - def required_names(): - """Return names in ``data_index`` required to compute bounce integrals.""" - return ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"] + # Add axis here instead of in ``_bounce_quadrature``. + for name in self._data: + self._data[name] = self._data[name][..., jnp.newaxis, :] @staticmethod - def reshape_data(grid, *data): - """Reshape ``data`` arrays for acceptable input to ``integrate``. + def reshape_data(grid, *arys): + """Reshape arrays for acceptable input to ``integrate``. Parameters ---------- grid : Grid Clebsch coordinate (ρ, α, ζ) tensor-product grid. - data : jnp.ndarray + arys : jnp.ndarray Data evaluated on grid. Returns ------- - f : list[jnp.ndarray] - List of reshaped data which may be given to ``integrate``. + f : jnp.ndarray + Shape (M, L, N). + Reshaped data which may be given to ``integrate``. """ - f = [grid.meshgrid_reshape(d, "raz").reshape(-1, grid.num_zeta) for d in data] - return f + f = [grid.meshgrid_reshape(d, "arz") for d in arys] + return f if len(f) > 1 else f[0] - def bounce_points(self, pitch, num_well=None): + def points(self, pitch_inv, *, num_well=None): """Compute bounce points. Parameters ---------- - pitch : jnp.ndarray - Shape must broadcast with (P, L * M). - λ values to evaluate the bounce integral at each field line. 
λ(ρ,α) is - specified by ``pitch[...,ρ]`` where in the latter the labels (ρ,α) are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is + specified by ``pitch_inv[α,ρ]`` where in the latter the labels + are interpreted as the indices that correspond to that field line. num_well : int or None Specify to return the first ``num_well`` pairs of bounce points for each pitch along each field line. This is useful if ``num_well`` tightly - bounds the actual number. As a reference, there are typically at most 5 - wells per toroidal transit for a given pitch. + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``check_points`` method. If not specified, then all bounce points are returned. If there were fewer wells detected along a field line than the size of the last axis of the @@ -777,103 +787,102 @@ def bounce_points(self, pitch, num_well=None): Returns ------- z1, z2 : (jnp.ndarray, jnp.ndarray) - Shape (P, L * M, num_well). - ζ coordinates of bounce points. The points are grouped and ordered such + Shape (M, L, P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. - If there were less than ``num_wells`` wells detected along a field line, - then the last axis, which enumerates bounce points for a particular field + If there were less than ``num_well`` wells detected along a field line, + then the last axis, which enumerates bounce points for a particular field line and pitch, is padded with zero. 
""" - return bounce_points( - pitch=pitch, - knots=self._zeta, - B=self._B, - dB_dz=self._dB_dz, - num_well=num_well, - ) + return bounce_points(pitch_inv, self._zeta, self.B, self._dB_dz, num_well) - def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs): + def check_points(self, z1, z2, pitch_inv, *, plot=True, **kwargs): """Check that bounce points are computed correctly. Parameters ---------- z1, z2 : (jnp.ndarray, jnp.ndarray) - Shape (P, L * M, num_well). - ζ coordinates of bounce points. The points are grouped and ordered such + Shape (M, L, P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. - pitch : jnp.ndarray - Shape must broadcast with (P, L * M). - λ values to evaluate the bounce integral at each field line. λ(ρ,α) is - specified by ``pitch[...,(ρ,α)]`` where in the latter the labels (ρ,α) are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is + specified by ``pitch_inv[α,ρ]`` where in the latter the labels + are interpreted as the indices that correspond to that field line. plot : bool - Whether to plot stuff. - kwargs : dict - Keyword arguments into ``self.plot_ppoly``. + Whether to plot the field lines and bounce points of the given pitch angles. + kwargs + Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``. + + Returns + ------- + plots : list + Matplotlib (fig, ax) tuples for the 1D plot of each field line. 
""" - _check_bounce_points( + return _check_bounce_points( z1=z1, z2=z2, - pitch=jnp.atleast_2d(pitch), + pitch_inv=pitch_inv, knots=self._zeta, - B=self._B, + B=self.B, plot=plot, **kwargs, ) def integrate( self, - pitch, integrand, + pitch_inv, f=None, weight=None, + *, num_well=None, method="cubic", batch=True, check=False, + plot=False, ): - """Bounce integrate ∫ f(ℓ) dℓ. + """Bounce integrate ∫ f(λ, ℓ) dℓ. - Computes the bounce integral ∫ f(ℓ) dℓ for every specified field line - for every λ value in ``pitch``. + Computes the bounce integral ∫ f(λ, ℓ) dℓ for every field line and pitch. Parameters ---------- - pitch : jnp.ndarray - Shape must broadcast with (P, L * M). - λ values to evaluate the bounce integral at each field line. λ(ρ,α) is - specified by ``pitch[...,(ρ,α)]`` where in the latter the labels (ρ,α) are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. integrand : callable The composition operator on the set of functions in ``f`` that maps the - functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the - arrays in ``f`` as arguments as well as the additional keyword arguments: - ``B`` and ``pitch``. A quadrature will be performed to approximate the - bounce integral of ``integrand(*f,B=B,pitch=pitch)``. - f : list[jnp.ndarray] - Shape (L * M, N). + functions in ``f`` to the integrand f(λ, ℓ) in ∫ f(λ, ℓ) dℓ. It should + accept the arrays in ``f`` as arguments as well as the additional keyword + arguments: ``B`` and ``pitch``. A quadrature will be performed to + approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by + ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. + f : list[jnp.ndarray] or jnp.ndarray + Shape (M, L, N). 
Real scalar-valued functions evaluated on the ``grid`` supplied to construct this object. These functions should be arguments to the callable ``integrand``. Use the method ``self.reshape_data`` to reshape the data into the expected shape. weight : jnp.ndarray - Shape must broadcast with (L * M, N). + Shape (M, L, N). If supplied, the bounce integral labeled by well j is weighted such that - the returned value is w(j) ∫ f(ℓ) dℓ, where w(j) is ``weight`` - interpolated to the deepest point in the magnetic well. Use the method + the returned value is w(j) ∫ f(λ, ℓ) dℓ, where w(j) is ``weight`` + interpolated to the deepest point in that magnetic well. Use the method ``self.reshape_data`` to reshape the data into the expected shape. num_well : int or None Specify to return the first ``num_well`` pairs of bounce points for each pitch along each field line. This is useful if ``num_well`` tightly - bounds the actual number. As a reference, there are typically at most 5 - wells per toroidal transit for a given pitch. + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``check_points`` method. If not specified, then all bounce points are returned. If there were fewer wells detected along a field line than the size of the last axis of the @@ -886,40 +895,83 @@ def integrate( Whether to perform computation in a batched manner. Default is true. check : bool Flag for debugging. Must be false for JAX transformations. + plot : bool + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. Returns ------- result : jnp.ndarray - Shape (P, L*M, num_well). - First axis enumerates pitch values. Second axis enumerates the field lines. - Last axis enumerates the bounce integrals. + Shape (M, L, P, num_well). 
+ Last axis enumerates the bounce integrals for a given field line, + flux surface, and pitch value. """ - pitch = jnp.atleast_2d(pitch) - z1, z2 = self.bounce_points(pitch, num_well) - result = bounce_quadrature( + z1, z2 = self.points(pitch_inv, num_well=num_well) + result = _bounce_quadrature( x=self._x, w=self._w, z1=z1, z2=z2, - pitch=pitch, integrand=integrand, + pitch_inv=pitch_inv, f=setdefault(f, []), data=self._data, knots=self._zeta, method=method, batch=batch, check=check, + plot=plot, ) if weight is not None: result *= interp_to_argmin( - h=weight, - z1=z1, - z2=z2, - knots=self._zeta, - g=self._B, - dg_dz=self._dB_dz, - method=method, + weight, + z1, + z2, + self._zeta, + self.B, + self._dB_dz, + method, ) - assert result.shape[-1] == setdefault(num_well, (self._zeta.size - 1) * 3) + assert result.shape == z1.shape return result + + def plot(self, m, l, pitch_inv=None, /, **kwargs): + """Plot the field line and bounce points of the given pitch angles. + + Parameters + ---------- + m, l : int, int + Indices into the nodes of the grid supplied to make this object. + ``alpha,rho=grid.meshgrid_reshape(grid.nodes[:,:2],"arz")[m,l,0]``. + pitch_inv : jnp.ndarray + Shape (P, ). + Optional, 1/λ values whose corresponding bounce points on the field line + specified by Clebsch coordinate α(m), ρ(l) will be plotted. + kwargs + Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. 
+ + """ + B, dB_dz = self.B, self._dB_dz + if B.ndim == 4: + B = B[m] + dB_dz = dB_dz[m] + if B.ndim == 3: + B = B[l] + dB_dz = dB_dz[l] + if pitch_inv is not None: + errorif( + pitch_inv.ndim > 1, + msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.", + ) + z1, z2 = bounce_points(pitch_inv, self._zeta, B, dB_dz) + kwargs["z1"] = z1 + kwargs["z2"] = z2 + kwargs["k"] = pitch_inv + fig, ax = plot_ppoly(PPoly(B.T, self._zeta), **_set_default_plot_kwargs(kwargs)) + return fig, ax diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index 34dab1b542..e5a1d3ebbd 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -1,14 +1,15 @@ """Utilities and functional programming interface for bounce integrals.""" +import numpy as np from interpax import PPoly from matplotlib import pyplot as plt -from desc.backend import imap, jnp, softmax +from desc.backend import imap, jnp, softargmax from desc.integrals.basis import _add2legend, _in_epigraph_and, _plot_intersect from desc.integrals.interp_utils import ( interp1d_Hermite_vec, interp1d_vec, - poly_root, + polyroot_vec, polyval_vec, ) from desc.integrals.quad_utils import ( @@ -16,11 +17,18 @@ composite_linspace, grad_bijection_from_disc, ) -from desc.utils import atleast_3d_mid, errorif, setdefault, take_mask +from desc.utils import ( + atleast_nd, + errorif, + flatten_matrix, + is_broadcastable, + setdefault, + take_mask, +) -def get_pitch(min_B, max_B, num, relative_shift=1e-6): - """Return uniformly spaced values between ``1/max_B`` and ``1/min_B``. +def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): + """Return 1/λ values for quadrature between ``min_B`` and ``max_B``. Parameters ---------- @@ -36,17 +44,19 @@ def get_pitch(min_B, max_B, num, relative_shift=1e-6): Returns ------- - pitch : jnp.ndarray - Shape (num + 2, *min_B.shape). + pitch_inv : jnp.ndarray + Shape (*min_B.shape, num + 2). + 1/λ values. 
""" # Floating point error impedes consistent detection of bounce points riding # extrema. Shift values slightly to resolve this issue. min_B = (1 + relative_shift) * min_B max_B = (1 - relative_shift) * max_B - pitch = composite_linspace(1 / jnp.stack([max_B, min_B]), num) - assert pitch.shape == (num + 2, *min_B.shape) - return pitch + # Samples should be uniformly spaced in |B| and not λ (GitHub issue #1228). + pitch_inv = jnp.moveaxis(composite_linspace(jnp.stack([min_B, max_B]), num), 0, -1) + assert pitch_inv.shape == (*min_B.shape, num + 2) + return pitch_inv # TODO: Generalize this beyond ζ = ϕ or just map to Clebsch with ϕ. @@ -77,94 +87,84 @@ def get_alpha(alpha_0, iota, num_transit, period): return alpha -def _check_spline_shape(knots, g, dg_dz, pitch=None): - """Ensure inputs have compatible shape, and return them with full dimension. +def _check_spline_shape(knots, g, dg_dz, pitch_inv=None): + """Ensure inputs have compatible shape. Parameters ---------- knots : jnp.ndarray - Shape (knots.size, ). + Shape (N, ). ζ coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (g.shape[0], S, knots.size - 1). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (g.shape[0] - 1, *g.shape[1:]). + Shape (..., N - 1, g.shape[-1] - 1). Polynomial coefficients of the spline of ∂g/∂ζ in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. - pitch : jnp.ndarray - Shape must broadcast with (P, S). 
- λ values to evaluate the bounce integral at each field line. λ(ρ,α) is - specified by ``pitch[...,(ρ,α)]`` where in the latter the labels (ρ,α) are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values. 1/λ(α,ρ) is specified by ``pitch_inv[α,ρ]`` where in + the latter the labels are interpreted as the indices that correspond + to that field line. """ errorif(knots.ndim != 1, msg=f"knots should be 1d; got shape {knots.shape}.") errorif( - g.shape[-1] != (knots.size - 1), + g.shape[-2] != (knots.size - 1), msg=( - "Last axis does not enumerate polynomials of spline. " + "Second to last axis does not enumerate polynomials of spline. " f"Spline shape {g.shape}. Knots shape {knots.shape}." ), ) errorif( - g.ndim > 3 - or dg_dz.ndim > 3 - or (g.shape[0] - 1) != dg_dz.shape[0] - or g.shape[1:] != dg_dz.shape[1:], + not (g.ndim == dg_dz.ndim < 5) + or g.shape != (*dg_dz.shape[:-1], dg_dz.shape[-1] + 1), msg=f"Invalid shape {g.shape} for spline and derivative {dg_dz.shape}.", ) - # Add axis which enumerates field lines if necessary. 
- g, dg_dz = atleast_3d_mid(g, dg_dz) - if pitch is not None: - pitch = jnp.atleast_2d(pitch) + g, dg_dz = jnp.atleast_2d(g, dg_dz) + if pitch_inv is not None: + pitch_inv = jnp.atleast_1d(pitch_inv) errorif( - pitch.ndim != 2 - or not (pitch.shape[-1] == 1 or pitch.shape[-1] == g.shape[1]), - msg=f"Invalid shape {pitch.shape} for pitch angles.", + pitch_inv.ndim > 3 + or not is_broadcastable(pitch_inv.shape[:-1], g.shape[:-2]), + msg=f"Invalid shape {pitch_inv.shape} for pitch angles.", ) - return g, dg_dz, pitch + return g, dg_dz, pitch_inv def bounce_points( - pitch, knots, B, dB_dz, num_well=None, check=False, plot=True, **kwargs + pitch_inv, knots, B, dB_dz, num_well=None, check=False, plot=True, **kwargs ): """Compute the bounce points given spline of |B| and pitch λ. Parameters ---------- - pitch : jnp.ndarray - Shape must broadcast with (P, S). - λ values to evaluate the bounce integral at each field line. λ(ρ,α) is - specified by ``pitch[...,(ρ,α)]`` where in the latter the labels (ρ,α) are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values to compute the bounce points. knots : jnp.ndarray - Shape (knots.size, ). + Shape (N, ). ζ coordinates of spline knots. Must be strictly increasing. B : jnp.ndarray - Shape (B.shape[0], S, knots.size - 1). + Shape (..., N - 1, B.shape[-1]). Polynomial coefficients of the spline of |B| in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. dB_dz : jnp.ndarray - Shape (B.shape[0] - 1, *B.shape[1:]). + Shape (..., N - 1, B.shape[-1] - 1). 
Polynomial coefficients of the spline of (∂|B|/∂ζ)|(ρ,α) in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. num_well : int or None Specify to return the first ``num_well`` pairs of bounce points for each pitch along each field line. This is useful if ``num_well`` tightly - bounds the actual number. As a reference, there are typically at most 5 - wells per toroidal transit for a given pitch. + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``_check_bounce_points`` method. If not specified, then all bounce points are returned. If there were fewer wells detected along a field line than the size of the last axis of the @@ -173,50 +173,52 @@ def bounce_points( Flag for debugging. Must be false for JAX transformations. plot : bool Whether to plot some things if check is true. Default is true. - kwargs : dict + kwargs Keyword arguments into ``plot_ppoly``. Returns ------- z1, z2 : (jnp.ndarray, jnp.ndarray) - Shape (P, S, num_well). - ζ coordinates of bounce points. The points are grouped and ordered such + Shape (..., P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. - If there were less than ``num_wells`` wells detected along a field line, + If there were less than ``num_well`` wells detected along a field line, then the last axis, which enumerates bounce points for a particular field line and pitch, is padded with zero. 
""" - B, dB_dz, pitch = _check_spline_shape(knots, B, dB_dz, pitch) - P, S, degree = pitch.shape[0], B.shape[1], B.shape[0] - 1 - # Intersection points in local power basis. - intersect = poly_root( - c=B, - k=(1 / pitch)[..., jnp.newaxis], + B, dB_dz, pitch_inv = _check_spline_shape(knots, B, dB_dz, pitch_inv) + intersect = polyroot_vec( + c=B[..., jnp.newaxis, :, :], # Add P axis + k=pitch_inv[..., jnp.newaxis], # Add N axis a_min=jnp.array([0.0]), a_max=jnp.diff(knots), sort=True, sentinel=-1.0, distinct=True, ) - assert intersect.shape == (P, S, knots.size - 1, degree) + assert intersect.shape[-3:] == ( + pitch_inv.shape[-1], + knots.size - 1, + B.shape[-1] - 1, + ) # Reshape so that last axis enumerates intersects of a pitch along a field line. - dB_dz_sign = jnp.sign( - polyval_vec(x=intersect, c=dB_dz[..., jnp.newaxis]).reshape(P, S, -1) + dB_sign = flatten_matrix( + jnp.sign(polyval_vec(x=intersect, c=dB_dz[..., jnp.newaxis, :, jnp.newaxis, :])) ) # Only consider intersect if it is within knots that bound that polynomial. - is_intersect = intersect.reshape(P, S, -1) >= 0 + is_intersect = flatten_matrix(intersect) >= 0 # Following discussion on page 3 and 5 of https://doi.org/10.1063/1.873749, # we ignore the bounce points of particles only assigned to a class that are # trapped outside this snapshot of the field line. - is_z1 = (dB_dz_sign <= 0) & is_intersect - is_z2 = (dB_dz_sign >= 0) & _in_epigraph_and(is_intersect, dB_dz_sign) + is_z1 = (dB_sign <= 0) & is_intersect + is_z2 = (dB_sign >= 0) & _in_epigraph_and(is_intersect, dB_sign) # Transform out of local power basis expansion. - intersect = (intersect + knots[:-1, jnp.newaxis]).reshape(P, S, -1) + intersect = flatten_matrix(intersect + knots[:-1, jnp.newaxis]) # New versions of JAX only like static sentinels. 
sentinel = -10000000.0 # instead of knots[0] - 1 z1 = take_mask(intersect, is_z1, size=num_well, fill_value=sentinel) @@ -228,14 +230,12 @@ def bounce_points( z2 = jnp.where(mask, z2, 0.0) if check: - _check_bounce_points(z1, z2, pitch, knots, B, plot, **kwargs) + _check_bounce_points(z1, z2, pitch_inv, knots, B, plot, **kwargs) return z1, z2 -def _check_bounce_points(z1, z2, pitch, knots, B, plot=True, **kwargs): - """Check that bounce points are computed correctly.""" - eps = kwargs.pop("eps", jnp.finfo(jnp.array(1.0).dtype).eps * 10) +def _set_default_plot_kwargs(kwargs): kwargs.setdefault( "title", r"Intersects $\zeta$ in epigraph($\vert B \vert$) s.t. " @@ -244,6 +244,18 @@ def _check_bounce_points(z1, z2, pitch, knots, B, plot=True, **kwargs): kwargs.setdefault("klabel", r"$1/\lambda$") kwargs.setdefault("hlabel", r"$\zeta$") kwargs.setdefault("vlabel", r"$\vert B \vert$") + return kwargs + + +def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): + """Check that bounce points are computed correctly.""" + z1 = atleast_nd(4, z1) + z2 = atleast_nd(4, z2) + pitch_inv = atleast_nd(3, pitch_inv) + B = atleast_nd(4, B) + + kwargs = _set_default_plot_kwargs(kwargs) + plots = [] assert z1.shape == z2.shape mask = (z1 - z2) != 0.0 @@ -253,51 +265,56 @@ def _check_bounce_points(z1, z2, pitch, knots, B, plot=True, **kwargs): err_1 = jnp.any(z1 > z2, axis=-1) err_2 = jnp.any(z1[..., 1:] < z2[..., :-1], axis=-1) - P, S, _ = z1.shape - for s in range(S): - Bs = PPoly(B[:, s], knots) - for p in range(P): - Bs_midpoint = Bs((z1[p, s] + z2[p, s]) / 2) - err_3 = jnp.any(Bs_midpoint > 1 / pitch[p, s] + eps) - if not (err_1[p, s] or err_2[p, s] or err_3): + eps = kwargs.pop("eps", jnp.finfo(jnp.array(1.0).dtype).eps * 10) + for ml in np.ndindex(B.shape[:-2]): + ppoly = PPoly(B[ml].T, knots) + for p in range(pitch_inv.shape[-1]): + idx = (*ml, p) + B_midpoint = ppoly((z1[idx] + z2[idx]) / 2) + err_3 = jnp.any(B_midpoint > pitch_inv[idx] + eps) + if not 
(err_1[idx] or err_2[idx] or err_3): continue - _z1 = z1[p, s][mask[p, s]] - _z2 = z2[p, s][mask[p, s]] + _z1 = z1[idx][mask[idx]] + _z2 = z2[idx][mask[idx]] if plot: plot_ppoly( - ppoly=Bs, + ppoly=ppoly, z1=_z1, z2=_z2, - k=1 / pitch[p, s], + k=pitch_inv[idx], + title=kwargs.pop("title") + f", (m,l,p)={idx}", **kwargs, ) print(" z1 | z2") print(jnp.column_stack([_z1, _z2])) - assert not err_1[p, s], "Intersects have an inversion.\n" - assert not err_2[p, s], "Detected discontinuity.\n" + assert not err_1[idx], "Intersects have an inversion.\n" + assert not err_2[idx], "Detected discontinuity.\n" assert not err_3, ( - f"Detected |B| = {Bs_midpoint[mask[p, s]]} > {1 / pitch[p, s] + eps} " + f"Detected |B| = {B_midpoint[mask[idx]]} > {pitch_inv[idx] + eps} " "= 1/λ in well, implying the straight line path between " "bounce points is in hypograph(|B|). Use more knots.\n" ) if plot: - plot_ppoly( - ppoly=Bs, - z1=z1[:, s], - z2=z2[:, s], - k=1 / pitch[:, s], - **kwargs, + plots.append( + plot_ppoly( + ppoly=ppoly, + z1=z1[ml], + z2=z2[ml], + k=pitch_inv[ml], + **kwargs, + ) ) + return plots -def bounce_quadrature( +def _bounce_quadrature( x, w, z1, z2, - pitch, integrand, + pitch_inv, f, data, knots, @@ -306,7 +323,7 @@ def bounce_quadrature( check=False, plot=False, ): - """Bounce integrate ∫ f(ℓ) dℓ. + """Bounce integrate ∫ f(λ, ℓ) dℓ. Parameters ---------- @@ -317,31 +334,29 @@ def bounce_quadrature( Shape (w.size, ). Quadrature weights. z1, z2 : jnp.ndarray - Shape (P, S, num_well). - ζ coordinates of bounce points. The points are grouped and ordered such + Shape (..., P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. - pitch : jnp.ndarray - Shape must broadcast with (P, S). - λ values to evaluate the bounce integral at each field line. 
λ(ρ,α) is - specified by ``pitch[...,(ρ,α)]`` where in the latter the labels (ρ,α) are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. integrand : callable The composition operator on the set of functions in ``f`` that maps the - functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the - arrays in ``f`` as arguments as well as the additional keyword arguments: - ``B`` and ``pitch``. A quadrature will be performed to approximate the - bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + functions in ``f`` to the integrand f(λ, ℓ) in ∫ f(λ, ℓ) dℓ. It should + accept the arrays in ``f`` as arguments as well as the additional keyword + arguments: ``B`` and ``pitch``. A quadrature will be performed to + approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values to compute the bounce integrals. f : list[jnp.ndarray] - Shape (S, knots.size). + Shape (..., N). Real scalar-valued functions evaluated on the ``knots``. These functions should be arguments to the callable ``integrand``. data : dict[str, jnp.ndarray] - Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. - Must include names in ``Bounce1D.required_names()``. + Shape (..., 1, N). + Required data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. + Must include names in ``Bounce1D.required_names``. knots : jnp.ndarray - Shape (knots.size, ). + Shape (N, ). Unique ζ coordinates where the arrays in ``data`` and ``f`` were evaluated. method : str Method of interpolation. @@ -353,30 +368,29 @@ def bounce_quadrature( Flag for debugging. Must be false for JAX transformations. Ignored if ``batch`` is false. plot : bool - Whether to plot stuff if ``check`` is true. Default is false. - Only developers doing debugging want to see these plots. 
+ Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. Returns ------- result : jnp.ndarray - Shape (P, S, num_well). - Quadrature for every pitch. - First axis enumerates pitch values. Second axis enumerates the field lines. - Last axis enumerates the bounce integrals. + Shape (..., P, num_well). + Last axis enumerates the bounce integrals for a field line, + flux surface, and pitch. """ - errorif(z1.ndim != 3 or z1.shape != z2.shape) errorif(x.ndim != 1 or x.shape != w.shape) - pitch = jnp.atleast_2d(pitch) + errorif(z1.ndim < 2 or z1.shape != z2.shape) + pitch_inv = jnp.atleast_1d(pitch_inv) if not isinstance(f, (list, tuple)): - f = [f] + f = [f] if isinstance(f, (jnp.ndarray, np.ndarray)) else list(f) # Integrate and complete the change of variable. if batch: result = _interpolate_and_integrate( w=w, Q=bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]), - pitch=pitch, + pitch_inv=pitch_inv, integrand=integrand, f=f, data=data, @@ -386,16 +400,14 @@ def bounce_quadrature( plot=plot, ) else: - f = list(f) - # TODO: Use batched vmap. - def loop(z): + def loop(z): # over num well axis z1, z2 = z # Need to return tuple because input was tuple; artifact of JAX map. 
return None, _interpolate_and_integrate( w=w, Q=bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]), - pitch=pitch, + pitch_inv=pitch_inv, integrand=integrand, f=f, data=data, @@ -403,6 +415,7 @@ def loop(z): method=method, check=False, plot=False, + batch=True, ) result = jnp.moveaxis( @@ -411,22 +424,21 @@ def loop(z): destination=-1, ) - result = result * grad_bijection_from_disc(z1, z2) - assert result.shape == (pitch.shape[0], data["|B|"].shape[0], z1.shape[-1]) - return result + return result * grad_bijection_from_disc(z1, z2) def _interpolate_and_integrate( w, Q, - pitch, + pitch_inv, integrand, f, data, knots, method, - check=False, - plot=False, + check, + plot, + batch=False, ): """Interpolate given functions to points ``Q`` and perform quadrature. @@ -436,68 +448,60 @@ def _interpolate_and_integrate( Shape (w.size, ). Quadrature weights. Q : jnp.ndarray - Shape (P, S, Q.shape[2], w.size). + Shape (..., P, Q.shape[-2], w.size). Quadrature points in ζ coordinates. - data : dict[str, jnp.ndarray] - Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. - Must include names in ``Bounce1D.required_names()``. Returns ------- result : jnp.ndarray Shape Q.shape[:-1]. - Quadrature for every pitch. + Quadrature result. 
""" - assert pitch.ndim == 2 - assert w.ndim == knots.ndim == 1 - assert 3 <= Q.ndim <= 4 and Q.shape[:2] == (pitch.shape[0], data["|B|"].shape[0]) - assert Q.shape[-1] == w.size - assert knots.size == data["|B|"].shape[-1] - assert ( - data["B^zeta"].shape - == data["B^zeta_z|r,a"].shape - == data["|B|"].shape - == data["|B|_z|r,a"].shape - ) + assert w.ndim == 1 and Q.shape[-1] == w.size + assert Q.shape[-3 + batch] == pitch_inv.shape[-1] + assert data["|B|"].shape[-1] == knots.size - pitch = jnp.expand_dims(pitch, axis=(2, 3) if (Q.ndim == 4) else 2) shape = Q.shape - Q = Q.reshape(Q.shape[0], Q.shape[1], -1) + if not batch: + Q = flatten_matrix(Q) b_sup_z = interp1d_Hermite_vec( Q, knots, data["B^zeta"] / data["|B|"], data["B^zeta_z|r,a"] / data["|B|"] - data["B^zeta"] * data["|B|_z|r,a"] / data["|B|"] ** 2, - ).reshape(shape) - B = interp1d_Hermite_vec(Q, knots, data["|B|"], data["|B|_z|r,a"]).reshape(shape) + ) + B = interp1d_Hermite_vec(Q, knots, data["|B|"], data["|B|_z|r,a"]) # Spline each function separately so that operations in the integrand # that do not preserve smoothness can be captured. - f = [interp1d_vec(Q, knots, f_i, method=method).reshape(shape) for f_i in f] - result = jnp.dot(integrand(*f, B=B, pitch=pitch) / b_sup_z, w) - + f = [interp1d_vec(Q, knots, f_i[..., jnp.newaxis, :], method=method) for f_i in f] + result = ( + (integrand(*f, B=B, pitch=1 / pitch_inv[..., jnp.newaxis]) / b_sup_z) + .reshape(shape) + .dot(w) + ) if check: - _check_interp(Q.reshape(shape), f, b_sup_z, B, data["|B|_z|r,a"], result, plot) + _check_interp(shape, Q, f, b_sup_z, B, result, plot) return result -def _check_interp(Q, f, b_sup_z, B, B_z_ra, result, plot): - """Check for floating point errors. +def _check_interp(shape, Q, f, b_sup_z, B, result, plot): + """Check for interpolation failures and floating point issues. Parameters ---------- + shape : tuple + (..., P, Q.shape[-2], w.size). Q : jnp.ndarray Quadrature points in ζ coordinates. 
- f : list of jnp.ndarray + f : list[jnp.ndarray] Arguments to the integrand, interpolated to Q. b_sup_z : jnp.ndarray Contravariant toroidal component of magnetic field, interpolated to Q. B : jnp.ndarray Norm of magnetic field, interpolated to Q. - B_z_ra : jnp.ndarray - Norm of magnetic field derivative, (∂|B|/∂ζ)|(ρ,α). result : jnp.ndarray Output of ``_interpolate_and_integrate``. plot : bool @@ -505,106 +509,95 @@ def _check_interp(Q, f, b_sup_z, B, B_z_ra, result, plot): """ assert jnp.isfinite(Q).all(), "NaN interpolation point." + assert not ( + jnp.isclose(B, 0).any() or jnp.isclose(b_sup_z, 0).any() + ), "|B| has vanished, violating the hairy ball theorem." + # Integrals that we should be computing. - marked = jnp.any(Q != 0.0, axis=-1) - goal = jnp.sum(marked) + marked = jnp.any(Q.reshape(shape) != 0.0, axis=-1) + goal = marked.sum() - msg = "Interpolation failed." - assert jnp.isfinite(B_z_ra).all(), msg - assert goal == jnp.sum(marked & jnp.isfinite(jnp.sum(b_sup_z, axis=-1))), msg - assert goal == jnp.sum(marked & jnp.isfinite(jnp.sum(B, axis=-1))), msg + assert goal == (marked & jnp.isfinite(b_sup_z).reshape(shape).all(axis=-1)).sum() + assert goal == (marked & jnp.isfinite(B).reshape(shape).all(axis=-1)).sum() for f_i in f: - assert goal == jnp.sum(marked & jnp.isfinite(jnp.sum(f_i, axis=-1))), msg - - msg = "|B| has vanished, violating the hairy ball theorem." - assert not jnp.isclose(B, 0).any(), msg - assert not jnp.isclose(b_sup_z, 0).any(), msg + assert goal == (marked & jnp.isfinite(f_i).reshape(shape).all(axis=-1)).sum() # Number of those integrals that were computed. - actual = jnp.sum(marked & jnp.isfinite(result)) + actual = (marked & jnp.isfinite(result)).sum() assert goal == actual, ( f"Lost {goal - actual} integrals from NaN generation in the integrand. This " - "can be caused by floating point error or a poor choice of quadrature nodes." + "is caused by floating point error, usually due to a poor quadrature choice." 
) if plot: - _plot_check_interp(Q, B, name=r"$\vert B \vert$") - _plot_check_interp(Q, b_sup_z, name=r"$ (B / \vert B \vert) \cdot e^{\zeta}$") + Q = Q.reshape(shape) + _plot_check_interp(Q, B.reshape(shape), name=r"$\vert B \vert$") + _plot_check_interp( + Q, b_sup_z.reshape(shape), name=r"$(B / \vert B \vert) \cdot e^{\zeta}$" + ) def _plot_check_interp(Q, V, name=""): - """Plot V[λ, (ρ, α), (ζ₁, ζ₂)](Q).""" - for p in range(Q.shape[0]): - for s in range(Q.shape[1]): - marked = jnp.nonzero(jnp.any(Q != 0.0, axis=-1))[0] - if marked.size == 0: - continue - fig, ax = plt.subplots() - ax.set_xlabel(r"$\zeta$") - ax.set_ylabel(name) - ax.set_title( - f"Interpolation of {name} to quadrature points. Index {p},{s}." - ) - for i in marked: - ax.plot(Q[p, s, i], V[p, s, i], marker="o") - fig.text( - 0.01, - 0.01, - f"Each color specifies the set of points and values (ζ, {name}(ζ)) " - "used to evaluate an integral.", - ) - plt.tight_layout() - plt.show() + """Plot V[..., λ, (ζ₁, ζ₂)](Q).""" + for idx in np.ndindex(Q.shape[:3]): + marked = jnp.nonzero(jnp.any(Q[idx] != 0.0, axis=-1))[0] + if marked.size == 0: + continue + fig, ax = plt.subplots() + ax.set_xlabel(r"$\zeta$") + ax.set_ylabel(name) + ax.set_title(f"Interpolation of {name} to quadrature points, (m,l,p)={idx}") + for i in marked: + ax.plot(Q[(*idx, i)], V[(*idx, i)], marker="o") + fig.text(0.01, 0.01, "Each color specifies a particular integral.") + plt.tight_layout() + plt.show() def _get_extrema(knots, g, dg_dz, sentinel=jnp.nan): - """Return ext (ζ*, g(ζ*)). + """Return extrema (z*, g(z*)). Parameters ---------- knots : jnp.ndarray - Shape (knots.size, ). + Shape (N, ). ζ coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (g.shape[0], S, knots.size - 1). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. 
Last axis enumerates the polynomials that - compose a particular spline. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (g.shape[0] - 1, *g.shape[1:]). - Polynomial coefficients of the spline of ∂g/∂ζ in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. sentinel : float Value with which to pad array to return fixed shape. Returns ------- ext, g_ext : jnp.ndarray - Shape (S, (knots.size - 1) * (degree - 1)). - First array enumerates ζ*. Second array enumerates g(ζ*) + Shape (..., (N - 1) * (g.shape[-1] - 2)). + First array enumerates z*. Second array enumerates g(z*) Sorting order of extrema is arbitrary. """ g, dg_dz, _ = _check_spline_shape(knots, g, dg_dz) - S, degree = g.shape[1], g.shape[0] - 1 - ext = poly_root( + ext = polyroot_vec( c=dg_dz, a_min=jnp.array([0.0]), a_max=jnp.diff(knots), sentinel=sentinel ) - assert ext.shape == (S, knots.size - 1, degree - 1) - g_ext = polyval_vec(x=ext, c=g[..., jnp.newaxis]).reshape(S, -1) + g_ext = flatten_matrix(polyval_vec(x=ext, c=g[..., jnp.newaxis, :])) # Transform out of local power basis expansion. 
- ext = (ext + knots[:-1, jnp.newaxis]).reshape(S, -1) + ext = flatten_matrix(ext + knots[:-1, jnp.newaxis]) + assert ext.shape == g_ext.shape and ext.shape[-1] == g.shape[-2] * (g.shape[-1] - 2) return ext, g_ext def _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel): - assert z1.shape[1] == z2.shape[1] == ext.shape[0] == g_ext.shape[0] return jnp.where( - (z1[..., jnp.newaxis] < ext[:, jnp.newaxis]) - & (ext[:, jnp.newaxis] < z2[..., jnp.newaxis]), - g_ext[:, jnp.newaxis], + (z1[..., jnp.newaxis] < ext[..., jnp.newaxis, jnp.newaxis, :]) + & (ext[..., jnp.newaxis, jnp.newaxis, :] < z2[..., jnp.newaxis]), + g_ext[..., jnp.newaxis, jnp.newaxis, :], upper_sentinel, ) @@ -619,28 +612,24 @@ def interp_to_argmin( Parameters ---------- h : jnp.ndarray - Shape must broadcast with (S, knots.size). + Shape (..., N). Values evaluated on ``knots`` to interpolate. z1, z2 : jnp.ndarray - Shape (P, S, num_well). - ζ coordinates of bounce points. The points are grouped and ordered such - that the straight line path between ``z1`` and ``z2`` resides in the - epigraph of g. + Shape (..., P, W). + Boundaries to detect argmin between. knots : jnp.ndarray - Shape (knots.size, ). - ζ coordinates of spline knots. Must be strictly increasing. + Shape (N, ). + z coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (g.shape[0], S, knots.size - 1). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (g.shape[0] - 1, *g.shape[1:]). - Polynomial coefficients of the spline of ∂g/∂ζ in local power basis. - First axis enumerates the coefficients of power series. 
Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. method : str Method of interpolation. See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. @@ -663,18 +652,21 @@ def interp_to_argmin( Returns ------- h : jnp.ndarray - Shape (P, S, num_well). - mean_A h(ζ) + Shape (..., P, W). """ - ext, g = _get_extrema(knots, g, dg_dz, sentinel=0) - # JAX softmax(x) does the proper shift to compute softmax(x - max(x)), but it's - # still not a good idea to compute over a large length scale, so we warn in - # docstring to choose upper sentinel properly. - argmin = softmax(beta * _where_for_argmin(z1, z2, ext, g, upper_sentinel), axis=-1) + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape + ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) + # Our softargmax(x) does the proper shift to compute softargmax(x - max(x)), + # but it's still not a good idea to compute over a large length scale, so we + # warn in docstring to choose upper sentinel properly. + argmin = softargmax( + beta * _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel), + axis=-1, + ) h = jnp.linalg.vecdot( argmin, - interp1d_vec(ext, knots, jnp.atleast_2d(h), method=method)[:, jnp.newaxis], + interp1d_vec(ext, knots, h, method=method)[..., jnp.newaxis, jnp.newaxis, :], ) assert h.shape == z1.shape return h @@ -694,28 +686,24 @@ def interp_to_argmin_hard(h, z1, z2, knots, g, dg_dz, method="cubic"): Parameters ---------- h : jnp.ndarray - Shape must broadcast with (S, knots.size). + Shape (..., N). Values evaluated on ``knots`` to interpolate. z1, z2 : jnp.ndarray - Shape (P, S, num_well). - ζ coordinates of bounce points. 
The points are grouped and ordered such - that the straight line path between ``z1`` and ``z2`` resides in the - epigraph of g. + Shape (..., P, W). + Boundaries to detect argmin between. knots : jnp.ndarray - Shape (knots.size, ). - ζ coordinates of spline knots. Must be strictly increasing. + Shape (N, ). + z coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (g.shape[0], S, knots.size - 1). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (g.shape[0] - 1, *g.shape[1:]). - Polynomial coefficients of the spline of ∂g/∂ζ in local power basis. - First axis enumerates the coefficients of power series. Second axis - enumerates the splines. Last axis enumerates the polynomials that - compose a particular spline. + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. method : str Method of interpolation. See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. @@ -724,19 +712,26 @@ def interp_to_argmin_hard(h, z1, z2, knots, g, dg_dz, method="cubic"): Returns ------- h : jnp.ndarray - Shape (P, S, num_well). - h(A) + Shape (..., P, W). 
""" - ext, g = _get_extrema(knots, g, dg_dz, sentinel=0) + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape + ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) # We can use the non-differentiable max because we actually want the gradients # to accumulate through only the minimum since we are differentiating how our # physics objective changes wrt equilibrium perturbations not wrt which of the # extrema get interpolated to. - argmin = jnp.argmin(_where_for_argmin(z1, z2, ext, g, jnp.max(g) + 1), axis=-1) - A = jnp.take_along_axis(ext[jnp.newaxis], argmin, axis=-1) - h = interp1d_vec(A, knots, jnp.atleast_2d(h), method=method) - assert h.shape == z1.shape + argmin = jnp.argmin( + _where_for_argmin(z1, z2, ext, g_ext, jnp.max(g_ext) + 1), + axis=-1, + ) + h = interp1d_vec( + jnp.take_along_axis(ext[jnp.newaxis], argmin, axis=-1), + knots, + h[..., jnp.newaxis, :], + method=method, + ) + assert h.shape == z1.shape, h.shape return h @@ -755,7 +750,8 @@ def plot_ppoly( start=None, stop=None, include_knots=False, - knot_transparency=0.1, + knot_transparency=0.2, + include_legend=True, ): """Plot the piecewise polynomial ``ppoly``. @@ -794,10 +790,13 @@ def plot_ppoly( Whether to plot vertical lines at the knots. knot_transparency : float Transparency of knot lines. + include_legend : bool + Whether to include the legend in the plot. Default is true. Returns ------- - fig, ax : matplotlib figure and axes + fig, ax + Matplotlib (fig, ax) tuple. 
""" fig, ax = plt.subplots() @@ -828,7 +827,8 @@ def plot_ppoly( ) ax.set_xlabel(hlabel) ax.set_ylabel(vlabel) - ax.legend(legend.values(), legend.keys(), loc="lower right") + if include_legend: + ax.legend(legend.values(), legend.keys(), loc="lower right") ax.set_title(title) plt.tight_layout() if show: diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py index c0fd0818c7..3dbc5b14a0 100644 --- a/desc/integrals/interp_utils.py +++ b/desc/integrals/interp_utils.py @@ -1,16 +1,21 @@ -"""Fast interpolation utilities.""" +"""Fast interpolation utilities. + +Notes +----- +These polynomial utilities are chosen for performance on gpu among +methods that have the best (asymptotic) algorithmic complexity. +For example, we prefer to not use Horner's method. +""" from functools import partial import numpy as np from interpax import interp1d from orthax.chebyshev import chebroots, chebvander -from orthax.polynomial import polyvander from desc.backend import dct, jnp, rfft, rfft2, take -from desc.compute.utils import safediv from desc.integrals.quad_utils import bijection_from_disc -from desc.utils import Index, errorif +from desc.utils import Index, errorif, safediv # TODO: Boyd's method 𝒪(N²) instead of Chebyshev companion matrix 𝒪(N³). # John P. Boyd, Computing real roots of a polynomial in Chebyshev series @@ -375,41 +380,50 @@ def idct_non_uniform(xq, a, n, axis=-1): return fq +# Warning: method must be specified as keyword argument. +interp1d_vec = jnp.vectorize( + interp1d, signature="(m),(n),(n)->(m)", excluded={"method"} +) + + +@partial(jnp.vectorize, signature="(m),(n),(n),(n)->(m)") +def interp1d_Hermite_vec(xq, x, f, fx, /): + """Vectorized cubic Hermite spline.""" + return interp1d(xq, x, f, method="cubic", fx=fx) + + def polyder_vec(c): """Coefficients for the derivatives of the given set of polynomials. Parameters ---------- c : jnp.ndarray - First axis should store coefficients of a polynomial. 
For a polynomial given by - ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[0]-1``, coefficient cᵢ should be stored at - ``c[n-i]``. + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. Returns ------- poly : jnp.ndarray Coefficients of polynomial derivative, ignoring the arbitrary constant. That is, - ``poly[i]`` stores the coefficient of the monomial xⁿ⁻ⁱ⁻¹, where n is - ``c.shape[0]-1``. + ``poly[...,i]`` stores the coefficient of the monomial xⁿ⁻ⁱ⁻¹, where n is + ``c.shape[-1]-1``. """ - poly = (c[:-1].T * jnp.arange(c.shape[0] - 1, 0, -1)).T - return poly + return c[..., :-1] * jnp.arange(c.shape[-1] - 1, 0, -1) -def polyval_vec(x, c): +def polyval_vec(*, x, c): """Evaluate the set of polynomials ``c`` at the points ``x``. - Note this function is not the same as ``np.polynomial.polynomial.polyval(x,c)``. - Parameters ---------- x : jnp.ndarray - Real coordinates at which to evaluate the set of polynomials. + Coordinates at which to evaluate the set of polynomials. c : jnp.ndarray - First axis should store coefficients of a polynomial. For a polynomial given by - ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[0]-1``, coefficient cᵢ should be stored at - ``c[n-i]``. + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. Returns ------- @@ -420,53 +434,60 @@ def polyval_vec(x, c): -------- .. 
code-block:: python - val = polyval_vec(x, c) - if val.ndim != max(x.ndim, c.ndim - 1): - raise ValueError(f"Incompatible shapes {x.shape} and {c.shape}.") - for index in np.ndindex(c.shape[1:]): - idx = (..., *index) - np.testing.assert_allclose( - actual=val[idx], - desired=np.poly1d(c[idx])(x[idx]), - err_msg=f"Failed with shapes {x.shape} and {c.shape}.", - ) + np.testing.assert_allclose( + polyval_vec(x=x, c=c), + np.sum(polyvander(x, c.shape[-1] - 1) * c[..., ::-1], axis=-1), + ) """ # Better than Horner's method as we expect to evaluate low order polynomials. # No need to use fast multipoint evaluation techniques for the same reason. - val = jnp.linalg.vecdot( - polyvander(x, c.shape[0] - 1), jnp.moveaxis(jnp.flipud(c), 0, -1) + return jnp.sum( + c * x[..., jnp.newaxis] ** jnp.arange(c.shape[-1] - 1, -1, -1), + axis=-1, ) - return val -# Warning: method must be specified as keyword argument. -interp1d_vec = jnp.vectorize( - interp1d, signature="(m),(n),(n)->(m)", excluded={"method"} -) +# TODO: Eventually do a PR to move this stuff into interpax. -@partial(jnp.vectorize, signature="(m),(n),(n),(n)->(m)") -def interp1d_Hermite_vec(xq, x, f, fx): - """Vectorized cubic Hermite spline. Does not support keyword arguments.""" - return interp1d(xq, x, f, method="cubic", fx=fx) +def _subtract_last(c, k): + """Subtract ``k`` from last index of last axis of ``c``. + Semantically same as ``return c.copy().at[...,-1].add(-k)``, + but allows dimension to increase. + """ + c_1 = c[..., -1] - k + c = jnp.concatenate( + [ + jnp.broadcast_to(c[..., :-1], (*c_1.shape, c.shape[-1] - 1)), + c_1[..., jnp.newaxis], + ], + axis=-1, + ) + return c -# TODO: Eventually do a PR to move this stuff into interpax. + +def _filter_distinct(r, sentinel, eps): + """Set all but one of matching adjacent elements in ``r`` to ``sentinel``.""" + # eps needs to be low enough that close distinct roots do not get removed. + # Otherwise, algorithms relying on continuity will fail. 
+ mask = jnp.isclose(jnp.diff(r, axis=-1, prepend=sentinel), 0, atol=eps) + r = jnp.where(mask, sentinel, r) + return r _roots = jnp.vectorize(partial(jnp.roots, strip_zeros=False), signature="(m)->(n)") -def poly_root( +def polyroot_vec( c, k=0, a_min=None, a_max=None, sort=False, sentinel=jnp.nan, - # About 2e-12 for 64 bit jax. - eps=min(jnp.finfo(jnp.array(1.0).dtype).eps * 1e4, 1e-8), + eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12), distinct=False, ): """Roots of polynomial with given coefficients. @@ -474,26 +495,26 @@ def poly_root( Parameters ---------- c : jnp.ndarray - First axis should store coefficients of a polynomial. For a polynomial given by - ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[0]-1``, coefficient cᵢ should be stored at - ``c[n-i]``. + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. k : jnp.ndarray - Specify to find solutions to ∑ᵢⁿ cᵢ xⁱ = ``k``. Should broadcast with arrays of - shape ``c.shape[1:]``. + Shape (..., *c.shape[:-1]). + Specify to find solutions to ∑ᵢⁿ cᵢ xⁱ = ``k``. a_min : jnp.ndarray + Shape (..., *c.shape[:-1]). Minimum ``a_min`` and maximum ``a_max`` value to return roots between. - If specified only real roots are returned. If None, returns all complex roots. - Should broadcast with arrays of shape ``c.shape[1:]``. + If specified only real roots are returned, otherwise returns all complex roots. a_max : jnp.ndarray + Shape (..., *c.shape[:-1]). Minimum ``a_min`` and maximum ``a_max`` value to return roots between. - If specified only real roots are returned. If None, returns all complex roots. - Should broadcast with arrays of shape ``c.shape[1:]``. + If specified only real roots are returned, otherwise returns all complex roots. sort : bool Whether to sort the roots. sentinel : float Value with which to pad array in place of filtered elements. 
Anything less than ``a_min`` or greater than ``a_max`` plus some floating point - error buffer will work just like nan while avoiding nan gradient. + error buffer will work just like nan while avoiding ``nan`` gradient. eps : float Absolute tolerance with which to consider value as zero. distinct : bool @@ -503,30 +524,29 @@ def poly_root( Returns ------- r : jnp.ndarray - Shape (..., c.shape[1:], c.shape[0] - 1). + Shape (..., *c.shape[:-1], c.shape[-1] - 1). The roots of the polynomial, iterated over the last axis. """ get_only_real_roots = not (a_min is None and a_max is None) - + num_coef = c.shape[-1] + c = _subtract_last(c, k) func = {2: _root_linear, 3: _root_quadratic, 4: _root_cubic} + if ( - c.shape[0] in func + num_coef in func and get_only_real_roots and not (jnp.iscomplexobj(c) or jnp.iscomplexobj(k)) ): # Compute from analytic formula to avoid the issue of complex roots with small # imaginary parts and to avoid nan in gradient. - r = func[c.shape[0]](*c[:-1], c[-1] - k, sentinel, eps, distinct) + r = func[num_coef](C=c, sentinel=sentinel, eps=eps, distinct=distinct) # We already filtered distinct roots for quadratics. - distinct = distinct and c.shape[0] > 3 + distinct = distinct and num_coef > 3 else: # Compute from eigenvalues of polynomial companion matrix. 
- c_n = c[-1] - k - c = [jnp.broadcast_to(c_i, c_n.shape) for c_i in c[:-1]] - c.append(c_n) - c = jnp.stack(c, axis=-1) r = _roots(c) + if get_only_real_roots: a_min = -jnp.inf if a_min is None else a_min[..., jnp.newaxis] a_max = +jnp.inf if a_max is None else a_max[..., jnp.newaxis] @@ -538,11 +558,13 @@ def poly_root( if sort or distinct: r = jnp.sort(r, axis=-1) - return _filter_distinct(r, sentinel, eps) if distinct else r + r = _filter_distinct(r, sentinel, eps) if distinct else r + assert r.shape[-1] == num_coef - 1 + return r -def _root_cubic(a, b, c, d, sentinel, eps, distinct): - """Return r such that a r³ + b r² + c r + d = 0, assuming real coef and roots.""" +def _root_cubic(C, sentinel, eps, distinct): + """Return real cubic root assuming real coefficients.""" # numerical.recipes/book.html, page 228 def irreducible(Q, R, b, mask): @@ -583,23 +605,36 @@ def root(b, c, d): reducible(Q, R, b), ) + a = C[..., 0] + b = C[..., 1] + c = C[..., 2] + d = C[..., 3] return jnp.where( - # Tests catch failure here if eps < 1e-12 for 64 bit jax. + # Tests catch failure here if eps < 1e-12 for 64 bit precision. 
jnp.expand_dims(jnp.abs(a) <= eps, axis=-1), - _concat_sentinel(_root_quadratic(b, c, d, sentinel, eps, distinct), sentinel), + _concat_sentinel( + _root_quadratic( + C=C[..., 1:], sentinel=sentinel, eps=eps, distinct=distinct + ), + sentinel, + ), root(b, c, d), ) -def _root_quadratic(a, b, c, sentinel, eps, distinct): - """Return r such that a r² + b r + c = 0, assuming real coefficients and roots.""" +def _root_quadratic(C, sentinel, eps, distinct): + """Return real quadratic root assuming real coefficients.""" # numerical.recipes/book.html, page 227 + a = C[..., 0] + b = C[..., 1] + c = C[..., 2] + discriminant = b**2 - 4 * a * c q = -0.5 * (b + jnp.sign(b) * jnp.sqrt(jnp.abs(discriminant))) r1 = jnp.where( discriminant < 0, sentinel, - safediv(q, a, _root_linear(b, c, sentinel, eps)), + safediv(q, a, _root_linear(C=C[..., 1:], sentinel=sentinel, eps=eps)), ) r2 = jnp.where( # more robust to remove repeated roots with discriminant @@ -610,21 +645,14 @@ def _root_quadratic(a, b, c, sentinel, eps, distinct): return jnp.stack([r1, r2], axis=-1) -def _root_linear(a, b, sentinel, eps, distinct=False): - """Return r such that a r + b = 0.""" +def _root_linear(C, sentinel, eps, distinct=False): + """Return real linear root assuming real coefficients.""" + a = C[..., 0] + b = C[..., 1] return safediv(-b, a, jnp.where(jnp.abs(b) <= eps, 0, sentinel)) def _concat_sentinel(r, sentinel, num=1): - """Concat ``sentinel`` ``num`` times to ``r`` on last axis.""" + """Append ``sentinel`` ``num`` times to ``r`` on last axis.""" sent = jnp.broadcast_to(sentinel, (*r.shape[:-1], num)) return jnp.append(r, sent, axis=-1) - - -def _filter_distinct(r, sentinel, eps): - """Set all but one of matching adjacent elements in ``r`` to ``sentinel``.""" - # eps needs to be low enough that close distinct roots do not get removed. - # Otherwise, algorithms relying on continuity will fail. 
- mask = jnp.isclose(jnp.diff(r, axis=-1, prepend=sentinel), 0, atol=eps) - r = jnp.where(mask, sentinel, r) - return r diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py index d1f66057da..692149e84e 100644 --- a/desc/integrals/quad_utils.py +++ b/desc/integrals/quad_utils.py @@ -19,7 +19,7 @@ def bijection_from_disc(x, a, b): def grad_bijection_from_disc(a, b): - """Gradient of affine bijection from disc.""" + """Gradient wrt ``x`` of ``bijection_from_disc``.""" dy_dx = 0.5 * (b - a) return dy_dx @@ -151,7 +151,7 @@ def leggauss_lob(deg, interior_only=False): Number of quadrature points. interior_only : bool Whether to exclude the points and weights at -1 and 1; - useful if f(-1) = f(1) = 0. If ``True``, then ``deg`` points are still + useful if f(-1) = f(1) = 0. If true, then ``deg`` points are still returned; these are the interior points for lobatto quadrature of ``deg+2``. Returns @@ -213,10 +213,9 @@ def get_quadrature(quad, automorphism): x, w = quad assert x.ndim == w.ndim == 1 if automorphism is not None: - # Apply automorphisms to supress singularities. auto, grad_auto = automorphism w = w * grad_auto(x) - # Recall bijection_from_disc(auto(x), ζ_b₁, ζ_b₂) = ζ. + # Recall bijection_from_disc(auto(x), ζ₁, ζ₂) = ζ. 
x = auto(x) return x, w diff --git a/desc/integrals/singularities.py b/desc/integrals/singularities.py index 3730c172af..ab2371a839 100644 --- a/desc/integrals/singularities.py +++ b/desc/integrals/singularities.py @@ -9,10 +9,9 @@ from desc.backend import fori_loop, jnp, put, vmap from desc.basis import DoubleFourierSeries from desc.compute.geom_utils import rpz2xyz, rpz2xyz_vec, xyz2rpz_vec -from desc.compute.utils import safediv, safenorm from desc.grid import LinearGrid from desc.io import IOAble -from desc.utils import isalmostequal, islinspaced +from desc.utils import isalmostequal, islinspaced, safediv, safenorm def _get_quadrature_nodes(q): diff --git a/desc/io/optimizable_io.py b/desc/io/optimizable_io.py index 554cdac070..e15a21756e 100644 --- a/desc/io/optimizable_io.py +++ b/desc/io/optimizable_io.py @@ -169,16 +169,17 @@ class IOAble(ABC, metaclass=_CombinedMeta): """Abstract Base Class for savable and loadable objects. Objects inheriting from this class can be saved and loaded via hdf5 or pickle. - To save properly, each object should have an attribute `_io_attrs_` which + To save properly, each object should have an attribute ``_io_attrs_`` which is a list of strings of the object attributes or properties that should be saved and loaded. - For saved objects to be loaded correctly, the __init__ method of any custom - types being saved should only assign attributes that are listed in `_io_attrs_`. + For saved objects to be loaded correctly, the ``__init__`` method of any custom + types being saved should only assign attributes that are listed in ``_io_attrs_``. Other attributes or other initialization should be done in a separate - `set_up()` method that can be called during __init__. The loading process - will involve creating an empty object, bypassing init, then setting any `_io_attrs_` - of the object, then calling `_set_up()` without any arguments, if it exists. + ``set_up()`` method that can be called during ``__init__``. 
The loading process
+    will involve creating an empty object, bypassing init, then setting any
+    ``_io_attrs_`` of the object, then calling ``_set_up()`` without any arguments,
+    if it exists.
 
     """
diff --git a/desc/magnetic_fields/__init__.py b/desc/magnetic_fields/__init__.py
index 0a8f18abd8..173b04e7ee 100644
--- a/desc/magnetic_fields/__init__.py
+++ b/desc/magnetic_fields/__init__.py
@@ -9,6 +9,7 @@
     SplineMagneticField,
     SumMagneticField,
     ToroidalMagneticField,
+    VectorPotentialField,
     VerticalMagneticField,
     _MagneticField,
     field_line_integrate,
diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py
index 7b32a1217a..760f4e372b 100644
--- a/desc/magnetic_fields/_core.py
+++ b/desc/magnetic_fields/_core.py
@@ -15,7 +15,7 @@
     DoubleFourierSeries,
 )
 from desc.compute import compute as compute_fun
-from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz
+from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
 from desc.compute.utils import get_params, get_transforms
 from desc.derivatives import Derivative
 from desc.equilibrium import EquilibriaFamily, Equilibrium
@@ -62,6 +62,40 @@ def body(i, B):
     return 1e-7 * fori_loop(0, J.shape[0], body, B)
 
 
+def biot_savart_general_vector_potential(re, rs, J, dV):
+    """Biot-Savart law for arbitrary sources for vector potential.
+
+    Parameters
+    ----------
+    re : ndarray, shape(n_eval_pts, 3)
+        evaluation points to evaluate A at, in cartesian.
+    rs : ndarray, shape(n_src_pts, 3)
+        source points for current density J, in cartesian.
+    J : ndarray, shape(n_src_pts, 3)
+        current density vector at source points, in cartesian.
+ dV : ndarray, shape(n_src_pts) + volume element at source points + + Returns + ------- + A : ndarray, shape(n,3) + magnetic vector potential in cartesian components at specified points + """ + re, rs, J, dV = map(jnp.asarray, (re, rs, J, dV)) + assert J.shape == rs.shape + JdV = J * dV[:, None] + A = jnp.zeros_like(re) + + def body(i, A): + r = re - rs[i, :] + num = JdV[i, :] + den = jnp.linalg.norm(r, axis=-1) + A = A + jnp.where(den[:, None] == 0, 0, num / den[:, None]) + return A + + return 1e-7 * fori_loop(0, J.shape[0], body, A) + + def read_BNORM_file(fname, surface, eval_grid=None, scale_by_curpol=True): """Read BNORM-style .txt file containing Bnormal Fourier coefficients. @@ -193,6 +227,8 @@ def compute_magnetic_field( source_grid : Grid, int or None or array-like, optional Grid used to discretize MagneticField object if calculating B from Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid Returns ------- @@ -205,6 +241,33 @@ def __call__(self, grid, params=None, basis="rpz"): """Compute magnetic field at a set of points.""" return self.compute_magnetic_field(grid, params, basis) + @abstractmethod + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating A from + Biot-Savart. Should NOT include endpoint at 2pi. 
+ transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + def compute_Bnormal( self, surface, @@ -410,6 +473,7 @@ def save_mgrid( nR=101, nZ=101, nphi=90, + save_vector_potential=True, ): """Save the magnetic field to an mgrid NetCDF file in "raw" format. @@ -431,6 +495,9 @@ def save_mgrid( Number of grid points in the Z coordinate (default = 101). nphi : int, optional Number of grid points in the toroidal angle (default = 90). + save_vector_potential : bool, optional + Whether or not to save the magnetic vector potential to the mgrid + file, in addition to the magnetic field. Defaults to True. Returns ------- @@ -451,6 +518,15 @@ def save_mgrid( B_phi = field[:, 1].reshape(nphi, nZ, nR) B_Z = field[:, 2].reshape(nphi, nZ, nR) + # evaluate magnetic vector potential on grid + if save_vector_potential: + field = self.compute_magnetic_vector_potential(grid, basis="rpz") + A_R = field[:, 0].reshape(nphi, nZ, nR) + A_phi = field[:, 1].reshape(nphi, nZ, nR) + A_Z = field[:, 2].reshape(nphi, nZ, nR) + else: + A_R = None + # write mgrid file file = Dataset(path, mode="w", format="NETCDF3_64BIT_OFFSET") @@ -537,6 +613,28 @@ def save_mgrid( ) bz_001[:] = B_Z + if save_vector_potential: + ar_001 = file.createVariable("ar_001", np.float64, ("phi", "zee", "rad")) + ar_001.long_name = ( + "A_R = radial component of magnetic vector potential " + "in lab frame (T/m)." + ) + ar_001[:] = A_R + + ap_001 = file.createVariable("ap_001", np.float64, ("phi", "zee", "rad")) + ap_001.long_name = ( + "A_phi = toroidal component of magnetic vector potential " + "in lab frame (T/m)." + ) + ap_001[:] = A_phi + + az_001 = file.createVariable("az_001", np.float64, ("phi", "zee", "rad")) + az_001.long_name = ( + "A_Z = vertical component of magnetic vector potential " + "in lab frame (T/m)." 
+ ) + az_001[:] = A_Z + file.close() @@ -618,6 +716,33 @@ def compute_magnetic_field( B = rpz2xyz_vec(B, phi=coords[:, 1]) return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + raise NotImplementedError( + "MagneticFieldFromUser does not have vector potential calculation " + "implemented." + ) + class ScaledMagneticField(_MagneticField, Optimizable): """Magnetic field scaled by a scalar value. @@ -703,6 +828,35 @@ def compute_magnetic_field( coords, params, basis, source_grid ) + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating A from + Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. 
Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + scaled magnetic vector potential at specified points + + """ + return self._scale * self._field.compute_magnetic_vector_potential( + coords, params, basis, source_grid + ) + class SumMagneticField(_MagneticField, MutableSequence, OptimizableCollection): """Sum of two or more magnetic field sources. @@ -724,10 +878,16 @@ def __init__(self, *fields): ) self._fields = fields - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -742,6 +902,9 @@ def compute_magnetic_field( Biot-Savart. Should NOT include endpoint at 2pi. transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- @@ -749,6 +912,11 @@ def compute_magnetic_field( scaled magnetic field at specified points """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) if params is None: params = [None] * len(self._fields) if isinstance(params, dict): @@ -770,13 +938,74 @@ def compute_magnetic_field( # zip does not terminate early transforms = transforms * len(self._fields) - B = 0 + op = {"B": "compute_magnetic_field", "A": "compute_magnetic_vector_potential"}[ + compute_A_or_B + ] + + AB = 0 for i, (field, g, tr) in enumerate(zip(self._fields, source_grid, transforms)): - B += field.compute_magnetic_field( + AB += getattr(field, op)( coords, params[i % len(params)], basis, source_grid=g, transforms=tr ) + return AB - return B + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating B from + Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + field : ndarray, shape(N,3) + sum magnetic field at specified points + + """ + return self._compute_A_or_B( + coords, params, basis, source_grid, transforms, compute_A_or_B="B" + ) + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. 
+ + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating A from + Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + sum magnetic vector potential at specified points + + """ + return self._compute_A_or_B( + coords, params, basis, source_grid, transforms, compute_A_or_B="A" + ) # dunder methods required by MutableSequence def __getitem__(self, i): @@ -886,10 +1115,54 @@ def compute_magnetic_field( return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The vector potential is specified assuming the Coulomb Gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for R0 and B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. 
+ + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + params = setdefault(params, {}) + B0 = params.get("B0", self.B0) + R0 = params.get("R0", self.R0) + + assert basis.lower() in ["rpz", "xyz"] + coords = jnp.atleast_2d(jnp.asarray(coords)) + if basis == "xyz": + coords = xyz2rpz(coords) + az = -B0 * R0 * jnp.log(coords[:, 0]) + arp = jnp.zeros_like(az) + A = jnp.array([arp, arp, az]).T + # b/c it only has a nonzero z component, no need + # to switch bases back if xyz is given + return A + class VerticalMagneticField(_MagneticField, Optimizable): """Uniform magnetic field purely in the vertical (Z) direction. + The vector potential is specified assuming the Coulomb Gauge. + Parameters ---------- B0 : float @@ -940,18 +1213,63 @@ def compute_magnetic_field( params = setdefault(params, {}) B0 = params.get("B0", self.B0) - assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) - if basis == "xyz": - coords = xyz2rpz(coords) bz = B0 * jnp.ones_like(coords[:, 2]) brp = jnp.zeros_like(bz) B = jnp.array([brp, brp, bz]).T - if basis == "xyz": - B = rpz2xyz_vec(B, phi=coords[:, 1]) + # b/c it only has a nonzero z component, no need + # to switch bases back if xyz is given return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The vector potential is specified assuming the Coulomb Gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. 
Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + params = setdefault(params, {}) + B0 = params.get("B0", self.B0) + + coords = jnp.atleast_2d(jnp.asarray(coords)) + + if basis == "xyz": + coords_xyz = coords + coords_rpz = xyz2rpz(coords) + else: + coords_rpz = coords + coords_xyz = rpz2xyz(coords) + ax = B0 / 2 * coords_xyz[:, 1] + ay = -B0 / 2 * coords_xyz[:, 0] + + az = jnp.zeros_like(ax) + A = jnp.array([ax, ay, az]).T + if basis == "rpz": + A = xyz2rpz_vec(A, phi=coords_rpz[:, 1]) + + return A + class PoloidalMagneticField(_MagneticField, Optimizable): """Pure poloidal magnetic field (ie in theta direction). @@ -1062,6 +1380,36 @@ def compute_magnetic_field( return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + raise NotImplementedError( + "PoloidalMagneticField has nonzero divergence, therefore it can't be " + "represented with a vector potential." + ) + class SplineMagneticField(_MagneticField, Optimizable): """Magnetic field from precomputed values on a grid. 
@@ -1080,6 +1428,12 @@ class SplineMagneticField(_MagneticField, Optimizable): toroidal magnetic field on grid BZ : array-like, shape(NR,Nphi,NZ,Ngroups) vertical magnetic field on grid + AR : array-like, shape(NR,Nphi,NZ,Ngroups) + radial magnetic vector potential on grid, optional + aphi : array-like, shape(NR,Nphi,NZ,Ngroups) + toroidal magnetic vector potential on grid, optional + AZ : array-like, shape(NR,Nphi,NZ,Ngroups) + vertical magnetic vector potential on grid, optional currents : array-like, shape(Ngroups) Currents or scaling factors for each field group. NFP : int, optional @@ -1098,6 +1452,9 @@ class SplineMagneticField(_MagneticField, Optimizable): "_BR", "_Bphi", "_BZ", + "_AR", + "_Aphi", + "_AZ", "_method", "_extrap", "_derivs", @@ -1110,7 +1467,20 @@ class SplineMagneticField(_MagneticField, Optimizable): _static_attrs = ["_extrap", "_period"] def __init__( - self, R, phi, Z, BR, Bphi, BZ, currents=1.0, NFP=1, method="cubic", extrap=False + self, + R, + phi, + Z, + BR, + Bphi, + BZ, + AR=None, + Aphi=None, + AZ=None, + currents=1.0, + NFP=1, + method="cubic", + extrap=False, ): R, phi, Z, currents = map( lambda x: jnp.atleast_1d(jnp.asarray(x)), (R, phi, Z, currents) @@ -1152,6 +1522,17 @@ def _atleast_4d(x): self._derivs["BR"] = self._approx_derivs(self._BR) self._derivs["Bphi"] = self._approx_derivs(self._Bphi) self._derivs["BZ"] = self._approx_derivs(self._BZ) + if AR is not None and Aphi is not None and AZ is not None: + AR, Aphi, AZ = map(_atleast_4d, (AR, Aphi, AZ)) + assert AR.shape == Aphi.shape == AZ.shape == shape + self._AR = AR + self._Aphi = Aphi + self._AZ = AZ + self._derivs["AR"] = self._approx_derivs(self._AR) + self._derivs["Aphi"] = self._approx_derivs(self._Aphi) + self._derivs["AZ"] = self._approx_derivs(self._AZ) + else: + self._AR = self._Aphi = self._AZ = None @property def NFP(self): @@ -1190,10 +1571,16 @@ def _approx_derivs(self, Bi): tempdict[key] = val[:, 0, :] return tempdict - def compute_magnetic_field( - self, 
coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or magnetic vector potential at a set of points. Parameters ---------- @@ -1208,107 +1595,185 @@ def compute_magnetic_field( transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid Unused by this MagneticField class. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points, in cylindrical form [BR, Bphi,BZ] + magnetic field or vector potential at specified points, + in cylindrical form [BR, Bphi,BZ] """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + errorif( + compute_A_or_B == "A" and self._AR is None, + ValueError, + "Cannot calculate vector potential" + " as no vector potential spline values exist.", + ) assert basis.lower() in ["rpz", "xyz"] currents = self.currents if params is None else params["currents"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis == "xyz": coords = xyz2rpz(coords) Rq, phiq, Zq = coords.T + if compute_A_or_B == "B": + A_or_B_R = self._BR + A_or_B_phi = self._Bphi + A_or_B_Z = self._BZ + elif compute_A_or_B == "A": + A_or_B_R = self._AR + A_or_B_phi = self._Aphi + A_or_B_Z = self._AZ + if self._axisym: - BRq = interp2d( + ABRq = interp2d( Rq, Zq, self._R, self._Z, - self._BR[:, 0, :], + A_or_B_R[:, 0, :], self._method, (0, 0), self._extrap, (None, None), - **self._derivs["BR"], + **self._derivs[compute_A_or_B + "R"], ) - Bphiq = interp2d( + ABphiq = interp2d( Rq, Zq, self._R, self._Z, - self._Bphi[:, 0, :], + A_or_B_phi[:, 0, :], self._method, (0, 0), 
self._extrap, (None, None), - **self._derivs["Bphi"], + **self._derivs[compute_A_or_B + "phi"], ) - BZq = interp2d( + ABZq = interp2d( Rq, Zq, self._R, self._Z, - self._BZ[:, 0, :], + A_or_B_Z[:, 0, :], self._method, (0, 0), self._extrap, (None, None), - **self._derivs["BZ"], + **self._derivs[compute_A_or_B + "Z"], ) else: - BRq = interp3d( + ABRq = interp3d( Rq, phiq, Zq, self._R, self._phi, self._Z, - self._BR, + A_or_B_R, self._method, (0, 0, 0), self._extrap, (None, 2 * np.pi / self.NFP, None), - **self._derivs["BR"], + **self._derivs[compute_A_or_B + "R"], ) - Bphiq = interp3d( + ABphiq = interp3d( Rq, phiq, Zq, self._R, self._phi, self._Z, - self._Bphi, + A_or_B_phi, self._method, (0, 0, 0), self._extrap, (None, 2 * np.pi / self.NFP, None), - **self._derivs["Bphi"], + **self._derivs[compute_A_or_B + "phi"], ) - BZq = interp3d( + ABZq = interp3d( Rq, phiq, Zq, self._R, self._phi, self._Z, - self._BZ, + A_or_B_Z, self._method, (0, 0, 0), self._extrap, (None, 2 * np.pi / self.NFP, None), - **self._derivs["BZ"], + **self._derivs[compute_A_or_B + "Z"], ) - # BRq etc shape(nq, ngroups) - B = jnp.stack([BRq, Bphiq, BZq], axis=1) - # B shape(nq, 3, ngroups) - B = jnp.sum(B * currents, axis=-1) + # ABRq etc shape(nq, ngroups) + AB = jnp.stack([ABRq, ABphiq, ABZq], axis=1) + # AB shape(nq, 3, ngroups) + AB = jnp.sum(AB * currents, axis=-1) if basis == "xyz": - B = rpz2xyz_vec(B, phi=coords[:, 1]) - return B + AB = rpz2xyz_vec(AB, phi=coords[:, 1]) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. 
+ source_grid : Grid, int or None + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points, in cylindrical form [BR, Bphi,BZ] + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") @classmethod def from_mgrid(cls, mgrid_file, extcur=None, method="cubic", extrap=False): @@ -1366,8 +1831,40 @@ def from_mgrid(cls, mgrid_file, extcur=None, method="cubic", extrap=False): bp = np.moveaxis(bp, (0, 1, 2), (1, 2, 0)) bz = np.moveaxis(bz, (0, 1, 2), (1, 2, 0)) + # sum magnetic vector potentials from each coil + ar = np.zeros([kp, jz, ir, nextcur]) + ap = np.zeros([kp, jz, ir, nextcur]) + az = np.zeros([kp, jz, ir, nextcur]) + try: + for i in range(nextcur): + coil_id = "%03d" % (i + 1,) + ar[:, :, :, i] += mgrid["ar_" + coil_id][ + () + ] # A_R radial mag. vec. 
potential + ap[:, :, :, i] += mgrid["ap_" + coil_id][ + () + ] # A_phi toroidal mag. vec. potential + az[:, :, :, i] += mgrid["az_" + coil_id][ + () + ] # A_Z vertical mag. vec. potential + + # shift axes to correct order + ar = np.moveaxis(ar, (0, 1, 2), (1, 2, 0)) + ap = np.moveaxis(ap, (0, 1, 2), (1, 2, 0)) + az = np.moveaxis(az, (0, 1, 2), (1, 2, 0)) + except IndexError: + warnif( + True, + UserWarning, + "mgrid does not appear to contain vector potential information." + " Vector potential will not be computable.", + ) + ar = ap = az = None + mgrid.close() - return cls(Rgrid, pgrid, Zgrid, br, bp, bz, extcur, nfp, method, extrap) + return cls( + Rgrid, pgrid, Zgrid, br, bp, bz, ar, ap, az, extcur, nfp, method, extrap + ) @classmethod def from_field( @@ -1397,6 +1894,15 @@ def from_field( shp = rr.shape coords = np.array([rr.flatten(), pp.flatten(), zz.flatten()]).T BR, BP, BZ = field.compute_magnetic_field(coords, params, basis="rpz").T + try: + AR, AP, AZ = field.compute_magnetic_vector_potential( + coords, params, basis="rpz" + ).T + AR = AR.reshape(shp) + AP = AP.reshape(shp) + AZ = AZ.reshape(shp) + except NotImplementedError: + AR = AP = AZ = None return cls( R, phi, @@ -1404,6 +1910,9 @@ def from_field( BR.reshape(shp), BP.reshape(shp), BZ.reshape(shp), + AR=AR, + Aphi=AP, + AZ=AZ, currents=1.0, NFP=NFP, method=method, @@ -1474,6 +1983,187 @@ def compute_magnetic_field( B = rpz2xyz_vec(B, phi=coords[:, 1]) return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. 
+ source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + raise NotImplementedError( + "ScalarPotentialField does not have vector potential calculation " + "implemented." + ) + + +class VectorPotentialField(_MagneticField): + """Magnetic field due to a vector magnetic potential in cylindrical coordinates. + + Parameters + ---------- + potential : callable + function to compute the vector potential. Should have a signature of + the form potential(R,phi,Z,*params) -> ndarray. + R,phi,Z are arrays of cylindrical coordinates. + params : dict, optional + default parameters to pass to potential function + + """ + + def __init__(self, potential, params=None): + self._potential = potential + self._params = params + + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", + ): + """Compute magnetic field or vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + assert basis.lower() in ["rpz", "xyz"] + coords = jnp.atleast_2d(jnp.asarray(coords)) + coords = coords.astype(float) + if basis == "xyz": + coords = xyz2rpz(coords) + + if params is None: + params = self._params + r, p, z = coords.T + + if compute_A_or_B == "B": + funR = lambda x: self._potential(x, p, z, **params) + funP = lambda x: self._potential(r, x, z, **params) + funZ = lambda x: self._potential(r, p, x, **params) + + ap = self._potential(r, p, z, **params)[:, 1] + + # these are the gradients of each component of A + dAdr = Derivative.compute_jvp(funR, 0, (jnp.ones_like(r),), r) + dAdp = Derivative.compute_jvp(funP, 0, (jnp.ones_like(p),), p) + dAdz = Derivative.compute_jvp(funZ, 0, (jnp.ones_like(z),), z) + + # form the B components with the appropriate combinations + B = jnp.array( + [ + dAdp[:, 2] / r - dAdz[:, 1], + dAdz[:, 0] - dAdr[:, 2], + dAdr[:, 1] + (ap - dAdp[:, 0]) / r, + ] + ).T + if basis == "xyz": + B = rpz2xyz_vec(B, phi=coords[:, 1]) + return B + elif compute_A_or_B == "A": + A = self._potential(r, p, z, **params) + if basis == "xyz": + A = rpz2xyz_vec(A, phi=coords[:, 1]) + return A + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None + Unused by this MagneticField class. 
+ transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. 
+ + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") + def field_line_integrate( r0, diff --git a/desc/magnetic_fields/_current_potential.py b/desc/magnetic_fields/_current_potential.py index ab8a42d909..a8759155ee 100644 --- a/desc/magnetic_fields/_current_potential.py +++ b/desc/magnetic_fields/_current_potential.py @@ -11,7 +11,11 @@ from desc.optimizable import Optimizable, optimizable_parameter from desc.utils import copy_coeffs, errorif, setdefault, warnif -from ._core import _MagneticField, biot_savart_general +from ._core import ( + _MagneticField, + biot_savart_general, + biot_savart_general_vector_potential, +) class CurrentPotentialField(_MagneticField, FourierRZToroidalSurface): @@ -177,10 +181,16 @@ def save(self, file_name, file_format=None, file_mode="w"): " as the potential function cannot be serialized." ) - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -194,11 +204,14 @@ def compute_magnetic_field( Source grid upon which to evaluate the surface current density K. transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points + magnetic field or vector potential at specified points """ source_grid = source_grid or LinearGrid( @@ -206,15 +219,70 @@ def compute_magnetic_field( N=30 + 2 * self.N, NFP=self.NFP, ) - return _compute_magnetic_field_from_CurrentPotentialField( + return _compute_A_or_B_from_CurrentPotentialField( field=self, coords=coords, params=params, basis=basis, source_grid=source_grid, transforms=transforms, + compute_A_or_B=compute_A_or_B, ) + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + This assumes the Coulomb gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. 
+ source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + Magnetic vector potential at specified points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") + @classmethod def from_surface( cls, @@ -496,10 +564,16 @@ def change_Phi_resolution(self, M=None, N=None, NFP=None, sym_Phi=None): NFP=NFP ) # make sure surface and Phi basis NFP are the same - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -513,11 +587,14 @@ def compute_magnetic_field( Source grid upon which to evaluate the surface current density K. transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points + magnetic field or vector potential at specified points """ source_grid = source_grid or LinearGrid( @@ -525,15 +602,70 @@ def compute_magnetic_field( N=30 + 2 * max(self.N, self.N_Phi), NFP=self.NFP, ) - return _compute_magnetic_field_from_CurrentPotentialField( + return _compute_A_or_B_from_CurrentPotentialField( field=self, coords=coords, params=params, basis=basis, source_grid=source_grid, transforms=transforms, + compute_A_or_B=compute_A_or_B, ) + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + This assumes the Coulomb gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. 
+ basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + Magnetic vector potential at specified points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") + @classmethod def from_surface( cls, @@ -613,10 +745,16 @@ def from_surface( ) -def _compute_magnetic_field_from_CurrentPotentialField( - field, coords, source_grid, params=None, basis="rpz", transforms=None +def _compute_A_or_B_from_CurrentPotentialField( + field, + coords, + source_grid, + params=None, + basis="rpz", + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -631,25 +769,36 @@ def _compute_magnetic_field_from_CurrentPotentialField( should include the potential basis : {"rpz", "xyz"} basis for input coordinates and returned magnetic field + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points + magnetic field or vector potential at specified points """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis == "rpz": coords = rpz2xyz(coords) - + op = {"B": biot_savart_general, "A": biot_savart_general_vector_potential}[ + compute_A_or_B + ] # compute surface current, and store grid quantities # needed for integration in class if not params or not transforms: data = field.compute( ["K", "x"], grid=source_grid, + basis="rpz", params=params, transforms=transforms, jitable=True, @@ -680,7 +829,7 @@ def nfp_loop(j, f): rs = jnp.vstack((_rs[:, 0], phi, _rs[:, 2])).T rs = rpz2xyz(rs) K = rpz2xyz_vec(_K, phi=phi) - fj = biot_savart_general( + fj = op( coords, rs, K, diff --git a/desc/objectives/_coils.py b/desc/objectives/_coils.py index 5e961523e5..04f2954910 100644 --- a/desc/objectives/_coils.py +++ b/desc/objectives/_coils.py @@ -12,10 +12,9 @@ ) from desc.compute import get_profiles, get_transforms, rpz2xyz from desc.compute.utils import _compute as compute_fun -from desc.compute.utils import safenorm from desc.grid import LinearGrid, _Grid from desc.integrals import compute_B_plasma -from desc.utils import Timer, errorif, warnif +from desc.utils import Timer, errorif, safenorm, warnif from .normalization import compute_scaling_factors from .objective_funs import _Objective @@ -124,6 +123,12 @@ def _prune_coilset_tree(coilset): # get individual coils from coilset coils, structure = tree_flatten(coil, is_leaf=_is_single_coil) + for c in coils: + errorif( + not isinstance(c, _Coil), + TypeError, + f"Expected object of type Coil, got {type(c)}", + ) self._num_coils = len(coils) # map grid to list of length coils @@ -1304,6 +1309,14 @@ class ToroidalFlux(_Objective): by making the coil 
currents zero. Instead, this objective ensures the coils create the necessary toroidal flux for the equilibrium field. + Will try to use the vector potential method to calculate the toroidal flux + (Φ = ∮ 𝐀 ⋅ 𝐝𝐥 over the perimeter of a constant zeta plane) + instead of the brute force method using the magnetic field + (Φ = ∯ 𝐁 ⋅ 𝐝𝐒 over a constant zeta XS). The vector potential method + is much more efficient, however not every ``MagneticField`` object + has a vector potential available to compute, so in those cases + the magnetic field method is used. + Parameters ---------- eq : Equilibrium @@ -1349,6 +1362,7 @@ class ToroidalFlux(_Objective): name : str, optional Name of the objective function. + """ _coordinates = "rtz" @@ -1376,6 +1390,7 @@ def __init__( self._field_grid = field_grid self._eval_grid = eval_grid self._eq = eq + # TODO: add eq_fixed option so this can be used in single stage super().__init__( things=[field], @@ -1401,9 +1416,17 @@ def build(self, use_jit=True, verbose=1): """ eq = self._eq + self._use_vector_potential = True + try: + self._field.compute_magnetic_vector_potential([0, 0, 0]) + except (NotImplementedError, ValueError): + self._use_vector_potential = False if self._eval_grid is None: eval_grid = LinearGrid( - L=eq.L_grid, M=eq.M_grid, zeta=jnp.array(0.0), NFP=eq.NFP + L=eq.L_grid if not self._use_vector_potential else 0, + M=eq.M_grid, + zeta=jnp.array(0.0), + NFP=eq.NFP, ) self._eval_grid = eval_grid eval_grid = self._eval_grid @@ -1438,10 +1461,12 @@ def build(self, use_jit=True, verbose=1): if verbose > 0: print("Precomputing transforms") timer.start("Precomputing transforms") - - data = eq.compute( - ["R", "phi", "Z", "|e_rho x e_theta|", "n_zeta"], grid=eval_grid - ) + data_keys = ["R", "phi", "Z"] + if self._use_vector_potential: + data_keys += ["e_theta"] + else: + data_keys += ["|e_rho x e_theta|", "n_zeta"] + data = eq.compute(data_keys, grid=eval_grid) plasma_coords = jnp.array([data["R"], data["phi"], data["Z"]]).T @@ 
-1483,22 +1508,32 @@ def compute(self, field_params=None, constants=None): data = constants["equil_data"] plasma_coords = constants["plasma_coords"] - - B = constants["field"].compute_magnetic_field( - plasma_coords, - basis="rpz", - source_grid=constants["field_grid"], - params=field_params, - ) grid = constants["eval_grid"] - B_dot_n_zeta = jnp.sum(B * data["n_zeta"], axis=1) + if self._use_vector_potential: + A = constants["field"].compute_magnetic_vector_potential( + plasma_coords, + basis="rpz", + source_grid=constants["field_grid"], + params=field_params, + ) - Psi = jnp.sum( - grid.spacing[:, 0] - * grid.spacing[:, 1] - * data["|e_rho x e_theta|"] - * B_dot_n_zeta - ) + A_dot_e_theta = jnp.sum(A * data["e_theta"], axis=1) + Psi = jnp.sum(grid.spacing[:, 1] * A_dot_e_theta) + else: + B = constants["field"].compute_magnetic_field( + plasma_coords, + basis="rpz", + source_grid=constants["field_grid"], + params=field_params, + ) + + B_dot_n_zeta = jnp.sum(B * data["n_zeta"], axis=1) + Psi = jnp.sum( + grid.spacing[:, 0] + * grid.spacing[:, 1] + * data["|e_rho x e_theta|"] + * B_dot_n_zeta + ) return Psi diff --git a/desc/objectives/_equilibrium.py b/desc/objectives/_equilibrium.py index 624fd99023..dc2f4bbb22 100644 --- a/desc/objectives/_equilibrium.py +++ b/desc/objectives/_equilibrium.py @@ -557,7 +557,7 @@ class HelicalForceBalance(_Objective): _equilibrium = True _coordinates = "rtz" _units = "(N)" - _print_value_fmt = "Helical force error: {:10.3e}, " + _print_value_fmt = "Helical force error: " def __init__( self, diff --git a/desc/objectives/_geometry.py b/desc/objectives/_geometry.py index 57b1eebe46..e405609c79 100644 --- a/desc/objectives/_geometry.py +++ b/desc/objectives/_geometry.py @@ -5,9 +5,8 @@ from desc.backend import jnp, vmap from desc.compute import get_profiles, get_transforms, rpz2xyz, xyz2rpz from desc.compute.utils import _compute as compute_fun -from desc.compute.utils import safenorm from desc.grid import LinearGrid, QuadratureGrid 
-from desc.utils import Timer, errorif, parse_argname_change, warnif +from desc.utils import Timer, errorif, parse_argname_change, safenorm, warnif from .normalization import compute_scaling_factors from .objective_funs import _Objective diff --git a/desc/objectives/_omnigenity.py b/desc/objectives/_omnigenity.py index 05f08356c0..1eb7a8da6e 100644 --- a/desc/objectives/_omnigenity.py +++ b/desc/objectives/_omnigenity.py @@ -47,7 +47,7 @@ class QuasisymmetryBoozer(_Objective): reverse mode and forward over reverse mode respectively. grid : Grid, optional Collocation grid containing the nodes to evaluate at. - Must be a LinearGrid with a single flux surface and sym=False. + Must be a LinearGrid with sym=False. Defaults to ``LinearGrid(M=M_booz, N=N_booz)``. helicity : tuple, optional Type of quasi-symmetry (M, N). Default = quasi-axisymmetry (1, 0). @@ -122,12 +122,6 @@ def build(self, use_jit=True, verbose=1): grid = self._grid errorif(grid.sym, ValueError, "QuasisymmetryBoozer grid must be non-symmetric") - errorif( - grid.num_rho != 1, - ValueError, - "QuasisymmetryBoozer grid must be on a single surface. " - "To target multiple surfaces, use multiple objectives.", - ) warnif( grid.num_theta < 2 * eq.M, RuntimeWarning, @@ -195,7 +189,7 @@ def compute(self, params, constants=None): Returns ------- f : ndarray - Quasi-symmetry flux function error at each node (T^3). + Symmetry breaking harmonics of B (T). 
""" if constants is None: @@ -207,8 +201,11 @@ def compute(self, params, constants=None): transforms=constants["transforms"], profiles=constants["profiles"], ) - B_mn = constants["matrix"] @ data["|B|_mn"] - return B_mn[constants["idx"]] + B_mn = data["|B|_mn"].reshape((constants["transforms"]["grid"].num_rho, -1)) + B_mn = constants["matrix"] @ B_mn.T + # output order = (rho, mn).flatten(), ie all the surfaces concatenated + # one after the other + return B_mn[constants["idx"]].T.flatten() @property def helicity(self): diff --git a/desc/objectives/_power_balance.py b/desc/objectives/_power_balance.py index 74b679ee2a..299b248358 100644 --- a/desc/objectives/_power_balance.py +++ b/desc/objectives/_power_balance.py @@ -61,7 +61,7 @@ class FusionPower(_Objective): _scalar = True _units = "(W)" - _print_value_fmt = "Fusion power: {:10.3e} " + _print_value_fmt = "Fusion power: " def __init__( self, @@ -246,7 +246,7 @@ class HeatingPowerISS04(_Objective): _scalar = True _units = "(W)" - _print_value_fmt = "Heating power: {:10.3e} " + _print_value_fmt = "Heating power: " def __init__( self, diff --git a/desc/plotting.py b/desc/plotting.py index f43f408af4..b3c8fdeb17 100644 --- a/desc/plotting.py +++ b/desc/plotting.py @@ -971,9 +971,9 @@ def plot_3d( if grid.num_rho == 1: n1, n2 = grid.num_theta, grid.num_zeta if not grid.nodes[-1][2] == 2 * np.pi: - p1, p2 = True, False + p1, p2 = False, False else: - p1, p2 = True, True + p1, p2 = False, True elif grid.num_theta == 1: n1, n2 = grid.num_rho, grid.num_zeta p1, p2 = False, True @@ -1352,10 +1352,9 @@ def plot_section( phi = np.atleast_1d(phi) nphi = len(phi) if grid is None: - nfp = eq.NFP grid_kwargs = { "L": 25, - "NFP": nfp, + "NFP": 1, "axis": False, "theta": np.linspace(0, 2 * np.pi, 91, endpoint=True), "zeta": phi, @@ -1610,9 +1609,14 @@ def plot_surfaces(eq, rho=8, theta=8, phi=None, ax=None, return_data=False, **kw phi = np.atleast_1d(phi) nphi = len(phi) + # do not need NFP supplied to these grids as + # the 
above logic takes care of the correct phi range + # if defaults are requested. Setting NFP here instead + # can create reshaping issues when phi is supplied and gets + # truncated by 2pi/NFP. See PR #1204 grid_kwargs = { "rho": rho, - "NFP": nfp, + "NFP": 1, "theta": np.linspace(0, 2 * np.pi, NT, endpoint=True), "zeta": phi, } @@ -1631,7 +1635,7 @@ def plot_surfaces(eq, rho=8, theta=8, phi=None, ax=None, return_data=False, **kw ) grid_kwargs = { "rho": np.linspace(0, 1, NR), - "NFP": nfp, + "NFP": 1, "theta": theta, "zeta": phi, } @@ -1960,7 +1964,7 @@ def plot_boundary(eq, phi=None, plot_axis=True, ax=None, return_data=False, **kw plot_axis = plot_axis and eq.L > 0 rho = np.array([0.0, 1.0]) if plot_axis else np.array([1.0]) - grid_kwargs = {"NFP": eq.NFP, "rho": rho, "theta": 100, "zeta": phi} + grid_kwargs = {"NFP": 1, "rho": rho, "theta": 100, "zeta": phi} grid = _get_grid(**grid_kwargs) nr, nt, nz = grid.num_rho, grid.num_theta, grid.num_zeta grid = Grid( @@ -2030,6 +2034,9 @@ def plot_boundaries( ): """Plot stellarator boundaries at multiple toroidal coordinates. + NOTE: If attempting to plot objects with differing NFP, `phi` must + be given explicitly. + Parameters ---------- eqs : array-like of Equilibrium, Surface or EquilibriaFamily @@ -2085,7 +2092,21 @@ def plot_boundaries( fig, ax = plot_boundaries((eq1, eq2, eq3)) """ + # if NFPs are not all equal, means there are + # objects with differing NFPs, which it is not clear + # how to choose the phis for by default, so we will throw an error + # unless phi was given. phi = parse_argname_change(phi, kwargs, "zeta", "phi") + errorif( + not np.allclose([thing.NFP for thing in eqs], eqs[0].NFP) and phi is None, + ValueError, + "supplied objects must have the same number of field periods, " + "or if there are differing field periods, `phi` must be given explicitly." + f" Instead, supplied objects have NFPs {[t.NFP for t in eqs]}." 
+ " If attempting to plot an axisymmetric object with non-axisymmetric objects," + " you must use the `change_resolution` method to make the axisymmetric " + "object have the same NFP as the non-axisymmetric objects.", + ) figsize = kwargs.pop("figsize", None) cmap = kwargs.pop("cmap", "rainbow") @@ -2129,7 +2150,7 @@ def plot_boundaries( plot_axis_i = plot_axis and eqs[i].L > 0 rho = np.array([0.0, 1.0]) if plot_axis_i else np.array([1.0]) - grid_kwargs = {"NFP": eqs[i].NFP, "theta": 100, "zeta": phi, "rho": rho} + grid_kwargs = {"NFP": 1, "theta": 100, "zeta": phi, "rho": rho} grid = _get_grid(**grid_kwargs) nr, nt, nz = grid.num_rho, grid.num_theta, grid.num_zeta grid = Grid( @@ -2198,6 +2219,9 @@ def plot_comparison( ): """Plot comparison between flux surfaces of multiple equilibria. + NOTE: If attempting to plot objects with differing NFP, `phi` must + be given explicitly. + Parameters ---------- eqs : array-like of Equilibrium or EquilibriaFamily @@ -2266,7 +2290,21 @@ def plot_comparison( ) """ + # if NFPs are not all equal, means there are + # objects with differing NFPs, which it is not clear + # how to choose the phis for by default, so we will throw an error + # unless phi was given. phi = parse_argname_change(phi, kwargs, "zeta", "phi") + errorif( + not np.allclose([thing.NFP for thing in eqs], eqs[0].NFP) and phi is None, + ValueError, + "supplied objects must have the same number of field periods, " + "or if there are differing field periods, `phi` must be given explicitly." + f" Instead, supplied objects have NFPs {[t.NFP for t in eqs]}." 
+ " If attempting to plot an axisymmetric object with non-axisymmetric objects," + " you must use the `change_resolution` method to make the axisymmetric " + "object have the same NFP as the non-axisymmetric objects.", + ) color = parse_argname_change(color, kwargs, "colors", "color") ls = parse_argname_change(ls, kwargs, "linestyles", "ls") lw = parse_argname_change(lw, kwargs, "lws", "lw") @@ -2576,7 +2614,7 @@ def plot_boozer_modes( # noqa: C901 elif np.isscalar(rho) and rho > 1: rho = np.linspace(1, 0, num=rho, endpoint=False) - B_mn = np.array([[]]) + rho = np.sort(rho) M_booz = kwargs.pop("M_booz", 2 * eq.M) N_booz = kwargs.pop("N_booz", 2 * eq.N) linestyle = kwargs.pop("ls", "-") @@ -2594,16 +2632,15 @@ def plot_boozer_modes( # noqa: C901 else: matrix, modes = ptolemy_linear_transform(basis.modes) - for i, r in enumerate(rho): - grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r)) - transforms = get_transforms( - "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("|B|_mn", grid=grid, transforms=transforms) - b_mn = np.atleast_2d(matrix @ data["|B|_mn"]) - B_mn = np.vstack((B_mn, b_mn)) if B_mn.size else b_mn + grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho) + transforms = get_transforms( + "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data = eq.compute("|B|_mn", grid=grid, transforms=transforms) + B_mn = data["|B|_mn"].reshape((len(rho), -1)) + B_mn = np.atleast_2d(matrix @ B_mn.T).T zidx = np.where((modes[:, 1:] == np.array([[0, 0]])).all(axis=1))[0] if norm: @@ -2972,6 +3009,7 @@ def plot_qs_error( # noqa: 16 fxn too complex rho = np.linspace(1, 0, num=20, endpoint=False) elif np.isscalar(rho) and rho > 1: rho = np.linspace(1, 0, num=rho, endpoint=False) + rho = np.sort(rho) fig, ax = _format_ax(ax, 
figsize=kwargs.pop("figsize", None)) @@ -2989,119 +3027,92 @@ def plot_qs_error( # noqa: 16 fxn too complex R0 = data["R0"] B0 = np.mean(data["|B|"] * data["sqrt(g)"]) / np.mean(data["sqrt(g)"]) - f_B = np.array([]) - f_C = np.array([]) - f_T = np.array([]) - plot_data = {} - for i, r in enumerate(rho): - grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r)) - if fB: - transforms = get_transforms( - "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz - ) - if i == 0: # only need to do this once for the first rho surface - matrix, modes, idx = ptolemy_linear_transform( - transforms["B"].basis.modes, - helicity=helicity, - NFP=transforms["B"].basis.NFP, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute( - ["|B|_mn", "B modes"], grid=grid, transforms=transforms - ) - B_mn = matrix @ data["|B|_mn"] - f_b = np.sqrt(np.sum(B_mn[idx] ** 2)) / np.sqrt(np.sum(B_mn**2)) - f_B = np.append(f_B, f_b) - if fC: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("f_C", grid=grid, helicity=helicity) - f_c = ( - np.mean(np.abs(data["f_C"]) * data["sqrt(g)"]) - / np.mean(data["sqrt(g)"]) - / B0**3 - ) - f_C = np.append(f_C, f_c) - if fT: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("f_T", grid=grid) - f_t = ( - np.mean(np.abs(data["f_T"]) * data["sqrt(g)"]) - / np.mean(data["sqrt(g)"]) - * R0**2 - / B0**4 - ) - f_T = np.append(f_T, f_t) + plot_data = {"rho": rho} - plot_data["f_B"] = f_B - plot_data["f_C"] = f_C - plot_data["f_T"] = f_T - plot_data["rho"] = rho + grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho) + names = [] + if fB: + names += ["|B|_mn"] + transforms = get_transforms( + "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz + ) + matrix, modes, idx = ptolemy_linear_transform( + transforms["B"].basis.modes, + helicity=helicity, + NFP=transforms["B"].basis.NFP, + ) + if fC or fT: + names += 
["sqrt(g)"] + if fC: + names += ["f_C"] + if fT: + names += ["f_T"] - if log: - if fB: - ax.semilogy( - rho, - f_B, - ls=ls[0 % len(ls)], - c=colors[0 % len(colors)], - marker=markers[0 % len(markers)], - label=labels[0 % len(labels)], - lw=lw[0 % len(lw)], - ) - if fC: - ax.semilogy( - rho, - f_C, - ls=ls[1 % len(ls)], - c=colors[1 % len(colors)], - marker=markers[1 % len(markers)], - label=labels[1 % len(labels)], - lw=lw[1 % len(lw)], - ) - if fT: - ax.semilogy( - rho, - f_T, - ls=ls[2 % len(ls)], - c=colors[2 % len(colors)], - marker=markers[2 % len(markers)], - label=labels[2 % len(labels)], - lw=lw[2 % len(lw)], - ) - else: - if fB: - ax.plot( - rho, - f_B, - ls=ls[0 % len(ls)], - c=colors[0 % len(colors)], - marker=markers[0 % len(markers)], - label=labels[0 % len(labels)], - lw=lw[0 % len(lw)], - ) - if fC: - ax.plot( - rho, - f_C, - ls=ls[1 % len(ls)], - c=colors[1 % len(colors)], - marker=markers[1 % len(markers)], - label=labels[1 % len(labels)], - lw=lw[1 % len(lw)], - ) - if fT: - ax.plot( - rho, - f_T, - ls=ls[2 % len(ls)], - c=colors[2 % len(colors)], - marker=markers[2 % len(markers)], - label=labels[2 % len(labels)], - lw=lw[2 % len(lw)], - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data = eq.compute( + names, grid=grid, M_booz=M_booz, N_booz=N_booz, helicity=helicity + ) + + if fB: + B_mn = data["|B|_mn"].reshape((len(rho), -1)) + B_mn = (matrix @ B_mn.T).T + f_B = np.sqrt(np.sum(B_mn[:, idx] ** 2, axis=-1)) / np.sqrt( + np.sum(B_mn**2, axis=-1) + ) + plot_data["f_B"] = f_B + if fC: + sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz") + f_C = grid.meshgrid_reshape(data["f_C"], "rtz") + f_C = ( + np.mean(np.abs(f_C) * sqrtg, axis=(1, 2)) + / np.mean(sqrtg, axis=(1, 2)) + / B0**3 + ) + plot_data["f_C"] = f_C + if fT: + sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz") + f_T = grid.meshgrid_reshape(data["f_T"], "rtz") + f_T = ( + np.mean(np.abs(f_T) * sqrtg, axis=(1, 2)) + / np.mean(sqrtg, axis=(1, 2)) + * R0**2 + 
/ B0**4 + ) + plot_data["f_T"] = f_T + + plot_op = ax.semilogy if log else ax.plot + + if fB: + plot_op( + rho, + f_B, + ls=ls[0 % len(ls)], + c=colors[0 % len(colors)], + marker=markers[0 % len(markers)], + label=labels[0 % len(labels)], + lw=lw[0 % len(lw)], + ) + if fC: + plot_op( + rho, + f_C, + ls=ls[1 % len(ls)], + c=colors[1 % len(colors)], + marker=markers[1 % len(markers)], + label=labels[1 % len(labels)], + lw=lw[1 % len(lw)], + ) + if fT: + plot_op( + rho, + f_T, + ls=ls[2 % len(ls)], + c=colors[2 % len(colors)], + marker=markers[2 % len(markers)], + label=labels[2 % len(labels)], + lw=lw[2 % len(lw)], + ) ax.set_xlabel(_AXIS_LABELS_RTZ[0], fontsize=xlabel_fontsize) if ylabel: diff --git a/desc/utils.py b/desc/utils.py index 27b5fa79ad..72dd10f975 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -693,13 +693,9 @@ def broadcast_tree(tree_in, tree_out, dtype=int): @partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"}) -def take_mask(a, mask, size=None, fill_value=None): +def take_mask(a, mask, /, *, size=None, fill_value=None): """JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``. - Warnings - -------- - The parameters ``size`` and ``fill_value`` must be specified as keyword arguments. - Parameters ---------- a : jnp.ndarray @@ -741,34 +737,193 @@ def flatten_matrix(y): # TODO: Eventually remove and use numpy's stuff. 
# https://github.com/numpy/numpy/issues/25805 -def atleast_nd(ndmin, *arys): +def atleast_nd(ndmin, ary): """Adds dimensions to front if necessary.""" - if ndmin == 1: - return jnp.atleast_1d(*arys) - if ndmin == 2: - return jnp.atleast_2d(*arys) - tup = tuple(jnp.array(ary, ndmin=ndmin) for ary in arys) - if len(tup) == 1: - tup = tup[0] - return tup - - -def atleast_3d_mid(*arys): - """Like np.atleast3d but if adds dim at axis 1 for 2d arrays.""" - arys = jnp.atleast_2d(*arys) - tup = tuple(ary[:, jnp.newaxis] if ary.ndim == 2 else ary for ary in arys) - if len(tup) == 1: - tup = tup[0] - return tup - - -def atleast_2d_end(*arys): - """Like np.atleast2d but if adds dim at axis 1 for 1d arrays.""" - arys = jnp.atleast_1d(*arys) - tup = tuple(ary[:, jnp.newaxis] if ary.ndim == 1 else ary for ary in arys) - if len(tup) == 1: - tup = tup[0] - return tup + return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text + + +def dot(a, b, axis=-1): + """Batched vector dot product. + + Parameters + ---------- + a : array-like + First array of vectors. + b : array-like + Second array of vectors. + axis : int + Axis along which vectors are stored. + + Returns + ------- + y : array-like + y = sum(a*b, axis=axis) + + """ + return jnp.sum(a * b, axis=axis, keepdims=False) + + +def cross(a, b, axis=-1): + """Batched vector cross product. + + Parameters + ---------- + a : array-like + First array of vectors. + b : array-like + Second array of vectors. + axis : int + Axis along which vectors are stored. + + Returns + ------- + y : array-like + y = a x b + + """ + return jnp.cross(a, b, axis=axis) + + +def safenorm(x, ord=None, axis=None, fill=0, threshold=0): + """Like jnp.linalg.norm, but without nan gradient at x=0. + + Parameters + ---------- + x : ndarray + Vector or array to norm. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of norm. 
+ axis : {None, int, 2-tuple of ints}, optional + Axis to take norm along. + fill : float, ndarray, optional + Value to return where x is zero. + threshold : float >= 0 + How small is x allowed to be. + + """ + is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) + y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero + n = jnp.linalg.norm(y, ord=ord, axis=axis) + n = jnp.where(is_zero.squeeze(), fill, n) # replace norm with zero if is_zero + return n + + +def safenormalize(x, ord=None, axis=None, fill=0, threshold=0): + """Normalize a vector to unit length, but without nan gradient at x=0. + + Parameters + ---------- + x : ndarray + Vector or array to norm. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of norm. + axis : {None, int, 2-tuple of ints}, optional + Axis to take norm along. + fill : float, ndarray, optional + Value to return where x is zero. + threshold : float >= 0 + How small is x allowed to be. + + """ + is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) + y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero + n = safenorm(x, ord, axis, fill, threshold) * jnp.ones_like(x) + # return unit vector with equal components if norm <= threshold + return jnp.where(n <= threshold, jnp.ones_like(y) / jnp.sqrt(y.size), y / n) + + +def safediv(a, b, fill=0, threshold=0): + """Divide a/b with guards for division by zero. + + Parameters + ---------- + a, b : ndarray + Numerator and denominator. + fill : float, ndarray, optional + Value to return where b is zero. + threshold : float >= 0 + How small is b allowed to be. + """ + mask = jnp.abs(b) <= threshold + num = jnp.where(mask, fill, a) + den = jnp.where(mask, 1, b) + return num / den + + +def cumtrapz(y, x=None, dx=1.0, axis=-1, initial=None): + """Cumulatively integrate y(x) using the composite trapezoidal rule. 
+ + Taken from SciPy, but changed NumPy references to JAX.NumPy: + https://github.com/scipy/scipy/blob/v1.10.1/scipy/integrate/_quadrature.py + + Parameters + ---------- + y : array_like + Values to integrate. + x : array_like, optional + The coordinate to integrate along. If None (default), use spacing `dx` + between consecutive elements in `y`. + dx : float, optional + Spacing between elements of `y`. Only used if `x` is None. + axis : int, optional + Specifies the axis to cumulate. Default is -1 (last axis). + initial : scalar, optional + If given, insert this value at the beginning of the returned result. + Typically, this value should be 0. Default is None, which means no + value at ``x[0]`` is returned and `res` has one element less than `y` + along the axis of integration. + + Returns + ------- + res : ndarray + The result of cumulative integration of `y` along `axis`. + If `initial` is None, the shape is such that the axis of integration + has one less value than `y`. If `initial` is given, the shape is equal + to that of `y`. + + """ + y = jnp.asarray(y) + if x is None: + d = dx + else: + x = jnp.asarray(x) + if x.ndim == 1: + d = jnp.diff(x) + # reshape to correct shape + shape = [1] * y.ndim + shape[axis] = -1 + d = d.reshape(shape) + elif len(x.shape) != len(y.shape): + raise ValueError("If given, shape of x must be 1-D or the " "same as y.") + else: + d = jnp.diff(x, axis=axis) + + if d.shape[axis] != y.shape[axis] - 1: + raise ValueError( + "If given, length of x along axis must be the " "same as y." 
+ ) + + def tupleset(t, i, value): + l = list(t) + l[i] = value + return tuple(l) + + nd = len(y.shape) + slice1 = tupleset((slice(None),) * nd, axis, slice(1, None)) + slice2 = tupleset((slice(None),) * nd, axis, slice(None, -1)) + res = jnp.cumsum(d * (y[slice1] + y[slice2]) / 2.0, axis=axis) + + if initial is not None: + if not jnp.isscalar(initial): + raise ValueError("`initial` parameter should be a scalar.") + + shape = list(res.shape) + shape[axis] = 1 + res = jnp.concatenate( + [jnp.full(shape, initial, dtype=res.dtype), res], axis=axis + ) + + return res diff --git a/desc/vmec.py b/desc/vmec.py index fc6fc5498f..17e7bf3b30 100644 --- a/desc/vmec.py +++ b/desc/vmec.py @@ -25,7 +25,7 @@ from desc.objectives.utils import factorize_linear_constraints from desc.profiles import PowerSeriesProfile, SplineProfile from desc.transform import Transform -from desc.utils import Timer +from desc.utils import Timer, warnif from desc.vmec_utils import ( fourier_to_zernike, ptolemy_identity_fwd, @@ -158,7 +158,7 @@ def load( zax_cs = file.variables["zaxis_cs"][:].filled() try: rax_cs = file.variables["raxis_cs"][:].filled() - rax_cc = file.variables["zaxis_cc"][:].filled() + zax_cc = file.variables["zaxis_cc"][:].filled() except KeyError: rax_cs = np.zeros_like(rax_cc) zax_cc = np.zeros_like(zax_cs) @@ -208,7 +208,9 @@ def load( return eq @classmethod - def save(cls, eq, path, surfs=128, verbose=1): # noqa: C901 - FIXME - simplify + def save( # noqa: C901 - FIXME - simplify + cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None + ): """Save an Equilibrium as a netCDF file in the VMEC format. Parameters @@ -224,6 +226,10 @@ def save(cls, eq, path, surfs=128, verbose=1): # noqa: C901 - FIXME - simplify * 0: no output * 1: status of quantities computed * 2: as above plus timing information + M_nyq, N_nyq: int + The max poloidal and toroidal modenumber to use in the + Nyquist spectrum that the derived quantities are Fourier + fit with. Defaults to M+4 and N+2. 
Returns ------- @@ -242,8 +248,14 @@ def save(cls, eq, path, surfs=128, verbose=1): # noqa: C901 - FIXME - simplify NFP = eq.NFP M = eq.M N = eq.N - M_nyq = M + 4 - N_nyq = N + 2 if N > 0 else 0 + M_nyq = M + 4 if M_nyq is None else M_nyq + warnif( + N_nyq is not None and int(N) == 0, + UserWarning, + "Passed in N_nyq but equilibrium is axisymmetric, setting N_nyq to zero", + ) + N_nyq = N + 2 if N_nyq is None else N_nyq + N_nyq = 0 if int(N) == 0 else N_nyq # VMEC radial coordinate: s = rho^2 = Psi / Psi(LCFS) s_full = np.linspace(0, 1, surfs) @@ -807,6 +819,14 @@ def save(cls, eq, path, surfs=128, verbose=1): # noqa: C901 - FIXME - simplify lmnc.long_name = "cos(m*t-n*p) component of lambda, on half mesh" lmnc.units = "rad" l1 = np.ones_like(eq.L_lmn) + # should negate lambda coefs bc theta_DESC + lambda = theta_PEST, + # since we are reversing the theta direction (and the theta_PEST direction), + # so -theta_PEST = -theta_DESC - lambda, so the negative of lambda is what + # should be saved, so that would be negating all of eq.L_lmn + # BUT since we are also reversing the poloidal angle direction, which + # would negate only the coeffs of L_lmn corresponding to m<0 + # (sin theta modes in DESC), the effective result is to only + # negate the cos(theta) (m>0) lambda modes l1[eq.L_basis.modes[:, 1] >= 0] *= -1 m, n, x_mn = zernike_to_fourier(l1 * eq.L_lmn, basis=eq.L_basis, rho=r_half) xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn) @@ -823,7 +843,7 @@ def save(cls, eq, path, surfs=128, verbose=1): # noqa: C901 - FIXME - simplify sin_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym="sin") cos_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym="cos") - full_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym=None) + full_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym=False) if eq.sym: sin_transform = Transform( grid=grid_lcfs, basis=sin_basis, build=False, build_pinv=True @@ -932,7 +952,7 @@ def fullfit(x): if eq.sym: x_mn[i, :] 
= cosfit(data[i, :]) else: - x_mn[i, :] = full_transform.fit(data[i, :]) + x_mn[i, :] = fullfit(data[i, :]) xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn) bmnc[0, :] = 0 bmnc[1:, :] = c @@ -975,7 +995,7 @@ def fullfit(x): if eq.sym: x_mn[i, :] = cosfit(data[i, :]) else: - x_mn[i, :] = full_transform.fit(data[i, :]) + x_mn[i, :] = fullfit(data[i, :]) xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn) bsupumnc[0, :] = 0 bsupumnc[1:, :] = -c # negative sign for negative Jacobian @@ -1018,7 +1038,7 @@ def fullfit(x): if eq.sym: x_mn[i, :] = cosfit(data[i, :]) else: - x_mn[i, :] = full_transform.fit(data[i, :]) + x_mn[i, :] = fullfit(data[i, :]) xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn) bsupvmnc[0, :] = 0 bsupvmnc[1:, :] = c @@ -1641,13 +1661,15 @@ def vmec_interpolate(Cmn, Smn, xm, xn, theta, phi, s=None, si=None, sym=True): return C + S @classmethod - def compute_theta_coords(cls, lmns, xm, xn, s, theta_star, zeta, si=None): + def compute_theta_coords( + cls, lmns, xm, xn, s, theta_star, zeta, si=None, lmnc=None + ): """Find theta such that theta + lambda(theta) == theta_star. Parameters ---------- lmns : array-like - fourier coefficients for lambda + sin(mt-nz) Fourier coefficients for lambda xm : array-like poloidal mode numbers xn : array-like @@ -1662,6 +1684,8 @@ def compute_theta_coords(cls, lmns, xm, xn, s, theta_star, zeta, si=None): si : ndarray values of radial coordinates where lmns are defined. 
Defaults to linearly spaced on half grid between (0,1) + lmnc : array-like, optional + cos(mt-nz) Fourier coefficients for lambda Returns ------- @@ -1672,19 +1696,30 @@ def compute_theta_coords(cls, lmns, xm, xn, s, theta_star, zeta, si=None): if si is None: si = np.linspace(0, 1, lmns.shape[0]) si[1:] = si[0:-1] + 0.5 / (lmns.shape[0] - 1) - lmbda_mn = interpolate.CubicSpline(si, lmns) + lmbda_mns = interpolate.CubicSpline(si, lmns) + if lmnc is None: + lmbda_mnc = lambda s: 0 + else: + lmbda_mnc = interpolate.CubicSpline(si, lmnc) # Note: theta* (also known as vartheta) is the poloidal straight field line # angle in PEST-like flux coordinates def root_fun(theta): lmbda = np.sum( - lmbda_mn(s) + lmbda_mns(s) * np.sin( xm[np.newaxis] * theta[:, np.newaxis] - xn[np.newaxis] * zeta[:, np.newaxis] ), axis=-1, + ) + np.sum( + lmbda_mnc(s) + * np.cos( + xm[np.newaxis] * theta[:, np.newaxis] + - xn[np.newaxis] * zeta[:, np.newaxis] + ), + axis=-1, ) theta_star_k = theta + lmbda # theta* = theta + lambda err = theta_star - theta_star_k # FIXME: mod by 2pi @@ -1782,6 +1817,8 @@ def compute_coord_surfaces(cls, equil, vmec_data, Nr=10, Nt=8, Nz=None, **kwargs t_nodes = t_grid.nodes t_nodes[:, 0] = t_nodes[:, 0] ** 2 + sym = "lmnc" not in vmec_data.keys() + v_nodes = cls.compute_theta_coords( vmec_data["lmns"], vmec_data["xm"], @@ -1789,29 +1826,71 @@ def compute_coord_surfaces(cls, equil, vmec_data, Nr=10, Nt=8, Nz=None, **kwargs t_nodes[:, 0], t_nodes[:, 1], t_nodes[:, 2], + lmnc=vmec_data["lmnc"] if not sym else None, ) t_nodes[:, 1] = v_nodes + if sym: + Rr_vmec, Zr_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=r_nodes[:, 1], + phi=r_nodes[:, 2], + s=r_nodes[:, 0], + ) - Rr_vmec, Zr_vmec = cls.vmec_interpolate( - vmec_data["rmnc"], - vmec_data["zmns"], - vmec_data["xm"], - vmec_data["xn"], - theta=r_nodes[:, 1], - phi=r_nodes[:, 2], - s=r_nodes[:, 0], - ) - - Rv_vmec, Zv_vmec = cls.vmec_interpolate( - 
vmec_data["rmnc"], - vmec_data["zmns"], - vmec_data["xm"], - vmec_data["xn"], - theta=t_nodes[:, 1], - phi=t_nodes[:, 2], - s=t_nodes[:, 0], - ) + Rv_vmec, Zv_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=t_nodes[:, 1], + phi=t_nodes[:, 2], + s=t_nodes[:, 0], + ) + else: + Rr_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["rmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=r_nodes[:, 1], + phi=r_nodes[:, 2], + s=r_nodes[:, 0], + sym=False, + ) + Zr_vmec = cls.vmec_interpolate( + vmec_data["zmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=r_nodes[:, 1], + phi=r_nodes[:, 2], + s=r_nodes[:, 0], + sym=False, + ) + Rv_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["rmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=t_nodes[:, 1], + phi=t_nodes[:, 2], + s=t_nodes[:, 0], + sym=False, + ) + Zv_vmec = cls.vmec_interpolate( + vmec_data["zmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=t_nodes[:, 1], + phi=t_nodes[:, 2], + s=t_nodes[:, 0], + sym=False, + ) coords = { "Rr_desc": Rr_desc, diff --git a/tests/baseline/test_bounce1d_checks.png b/tests/baseline/test_bounce1d_checks.png new file mode 100644 index 0000000000..51e5a4d94f Binary files /dev/null and b/tests/baseline/test_bounce1d_checks.png differ diff --git a/tests/baseline/test_plot_comparison_different_NFPs.png b/tests/baseline/test_plot_comparison_different_NFPs.png new file mode 100644 index 0000000000..96a140648f Binary files /dev/null and b/tests/baseline/test_plot_comparison_different_NFPs.png differ diff --git a/tests/benchmarks/compare_bench_results.py b/tests/benchmarks/compare_bench_results.py index 09fc580e22..ab56816153 100644 --- a/tests/benchmarks/compare_bench_results.py +++ b/tests/benchmarks/compare_bench_results.py @@ -8,60 +8,87 @@ cwd = os.getcwd() data = {} -master_idx = 0 -latest_idx = 0 +master_idx = [] +latest_idx = [] commit_ind = 0 -for 
diret in os.walk(cwd + "/compare_results"): - files = diret[2] - timing_file_exists = False - - for filename in files: - if filename.find("json") != -1: # check if json output file is present - try: - filepath = os.path.join(diret[0], filename) - with open(filepath) as f: - print(filepath) - curr_data = json.load(f) - commit_id = curr_data["commit_info"]["id"][0:7] - data[commit_id] = curr_data - if filepath.find("master") != -1: - master_idx = commit_ind - elif filepath.find("Latest_Commit") != -1: - latest_idx = commit_ind - commit_ind += 1 - except Exception as e: - print(e) - continue - +folder_names = [] + +for root1, dirs1, files1 in os.walk(cwd): + for dir_name in dirs1: + if dir_name == "compare_results" or dir_name.startswith("benchmark_artifact"): + print("Including folder: " + dir_name) + # "compare_results" is the folder containing the benchmark results from this + # job "benchmark_artifact" is the folder containing the benchmark results + # from other jobs if in future we change the Python version of the + # benchmarks, we will need to update this + # "/Linux-CPython--64bit" + files2walk = ( + os.walk(cwd + "/" + dir_name) + if dir_name == "compare_results" + else os.walk(cwd + "/" + dir_name + "/Linux-CPython-3.9-64bit") + ) + for root, dirs, files in files2walk: + for filename in files: + if ( + filename.find("json") != -1 + ): # check if json output file is present + try: + filepath = os.path.join(root, filename) + with open(filepath) as f: + curr_data = json.load(f) + commit_id = curr_data["commit_info"]["id"][0:7] + data[commit_ind] = curr_data["benchmarks"] + if filepath.find("master") != -1: + master_idx.append(commit_ind) + elif filepath.find("Latest_Commit") != -1: + latest_idx.append(commit_ind) + commit_ind += 1 + except Exception as e: + print(e) + continue # need arrays of size [ num benchmarks x num commits ] # one for mean one for stddev # number of benchmark cases -num_benchmarks = len(data[list(data.keys())[0]]["benchmarks"]) 
-num_commits = len(list(data.keys())) +num_benchmarks = 0 +# sum number of benchmarks splitted into different jobs +for split in master_idx: + num_benchmarks += len(data[split]) +num_commits = 2 + times = np.zeros([num_benchmarks, num_commits]) stddevs = np.zeros([num_benchmarks, num_commits]) commit_ids = [] test_names = [None] * num_benchmarks -for id_num, commit_id in enumerate(data.keys()): - commit_ids.append(commit_id) - for i, test in enumerate(data[commit_id]["benchmarks"]): +id_num = 0 +for i in master_idx: + for test in data[i]: t_mean = test["stats"]["median"] t_stddev = test["stats"]["iqr"] - times[i, id_num] = t_mean - stddevs[i, id_num] = t_stddev - test_names[i] = test["name"] - + times[id_num, 0] = t_mean + stddevs[id_num, 0] = t_stddev + test_names[id_num] = test["name"] + id_num = id_num + 1 + +id_num = 0 +for i in latest_idx: + for test in data[i]: + t_mean = test["stats"]["median"] + t_stddev = test["stats"]["iqr"] + times[id_num, 1] = t_mean + stddevs[id_num, 1] = t_stddev + test_names[id_num] = test["name"] + id_num = id_num + 1 # we say a slowdown/speedup has occurred if the mean time difference is greater than # n_sigma * (stdev of time delta) significance = 3 # n_sigmas of normal distribution, ie z score of 3 colors = [" "] * num_benchmarks # g if faster, w if similar, r if slower -delta_times_ms = times[:, latest_idx] - times[:, master_idx] -delta_stds_ms = np.sqrt(stddevs[:, latest_idx] ** 2 + stddevs[:, master_idx] ** 2) -delta_times_pct = delta_times_ms / times[:, master_idx] * 100 -delta_stds_pct = delta_stds_ms / times[:, master_idx] * 100 +delta_times_ms = times[:, 1] - times[:, 0] +delta_stds_ms = np.sqrt(stddevs[:, 1] ** 2 + stddevs[:, 0] ** 2) +delta_times_pct = delta_times_ms / times[:, 0] * 100 +delta_stds_pct = delta_stds_ms / times[:, 0] * 100 for i, (pct, spct) in enumerate(zip(delta_times_pct, delta_stds_pct)): if pct > 0 and pct > significance * spct: colors[i] = "-" # this will make the line red @@ -72,8 +99,6 @@ # now 
make the commit message, save as a txt file # benchmark_name dt(%) dt(s) t_new(s) t_old(s) -print(latest_idx) -print(master_idx) commit_msg_lines = [ "```diff", f"| {'benchmark_name':^38} | {'dt(%)':^22} | {'dt(s)':^22} |" @@ -88,8 +113,8 @@ line = f"{colors[i]:>1}{test_names[i]:<39} |" line += f" {f'{dpct:+6.2f} +/- {sdpct:4.2f}':^22} |" line += f" {f'{dt:+.2e} +/- {sdt:.2e}':^22} |" - line += f" {f'{times[i, latest_idx]:.2e} +/- {stddevs[i, latest_idx]:.1e}':^22} |" - line += f" {f'{times[i, master_idx]:.2e} +/- {stddevs[i, master_idx]:.1e}':^22} |" + line += f" {f'{times[i, 1]:.2e} +/- {stddevs[i, 1]:.1e}':^22} |" + line += f" {f'{times[i, 0]:.2e} +/- {stddevs[i, 0]:.1e}':^22} |" commit_msg_lines.append(line) diff --git a/tests/conftest.py b/tests/conftest.py index 873d2c3f0a..ccab0e07a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -335,3 +335,22 @@ def VMEC_save(SOLOVEV, tmpdir_factory): ) desc = Dataset(str(SOLOVEV["desc_nc_path"]), mode="r") return vmec, desc + + +@pytest.fixture(scope="session") +def VMEC_save_asym(tmpdir_factory): + """Save an asymmetric equilibrium in VMEC netcdf format for comparison.""" + tmpdir = tmpdir_factory.mktemp("asym_wout") + filename = tmpdir.join("wout_HELIO_asym_desc.nc") + vmec = Dataset("./tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc", mode="r") + eq = Equilibrium.load("./tests/inputs/HELIO_asym.h5") + VMECIO.save( + eq, + filename, + surfs=vmec.variables["ns"][:], + verbose=0, + M_nyq=round(np.max(vmec.variables["xm_nyq"][:])), + N_nyq=round(np.max(vmec.variables["xn_nyq"][:]) / eq.NFP), + ) + desc = Dataset(filename, mode="r") + return vmec, desc, eq diff --git a/tests/inputs/HELIO_asym.h5 b/tests/inputs/HELIO_asym.h5 new file mode 100644 index 0000000000..c66a6cb100 Binary files /dev/null and b/tests/inputs/HELIO_asym.h5 differ diff --git a/tests/inputs/master_compute_data_rpz.pkl b/tests/inputs/master_compute_data_rpz.pkl index 7591194c45..d72778328e 100644 Binary files 
a/tests/inputs/master_compute_data_rpz.pkl and b/tests/inputs/master_compute_data_rpz.pkl differ diff --git a/tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc b/tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc new file mode 100644 index 0000000000..cc51c535a3 Binary files /dev/null and b/tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc differ diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 925cdfc83b..e204dc423d 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -12,22 +12,25 @@ import pytest from desc.compute import data_index -from desc.compute.utils import _grow_seeds, dot +from desc.compute.utils import _grow_seeds from desc.equilibrium import Equilibrium from desc.examples import get from desc.grid import LinearGrid from desc.integrals import surface_integrals_map from desc.objectives import GenericObjective, ObjectiveFunction +from desc.utils import dot # Unless mentioned in the source code of the compute function, the assumptions # made to compute the magnetic axis limit can be reduced to assuming that these # functions tend toward zero as the magnetic axis is approached and that # d²ψ/(dρ)² and 𝜕√𝑔/𝜕𝜌 are both finite nonzero at the magnetic axis. # Also, dⁿψ/(dρ)ⁿ for n > 3 is assumed zero everywhere. 
-zero_limits = {"rho", "psi", "psi_r", "e_theta", "sqrt(g)", "B_t"} +zero_limits = {"rho", "psi", "psi_r", "psi_rrr", "e_theta", "sqrt(g)", "B_t"} + # These compute quantities require kinetic profiles, which are not defined for all # configurations (giving NaN values) not_continuous_limits = {"current Redl", "P_ISS04", "P_fusion", ""} + not_finite_limits = { "D_Mercier", "D_geodesic", diff --git a/tests/test_coils.py b/tests/test_coils.py index 704ad5f761..71127660da 100644 --- a/tests/test_coils.py +++ b/tests/test_coils.py @@ -4,7 +4,9 @@ import numpy as np import pytest +import scipy +from desc.backend import jnp from desc.coils import ( CoilSet, FourierPlanarCoil, @@ -13,12 +15,13 @@ MixedCoilSet, SplineXYZCoil, ) -from desc.compute import get_params, get_transforms, xyz2rpz, xyz2rpz_vec +from desc.compute import get_params, get_transforms, rpz2xyz, xyz2rpz, xyz2rpz_vec from desc.examples import get -from desc.geometry import FourierRZCurve, FourierRZToroidalSurface +from desc.geometry import FourierRZCurve, FourierRZToroidalSurface, FourierXYZCurve from desc.grid import Grid, LinearGrid from desc.io import load from desc.magnetic_fields import SumMagneticField, VerticalMagneticField +from desc.utils import dot class TestCoil: @@ -149,6 +152,198 @@ def test_biot_savart_all_coils(self): B_true_rpz_phi, B_rpz, rtol=1e-3, atol=1e-10, err_msg="Using FourierRZCoil" ) + @pytest.mark.unit + def test_biot_savart_vector_potential_all_coils(self): + """Test biot-savart vec potential implementation against analytic formula.""" + coil_grid = LinearGrid(zeta=100, endpoint=False) + + R = 2 + y = 1 + I = 1e7 + + A_true = np.atleast_2d([0, 0, 0]) + grid_xyz = np.atleast_2d([10, y, 0]) + grid_rpz = xyz2rpz(grid_xyz) + + def test(coil, grid_xyz, grid_rpz): + A_xyz = coil.compute_magnetic_vector_potential( + grid_xyz, basis="xyz", source_grid=coil_grid + ) + A_rpz = coil.compute_magnetic_vector_potential( + grid_rpz, basis="rpz", source_grid=coil_grid + ) + 
np.testing.assert_allclose( + A_true, A_xyz, rtol=1e-3, atol=1e-10, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + A_true, A_rpz, rtol=1e-3, atol=1e-10, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + A_true, A_rpz, rtol=1e-3, atol=1e-10, err_msg=f"Using {coil}" + ) + + # FourierXYZCoil + coil = FourierXYZCoil(I) + test(coil, grid_xyz, grid_rpz) + + # SplineXYZCoil + x = coil.compute("x", grid=coil_grid, basis="xyz")["x"] + coil = SplineXYZCoil(I, X=x[:, 0], Y=x[:, 1], Z=x[:, 2]) + test(coil, grid_xyz, grid_rpz) + + # FourierPlanarCoil + coil = FourierPlanarCoil(I) + test(coil, grid_xyz, grid_rpz) + + grid_xyz = np.atleast_2d([0, 0, y]) + grid_rpz = xyz2rpz(grid_xyz) + + # FourierRZCoil + coil = FourierRZCoil(I, R_n=np.array([R]), modes_R=np.array([0])) + test(coil, grid_xyz, grid_rpz) + # test in a CoilSet + coil2 = CoilSet(coil) + test(coil2, grid_xyz, grid_rpz) + # test in a MixedCoilSet + coil3 = MixedCoilSet(coil2, coil, check_intersection=False) + coil3[1].current = 0 + test(coil3, grid_xyz, grid_rpz) + + @pytest.mark.unit + def test_biot_savart_vector_potential_integral_all_coils(self): + """Test analytic expression of flux integral for all coils.""" + # taken from analytic benchmark in + # "A Magnetic Diagnostic Code for 3D Fusion Equilibria", Lazerson 2013 + # find flux for concentric loops of varying radii to a circular coil + + coil_grid = LinearGrid(zeta=1000, endpoint=False) + + R = 1 + I = 1e7 + + # analytic eqn for "A_phi" (phi is in dl direction for loop) + def _A_analytic(r): + # elliptic integral arguments must be k^2, not k, + # error in original paper and apparently in Jackson EM book too. 
+ theta = np.pi / 2 + arg = R**2 + r**2 + 2 * r * R * np.sin(theta) + term_1_num = 4.0e-7 * I * R + term_1_den = np.sqrt(arg) + k_sqd = 4 * r * R * np.sin(theta) / arg + term_2_num = (2 - k_sqd) * scipy.special.ellipk( + k_sqd + ) - 2 * scipy.special.ellipe(k_sqd) + term_2_den = k_sqd + return term_1_num * term_2_num / term_1_den / term_2_den + + # we only evaluate it at theta=np.pi/2 (b/c it is in spherical coords) + rs = np.linspace(0.1, 3, 10, endpoint=True) + N = 200 + curve_grid = LinearGrid(zeta=N) + + def test( + coil, grid_xyz, grid_rpz, A_true_rpz, correct_flux, rtol=1e-10, atol=1e-12 + ): + """Test that we compute the correct flux for the given coil.""" + A_xyz = coil.compute_magnetic_vector_potential( + grid_xyz, basis="xyz", source_grid=coil_grid + ) + A_rpz = coil.compute_magnetic_vector_potential( + grid_rpz, basis="rpz", source_grid=coil_grid + ) + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose( + correct_flux, flux_xyz, rtol=rtol, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + correct_flux, flux_rpz, rtol=rtol, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + A_true_rpz, + A_rpz, + rtol=rtol, + atol=atol, + err_msg=f"Using {coil}", + ) + + for r in rs: + # A_phi is constant around the loop (no phi dependence) + A_true_phi = _A_analytic(r) * np.ones(N) + A_true_rpz = np.vstack( + (np.zeros_like(A_true_phi), A_true_phi, np.zeros_like(A_true_phi)) + ).T + correct_flux = np.sum(r * A_true_phi * 2 * np.pi / N) + + curve = FourierXYZCurve( + X_n=[-r, 0, 0], Y_n=[0, 0, r], Z_n=[0, 0, 0] + ) # flux loop to integrate A over + + curve_data = curve.compute(["x", "x_s"], grid=curve_grid, basis="xyz") + curve_data_rpz = curve.compute(["x", "x_s"], grid=curve_grid, basis="rpz") + + grid_rpz = np.vstack( + [ + curve_data_rpz["x"][:, 0], + curve_data_rpz["x"][:, 1], + 
curve_data_rpz["x"][:, 2], + ] + ).T + grid_xyz = rpz2xyz(grid_rpz) + # FourierXYZCoil + coil = FourierXYZCoil(I, X_n=[-R, 0, 0], Y_n=[0, 0, R], Z_n=[0, 0, 0]) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-8, + atol=1e-12, + ) + + # SplineXYZCoil + x = coil.compute("x", grid=coil_grid, basis="xyz")["x"] + coil = SplineXYZCoil(I, X=x[:, 0], Y=x[:, 1], Z=x[:, 2]) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-4, + atol=1e-12, + ) + + # FourierPlanarCoil + coil = FourierPlanarCoil(I, center=[0, 0, 0], normal=[0, 0, -1], r_n=R) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-8, + atol=1e-12, + ) + + # FourierRZCoil + coil = FourierRZCoil(I, R_n=np.array([R]), modes_R=np.array([0])) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-8, + atol=1e-12, + ) + @pytest.mark.unit def test_properties(self): """Test getting/setting attributes for Coil class.""" diff --git a/tests/test_compute_everything.py b/tests/test_compute_everything.py index aff0345e8f..308138b26e 100644 --- a/tests/test_compute_everything.py +++ b/tests/test_compute_everything.py @@ -80,8 +80,21 @@ def _compare_against_rpz(p, data, data_rpz, coordinate_conversion_func): def test_compute_everything(): """Test that the computations on this branch agree with those on master. - Also make sure we can compute everything without errors. Computed quantities - are both in "rpz" and "xyz" basis. + Also make sure we can compute everything without errors. + + Notes + ----- + This test will fail if the benchmark file has been updated on both + the local and upstream branches and git cannot resolve the merge + conflict. In that case, please regenerate the benchmark file. + Here are instructions for convenience. + + 1. Prepend true to the line near the end of this test. + ``if True or (not error_rpz and update_master_data_rpz):`` + 2. Run pytest -k test_compute_everything + 3. Revert 1. + 4. 
git add tests/inputs/master_compute_data_rpz.pkl + """ elliptic_cross_section_with_torsion = { "R_lmn": [10, 1, 0.2], diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 43c3d81449..9a9216cc8e 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -5,12 +5,12 @@ from scipy.signal import convolve2d from desc.compute import rpz2xyz_vec -from desc.compute.utils import dot from desc.equilibrium import Equilibrium from desc.examples import get from desc.geometry import FourierRZToroidalSurface from desc.grid import LinearGrid from desc.io import load +from desc.utils import dot # convolve kernel is reverse of FD coeffs FD_COEF_1_2 = np.array([-1 / 2, 0, 1 / 2])[::-1] @@ -1134,6 +1134,24 @@ def test_boozer_transform(): ) +@pytest.mark.unit +def test_boozer_transform_multiple_surfaces(): + """Test that computing over multiple surfaces is the same as over 1 at a time.""" + eq = get("HELIOTRON") + grid1 = LinearGrid(rho=0.6, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + grid2 = LinearGrid(rho=0.8, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + grid3 = LinearGrid(rho=np.array([0.6, 0.8]), M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + data1 = eq.compute("|B|_mn", grid=grid1, M_booz=eq.M, N_booz=eq.N) + data2 = eq.compute("|B|_mn", grid=grid2, M_booz=eq.M, N_booz=eq.N) + data3 = eq.compute("|B|_mn", grid=grid3, M_booz=eq.M, N_booz=eq.N) + np.testing.assert_allclose( + data1["|B|_mn"], data3["|B|_mn"].reshape((grid3.num_rho, -1))[0] + ) + np.testing.assert_allclose( + data2["|B|_mn"], data3["|B|_mn"].reshape((grid3.num_rho, -1))[1] + ) + + @pytest.mark.unit def test_compute_averages(): """Test that computing averages uses the correct grid.""" diff --git a/tests/test_examples.py b/tests/test_examples.py index 6bae9dd16b..a03e364c35 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1077,9 +1077,12 @@ def test_freeb_axisym(): -6.588300858364606e04, -3.560589388468855e05, ] - ext_field = SplineMagneticField.from_mgrid( - 
r"tests/inputs/mgrid_solovev.nc", extcur=extcur - ) + with pytest.warns(UserWarning, match="Vector potential"): + # the mgrid file does not have the vector potential + # saved so we will ignore the thrown warning + ext_field = SplineMagneticField.from_mgrid( + r"tests/inputs/mgrid_solovev.nc", extcur=extcur + ) pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) diff --git a/tests/test_integrals.py b/tests/test_integrals.py index d2f5950804..ca4e7e6d91 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -15,7 +15,6 @@ from desc.backend import jnp from desc.basis import FourierZernikeBasis -from desc.compute.utils import dot, safediv from desc.equilibrium import Equilibrium from desc.equilibrium.coords import get_rtz_grid, map_coordinates from desc.examples import get @@ -40,15 +39,15 @@ _get_extrema, bounce_points, get_alpha, - get_pitch, + get_pitch_inv, interp_to_argmin, interp_to_argmin_hard, - plot_ppoly, ) from desc.integrals.interp_utils import fourier_pts from desc.integrals.quad_utils import ( automorphism_sin, bijection_from_disc, + get_quadrature, grad_automorphism_sin, grad_bijection_from_disc, leggauss_lob, @@ -57,6 +56,7 @@ from desc.integrals.singularities import _get_quadrature_nodes from desc.integrals.surface_integral import _get_grid_surface from desc.transform import Transform +from desc.utils import dot, safediv class TestSurfaceIntegral: @@ -735,14 +735,16 @@ def filter(z1, z2): @pytest.mark.unit def test_z1_first(self): - """Test that bounce points are computed correctly.""" + """Case where straight line through first two intersects is in epigraph.""" start = np.pi / 3 end = 6 * np.pi knots = np.linspace(start, end, 5) B = CubicHermiteSpline(knots, np.cos(knots), -np.sin(knots)) - pitch = 2.0 - intersect = B.solve(1 / pitch, extrapolate=False) - z1, z2 = bounce_points(pitch, knots, B.c, B.derivative().c, check=True) + pitch_inv = 0.5 + intersect = B.solve(pitch_inv, 
extrapolate=False) + z1, z2 = bounce_points( + pitch_inv, knots, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size np.testing.assert_allclose(z1, intersect[0::2]) @@ -750,14 +752,16 @@ def test_z1_first(self): @pytest.mark.unit def test_z2_first(self): - """Test that bounce points are computed correctly.""" + """Case where straight line through first two intersects is in hypograph.""" start = -3 * np.pi end = -start k = np.linspace(start, end, 5) B = CubicHermiteSpline(k, np.cos(k), -np.sin(k)) - pitch = 2.0 - intersect = B.solve(1 / pitch, extrapolate=False) - z1, z2 = bounce_points(pitch, k, B.c, B.derivative().c, check=True) + pitch_inv = 0.5 + intersect = B.solve(pitch_inv, extrapolate=False) + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size np.testing.assert_allclose(z1, intersect[1:-1:2]) @@ -765,7 +769,9 @@ def test_z2_first(self): @pytest.mark.unit def test_z1_before_extrema(self): - """Test that bounce points are computed correctly.""" + """Case where local maximum is the shared intersect between two wells.""" + # To make sure both regions in epigraph left and right of extrema are + # integrated over. 
start = -np.pi end = -2 * start k = np.linspace(start, end, 5) @@ -773,11 +779,13 @@ def test_z1_before_extrema(self): k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k) ) dB_dz = B.derivative() - pitch = 1 / B(dB_dz.roots(extrapolate=False))[3] + 1e-13 - z1, z2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) + pitch_inv = B(dB_dz.roots(extrapolate=False))[3] - 1e-13 + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size - intersect = B.solve(1 / pitch, extrapolate=False) + intersect = B.solve(pitch_inv, extrapolate=False) np.testing.assert_allclose(z1[1], 1.982767, rtol=1e-6) np.testing.assert_allclose(z1, intersect[[1, 2]], rtol=1e-6) # intersect array could not resolve double root as single at index 2,3 @@ -786,7 +794,9 @@ def test_z1_before_extrema(self): @pytest.mark.unit def test_z2_before_extrema(self): - """Test that bounce points are computed correctly.""" + """Case where local minimum is the shared intersect between two wells.""" + # To make sure both regions in hypograph left and right of extrema are not + # integrated over. 
start = -1.2 * np.pi end = -2 * start k = np.linspace(start, end, 7) @@ -796,17 +806,20 @@ def test_z2_before_extrema(self): -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 4, ) dB_dz = B.derivative() - pitch = 1 / B(dB_dz.roots(extrapolate=False))[2] - z1, z2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) + pitch_inv = B(dB_dz.roots(extrapolate=False))[2] + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size - intersect = B.solve(1 / pitch, extrapolate=False) + intersect = B.solve(pitch_inv, extrapolate=False) np.testing.assert_allclose(z1, intersect[[0, -2]]) np.testing.assert_allclose(z2, intersect[[1, -1]]) @pytest.mark.unit def test_extrema_first_and_before_z1(self): - """Test that bounce points are computed correctly.""" + """Case where first intersect is extrema and second enters epigraph.""" + # To make sure we don't perform integral between first pair of intersects. start = -1.2 * np.pi end = -2 * start k = np.linspace(start, end, 7) @@ -816,14 +829,19 @@ def test_extrema_first_and_before_z1(self): -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 20, ) dB_dz = B.derivative() - pitch = 1 / B(dB_dz.roots(extrapolate=False))[2] - 1e-13 + pitch_inv = B(dB_dz.roots(extrapolate=False))[2] + 1e-13 z1, z2 = bounce_points( - pitch, k[2:], B.c[:, 2:], dB_dz.c[:, 2:], check=True, plot=False + pitch_inv, + k[2:], + B.c[:, 2:].T, + dB_dz.c[:, 2:].T, + check=True, + start=k[2], + include_knots=True, ) - plot_ppoly(B, z1=z1, z2=z2, k=1 / pitch, start=k[2]) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size - intersect = B.solve(1 / pitch, extrapolate=False) + intersect = B.solve(pitch_inv, extrapolate=False) np.testing.assert_allclose(z1[0], 0.835319, rtol=1e-6) intersect = intersect[intersect >= k[2]] np.testing.assert_allclose(z1, intersect[[0, 2, 4]], rtol=1e-6) @@ -831,7 +849,8 @@ def test_extrema_first_and_before_z1(self): @pytest.mark.unit def 
test_extrema_first_and_before_z2(self): - """Test that bounce points are computed correctly.""" + """Case where first intersect is extrema and second exits epigraph.""" + # To make sure we do perform integral between first pair of intersects. start = -1.2 * np.pi end = -2 * start + 1 k = np.linspace(start, end, 7) @@ -841,12 +860,14 @@ def test_extrema_first_and_before_z2(self): -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 10, ) dB_dz = B.derivative() - pitch = 1 / B(dB_dz.roots(extrapolate=False))[1] + 1e-13 - z1, z2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) + pitch_inv = B(dB_dz.roots(extrapolate=False))[1] - 1e-13 + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size # Our routine correctly detects intersection, while scipy, jnp.root fails. - intersect = B.solve(1 / pitch, extrapolate=False) + intersect = B.solve(pitch_inv, extrapolate=False) np.testing.assert_allclose(z1[0], -0.671904, rtol=1e-6) np.testing.assert_allclose(z1, intersect[[0, 3, 5]], rtol=1e-5) # intersect array could not resolve double root as single at index 0,1 @@ -863,7 +884,7 @@ def test_get_extrema(self): k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k) ) dB_dz = B.derivative() - ext, B_ext = _get_extrema(k, B.c, dB_dz.c) + ext, B_ext = _get_extrema(k, B.c.T, dB_dz.c.T) mask = ~np.isnan(ext) ext, B_ext = ext[mask], B_ext[mask] idx = np.argsort(ext) @@ -904,27 +925,26 @@ class TestBounce1DQuadrature: ], ) def test_bounce_quadrature(self, is_strong, quad, automorphism): - """Test bounce integral matches singular elliptic integrals.""" + """Test quadrature matches singular (strong and weak) elliptic integrals.""" p = 1e-4 m = 1 - p # Some prime number that doesn't appear anywhere in calculation. 
- # Ensures no lucky cancellation occurs from this test case since otherwise - # (z2 - z1) / pi = pi / (z2 - z1) which could mask errors since pi - # appears often in transformations. + # Ensures no lucky cancellation occurs from ζ₂ − ζ₁ / π = π / (ζ₂ − ζ₁) + # which could mask errors since π appears often in transformations. v = 7 z1 = -np.pi / 2 * v z2 = -z1 knots = np.linspace(z1, z2, 50) - pitch = 1 + 50 * jnp.finfo(jnp.array(1.0).dtype).eps + pitch_inv = 1 - 50 * jnp.finfo(jnp.array(1.0).dtype).eps b = np.clip(np.sin(knots / v) ** 2, 1e-7, 1) db = np.sin(2 * knots / v) / v data = {"B^zeta": b, "B^zeta_z|r,a": db, "|B|": b, "|B|_z|r,a": db} if is_strong: - integrand = lambda B, pitch: 1 / jnp.sqrt(1 - pitch * m * B) + integrand = lambda B, pitch: 1 / jnp.sqrt(1 - m * pitch * B) truth = v * 2 * ellipkm1(p) else: - integrand = lambda B, pitch: jnp.sqrt(1 - pitch * m * B) + integrand = lambda B, pitch: jnp.sqrt(1 - m * pitch * B) truth = v * 2 * ellipe(m) kwargs = {} if automorphism != "default": @@ -936,9 +956,9 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism): check=True, **kwargs, ) - result = bounce.integrate(pitch, integrand, [], check=True) + result = bounce.integrate(integrand, pitch_inv, check=True, plot=True) assert np.count_nonzero(result) == 1 - np.testing.assert_allclose(np.sum(result), truth, rtol=1e-4) + np.testing.assert_allclose(result.sum(), truth, rtol=1e-4) @staticmethod @partial(np.vectorize, excluded={0}) @@ -952,17 +972,24 @@ def _fixed_elliptic(integrand, k, deg): k = np.atleast_1d(k) a = np.zeros_like(k) b = 2 * np.arcsin(k) - x, w = leggauss(deg) - w = w * grad_automorphism_sin(x) - x = automorphism_sin(x) + x, w = get_quadrature(leggauss(deg), (automorphism_sin, grad_automorphism_sin)) Z = bijection_from_disc(x, a[..., np.newaxis], b[..., np.newaxis]) k = k[..., np.newaxis] - quad = np.dot(integrand(Z, k), w) * grad_bijection_from_disc(a, b) + quad = integrand(Z, k).dot(w) * grad_bijection_from_disc(a, b) return quad + # 
TODO: add the analytical test that converts incomplete elliptic integrals to + # complete ones using the Reciprocal Modulus transformation + # https://dlmf.nist.gov/19.7#E4. @staticmethod def elliptic_incomplete(k2): - """Calculate elliptic integrals for bounce averaged binormal drift.""" + """Calculate elliptic integrals for bounce averaged binormal drift. + + The test is nice because it is independent of all the bounce integrals + and splines. One can test performance of different quadrature methods + by using that method in the ``_fixed_elliptic`` method above. + + """ K_integrand = lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * (k / 4) E_integrand = lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) / (k * 4) # Scipy's elliptic integrals are broken. @@ -1017,7 +1044,7 @@ def elliptic_incomplete(k2): TestBounce1DQuadrature._fixed_elliptic( lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.cos(Z), k, - deg=10, + deg=11, ), ) np.testing.assert_allclose( @@ -1032,8 +1059,18 @@ def elliptic_incomplete(k2): class TestBounce1D: """Test bounce integration with one-dimensional local spline methods.""" + @staticmethod + def _example_numerator(g_zz, B, pitch): + f = (1 - 0.5 * pitch * B) * g_zz + return safediv(f, jnp.sqrt(jnp.abs(1 - pitch * B))) + + @staticmethod + def _example_denominator(B, pitch): + return safediv(1, jnp.sqrt(jnp.abs(1 - pitch * B))) + @pytest.mark.unit - def test_integrate_checks(self): + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d * 4) + def test_bounce1d_checks(self): """Test that all the internal correctness checks pass for real example.""" # noqa: D202 # Suppose we want to compute a bounce average of the function @@ -1042,55 +1079,67 @@ def test_integrate_checks(self): # coordinates. 
This is defined as # [∫ f(ℓ) / √(1 − λ|B|) dℓ] / [∫ 1 / √(1 − λ|B|) dℓ] - def numerator(g_zz, B, pitch): - f = (1 - pitch * B / 2) * g_zz - return safediv(f, jnp.sqrt(jnp.abs(1 - pitch * B))) - - def denominator(B, pitch): - return safediv(1, jnp.sqrt(jnp.abs(1 - pitch * B))) - - # Pick flux surfaces, field lines, and how far to follow the field line - # in Clebsch-Type field-line coordinates ρ, α, ζ. + # 1. Define python functions for the integrands. We do that above. + # 2. Pick flux surfaces, field lines, and how far to follow the field + # line in Clebsch coordinates ρ, α, ζ. rho = np.linspace(0.1, 1, 6) - alpha = np.array([0]) + alpha = np.array([0, 0.5]) zeta = np.linspace(-2 * np.pi, 2 * np.pi, 200) eq = get("HELIOTRON") - # Convert above coordinates to DESC computational coordinates. + # 3. Convert above coordinates to DESC computational coordinates. grid = get_rtz_grid( eq, rho, alpha, zeta, coordinates="raz", period=(np.inf, 2 * np.pi, np.inf) ) + # 4. Compute input data. data = eq.compute( - Bounce1D.required_names() + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid + Bounce1D.required_names + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid + ) + # 5. Make the bounce integration operator. 
+ bounce = Bounce1D( + grid.source_grid, + data, + quad=leggauss(3), # not checking quadrature accuracy in this test + check=True, ) - bounce = Bounce1D(grid.source_grid, data, quad=leggauss(3), check=True) - pitch = get_pitch( + pitch_inv = bounce.get_pitch_inv( grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 ) num = bounce.integrate( - pitch, - numerator, - Bounce1D.reshape_data(grid.source_grid, data["g_zz"]), + integrand=TestBounce1D._example_numerator, + pitch_inv=pitch_inv, + f=Bounce1D.reshape_data(grid.source_grid, data["g_zz"]), check=True, ) - den = bounce.integrate(pitch, denominator, [], check=True) + den = bounce.integrate( + integrand=TestBounce1D._example_denominator, + pitch_inv=pitch_inv, + check=True, + batch=False, + ) avg = safediv(num, den) - - # Sum all bounce integrals across each particular field line. - avg = np.sum(avg, axis=-1) - assert np.isfinite(avg).all() - # Group the averages by field line. - avg = avg.reshape(pitch.shape[0], rho.size, alpha.size) - # The sum stored at index i, j - i, j = 0, 0 - print(avg[:, i, j]) - # is the summed bounce average among wells along the field line with nodes - # given in Clebsch-Type field-line coordinates ρ, α, ζ - nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes, "raz") - print(nodes[i, j]) - # for the pitch values stored in - pitch = pitch.reshape(pitch.shape[0], rho.size, alpha.size) - print(pitch[:, i, j]) + assert np.isfinite(avg).all() and np.count_nonzero(avg) + + # 6. Basic manipulation of the output. + # Sum all bounce averages across a particular field line, for every field line. + result = avg.sum(axis=-1) + # Group the result by pitch and flux surface. 
+ result = result.reshape(alpha.size, rho.size, pitch_inv.shape[-1]) + # The result stored at + m, l, p = 0, 1, 3 + print("Result(α, ρ, λ):", result[m, l, p]) + # corresponds to the 1/λ value + print("1/λ(α, ρ):", pitch_inv[l, p]) + # for the Clebsch-type field line coordinates + nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes[:, :2], "arz") + print("(α, ρ):", nodes[m, l, 0]) + + # 7. Optionally check for correctness of bounce points + bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) + + # 8. Plotting + fig, ax = bounce.plot(m, l, pitch_inv[l], include_legend=False, show=False) + return fig @pytest.mark.unit @pytest.mark.parametrize("func", [interp_to_argmin, interp_to_argmin_hard]) @@ -1122,21 +1171,22 @@ def dg_dz(z): "|B|_z|r,a": dg_dz(zeta), }, ) - np.testing.assert_allclose(bounce._zeta, zeta) + z1 = np.array(0, ndmin=4) + z2 = np.array(2 * np.pi, ndmin=4) argmin = 5.61719 - np.testing.assert_allclose( - h(argmin), - func( - h(zeta), - z1=np.array(0, ndmin=3), - z2=np.array(2 * np.pi, ndmin=3), - knots=zeta, - g=bounce._B, - dg_dz=bounce._dB_dz, - ), - rtol=1e-3, - ) - + h_min = h(argmin) + result = func( + h=h(zeta), + z1=z1, + z2=z2, + knots=zeta, + g=bounce.B, + dg_dz=bounce._dB_dz, + ) + assert result.shape == z1.shape + np.testing.assert_allclose(h_min, result, rtol=1e-3) + + # TODO: stellarator geometry test with ripples @staticmethod def drift_analytic(data): """Compute analytic approximation for bounce-averaged binormal drift. @@ -1150,9 +1200,9 @@ def drift_analytic(data): Numerically computed ``data["cvdrift"]` and ``data["gbdrift"]`` normalized by some scale factors for this unit test. These should be fed to the bounce integration as input. - pitch : jnp.ndarray + pitch_inv : jnp.ndarray Shape (P, ). - Pitch values used. + 1/λ values used. 
""" B = data["|B|"] / data["Bref"] @@ -1216,12 +1266,14 @@ def drift_analytic(data): np.testing.assert_allclose(gbdrift, gbdrift_analytic_low_order, atol=1e-2) np.testing.assert_allclose(cvdrift, cvdrift_analytic_low_order, atol=2e-2) - pitch = get_pitch(np.min(B), np.max(B), 100)[1:] - k2 = 0.5 * ((1 - pitch * B0) / (epsilon * pitch * B0) + 1) + # Exclude singularity not captured by analytic approximation for pitch near + # the maximum |B|. (This is captured by the numerical integration). + pitch_inv = get_pitch_inv(np.min(B), np.max(B), 100)[:-1] + k2 = 0.5 * ((1 - B0 / pitch_inv) / (epsilon * B0 / pitch_inv) + 1) I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 = ( TestBounce1DQuadrature.elliptic_incomplete(k2) ) - y = np.sqrt(2 * epsilon * pitch * B0) + y = np.sqrt(2 * epsilon * B0 / pitch_inv) I_0, I_2, I_4, I_6 = map(lambda I: I / y, (I_0, I_2, I_4, I_6)) I_1, I_3, I_5, I_7 = map(lambda I: I * y, (I_1, I_3, I_5, I_7)) @@ -1237,7 +1289,7 @@ def drift_analytic(data): ) / G0 drift_analytic_den = I_0 / G0 drift_analytic = drift_analytic_num / drift_analytic_den - return drift_analytic, cvdrift, gbdrift, pitch + return drift_analytic, cvdrift, gbdrift, pitch_inv @staticmethod def drift_num_integrand(cvdrift, gbdrift, B, pitch): @@ -1276,7 +1328,7 @@ def test_binormal_drift_bounce1d(self): iota=iota, ) data = eq.compute( - Bounce1D.required_names() + Bounce1D.required_names + [ "cvdrift", "gbdrift", @@ -1301,7 +1353,7 @@ def test_binormal_drift_bounce1d(self): data["shear"] = grid.compress(data["shear"]) # Compute analytic approximation. - drift_analytic, cvdrift, gbdrift, pitch = TestBounce1D.drift_analytic(data) + drift_analytic, cvdrift, gbdrift, pitch_inv = TestBounce1D.drift_analytic(data) # Compute numerical result. 
bounce = Bounce1D( grid.source_grid, @@ -1311,17 +1363,19 @@ def test_binormal_drift_bounce1d(self): Lref=data["a"], check=True, ) + bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) + f = Bounce1D.reshape_data(grid.source_grid, cvdrift, gbdrift) drift_numerical_num = bounce.integrate( - pitch=pitch[:, np.newaxis], integrand=TestBounce1D.drift_num_integrand, + pitch_inv=pitch_inv, f=f, num_well=1, check=True, ) drift_numerical_den = bounce.integrate( - pitch=pitch[:, np.newaxis], integrand=TestBounce1D.drift_den_integrand, + pitch_inv=pitch_inv, num_well=1, weight=np.ones(zeta.size), check=True, @@ -1333,7 +1387,7 @@ def test_binormal_drift_bounce1d(self): drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2 ) - self._test_bounce_autodiff( + TestBounce1D._test_bounce_autodiff( bounce, TestBounce1D.drift_num_integrand, f=f, @@ -1341,32 +1395,76 @@ def test_binormal_drift_bounce1d(self): ) fig, ax = plt.subplots() - ax.plot(1 / pitch, drift_analytic) - ax.plot(1 / pitch, drift_numerical) + ax.plot(pitch_inv, drift_analytic) + ax.plot(pitch_inv, drift_numerical) return fig @staticmethod def _test_bounce_autodiff(bounce, integrand, **kwargs): - """Make sure reverse mode AD works correctly on this algorithm.""" + """Make sure reverse mode AD works correctly on this algorithm. + + Non-differentiable operations (e.g. ``take_mask``) are used in computation. + See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html + and https://jax.readthedocs.io/en/latest/faq.html# + why-are-gradients-zero-for-functions-based-on-sort-order. + + If the AD tool works properly, then these operations should be assigned + zero gradients while the gradients wrt parameters of our physics computations + accumulate correctly. Less mature AD tools may have subtle bugs that cause + the gradients to not accumulate correctly. (There's a few + GitHub issues that JAX has fixed related to this in the past.) 
+ + This test first confirms the gradients computed by reverse mode AD matches + the analytic approximation of the true gradient. Then we confirm that the + partial gradients wrt the integrand and bounce points are correct. + + Apply the Leibniz integral rule + https://en.wikipedia.org/wiki/Leibniz_integral_rule, with + the label w summing over the magnetic wells: + + ∂_λ ∑_w ∫_ζ₁^ζ₂ f dζ (λ) = ∑_w [ + ∫_ζ₁^ζ₂ (∂f/∂λ)(λ) dζ + + f(λ,ζ₂) (∂ζ₂/∂λ)(λ) + - f(λ,ζ₁) (∂ζ₁/∂λ)(λ) + ] + where (∂ζ₁/∂λ)(λ) = -λ² / (∂|B|/∂ζ|ρ,α)(ζ₁) + (∂ζ₂/∂λ)(λ) = -λ² / (∂|B|/∂ζ|ρ,α)(ζ₂) + + All terms in these expressions are known analytically. + If we wanted, it's simple to check explicitly that AD takes each derivative + correctly because |w| = 1 is constant and our tokamak has symmetry + (∂|B|/∂ζ|ρ,α)(ζ₁) = - (∂|B|/∂ζ|ρ,α)(ζ₂). + + After confirming the left hand side is correct, we just check that derivative + wrt bounce points of the right hand side doesn't vanish due to some zero + gradient issue mentioned above. + + """ def integrand_grad(*args, **kwargs2): - return jnp.vectorize( + grad_fun = jnp.vectorize( grad(integrand, -1), signature="()," * len(kwargs["f"]) + "(),()->()" - )(*args, *kwargs2.values()) + ) + return grad_fun(*args, *kwargs2.values()) def fun1(pitch): - return jnp.sum(bounce.integrate(pitch, integrand, check=False, **kwargs)) + return bounce.integrate(integrand, 1 / pitch, check=False, **kwargs).sum() def fun2(pitch): - return jnp.sum( - bounce.integrate(pitch, integrand_grad, check=True, **kwargs) - ) + return bounce.integrate( + integrand_grad, 1 / pitch, check=True, **kwargs + ).sum() pitch = 1.0 - truth = 650 # Extrapolated from plot. - assert np.isclose(grad(fun1)(pitch), truth, rtol=1e-3) - # Make sure bounce points get differentiated too. 
- assert np.isclose(fun2(pitch), -131750, rtol=1e-1) + # can easily obtain from math or just extrapolate from analytic expression plot + analytic_approximation_of_gradient = 650 + np.testing.assert_allclose( + grad(fun1)(pitch), analytic_approximation_of_gradient, rtol=1e-3 + ) + # It is expected that this is much larger because the integrand is singular + # wrt λ but the boundary derivative: f(λ,ζ₂) (∂ζ₂/∂λ)(λ) - f(λ,ζ₁) (∂ζ₁/∂λ)(λ). + # smooths out because the bounce points ζ₁ and ζ₂ are smooth functions of λ. + np.testing.assert_allclose(fun2(pitch), -131750, rtol=1e-1) class TestBounce2DPoints: @@ -1451,7 +1549,7 @@ def test_fourier_chebyshev(self, rho=1, M=8, N=32, f=lambda B, pitch: B * pitch) fb = Bounce2D( grid, data, M, N, desc_from_clebsch, check=True, warn=False ) # TODO check true - pitch = get_pitch( + pitch = get_pitch_inv( grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 ) result = fb.integrate(f, [], pitch) # noqa: F841 diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py index 250ca42a8e..47b4e6226b 100644 --- a/tests/test_interp_utils.py +++ b/tests/test_interp_utils.py @@ -10,6 +10,7 @@ chebpts2, chebval, ) +from numpy.polynomial.polynomial import polyvander from scipy.fft import dct as sdct from scipy.fft import idct as sidct @@ -22,36 +23,46 @@ interp_dct, interp_rfft, interp_rfft2, - poly_root, polyder_vec, + polyroot_vec, polyval_vec, ) from desc.integrals.quad_utils import bijection_to_disc class TestPolyUtils: - """Test polynomial stuff used for local spline interpolation.""" + """Test polynomial utilities used for local spline interpolation in integrals.""" @pytest.mark.unit - def test_poly_root(self): + def test_polyroot_vec(self): """Test vectorized computation of cubic polynomial exact roots.""" - cubic = 4 - c = np.arange(-24, 24).reshape(cubic, 6, -1) * np.pi - # make sure broadcasting won't hide error in implementation + c = np.arange(-24, 24).reshape(4, 6, -1).transpose(-1, 1, 0) + # Ensure 
broadcasting won't hide error in implementation. assert np.unique(c.shape).size == c.ndim - constant = np.broadcast_to(np.arange(c.shape[-1]), c.shape[1:]) - constant = np.stack([constant, constant]) - root = poly_root(c, constant, sort=True) - - for i in range(constant.shape[0]): - for j in range(c.shape[1]): - for k in range(c.shape[2]): - d = c[-1, j, k] - constant[i, j, k] - np.testing.assert_allclose( - actual=root[i, j, k], - desired=np.sort(np.roots([*c[:-1, j, k], d])), - ) + k = np.broadcast_to(np.arange(c.shape[-2]), c.shape[:-1]) + # Now increase dimension so that shapes still broadcast, but stuff like + # ``c[...,-1]-=k`` is not allowed because it grows the dimension of ``c``. + # This is needed functionality in ``polyroot_vec`` that requires an awkward + # loop to obtain if using jnp.vectorize. + k = np.stack([k, k * 2 + 1]) + r = polyroot_vec(c, k, sort=True) + + for i in range(k.shape[0]): + d = c.copy() + d[..., -1] -= k[i] + # np.roots cannot be vectorized because it strips leading zeros and + # output shape is therefore dynamic. + for idx in np.ndindex(d.shape[:-1]): + np.testing.assert_allclose( + r[(i, *idx)], + np.sort(np.roots(d[idx])), + err_msg=f"Eigenvalue branch of polyroot_vec failed at {i, *idx}.", + ) + + # Now test analytic formula branch, Ensure it filters distinct roots, + # and ensure zero coefficients don't bust computation due to singularities + # in analytic formulae which are not present in iterative eigenvalue scheme. 
c = np.array( [ [1, 0, 0, 0], @@ -63,58 +74,55 @@ def test_poly_root(self): [0, -6, 11, -2], ] ) - root = poly_root(c.T, sort=True, distinct=True) + r = polyroot_vec(c, sort=True, distinct=True) for j in range(c.shape[0]): - unique_roots = np.unique(np.roots(c[j])) + root = r[j][~np.isnan(r[j])] + unique_root = np.unique(np.roots(c[j])) + assert root.size == unique_root.size np.testing.assert_allclose( - actual=root[j][~np.isnan(root[j])], desired=unique_roots, err_msg=str(j) + root, + unique_root, + err_msg=f"Analytic branch of polyroot_vec failed at {j}.", ) c = np.array([0, 1, -1, -8, 12]) - root = poly_root(c, sort=True, distinct=True) - root = root[~np.isnan(root)] - unique_root = np.unique(np.roots(c)) - assert root.size == unique_root.size - np.testing.assert_allclose(root, unique_root) + r = polyroot_vec(c, sort=True, distinct=True) + r = r[~np.isnan(r)] + unique_r = np.unique(np.roots(c)) + assert r.size == unique_r.size + np.testing.assert_allclose(r, unique_r) @pytest.mark.unit def test_polyder_vec(self): """Test vectorized computation of polynomial derivative.""" - quintic = 6 - c = np.arange(-18, 18).reshape(quintic, 3, -1) * np.pi - # make sure broadcasting won't hide error in implementation + c = np.arange(-18, 18).reshape(3, -1, 6) + # Ensure broadcasting won't hide error in implementation. assert np.unique(c.shape).size == c.ndim - derivative = polyder_vec(c) - desired = np.vectorize(np.polyder, signature="(m)->(n)")(c.T).T - np.testing.assert_allclose(derivative, desired) + np.testing.assert_allclose( + polyder_vec(c), + np.vectorize(np.polyder, signature="(m)->(n)")(c), + ) @pytest.mark.unit def test_polyval_vec(self): """Test vectorized computation of polynomial evaluation.""" def test(x, c): - val = polyval_vec(x=x, c=c) - c = np.moveaxis(c, 0, -1) - x = x[..., np.newaxis] + # Ensure broadcasting won't hide error in implementation. 
+ assert np.unique(x.shape).size == x.ndim + assert np.unique(c.shape).size == c.ndim np.testing.assert_allclose( - val, - np.vectorize(np.polyval, signature="(m),(n)->(n)")(c, x).squeeze( - axis=-1 - ), + polyval_vec(x=x, c=c), + np.sum(polyvander(x, c.shape[-1] - 1) * c[..., ::-1], axis=-1), ) - quartic = 5 - c = np.arange(-60, 60).reshape(quartic, 3, -1) * np.pi - # make sure broadcasting won't hide error in implementation - assert np.unique(c.shape).size == c.ndim - x = np.linspace(0, 20, c.shape[1] * c.shape[2]).reshape(c.shape[1], c.shape[2]) + c = np.arange(-60, 60).reshape(-1, 5, 3) + x = np.linspace(0, 20, np.prod(c.shape[:-1])).reshape(c.shape[:-1]) test(x, c) x = np.stack([x, x * 2], axis=0) x = np.stack([x, x * 2, x * 3, x * 4], axis=0) - # make sure broadcasting won't hide error in implementation - assert np.unique(x.shape).size == x.ndim - assert c.shape[1:] == x.shape[x.ndim - (c.ndim - 1) :] - assert np.unique((c.shape[0],) + x.shape[c.ndim - 1 :]).size == x.ndim - 1 + assert c.shape[:-1] == x.shape[x.ndim - (c.ndim - 1) :] + assert np.unique((c.shape[-1],) + x.shape[c.ndim - 1 :]).size == x.ndim - 1 test(x, c) diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index 06e7b83800..86f4174547 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -6,10 +6,11 @@ from desc.backend import jit, jnp from desc.basis import DoubleFourierSeries -from desc.compute import rpz2xyz_vec, xyz2rpz_vec +from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz_vec from desc.compute.utils import get_params, get_transforms +from desc.derivatives import FiniteDiffDerivative as Derivative from desc.examples import get -from desc.geometry import FourierRZToroidalSurface +from desc.geometry import FourierRZToroidalSurface, FourierXYZCurve from desc.grid import LinearGrid from desc.io import load from desc.magnetic_fields import ( @@ -22,11 +23,13 @@ ScalarPotentialField, SplineMagneticField, ToroidalMagneticField, + 
VectorPotentialField, VerticalMagneticField, field_line_integrate, read_BNORM_file, ) from desc.magnetic_fields._dommaschk import CD_m_k, CN_m_k +from desc.utils import dot def phi_lm(R, phi, Z, a, m): @@ -59,8 +62,41 @@ def test_basic_fields(self): tfield = ToroidalMagneticField(2, 1) vfield = VerticalMagneticField(1) pfield = PoloidalMagneticField(2, 1, 2) + + def tfield_A(R, phi, Z, B0=2, R0=1): + az = -B0 * R0 * jnp.log(R) + arp = jnp.zeros_like(az) + A = jnp.array([arp, arp, az]).T + return A + + tfield_from_A = VectorPotentialField(tfield_A, params={"B0": 2, "R0": 1}) + + def vfield_A(R, phi, Z, B0=None): + coords_rpz = jnp.vstack([R, phi, Z]).T + coords_xyz = rpz2xyz(coords_rpz) + ax = B0 / 2 * coords_xyz[:, 1] + ay = -B0 / 2 * coords_xyz[:, 0] + + az = jnp.zeros_like(ax) + A = jnp.array([ax, -ay, az]).T + A = xyz2rpz_vec(A, phi=coords_rpz[:, 1]) + return A + + vfield_params = {"B0": 1} + vfield_from_A = VectorPotentialField(vfield_A, params=vfield_params) + np.testing.assert_allclose(tfield([1, 0, 0]), [[0, 2, 0]]) np.testing.assert_allclose((4 * tfield)([2, 0, 0]), [[0, 4, 0]]) + np.testing.assert_allclose(tfield_from_A([1, 0, 0]), [[0, 2, 0]]) + np.testing.assert_allclose( + tfield_A(1, 0, 0), + tfield_from_A.compute_magnetic_vector_potential([1, 0, 0]).squeeze(), + ) + np.testing.assert_allclose( + vfield_A(1, 0, 0, **vfield_params), + vfield_from_A.compute_magnetic_vector_potential([1, 0, 0]), + ) + np.testing.assert_allclose((tfield + vfield)([1, 0, 0]), [[0, 2, 1]]) np.testing.assert_allclose( (tfield + vfield - pfield)([1, 0, 0.1]), [[0.4, 2, 1]] @@ -104,17 +140,40 @@ def test_combined_fields(self): assert scaled_field.B0 == 2 assert scaled_field.scale == 3.1 np.testing.assert_allclose(scaled_field([1.0, 0, 0]), np.array([[0, 6.2, 0]])) + np.testing.assert_allclose( + scaled_field.compute_magnetic_vector_potential([2.0, 0, 0]), + np.array([[0, 0, -3.1 * 2 * 1 * np.log(2)]]), + ) + scaled_field.R0 = 1.3 scaled_field.scale = 1.0 
np.testing.assert_allclose(scaled_field([1.3, 0, 0]), np.array([[0, 2, 0]])) + np.testing.assert_allclose( + scaled_field.compute_magnetic_vector_potential([2.0, 0, 0]), + np.array([[0, 0, -2 * 1.3 * np.log(2)]]), + ) assert scaled_field.optimizable_params == ["B0", "R0", "scale"] assert hasattr(scaled_field, "B0") sum_field = vfield + pfield + tfield + sum_field_tv = vfield + tfield # to test A since pfield does not have A assert len(sum_field) == 3 + assert len(sum_field_tv) == 2 + np.testing.assert_allclose( sum_field([1.3, 0, 0.0]), [[0.0, 2, 3.2 + 2 * 1.2 * 0.3]] ) + + tfield_A = np.array([[0, 0, -tfield.B0 * tfield.R0 * np.log(tfield.R0)]]) + x = tfield.R0 * np.cos(np.pi / 4) + y = tfield.R0 * np.sin(np.pi / 4) + vfield_A = np.array([[vfield.B0 * y, -vfield.B0 * x, 0]]) / 2 + + np.testing.assert_allclose( + sum_field_tv.compute_magnetic_vector_potential([x, y, 0.0], basis="xyz"), + tfield_A + vfield_A, + ) + assert sum_field.optimizable_params == [ ["B0"], ["B0", "R0", "iota"], @@ -304,6 +363,87 @@ def test_current_potential_field(self): with pytest.raises(AssertionError): field.potential_dzeta = 1 + @pytest.mark.unit + def test_current_potential_vector_potential(self): + """Test current potential field vector potential against analytic result.""" + R0 = 10 + a = 1 + surface = FourierRZToroidalSurface( + R_lmn=jnp.array([R0, a]), + Z_lmn=jnp.array([0, -a]), + modes_R=jnp.array([[0, 0], [1, 0]]), + modes_Z=jnp.array([[0, 0], [-1, 0]]), + NFP=10, + ) + # make a current potential corresponding a purely poloidal current + G = 100 # net poloidal current + potential = lambda theta, zeta, G: G * zeta / 2 / jnp.pi + potential_dtheta = lambda theta, zeta, G: jnp.zeros_like(theta) + potential_dzeta = lambda theta, zeta, G: G * jnp.ones_like(theta) / 2 / jnp.pi + + params = {"G": -G} + + field = CurrentPotentialField( + potential, + R_lmn=surface.R_lmn, + Z_lmn=surface.Z_lmn, + modes_R=surface._R_basis.modes[:, 1:], + modes_Z=surface._Z_basis.modes[:, 1:], + 
params=params, + potential_dtheta=potential_dtheta, + potential_dzeta=potential_dzeta, + NFP=surface.NFP, + ) + # test the loop integral of A around a curve encompassing the torus + # against the analytic result for flux in an ideal toroidal solenoid + prefactors = mu_0 * G / 2 / jnp.pi + correct_flux = -2 * np.pi * prefactors * (np.sqrt(R0**2 - a**2) - R0) + + curve = FourierXYZCurve() # curve to integrate A over + curve_grid = LinearGrid(zeta=20) + curve_data = curve.compute(["x", "x_s"], grid=curve_grid, basis="xyz") + curve_data_rpz = curve.compute(["x", "x_s"], grid=curve_grid, basis="rpz") + + surface_grid = LinearGrid(M=60, N=60, NFP=10) + + A_xyz = field.compute_magnetic_vector_potential( + curve_data["x"], basis="xyz", source_grid=surface_grid + ) + A_rpz = field.compute_magnetic_vector_potential( + curve_data_rpz["x"], basis="rpz", source_grid=surface_grid + ) + + # integrate to get the flux + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose(correct_flux, flux_xyz, rtol=1e-8) + np.testing.assert_allclose(correct_flux, flux_rpz, rtol=1e-8) + + field.params["G"] = -2 * field.params["G"] + + A_xyz = field.compute_magnetic_vector_potential( + curve_data["x"], basis="xyz", source_grid=surface_grid + ) + A_rpz = field.compute_magnetic_vector_potential( + curve_data_rpz["x"], basis="rpz", source_grid=surface_grid + ) + + # integrate to get the flux + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose(-2 * correct_flux, flux_xyz, rtol=1e-8) + np.testing.assert_allclose(-2 * correct_flux, flux_rpz, rtol=1e-8) + @pytest.mark.unit def test_fourier_current_potential_field(self): """Test Fourier current potential magnetic field 
against analytic result.""" @@ -416,6 +556,124 @@ def test_fourier_current_potential_field(self): atol=1e-16, ) + @pytest.mark.unit + def test_fourier_current_potential_vector_potential(self): + """Test Fourier current potential vector potential against analytic result.""" + R0 = 10 + a = 1 + surface = FourierRZToroidalSurface( + R_lmn=jnp.array([R0, a]), + Z_lmn=jnp.array([0, -a]), + modes_R=jnp.array([[0, 0], [1, 0]]), + modes_Z=jnp.array([[0, 0], [-1, 0]]), + NFP=10, + ) + + basis = DoubleFourierSeries(M=2, N=2, sym="sin") + phi_mn = np.ones((basis.num_modes,)) + # make a current potential corresponding a purely poloidal current + G = 100 # net poloidal current + + # test the loop integral of A around a curve encompassing the torus + # against the analytic result for flux in an ideal toroidal solenoid + ## expression for flux inside of toroidal solenoid of radius a + prefactors = mu_0 * G / 2 / jnp.pi + correct_flux = -2 * np.pi * prefactors * (np.sqrt(R0**2 - a**2) - R0) + + curve = FourierXYZCurve() # curve to integrate A over + curve_grid = LinearGrid(zeta=20) + curve_data = curve.compute(["x", "x_s"], grid=curve_grid) + curve_data_rpz = curve.compute(["x", "x_s"], grid=curve_grid, basis="rpz") + + field = FourierCurrentPotentialField( + Phi_mn=phi_mn, + modes_Phi=basis.modes[:, 1:], + I=0, + G=-G, # to get a positive B_phi, we must put G negative + # since -G is the net poloidal current on the surface + # ( with G=-(net_current) meaning that we have net_current + # flowing poloidally (in clockwise direction) around torus) + sym_Phi="sin", + R_lmn=surface.R_lmn, + Z_lmn=surface.Z_lmn, + modes_R=surface._R_basis.modes[:, 1:], + modes_Z=surface._Z_basis.modes[:, 1:], + NFP=10, + ) + surface_grid = LinearGrid(M=60, N=60, NFP=10) + + phi_mn = np.zeros((basis.num_modes,)) + + field.Phi_mn = phi_mn + + field.change_resolution(3, 3) + field.change_Phi_resolution(2, 2) + + A_xyz = field.compute_magnetic_vector_potential( + curve_data["x"], basis="xyz", 
source_grid=surface_grid + ) + A_rpz = field.compute_magnetic_vector_potential( + curve_data_rpz["x"], basis="rpz", source_grid=surface_grid + ) + + # integrate to get the flux + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose(correct_flux, flux_xyz, rtol=1e-8) + np.testing.assert_allclose(correct_flux, flux_rpz, rtol=1e-8) + + field.G = -2 * field.G + field.I = 0 + + A_xyz = field.compute_magnetic_vector_potential( + curve_data["x"], basis="xyz", source_grid=surface_grid + ) + A_rpz = field.compute_magnetic_vector_potential( + curve_data_rpz["x"], basis="rpz", source_grid=surface_grid + ) + + # integrate to get the flux + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose(-2 * correct_flux, flux_xyz, rtol=1e-8) + np.testing.assert_allclose(-2 * correct_flux, flux_rpz, rtol=1e-8) + + field = FourierCurrentPotentialField.from_surface( + surface=surface, + Phi_mn=phi_mn, + modes_Phi=basis.modes[:, 1:], + I=0, + G=-G, + ) + + A_xyz = field.compute_magnetic_vector_potential( + curve_data["x"], basis="xyz", source_grid=surface_grid + ) + A_rpz = field.compute_magnetic_vector_potential( + curve_data_rpz["x"], basis="rpz", source_grid=surface_grid + ) + + # integrate to get the flux + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose(correct_flux, flux_xyz, rtol=1e-8) + np.testing.assert_allclose(correct_flux, flux_rpz, rtol=1e-8) + @pytest.mark.unit def test_fourier_current_potential_field_symmetry(self): """Test Fourier current potential magnetic field Phi symmetry 
logic.""" @@ -644,7 +902,7 @@ def test_init_Phi_mn_fourier_current_field(self): @pytest.mark.slow @pytest.mark.unit - def test_spline_field(self): + def test_spline_field(self, tmpdir_factory): """Test accuracy of spline magnetic field.""" field1 = ScalarPotentialField(phi_lm, args) R = np.linspace(0.5, 1.5, 20) @@ -659,10 +917,65 @@ def test_spline_field(self): extcur = [4700.0, 1000.0] mgrid = "tests/inputs/mgrid_test.nc" field3 = SplineMagneticField.from_mgrid(mgrid, extcur) + # test saving and loading from mgrid + tmpdir = tmpdir_factory.mktemp("spline_mgrid_with_A") + path = tmpdir.join("spline_mgrid_with_A.nc") + field3.save_mgrid( + path, + Rmin=np.min(field3._R), + Rmax=np.max(field3._R), + Zmin=np.min(field3._Z), + Zmax=np.max(field3._Z), + nR=field3._R.size, + nZ=field3._Z.size, + nphi=field3._phi.size, + ) + # no need for extcur b/c is saved in "raw" format, no need to scale again + field4 = SplineMagneticField.from_mgrid(path) + attrs_4d = ["_AR", "_Aphi", "_AZ", "_BR", "_Bphi", "_BZ"] + for attr in attrs_4d: + np.testing.assert_allclose( + (getattr(field3, attr) * np.array(extcur)).sum(axis=-1), + getattr(field4, attr).squeeze(), + err_msg=attr, + ) + attrs_3d = ["_R", "_phi", "_Z"] + for attr in attrs_3d: + np.testing.assert_allclose(getattr(field3, attr), getattr(field4, attr)) + + r = 0.70 + p = 0 + z = 0 + # use finite diff derivatives to check A accuracy + tfield_A = lambda R, phi, Z: field3.compute_magnetic_vector_potential( + jnp.vstack([R, phi, Z]).T + ) + funR = lambda x: tfield_A(x, p, z) + funP = lambda x: tfield_A(r, x, z) + funZ = lambda x: tfield_A(r, p, x) + + ap = tfield_A(r, p, z)[:, 1] + + # these are the gradients of each component of A + dAdr = Derivative.compute_jvp(funR, 0, (jnp.ones_like(r),), r) + dAdp = Derivative.compute_jvp(funP, 0, (jnp.ones_like(p),), p) + dAdz = Derivative.compute_jvp(funZ, 0, (jnp.ones_like(z),), z) + + # form the B components with the appropriate combinations + B2 = jnp.array( + [ + dAdp[:, 2] / r - 
dAdz[:, 1], + dAdz[:, 0] - dAdr[:, 2], + dAdr[:, 1] + (ap - dAdp[:, 0]) / r, + ] + ).T np.testing.assert_allclose( field3([0.70, 0, 0]), np.array([[0, -0.671, 0.0858]]), rtol=1e-3, atol=1e-8 ) + + np.testing.assert_allclose(field3([0.70, 0, 0]), B2, rtol=1e-3, atol=5e-3) + field3.currents *= 2 np.testing.assert_allclose( field3([0.70, 0, 0]), @@ -697,14 +1010,20 @@ def test_spline_field_axisym(self): -2.430716e04, -2.380229e04, ] - field = SplineMagneticField.from_mgrid( - "tests/inputs/mgrid_d3d.nc", extcur=extcur - ) + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + field = SplineMagneticField.from_mgrid( + "tests/inputs/mgrid_d3d.nc", extcur=extcur + ) # make sure field is invariant to shift in phi B1 = field.compute_magnetic_field(np.array([1.75, 0.0, 0.0])) B2 = field.compute_magnetic_field(np.array([1.75, 1.0, 0.0])) np.testing.assert_allclose(B1, B2) + # test the error when no vec pot values exist + with pytest.raises(ValueError, match="no vector potential"): + field.compute_magnetic_vector_potential(np.array([1.75, 0.0, 0.0])) + @pytest.mark.unit def test_field_line_integrate(self): """Test field line integration.""" @@ -842,8 +1161,15 @@ def test_mgrid_io(self, tmpdir_factory): Rmax = 7 Zmin = -2 Zmax = 2 - save_field.save_mgrid(path, Rmin, Rmax, Zmin, Zmax) - load_field = SplineMagneticField.from_mgrid(path) + with pytest.raises(NotImplementedError): + # Raises error because poloidal field has no vector potential + # and so cannot save the vector potential + save_field.save_mgrid(path, Rmin, Rmax, Zmin, Zmax) + save_field.save_mgrid(path, Rmin, Rmax, Zmin, Zmax, save_vector_potential=False) + with pytest.warns(UserWarning): + # user warning because saved mgrid has no vector potential + # and so cannot load the vector potential + load_field = SplineMagneticField.from_mgrid(path) # check that the fields are the same num_nodes = 50 diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 
82f0dd337a..4711907ab9 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -24,12 +24,13 @@ from desc.compute import get_transforms from desc.equilibrium import Equilibrium from desc.examples import get -from desc.geometry import FourierRZToroidalSurface, FourierXYZCurve +from desc.geometry import FourierPlanarCurve, FourierRZToroidalSurface, FourierXYZCurve from desc.grid import ConcentricGrid, LinearGrid, QuadratureGrid from desc.io import load from desc.magnetic_fields import ( FourierCurrentPotentialField, OmnigenousField, + PoloidalMagneticField, SplineMagneticField, ToroidalMagneticField, VerticalMagneticField, @@ -367,6 +368,51 @@ def test_qh_boozer(self): # should have the same values up until then np.testing.assert_allclose(f[idx_f][:120], B_mn[idx_B][:120]) + @pytest.mark.unit + def test_qh_boozer_multiple_surfaces(self): + """Test for computing Boozer error on multiple surfaces.""" + eq = get("WISTELL-A") # WISTELL-A is optimized for QH symmetry + helicity = (1, -eq.NFP) + M_booz = eq.M + N_booz = eq.N + grid1 = LinearGrid(rho=0.5, M=2 * eq.M, N=2 * eq.N, NFP=eq.NFP, sym=False) + grid2 = LinearGrid(rho=1.0, M=2 * eq.M, N=2 * eq.N, NFP=eq.NFP, sym=False) + grid3 = LinearGrid( + rho=np.array([0.5, 1.0]), M=2 * eq.M, N=2 * eq.N, NFP=eq.NFP, sym=False + ) + + obj1 = QuasisymmetryBoozer( + helicity=helicity, + M_booz=M_booz, + N_booz=N_booz, + grid=grid1, + normalize=False, + eq=eq, + ) + obj2 = QuasisymmetryBoozer( + helicity=helicity, + M_booz=M_booz, + N_booz=N_booz, + grid=grid2, + normalize=False, + eq=eq, + ) + obj3 = QuasisymmetryBoozer( + helicity=helicity, + M_booz=M_booz, + N_booz=N_booz, + grid=grid3, + normalize=False, + eq=eq, + ) + obj1.build() + obj2.build() + obj3.build() + f1 = obj1.compute_unscaled(*obj1.xs(eq)) + f2 = obj2.compute_unscaled(*obj2.xs(eq)) + f3 = obj3.compute_unscaled(*obj3.xs(eq)) + np.testing.assert_allclose(f3, np.concatenate([f1, f2]), atol=1e-14) + @pytest.mark.unit def test_qs_twoterm(self): 
"""Test calculation of two term QS metric.""" @@ -441,11 +487,6 @@ def test_qs_boozer_grids(self): with pytest.raises(ValueError): QuasisymmetryBoozer(eq=eq, grid=grid).build() - # multiple flux surfaces - grid = LinearGrid(M=eq.M, N=eq.N, NFP=eq.NFP, rho=[0.25, 0.5, 0.75, 1]) - with pytest.raises(ValueError): - QuasisymmetryBoozer(eq=eq, grid=grid).build() - @pytest.mark.unit def test_mercier_stability(self): """Test calculation of mercier stability criteria.""" @@ -869,6 +910,13 @@ def test(coil, grid=None): test(mixed_coils) test(nested_coils, grid=grid) + def test_coil_type_error(self): + """Tests error when objective is not passed a coil.""" + curve = FourierPlanarCurve(r_n=2, basis="rpz") + obj = CoilLength(curve) + with pytest.raises(TypeError): + obj.build() + @pytest.mark.unit def test_coil_min_distance(self): """Tests minimum distance between coils in a coilset.""" @@ -1113,10 +1161,14 @@ def test_quadratic_flux(self): @pytest.mark.unit def test_toroidal_flux(self): """Test calculation of toroidal flux from coils.""" - grid1 = LinearGrid(L=10, M=10, zeta=np.array(0.0)) + grid1 = LinearGrid(L=0, M=40, zeta=np.array(0.0)) def test(eq, field, correct_value, rtol=1e-14, grid=None): - obj = ToroidalFlux(eq=eq, field=field, eval_grid=grid) + obj = ToroidalFlux( + eq=eq, + field=field, + eval_grid=grid, + ) obj.build(verbose=2) torflux = obj.compute_unscaled(*obj.xs(field)) np.testing.assert_allclose(torflux, correct_value, rtol=rtol) @@ -1126,22 +1178,20 @@ def test(eq, field, correct_value, rtol=1e-14, grid=None): field = ToroidalMagneticField(B0=1, R0=1) # calc field Psi - data = eq.compute(["R", "phi", "Z", "|e_rho x e_theta|", "n_zeta"], grid=grid1) - field_B = field.compute_magnetic_field( + data = eq.compute(["R", "phi", "Z", "e_theta"], grid=grid1) + field_A = field.compute_magnetic_vector_potential( np.vstack([data["R"], data["phi"], data["Z"]]).T ) - B_dot_n_zeta = jnp.sum(field_B * data["n_zeta"], axis=1) + A_dot_e_theta = jnp.sum(field_A * 
data["e_theta"], axis=1) - psi_from_field = np.sum( - grid1.spacing[:, 0] - * grid1.spacing[:, 1] - * data["|e_rho x e_theta|"] - * B_dot_n_zeta - ) - eq.change_resolution(L_grid=10, M_grid=10) + psi_from_field = np.sum(grid1.spacing[:, 1] * A_dot_e_theta) + eq.change_resolution(L_grid=20, M_grid=20) test(eq, field, psi_from_field) + test(eq, field, psi_from_field, rtol=1e-3) + # test on field with no vector potential + test(eq, PoloidalMagneticField(1, 1, 1), 0.0) @pytest.mark.unit def test_signed_plasma_vessel_distance(self): @@ -2228,7 +2278,9 @@ def test_compute_scalar_resolution_heating_power(self): @pytest.mark.regression def test_compute_scalar_resolution_boundary_error(self): """BoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2254,7 +2306,9 @@ def test_compute_scalar_resolution_boundary_error(self): @pytest.mark.regression def test_compute_scalar_resolution_vacuum_boundary_error(self): """VacuumBoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2281,7 +2335,8 @@ def test_compute_scalar_resolution_vacuum_boundary_error(self): @pytest.mark.regression def test_compute_scalar_resolution_quadratic_flux(self): """VacuumBoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = 
PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2305,7 +2360,25 @@ def test_compute_scalar_resolution_quadratic_flux(self): np.testing.assert_allclose(f, f[-1], rtol=5e-2) @pytest.mark.regression - def test_compute_scalar_resolution_toroidal_flux(self): + def test_compute_scalar_resolution_toroidal_flux_A(self): + """ToroidalFlux.""" + ext_field = ToroidalMagneticField(1, 1) + eq = get("precise_QA") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(4, 4, 4, 8, 8, 8) + + f = np.zeros_like(self.res_array, dtype=float) + for i, res in enumerate(self.res_array): + eq.change_resolution( + L_grid=int(eq.L * res), M_grid=int(eq.M * res), N_grid=int(eq.N * res) + ) + obj = ObjectiveFunction(ToroidalFlux(eq, ext_field), use_jit=False) + obj.build(verbose=0) + f[i] = obj.compute_scalar(obj.x()) + np.testing.assert_allclose(f, f[-1], rtol=5e-2) + + @pytest.mark.regression + def test_compute_scalar_resolution_toroidal_flux_B(self): """ToroidalFlux.""" ext_field = ToroidalMagneticField(1, 1) eq = get("precise_QA") @@ -2579,7 +2652,9 @@ def test_objective_no_nangrad_heating_power(self): @pytest.mark.unit def test_objective_no_nangrad_boundary_error(self): """BoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2600,7 +2675,9 @@ def test_objective_no_nangrad_boundary_error(self): @pytest.mark.unit def test_objective_no_nangrad_vacuum_boundary_error(self): """VacuumBoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = 
SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2623,7 +2700,9 @@ def test_objective_no_nangrad_vacuum_boundary_error(self): @pytest.mark.unit def test_objective_no_nangrad_quadratic_flux(self): """QuadraticFlux.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2654,7 +2733,12 @@ def test_objective_no_nangrad_toroidal_flux(self): obj = ObjectiveFunction(ToroidalFlux(eq, ext_field), use_jit=False) obj.build() g = obj.grad(obj.x(ext_field)) - assert not np.any(np.isnan(g)), "toroidal flux" + assert not np.any(np.isnan(g)), "toroidal flux A" + + obj = ObjectiveFunction(ToroidalFlux(eq, ext_field), use_jit=False) + obj.build() + g = obj.grad(obj.x(ext_field)) + assert not np.any(np.isnan(g)), "toroidal flux B" @pytest.mark.unit @pytest.mark.parametrize( diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 16b89c1543..6048a7d601 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -513,7 +513,14 @@ def test_plot_boundaries(self): eq1 = get("SOLOVEV") eq2 = get("DSHAPE") eq3 = get("W7-X") - fig, ax, data = plot_boundaries((eq1, eq2, eq3), return_data=True) + eq4 = get("ESTELL") + with pytest.raises(ValueError, match="differing field periods"): + fig, ax = plot_boundaries([eq3, eq4], theta=0) + fig, ax, data = plot_boundaries( + (eq1, eq2, eq3), + phi=np.linspace(0, 2 * np.pi / eq3.NFP, 4, endpoint=False), + return_data=True, + ) assert "R" in data.keys() assert "Z" in data.keys() assert len(data["R"]) == 3 @@ -550,6 +557,22 @@ def test_plot_comparison_no_theta(self): fig, ax = plot_comparison(eqf, theta=0) return fig + 
@pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_2d) + def test_plot_comparison_different_NFPs(self): + """Test plotting comparison of flux surfaces with differing NFPs.""" + eq = get("SOLOVEV") + eq_nonax = get("HELIOTRON") + eq_nonax2 = get("ESTELL") + with pytest.raises(ValueError, match="differing field periods"): + fig, ax = plot_comparison([eq_nonax, eq_nonax2], theta=0) + fig, ax = plot_comparison( + [eq, eq_nonax], + phi=np.linspace(0, 2 * np.pi / eq_nonax.NFP, 6, endpoint=False), + theta=0, + ) + return fig + class TestPlotGrid: """Tests for the plot_grid function.""" diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py index a23b81c8d8..5a7c3d00e7 100644 --- a/tests/test_quad_utils.py +++ b/tests/test_quad_utils.py @@ -2,7 +2,9 @@ import numpy as np import pytest +from jax import grad +from desc.backend import jnp from desc.integrals.quad_utils import ( automorphism_arcsin, automorphism_sin, @@ -20,7 +22,7 @@ @pytest.mark.unit def test_composite_linspace(): - """Test this utility function useful for Newton-Cotes integration over pitch.""" + """Test this utility function which is used for integration over pitch.""" B_min_tz = np.array([0.1, 0.2]) B_max_tz = np.array([1, 3]) breaks = np.linspace(B_min_tz, B_max_tz, num=5) @@ -91,3 +93,11 @@ def test_leggauss_lobatto(): np.testing.assert_allclose(x, [-1, -np.sqrt(3 / 7), 0, np.sqrt(3 / 7), 1]) np.testing.assert_allclose(w, [1 / 10, 49 / 90, 32 / 45, 49 / 90, 1 / 10]) np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + def fun(a): + x, w = leggauss_lob(a.size) + return jnp.dot(x * a, w) + + # make sure differentiable + # https://github.com/PlasmaControl/DESC/pull/854#discussion_r1733323161 + assert np.isfinite(grad(fun)(jnp.arange(10) * np.pi)).all() diff --git a/tests/test_vmec.py b/tests/test_vmec.py index d7ae22f2b4..0fef594b3c 100644 --- a/tests/test_vmec.py +++ b/tests/test_vmec.py @@ -368,14 +368,6 @@ def 
test_axis_surf_after_load(): f.close() -@pytest.mark.unit -def test_vmec_save_asym(TmpDir): - """Tests that saving a non-symmetric equilibrium runs without errors.""" - output_path = str(TmpDir.join("output.nc")) - eq = Equilibrium(L=2, M=2, N=2, NFP=3, pressure=np.array([[2, 0]]), sym=False) - VMECIO.save(eq, output_path) - - @pytest.mark.unit def test_vmec_save_kinetic(TmpDir): """Tests that saving an equilibrium with kinetic profiles runs without errors.""" @@ -874,6 +866,369 @@ def test_vmec_save_2(VMEC_save): np.testing.assert_allclose(currv_vmec, currv_desc, rtol=1e-2) +@pytest.mark.regression +@pytest.mark.slow +def test_vmec_save_asym(VMEC_save_asym): + """Tests that saving in NetCDF format agrees with VMEC.""" + vmec, desc, eq = VMEC_save_asym + # first, compare some quantities which don't require calculation + assert vmec.variables["version_"][:] == desc.variables["version_"][:] + assert vmec.variables["mgrid_mode"][:] == desc.variables["mgrid_mode"][:] + assert np.all( + np.char.compare_chararrays( + vmec.variables["mgrid_file"][:], + desc.variables["mgrid_file"][:], + "==", + False, + ) + ) + assert vmec.variables["ier_flag"][:] == desc.variables["ier_flag"][:] + assert ( + vmec.variables["lfreeb__logical__"][:] == desc.variables["lfreeb__logical__"][:] + ) + assert ( + vmec.variables["lrecon__logical__"][:] == desc.variables["lrecon__logical__"][:] + ) + assert vmec.variables["lrfp__logical__"][:] == desc.variables["lrfp__logical__"][:] + assert ( + vmec.variables["lasym__logical__"][:] == desc.variables["lasym__logical__"][:] + ) + assert vmec.variables["nfp"][:] == desc.variables["nfp"][:] + assert vmec.variables["ns"][:] == desc.variables["ns"][:] + assert vmec.variables["mpol"][:] == desc.variables["mpol"][:] + assert vmec.variables["ntor"][:] == desc.variables["ntor"][:] + assert vmec.variables["mnmax"][:] == desc.variables["mnmax"][:] + np.testing.assert_allclose(vmec.variables["xm"][:], desc.variables["xm"][:]) + 
np.testing.assert_allclose(vmec.variables["xn"][:], desc.variables["xn"][:]) + assert vmec.variables["mnmax_nyq"][:] == desc.variables["mnmax_nyq"][:] + np.testing.assert_allclose(vmec.variables["xm_nyq"][:], desc.variables["xm_nyq"][:]) + np.testing.assert_allclose(vmec.variables["xn_nyq"][:], desc.variables["xn_nyq"][:]) + assert vmec.variables["signgs"][:] == desc.variables["signgs"][:] + assert vmec.variables["gamma"][:] == desc.variables["gamma"][:] + assert vmec.variables["nextcur"][:] == desc.variables["nextcur"][:] + assert np.all( + np.char.compare_chararrays( + vmec.variables["pmass_type"][:], + desc.variables["pmass_type"][:], + "==", + False, + ) + ) + assert np.all( + np.char.compare_chararrays( + vmec.variables["piota_type"][:], + desc.variables["piota_type"][:], + "==", + False, + ) + ) + assert np.all( + np.char.compare_chararrays( + vmec.variables["pcurr_type"][:], + desc.variables["pcurr_type"][:], + "==", + False, + ) + ) + np.testing.assert_allclose( + vmec.variables["am"][:], desc.variables["am"][:], atol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["ai"][:], desc.variables["ai"][:], atol=1e-8 + ) + np.testing.assert_allclose( + vmec.variables["ac"][:], desc.variables["ac"][:], atol=3e-5 + ) + np.testing.assert_allclose( + vmec.variables["presf"][:], desc.variables["presf"][:], atol=2e-5 + ) + np.testing.assert_allclose(vmec.variables["pres"][:], desc.variables["pres"][:]) + np.testing.assert_allclose(vmec.variables["mass"][:], desc.variables["mass"][:]) + np.testing.assert_allclose( + vmec.variables["iotaf"][:], desc.variables["iotaf"][:], rtol=5e-4 + ) + np.testing.assert_allclose( + vmec.variables["q_factor"][:], desc.variables["q_factor"][:], rtol=5e-4 + ) + np.testing.assert_allclose( + vmec.variables["iotas"][:], desc.variables["iotas"][:], rtol=5e-4 + ) + np.testing.assert_allclose(vmec.variables["phi"][:], desc.variables["phi"][:]) + np.testing.assert_allclose(vmec.variables["phipf"][:], desc.variables["phipf"][:]) + 
np.testing.assert_allclose(vmec.variables["phips"][:], desc.variables["phips"][:]) + np.testing.assert_allclose( + vmec.variables["chi"][:], desc.variables["chi"][:], atol=3e-5, rtol=1e-3 + ) + np.testing.assert_allclose( + vmec.variables["chipf"][:], desc.variables["chipf"][:], atol=3e-5, rtol=1e-3 + ) + np.testing.assert_allclose( + vmec.variables["Rmajor_p"][:], desc.variables["Rmajor_p"][:] + ) + np.testing.assert_allclose( + vmec.variables["Aminor_p"][:], desc.variables["Aminor_p"][:] + ) + np.testing.assert_allclose(vmec.variables["aspect"][:], desc.variables["aspect"][:]) + np.testing.assert_allclose( + vmec.variables["volume_p"][:], desc.variables["volume_p"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["volavgB"][:], desc.variables["volavgB"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["betatotal"][:], desc.variables["betatotal"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["betapol"][:], desc.variables["betapol"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["betator"][:], desc.variables["betator"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["ctor"][:], + desc.variables["ctor"][:], + atol=1e-9, # it is a zero current solve + ) + np.testing.assert_allclose( + vmec.variables["rbtor"][:], desc.variables["rbtor"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["rbtor0"][:], desc.variables["rbtor0"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["b0"][:], desc.variables["b0"][:], rtol=4e-3 + ) + np.testing.assert_allclose( + vmec.variables["buco"][20:100], desc.variables["buco"][20:100], atol=1e-15 + ) + np.testing.assert_allclose( + vmec.variables["bvco"][20:100], desc.variables["bvco"][20:100], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["vp"][20:100], desc.variables["vp"][20:100], rtol=3e-4 + ) + np.testing.assert_allclose( + vmec.variables["bdotb"][20:100], desc.variables["bdotb"][20:100], rtol=3e-4 + ) + 
np.testing.assert_allclose( + vmec.variables["jdotb"][20:100], + desc.variables["jdotb"][20:100], + atol=4e-3, # nearly zero bc is vacuum + ) + np.testing.assert_allclose( + vmec.variables["jcuru"][20:100], desc.variables["jcuru"][20:100], atol=2 + ) + np.testing.assert_allclose( + vmec.variables["jcurv"][20:100], desc.variables["jcurv"][20:100], rtol=2 + ) + np.testing.assert_allclose( + vmec.variables["DShear"][20:100], desc.variables["DShear"][20:100], rtol=3e-2 + ) + np.testing.assert_allclose( + vmec.variables["DCurr"][20:100], + desc.variables["DCurr"][20:100], + atol=1e-4, # nearly zero bc vacuum + ) + np.testing.assert_allclose( + vmec.variables["DWell"][20:100], desc.variables["DWell"][20:100], rtol=1e-2 + ) + np.testing.assert_allclose( + vmec.variables["DGeod"][20:100], + desc.variables["DGeod"][20:100], + atol=4e-3, + rtol=1e-2, + ) + + # the Mercier stability is pretty off, + # but these are not exactly similar solutions to eachother + np.testing.assert_allclose( + vmec.variables["DMerc"][20:100], desc.variables["DMerc"][20:100], atol=4e-3 + ) + np.testing.assert_allclose( + vmec.variables["raxis_cc"][:], + desc.variables["raxis_cc"][:], + rtol=5e-5, + atol=4e-3, + ) + np.testing.assert_allclose( + vmec.variables["zaxis_cs"][:], + desc.variables["zaxis_cs"][:], + rtol=5e-5, + atol=1e-3, + ) + np.testing.assert_allclose( + vmec.variables["rmin_surf"][:], desc.variables["rmin_surf"][:], rtol=5e-3 + ) + np.testing.assert_allclose( + vmec.variables["rmax_surf"][:], desc.variables["rmax_surf"][:], rtol=5e-3 + ) + np.testing.assert_allclose( + vmec.variables["zmax_surf"][:], desc.variables["zmax_surf"][:], rtol=5e-3 + ) + np.testing.assert_allclose( + vmec.variables["beta_vol"][:], desc.variables["beta_vol"][:], rtol=5e-5 + ) + np.testing.assert_allclose( + vmec.variables["betaxis"][:], desc.variables["betaxis"][:], rtol=5e-5 + ) + # Next, calculate some quantities and compare + # the DESC wout -> DESC (should be very close) + # and the DESC wout -> VMEC 
wout (should be approximately close) + vol_grid = LinearGrid( + rho=np.sqrt( + abs( + vmec.variables["phi"][:].filled() + / np.max(np.abs(vmec.variables["phi"][:].filled())) + ) + )[10::10], + M=15, + N=15, + NFP=eq.NFP, + axis=False, + sym=False, + ) + bdry_grid = LinearGrid(rho=1.0, M=15, N=15, NFP=eq.NFP, axis=False, sym=False) + + def test( + nc_str, + desc_str, + negate_DESC_quant=False, + use_nyq=True, + convert_sqrt_g_or_B_rho=False, + atol_desc_desc_wout=5e-5, + rtol_desc_desc_wout=1e-5, + atol_vmec_desc_wout=1e-5, + rtol_vmec_desc_wout=1e-2, + grid=vol_grid, + ): + """Helper fxn to evaluate Fourier series from wout and compare to DESC.""" + xm = desc.variables["xm_nyq"][:] if use_nyq else desc.variables["xm"][:] + xn = desc.variables["xn_nyq"][:] if use_nyq else desc.variables["xn"][:] + + si = abs(vmec.variables["phi"][:] / np.max(np.abs(vmec.variables["phi"][:]))) + rho = grid.nodes[:, 0] + s = rho**2 + # some quantities must be negated before comparison bc + # they are negative in the wout i.e. 
B^theta + negate = -1 if negate_DESC_quant else 1 + + quant_from_desc_wout = VMECIO.vmec_interpolate( + desc.variables[nc_str + "c"][:], + desc.variables[nc_str + "s"][:], + xm, + xn, + theta=-grid.nodes[:, 1], # -theta bc when we save wout we reverse theta + phi=grid.nodes[:, 2], + s=s, + sym=False, + si=si, + ) + + quant_from_vmec_wout = VMECIO.vmec_interpolate( + vmec.variables[nc_str + "c"][:], + vmec.variables[nc_str + "s"][:], + xm, + xn, + # pi - theta bc VMEC, when it gets a CW angle bdry, + # changes poloidal angle to theta -> pi-theta + theta=np.pi - grid.nodes[:, 1], + phi=grid.nodes[:, 2], + s=s, + sym=False, + si=si, + ) + + data = eq.compute(["rho", "sqrt(g)", desc_str], grid=grid) + # convert sqrt(g) or B_rho->B_psi if needed + quant_desc = ( + data[desc_str] / 2 / data["rho"] + if convert_sqrt_g_or_B_rho + else data[desc_str] + ) + + # add sqrt(g) factor if currents being compared + quant_desc = ( + quant_desc * abs(data["sqrt(g)"]) / 2 / data["rho"] + if "J" in desc_str + else quant_desc + ) + + np.testing.assert_allclose( + negate * quant_desc, + quant_from_desc_wout, + atol=atol_desc_desc_wout, + rtol=rtol_desc_desc_wout, + ) + np.testing.assert_allclose( + quant_from_desc_wout, + quant_from_vmec_wout, + atol=atol_vmec_desc_wout, + rtol=rtol_vmec_desc_wout, + ) + + # R & Z & lambda + test("rmn", "R", use_nyq=False) + test("zmn", "Z", use_nyq=False, atol_vmec_desc_wout=4e-2) + + # |B| + test("bmn", "|B|", rtol_desc_desc_wout=7e-4) + + # B^zeta + test("bsupvmn", "B^zeta") # ,rtol_desc_desc_wout=6e-5) + + # B_zeta + test("bsubvmn", "B_zeta", rtol_desc_desc_wout=3e-4) + + # hard to compare to VMEC for the currents, since + # VMEC F error is worse and equilibria are not exactly similar + # just compare back to DESC + test("currumn", "J^theta", atol_vmec_desc_wout=1e4) + test("currvmn", "J^zeta", negate_DESC_quant=True, atol_vmec_desc_wout=1e5) + + # can only compare lambda, sqrt(g) B_psi B^theta and B_theta at bdry + test( + "lmn", + "lambda", + 
use_nyq=False, + negate_DESC_quant=True, + grid=bdry_grid, + atol_desc_desc_wout=4e-4, + atol_vmec_desc_wout=5e-2, + ) + test( + "gmn", + "sqrt(g)", + convert_sqrt_g_or_B_rho=True, + negate_DESC_quant=True, + grid=bdry_grid, + rtol_desc_desc_wout=5e-4, + rtol_vmec_desc_wout=4e-2, + ) + test( + "bsupumn", + "B^theta", + negate_DESC_quant=True, + grid=bdry_grid, + atol_vmec_desc_wout=6e-4, + ) + test( + "bsubumn", + "B_theta", + negate_DESC_quant=True, + grid=bdry_grid, + atol_desc_desc_wout=1e-4, + atol_vmec_desc_wout=4e-4, + ) + test( + "bsubsmn", + "B_rho", + grid=bdry_grid, + convert_sqrt_g_or_B_rho=True, + rtol_vmec_desc_wout=6e-2, + atol_vmec_desc_wout=9e-3, + ) + + @pytest.mark.unit @pytest.mark.mpl_image_compare(tolerance=1) def test_plot_vmec_comparison():