diff --git a/docs/api/precompute_transforms/fourier_wigner.rst b/docs/api/precompute_transforms/fourier_wigner.rst new file mode 100644 index 0000000..a45986b --- /dev/null +++ b/docs/api/precompute_transforms/fourier_wigner.rst @@ -0,0 +1,7 @@ +:html_theme.sidebar_secondary.remove: + +************************** +Fourier-Wigner Transform +************************** +.. automodule:: s2fft.precompute_transforms.fourier_wigner + :members: \ No newline at end of file diff --git a/docs/api/precompute_transforms/index.rst b/docs/api/precompute_transforms/index.rst index 19dbe42..100bd2b 100644 --- a/docs/api/precompute_transforms/index.rst +++ b/docs/api/precompute_transforms/index.rst @@ -50,6 +50,21 @@ Precompute Functions * - :func:`~s2fft.precompute_transforms.wigner.forward_transform_torch` - Forward Wigner transform (Torch) +.. list-table:: Fourier-Wigner transforms. + :widths: 25 25 + :header-rows: 1 + + * - Function Name + - Description + * - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform` + - Inverse Wigner transform with Fourier method (NumPy) + * - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform_jax` + - Inverse Wigner transform with Fourier method (JAX) + * - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform` + - Forward Wigner transform with Fourier method (NumPy) + * - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform_jax` + - Forward Wigner transform with Fourier method (JAX) + .. list-table:: Constructing Kernels for precompute transforms. :widths: 25 25 :header-rows: 1 @@ -64,6 +79,10 @@ Precompute Functions - Builds a kernel including quadrature weights and Wigner-D coefficients for spherical harmonic transform (JAX). * - :func:`~s2fft.precompute_transforms.construct.wigner_kernel_jax` - Builds a kernel including quadrature weights and Wigner-D coefficients for Wigner transform (JAX). + * - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel` + - Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions + * - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel_jax` + - Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions (JAX). * - :func:`~s2fft.precompute_transforms.construct.healpix_phase_shifts` - Builds a vector of corresponding phase shifts for each HEALPix latitudinal ring. @@ -76,4 +95,5 @@ Precompute Functions alt_construct spin_spherical wigner + fourier_wigner diff --git a/s2fft/precompute_transforms/__init__.py b/s2fft/precompute_transforms/__init__.py index 28e3184..78d407c 100644 --- a/s2fft/precompute_transforms/__init__.py +++ b/s2fft/precompute_transforms/__init__.py @@ -1 +1 @@ -from . import construct, spherical, wigner +from . import construct, fourier_wigner, spherical, wigner diff --git a/s2fft/precompute_transforms/construct.py b/s2fft/precompute_transforms/construct.py index 9c07ca1..a74a657 100644 --- a/s2fft/precompute_transforms/construct.py +++ b/s2fft/precompute_transforms/construct.py @@ -1,3 +1,4 @@ +from typing import Tuple from warnings import warn import jax @@ -610,6 +611,62 @@ def wigner_kernel_jax( return dl +def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Computes Fourier coefficients of the reduced Wigner d-functions and quadrature + weights upsampled for the forward Fourier-Wigner transform. + + Args: + L (int): Harmonic band-limit. + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple of delta Fourier coefficients and weights. + + """ + # Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices) + deltas = np.zeros((L, 2 * L - 1, 2 * L - 1), dtype=np.float64) + d = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) + for el in range(L): + d = recursions.risbo.compute_full(d, np.pi / 2, L, el) + deltas[el] = d + + # Calculate upsampled quadrature weights + w = np.zeros(4 * L - 3, dtype=np.complex128) + for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): + w[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm) + w = np.fft.ifft(np.fft.ifftshift(w), norm="forward") + + return deltas, w + + +def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Computes Fourier coefficients of the reduced Wigner d-functions and quadrature + weights upsampled for the forward Fourier-Wigner transform (JAX implementation). + + Args: + L (int): Harmonic band-limit. + + Returns: + Tuple[jnp.ndarray, jnp.ndarray]: Tuple of delta Fourier coefficients and weights. + + """ + # Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices) + deltas = jnp.zeros((L, 2 * L - 1, 2 * L - 1), dtype=jnp.float64) + d = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) + for el in range(L): + d = recursions.risbo_jax.compute_full(d, jnp.pi / 2, L, el) + deltas = deltas.at[el].set(d) + + # Calculate upsampled quadrature weights + w = jnp.zeros(4 * L - 3, dtype=jnp.complex128) + for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): + w = w.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm)) + w = jnp.fft.ifft(jnp.fft.ifftshift(w), norm="forward") + + return deltas, w + + def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray: r""" Generates a phase shift vector for HEALPix for all :math:`\theta` rings. diff --git a/s2fft/precompute_transforms/fourier_wigner.py b/s2fft/precompute_transforms/fourier_wigner.py new file mode 100644 index 0000000..b7ed597 --- /dev/null +++ b/s2fft/precompute_transforms/fourier_wigner.py @@ -0,0 +1,344 @@ +from functools import partial + +import jax.numpy as jnp +import numpy as np +from jax import jit + + +def inverse_transform( + flmn: np.ndarray, + DW: np.ndarray, + L: int, + N: int, + reality: bool = False, + sampling: str = "mw", +) -> np.ndarray: + """ + Computes the inverse Wigner transform using the Fourier decomposition algorithm. + + Args: + flmn (np.ndarray): Wigner coefficients. + DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + L (int): Harmonic band-limit. + N (int): Azimuthal band-limit. + reality (bool, optional): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + Defaults to False. + sampling (str, optional): Sampling scheme. Supported sampling schemes include + {"mw", "mwss"}. Defaults to "mw". + + Returns: + np.ndarray: Pixel-space function sampled on the rotation group. + + """ + if sampling.lower() not in ["mw", "mwss"]: + raise ValueError( + f"Fourier-Wigner algorithm does not support {sampling} sampling." + ) + + # EXTRACT VARIOUS PRECOMPUTES + Delta, _ = DW + + # INDEX VALUES + n_start_ind = N - 1 if reality else 0 + n_dim = N if reality else 2 * N - 1 + m_offset = 1 if sampling.lower() == "mwss" else 0 + ntheta = L + 1 if sampling.lower() == "mwss" else L + theta0 = 0 if sampling.lower() == "mwss" else np.pi / (2 * L - 1) + xnlm_size = 2 * L if sampling.lower() == "mwss" else 2 * L - 1 + + # REUSED ARRAYS + m = np.arange(-L + 1 - m_offset, L) + n = np.arange(n_start_ind - N + 1, N) + + # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2) + x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype) + x[m_offset:, m_offset:] = np.einsum( + "nlm,lam,lan,l->amn", + flmn[n_start_ind:], + Delta, + Delta[:, :, L - 1 + n], + (2 * np.arange(L) + 1) / (8 * np.pi**2), + ) + + # APPLY SIGN FUNCTION AND PHASE SHIFT + x = np.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), np.exp(1j * m * theta0)) + + # PERFORM FFT OVER BETA, GAMMA, ALPHA + if reality: + x = np.fft.ifftshift(x, axes=(1, 2)) + x = np.fft.ifft(x, axis=1, norm="forward")[:, :ntheta] + x = np.fft.ifft(x, axis=2, norm="forward") + return np.fft.irfft(x, 2 * N - 1, axis=0, norm="forward") + else: + x = np.fft.ifftshift(x) + x = np.fft.ifft(x, axis=1, norm="forward")[:, :ntheta] + return np.fft.ifft2(x, axes=(0, 2), norm="forward") + + +@partial(jit, static_argnums=(2, 3, 4, 5)) +def inverse_transform_jax( + flmn: jnp.ndarray, + DW: jnp.ndarray, + L: int, + N: int, + reality: bool = False, + sampling: str = "mw", +) -> jnp.ndarray: + """ + Computes the inverse Wigner transform using the Fourier decomposition algorithm (JAX). + + Args: + flmn (jnp.ndarray): Wigner coefficients. + DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + L (int): Harmonic band-limit. + N (int): Azimuthal band-limit. + reality (bool, optional): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + Defaults to False. + sampling (str, optional): Sampling scheme. Supported sampling schemes include + {"mw", "mwss"}. Defaults to "mw". + + Returns: + jnp.ndarray: Pixel-space function sampled on the rotation group. + + """ + if sampling.lower() not in ["mw", "mwss"]: + raise ValueError( + f"Fourier-Wigner algorithm does not support {sampling} sampling." + ) + + # EXTRACT VARIOUS PRECOMPUTES + Delta, _ = DW + + # INDEX VALUES + n_start_ind = N - 1 if reality else 0 + n_dim = N if reality else 2 * N - 1 + m_offset = 1 if sampling.lower() == "mwss" else 0 + ntheta = L + 1 if sampling.lower() == "mwss" else L + theta0 = 0 if sampling.lower() == "mwss" else jnp.pi / (2 * L - 1) + xnlm_size = 2 * L if sampling.lower() == "mwss" else 2 * L - 1 + + # REUSED ARRAYS + m = jnp.arange(-L + 1 - m_offset, L) + n = jnp.arange(n_start_ind - N + 1, N) + + # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2) + x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128) + flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2)) + x = x.at[m_offset:, m_offset:].set( + jnp.einsum( + "nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n] + ) + ) + + # APPLY SIGN FUNCTION AND PHASE SHIFT + x = jnp.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0)) + + # PERFORM FFT OVER BETA, GAMMA, ALPHA + if reality: + x = jnp.fft.ifftshift(x, axes=(1, 2)) + x = jnp.fft.ifft(x, axis=1, norm="forward")[:, :ntheta] + x = jnp.fft.ifft(x, axis=2, norm="forward") + return jnp.fft.irfft(x, 2 * N - 1, axis=0, norm="forward") + else: + x = jnp.fft.ifftshift(x) + x = jnp.fft.ifft(x, axis=1, norm="forward")[:, :ntheta] + return jnp.fft.ifft2(x, axes=(0, 2), norm="forward") + + +def forward_transform( + f: np.ndarray, + DW: np.ndarray, + L: int, + N: int, + reality: bool = False, + sampling: str = "mw", +) -> np.ndarray: + """ + Computes the forward Wigner transform using the Fourier decomposition algorithm. + + Args: + f (np.ndarray): Function sampled on the rotation group. + DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + L (int): Harmonic band-limit. + N (int): Azimuthal band-limit. + reality (bool, optional): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + Defaults to False. + sampling (str, optional): Sampling scheme. Supported sampling schemes include + {"mw", "mwss"}. Defaults to "mw". + + Returns: + np.ndarray: Wigner coefficients of function f. + + """ + if sampling.lower() not in ["mw", "mwss"]: + raise ValueError( + f"Fourier-Wigner algorithm does not support {sampling} sampling." + ) + + # EXTRACT VARIOUS PRECOMPUTES + Delta, Quads = DW + + # INDEX VALUES + n_start_ind = N - 1 if reality else 0 + m_offset = 1 if sampling.lower() == "mwss" else 0 + lpad = (L - 2) if sampling.lower() == "mwss" else (L - 1) + + # REUSED ARRAYS + m = np.arange(-L + 1, L) + n = np.arange(n_start_ind - N + 1, N) + + # COMPUTE ALPHA + GAMMA FFT + if reality: + x = np.fft.rfft(np.real(f), axis=0, norm="forward") + x = np.fft.fft(x, axis=2, norm="forward") + x = np.fft.fftshift(x, axes=2)[:, :, m_offset:] + else: + x = np.fft.fft2(f, axes=(0, 2), norm="forward") + x = np.fft.fftshift(x, axes=(0, 2))[:, :, m_offset:] + + # PERIODICALLY EXTEND BETA FROM [0,pi]->[0,2pi) + temp = np.einsum("ntm,m,n->ntm", x, (-1) ** np.abs(m), (-1) ** np.abs(n)) + x = np.concatenate((x, np.flip(temp, axis=1)[:, 1:L]), axis=1) + + # COMPUTE BETA FFT OVER PERIODICALLY EXTENDED FTM + x = np.fft.fft(x, axis=1, norm="forward") + x = np.fft.fftshift(x, axes=1) + + # APPLY PHASE SHIFT + if sampling.lower() == "mw": + x = np.einsum("nbm,b->nbm", x, np.exp(-1j * m * np.pi / (2 * L - 1))) + + # FOURIER UPSAMPLE TO 4L-3 + x = np.pad(x, ((0, 0), (lpad, L - 1), (0, 0))) + x = np.fft.ifftshift(x, axes=1) + x = np.fft.ifft(x, axis=1, norm="forward") + + # PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE + # NB: Our convention here is conjugate to that of SSHT, in which + # the weights are conjugate but applied flipped and therefore are + # equivalent. To avoid flipping here he simply conjugate the weights. + x = np.einsum("nbm,b->nbm", x, Quads) + + # COMPUTE GMM BY FFT + x = np.fft.fft(x, axis=1, norm="forward") + x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + + # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt + x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) + x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n)) + + # SYMMETRY REFLECT FOR N < 0 + if reality: + temp = np.einsum( + "nlm,m,n->nlm", + np.conj(np.flip(x[1:], axis=(-1, -3))), + (-1) ** np.abs(np.arange(-L + 1, L)), + (-1) ** np.abs(np.arange(-N + 1, 0)), + ) + x = np.concatenate((temp, x), axis=0) + + return x * (2.0 * np.pi) ** 2 + + +@partial(jit, static_argnums=(2, 3, 4, 5)) +def forward_transform_jax( + f: jnp.ndarray, + DW: jnp.ndarray, + L: int, + N: int, + reality: bool = False, + sampling: str = "mw", +) -> jnp.ndarray: + """ + Computes the forward Wigner transform using the Fourier decomposition algorithm (JAX). + + Args: + f (jnp.ndarray): Function sampled on the rotation group. + DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced + Wigner d-functions and the corresponding upsampled quadrature weights. + L (int): Harmonic band-limit. + N (int): Azimuthal band-limit. + reality (bool, optional): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + Defaults to False. + sampling (str, optional): Sampling scheme. Supported sampling schemes include + {"mw", "mwss"}. Defaults to "mw". + + Returns: + jnp.ndarray: Wigner coefficients of function f. + + """ + if sampling.lower() not in ["mw", "mwss"]: + raise ValueError( + f"Fourier-Wigner algorithm does not support {sampling} sampling." + ) + + # EXTRACT VARIOUS PRECOMPUTES + Delta, Quads = DW + + # INDEX VALUES + n_start_ind = N - 1 if reality else 0 + m_offset = 1 if sampling.lower() == "mwss" else 0 + lpad = (L - 2) if sampling.lower() == "mwss" else (L - 1) + + # REUSED ARRAYS + m = jnp.arange(-L + 1, L) + n = jnp.arange(n_start_ind - N + 1, N) + + # COMPUTE ALPHA + GAMMA FFT + if reality: + x = jnp.fft.rfft(jnp.real(f), axis=0, norm="forward") + x = jnp.fft.fft(x, axis=2, norm="forward") + x = jnp.fft.fftshift(x, axes=2)[:, :, m_offset:] + else: + x = jnp.fft.fft2(f, axes=(0, 2), norm="forward") + x = jnp.fft.fftshift(x, axes=(0, 2))[:, :, m_offset:] + + # PERIODICALLY EXTEND BETA FROM [0,pi]->[0,2pi) + temp = jnp.einsum("ntm,m,n->ntm", x, (-1) ** jnp.abs(m), (-1) ** jnp.abs(n)) + x = jnp.concatenate((x, jnp.flip(temp, axis=1)[:, 1:L]), axis=1) + + # COMPUTE BETA FFT OVER PERIODICALLY EXTENDED FTM + x = jnp.fft.fft(x, axis=1, norm="forward") + x = jnp.fft.fftshift(x, axes=1) + + # APPLY PHASE SHIFT + if sampling.lower() == "mw": + x = jnp.einsum("nbm,b->nbm", x, jnp.exp(-1j * m * jnp.pi / (2 * L - 1))) + + # FOURIER UPSAMPLE TO 4L-3 + x = jnp.pad(x, ((0, 0), (lpad, L - 1), (0, 0))) + x = jnp.fft.ifftshift(x, axes=1) + x = jnp.fft.ifft(x, axis=1, norm="forward") + + # PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE + # NB: Our convention here is conjugate to that of SSHT, in which + # the weights are conjugate but applied flipped and therefore are + # equivalent. To avoid flipping here he simply conjugate the weights. + x = jnp.einsum("nbm,b->nbm", x, Quads) + + # COMPUTE GMM BY FFT + x = jnp.fft.fft(x, axis=1, norm="forward") + x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2] + + # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt + x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n]) + x = jnp.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n)) + + # SYMMETRY REFLECT FOR N < 0 + if reality: + temp = jnp.einsum( + "nlm,m,n->nlm", + jnp.conj(jnp.flip(x[1:], axis=(-1, -3))), + (-1) ** jnp.abs(jnp.arange(-L + 1, L)), + (-1) ** jnp.abs(jnp.arange(-N + 1, 0)), + ) + x = jnp.concatenate((temp, x), axis=0) + + return x * (2.0 * jnp.pi) ** 2 diff --git a/s2fft/precompute_transforms/wigner.py b/s2fft/precompute_transforms/wigner.py index cd3d4dc..219ce71 100644 --- a/s2fft/precompute_transforms/wigner.py +++ b/s2fft/precompute_transforms/wigner.py @@ -110,7 +110,7 @@ def inverse_transform( fnab = np.zeros(samples.fnab_shape(L, N, sampling, nside), dtype=np.complex128) fnab[n_start_ind:, :, m_offset:] = np.einsum( - "...ntlm, ...nlm -> ...ntm", kernel, flmn[n_start_ind:, :, :] + "...ntlm, ...nlm -> ...ntm", kernel, flmn[n_start_ind:] ) if sampling.lower() in "healpix": @@ -122,7 +122,6 @@ def inverse_transform( return np.fft.irfft(f[n_start_ind:], 2 * N - 1, axis=-2, norm="forward") else: return np.fft.ifft(np.fft.ifftshift(f, axes=-2), axis=-2, norm="forward") - else: if reality: fnab = np.fft.ifft(np.fft.ifftshift(fnab, axes=-1), axis=-1, norm="forward") diff --git a/tests/test_fourier_wigner.py b/tests/test_fourier_wigner.py new file mode 100644 index 0000000..2de9da9 --- /dev/null +++ b/tests/test_fourier_wigner.py @@ -0,0 +1,145 @@ +import numpy as np +import pytest +import so3 + +from s2fft.precompute_transforms import construct as c +from s2fft.precompute_transforms import fourier_wigner as fw +from s2fft.sampling import so3_samples as samples + +# Test cases +L_to_test = [16] +N_to_test = [2, 8, 16] +reality_to_test = [False, True] +sampling_schemes = ["mw", "mwss"] +methods_to_test = ["numpy", "jax"] + +# Test tolerance +atol = 1e-12 + + +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("sampling", sampling_schemes) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("method", methods_to_test) +def test_inverse_fourier_wigner_transform( + flmn_generator, + s2fft_to_so3_sampling, + L: int, + N: int, + sampling: str, + reality: bool, + method: str, +): + flmn = flmn_generator(L=L, N=N, reality=reality) + + params = so3.create_parameter_dict( + L=L, + N=N, + sampling_scheme_str=s2fft_to_so3_sampling(sampling), + reality=False, + ) + f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params) + + delta = ( + c.fourier_wigner_kernel_jax(L) + if method == "jax" + else c.fourier_wigner_kernel(L) + ) + transform = fw.inverse_transform_jax if method == "jax" else fw.inverse_transform + f_check = transform(flmn, delta, L, N, reality, sampling) + np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol) + + +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("sampling", sampling_schemes) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("method", methods_to_test) +def test_forward_fourier_wigner_transform( + flmn_generator, + s2fft_to_so3_sampling, + L: int, + N: int, + sampling: str, + reality: bool, + method: str, +): + flmn = flmn_generator(L=L, N=N, reality=reality) + + params = so3.create_parameter_dict( + L=L, + N=N, + sampling_scheme_str=s2fft_to_so3_sampling(sampling), + reality=False, + ) + f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params) + f_3D = f.reshape( + samples._ngamma(N), + samples._nbeta(L, sampling), + samples._nalpha(L, sampling), + ) + flmn = samples.flmn_1d_to_3d(so3.forward(f, params), L, N) + + delta = ( + c.fourier_wigner_kernel_jax(L) + if method == "jax" + else c.fourier_wigner_kernel(L) + ) + transform = fw.forward_transform_jax if method == "jax" else fw.forward_transform + + flmn_check = transform(f_3D, delta, L, N, reality, sampling) + np.testing.assert_allclose(flmn, flmn_check, atol=atol) + + +@pytest.mark.parametrize("L", [8, 16, 32, 64]) +@pytest.mark.parametrize("sampling", sampling_schemes) +@pytest.mark.parametrize("reality", reality_to_test) +def test_inverse_fourier_wigner_transform_high_N( + flmn_generator, s2fft_to_so3_sampling, L: int, sampling: str, reality: bool +): + N = L + flmn = flmn_generator(L=L, N=N, reality=reality) + + params = so3.create_parameter_dict( + L=L, + N=N, + sampling_scheme_str=s2fft_to_so3_sampling(sampling), + reality=False, + ) + f = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params) + + f = f.real if reality else f + delta = c.fourier_wigner_kernel(L) + f_check = fw.inverse_transform(flmn, delta, L, N, reality, sampling) + + np.testing.assert_allclose(f, f_check.flatten("C"), atol=atol) + + +@pytest.mark.parametrize("L", [8, 16, 32, 64]) +@pytest.mark.parametrize("sampling", sampling_schemes) +@pytest.mark.parametrize("reality", reality_to_test) +def test_forward_fourier_wigner_transform_high_N( + flmn_generator, s2fft_to_so3_sampling, L: int, sampling: str, reality: bool +): + N = L + flmn = flmn_generator(L=L, N=N, reality=reality) + + params = so3.create_parameter_dict( + L=L, + N=N, + sampling_scheme_str=s2fft_to_so3_sampling(sampling), + reality=False, + ) + + f_1D = so3.inverse(samples.flmn_3d_to_1d(flmn, L, N), params) + f_3D = f_1D.reshape( + samples._ngamma(N), + samples._nbeta(L, sampling), + samples._nalpha(L, sampling), + ) + flmn_so3 = samples.flmn_1d_to_3d(so3.forward(f_1D, params), L, N) + + delta = c.fourier_wigner_kernel_jax(L) + flmn_check = fw.forward_transform_jax(f_3D, delta, L, N, reality, sampling) + np.testing.assert_allclose(flmn_so3, flmn_check, atol=atol)