Skip to content

Commit

Permalink
Merge pull request #238 from astro-informatics/map/risbo-precompute-t…
Browse files Browse the repository at this point in the history
…ransform-memeff

add stable forward/inverse memory efficient Wigner transforms
  • Loading branch information
CosmoMatt authored Oct 24, 2024
2 parents 11f76bf + 6f64ebb commit 5d1d13f
Show file tree
Hide file tree
Showing 7 changed files with 575 additions and 3 deletions.
7 changes: 7 additions & 0 deletions docs/api/precompute_transforms/fourier_wigner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:html_theme.sidebar_secondary.remove:

**************************
Fourier-Wigner Transform
**************************
.. automodule:: s2fft.precompute_transforms.fourier_wigner
:members:
20 changes: 20 additions & 0 deletions docs/api/precompute_transforms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ Precompute Functions
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform_torch`
- Forward Wigner transform (Torch)

.. list-table:: Fourier-Wigner transforms.
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform`
- Inverse Wigner transform with Fourier method (NumPy)
* - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform_jax`
- Inverse Wigner transform with Fourier method (JAX)
* - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform`
- Forward Wigner transform with Fourier method (NumPy)
* - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform_jax`
- Forward Wigner transform with Fourier method (JAX)

.. list-table:: Constructing Kernels for precompute transforms.
:widths: 25 25
:header-rows: 1
Expand All @@ -64,6 +79,10 @@ Precompute Functions
- Builds a kernel including quadrature weights and Wigner-D coefficients for spherical harmonic transform (JAX).
* - :func:`~s2fft.precompute_transforms.construct.wigner_kernel_jax`
- Builds a kernel including quadrature weights and Wigner-D coefficients for Wigner transform (JAX).
* - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel`
- Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions
* - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel_jax`
- Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions (JAX).
* - :func:`~s2fft.precompute_transforms.construct.healpix_phase_shifts`
- Builds a vector of corresponding phase shifts for each HEALPix latitudinal ring.

Expand All @@ -76,4 +95,5 @@ Precompute Functions
alt_construct
spin_spherical
wigner
fourier_wigner

2 changes: 1 addition & 1 deletion s2fft/precompute_transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import construct, spherical, wigner
from . import construct, fourier_wigner, spherical, wigner
57 changes: 57 additions & 0 deletions s2fft/precompute_transforms/construct.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Tuple
from warnings import warn

import jax
Expand Down Expand Up @@ -610,6 +611,62 @@ def wigner_kernel_jax(
return dl


def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
weights upsampled for the forward Fourier-Wigner transform.
Args:
L (int): Harmonic band-limit.
Returns:
Tuple[np.ndarray, np.ndarray]: Tuple of delta Fourier coefficients and weights.
"""
# Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices)
deltas = np.zeros((L, 2 * L - 1, 2 * L - 1), dtype=np.float64)
d = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
for el in range(L):
d = recursions.risbo.compute_full(d, np.pi / 2, L, el)
deltas[el] = d

# Calculate upsampled quadrature weights
w = np.zeros(4 * L - 3, dtype=np.complex128)
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
w[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
w = np.fft.ifft(np.fft.ifftshift(w), norm="forward")

return deltas, w


def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
weights upsampled for the forward Fourier-Wigner transform (JAX implementation).
Args:
L (int): Harmonic band-limit.
Returns:
Tuple[jnp.ndarray, jnp.ndarray]: Tuple of delta Fourier coefficients and weights.
"""
# Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices)
deltas = jnp.zeros((L, 2 * L - 1, 2 * L - 1), dtype=jnp.float64)
d = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
for el in range(L):
d = recursions.risbo_jax.compute_full(d, jnp.pi / 2, L, el)
deltas = deltas.at[el].set(d)

# Calculate upsampled quadrature weights
w = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
w = w.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
w = jnp.fft.ifft(jnp.fft.ifftshift(w), norm="forward")

return deltas, w


def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray:
r"""
Generates a phase shift vector for HEALPix for all :math:`\theta` rings.
Expand Down
Loading

0 comments on commit 5d1d13f

Please sign in to comment.