From 3551865e38d45f947e3bd7cf92437466fe6cc6b0 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Thu, 10 Oct 2024 11:14:16 +0100 Subject: [PATCH] run linting on test scripts --- s2fft/precompute_transforms/construct.py | 76 ++++++------------------ s2fft/precompute_transforms/spherical.py | 34 +++-------- s2fft/recursions/risbo.py | 4 +- s2fft/recursions/risbo_jax.py | 22 +++---- tests/test_spherical_precompute.py | 15 ++--- tests/test_wigner_precompute.py | 11 ++-- tests/test_wigner_recursions.py | 28 ++++----- 7 files changed, 57 insertions(+), 133 deletions(-) diff --git a/s2fft/precompute_transforms/construct.py b/s2fft/precompute_transforms/construct.py index 3d23cc8..f74a00d 100644 --- a/s2fft/precompute_transforms/construct.py +++ b/s2fft/precompute_transforms/construct.py @@ -74,9 +74,7 @@ def spin_spherical_kernel( if recursion.lower() == "auto": # This mode automatically determines which recursion is best suited for the # current parameter configuration. - recursion = ( - "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen" - ) + recursion = "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen" dl = [] m_start_ind = L - 1 if reality else 0 @@ -111,13 +109,9 @@ def spin_spherical_kernel( # - The complexity of this approach is O(L^4). # - This approach is stable for arbitrary abs(spins) <= L. if sampling.lower() in ["healpix", "gl"]: - delta = np.zeros( - (len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64 - ) + delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64) for el in range(L): - delta = recursions.risbo.compute_full_vectorised( - delta, thetas, L, el - ) + delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el) dl[:, el] = delta[:, m_start_ind:, L - 1 - spin] # MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated @@ -144,19 +138,13 @@ def spin_spherical_kernel( delta[:, L - 1 - spin], 1j ** (-spin - m_value[m_start_ind:]), ) - temp = np.einsum( - "am,a->am", temp, np.exp(1j * m_value * thetas[0]) - ) - temp = np.fft.irfft( - temp[L - 1 :], n=nsamps, axis=0, norm="forward" - ) + temp = np.einsum("am,a->am", temp, np.exp(1j * m_value * thetas[0])) + temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward") dl[:, el] = temp[: len(thetas)] # Fold in normalisation to avoid recomputation at run-time. - dl = np.einsum( - "tlm,l->tlm", dl, np.sqrt((2 * np.arange(L) + 1) / (4 * np.pi)) - ) + dl = np.einsum("tlm,l->tlm", dl, np.sqrt((2 * np.arange(L) + 1) / (4 * np.pi))) else: raise ValueError(f"Recursion method {recursion} not recognised.") @@ -234,9 +222,7 @@ def spin_spherical_kernel_jax( if recursion.lower() == "auto": # This mode automatically determines which recursion is best suited for the # current parameter configuration. - recursion = ( - "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen" - ) + recursion = "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen" dl = [] m_start_ind = L - 1 if reality else 0 @@ -283,9 +269,7 @@ def spin_spherical_kernel_jax( # - The complexity of this approach is O(L^4). # - This approach is stable for arbitrary abs(spins) <= L. if sampling.lower() in ["healpix", "gl"]: - delta = jnp.zeros( - (len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64 - ) + delta = jnp.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64) vfunc = jax.vmap( recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None) ) @@ -309,9 +293,7 @@ def spin_spherical_kernel_jax( # Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2). for el in range(L): - delta = recursions.risbo_jax.compute_full( - delta, jnp.pi / 2, L, el - ) + delta = recursions.risbo_jax.compute_full(delta, jnp.pi / 2, L, el) m_value = jnp.arange(-L + 1, L) temp = jnp.einsum( "am,a,m->am", @@ -319,12 +301,8 @@ def spin_spherical_kernel_jax( delta[:, L - 1 - spin], 1j ** (-spin - m_value[m_start_ind:]), ) - temp = jnp.einsum( - "am,a->am", temp, jnp.exp(1j * m_value * thetas[0]) - ) - temp = jnp.fft.irfft( - temp[L - 1 :], n=nsamps, axis=0, norm="forward" - ) + temp = jnp.einsum("am,a->am", temp, jnp.exp(1j * m_value * thetas[0])) + temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward") dl = dl.at[:, el].set(temp[: len(thetas)]) @@ -332,9 +310,7 @@ def spin_spherical_kernel_jax( raise ValueError(f"Recursion method {recursion} not recognised.") # Fold in normalisation to avoid recomputation at run-time. - dl = jnp.einsum( - "tlm,l->tlm", dl, jnp.sqrt((2 * jnp.arange(L) + 1) / (4 * jnp.pi)) - ) + dl = jnp.einsum("tlm,l->tlm", dl, jnp.sqrt((2 * jnp.arange(L) + 1) / (4 * jnp.pi))) # Fold in quadrature to avoid recomputation at run-time. if forward: @@ -433,9 +409,7 @@ def wigner_kernel( if mode.lower() == "direct": delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64) for el in range(L): - delta = recursions.risbo.compute_full_vectorised( - delta, thetas, L, el - ) + delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el) dl[:, :, el] = np.moveaxis(delta, -1, 0)[L - 1 + n] # MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated @@ -464,9 +438,7 @@ def wigner_kernel( 1j ** (-m_value), 1j ** (n), ) - temp = np.einsum( - "amn,a->amn", temp, np.exp(1j * m_value * thetas[0]) - ) + temp = np.einsum("amn,a->amn", temp, np.exp(1j * m_value * thetas[0])) temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward") dl[:, :, el] = np.moveaxis(temp[: len(thetas)], -1, 0) @@ -574,12 +546,8 @@ def wigner_kernel_jax( # - The complexity of this approach is ALWAYS O(L^4). # - This approach is stable for arbitrary abs(spins) <= L. if mode.lower() == "direct": - delta = jnp.zeros( - (len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64 - ) - vfunc = jax.vmap( - recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None) - ) + delta = jnp.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64) + vfunc = jax.vmap(recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None)) for el in range(L): delta = vfunc(delta, thetas, L, el) dl = dl.at[:, :, el].set(jnp.moveaxis(delta, -1, 0)[L - 1 + n]) @@ -610,12 +578,8 @@ def wigner_kernel_jax( 1j ** (-m_value), 1j ** (n), ) - temp = jnp.einsum( - "amn,a->amn", temp, jnp.exp(1j * m_value * thetas[0]) - ) - temp = jnp.fft.irfft( - temp[L - 1 :], n=nsamps, axis=0, norm="forward" - ) + temp = jnp.einsum("amn,a->amn", temp, jnp.exp(1j * m_value * thetas[0])) + temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward") dl = dl.at[:, :, el].set(jnp.moveaxis(temp[: len(thetas)], -1, 0)) else: @@ -646,9 +610,7 @@ def wigner_kernel_jax( return dl -def healpix_phase_shifts( - L: int, nside: int, forward: bool = False -) -> np.ndarray: +def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray: r""" Generates a phase shift vector for HEALPix for all :math:`\theta` rings. diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index dc26a0e..c6fe4a0 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -65,13 +65,9 @@ def inverse( if method == "numpy": return inverse_transform(flm, kernel, L, sampling, reality, spin, nside) elif method == "jax": - return inverse_transform_jax( - flm, kernel, L, sampling, reality, spin, nside - ) + return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside) elif method == "torch": - return inverse_transform_torch( - flm, kernel, L, sampling, reality, spin, nside - ) + return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside) else: raise ValueError(f"Method {method} not recognised.") @@ -193,9 +189,7 @@ def inverse_transform_jax( if sampling.lower() == "healpix": if reality: ftm = ftm.at[:, m_offset : m_start_ind + m_offset].set( - jnp.flip( - jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1 - ) + jnp.flip(jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1) ) f = hp.healpix_ifft(ftm, L, nside, "jax", reality) @@ -252,9 +246,7 @@ def inverse_transform_torch( m_offset = 1 if sampling in ["mwss", "healpix"] else 0 m_start_ind = L - 1 if reality else 0 - ftm = torch.zeros( - samples.ftm_shape(L, sampling, nside), dtype=torch.complex128 - ) + ftm = torch.zeros(samples.ftm_shape(L, sampling, nside), dtype=torch.complex128) if sampling.lower() == "healpix": ftm[:, m_start_ind + m_offset :] += torch.einsum( "...tlm, ...lm -> ...tm", kernel, flm[:, m_start_ind:] @@ -348,13 +340,9 @@ def forward( if method == "numpy": return forward_transform(f, kernel, L, sampling, reality, spin, nside) elif method == "jax": - return forward_transform_jax( - f, kernel, L, sampling, reality, spin, nside - ) + return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside) elif method == "torch": - return forward_transform_torch( - f, kernel, L, sampling, reality, spin, nside - ) + return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside) else: raise ValueError(f"Method {method} not recognised.") @@ -495,8 +483,7 @@ def forward_transform_jax( if reality: flm = flm.at[:, :m_start_ind].set( jnp.flip( - (-1) ** (jnp.arange(1, L) % 2) - * jnp.conj(flm[:, m_start_ind + 1 :]), + (-1) ** (jnp.arange(1, L) % 2) * jnp.conj(flm[:, m_start_ind + 1 :]), axis=-1, ) ) @@ -564,9 +551,7 @@ def forward_transform_torch( flm = torch.zeros(samples.flm_shape(L), dtype=torch.complex128) if sampling.lower() == "healpix": - flm[:, m_start_ind:] = torch.einsum( - "...tlm, ...tm -> ...lm", kernel, ftm - ) + flm[:, m_start_ind:] = torch.einsum("...tlm, ...tm -> ...lm", kernel, ftm) else: flm[:, m_start_ind:].real = torch.einsum( "...tlm, ...tm -> ...lm", kernel, ftm.real @@ -577,8 +562,7 @@ def forward_transform_torch( if reality: flm[:, :m_start_ind] = torch.flip( - (-1) ** (torch.arange(1, L) % 2) - * torch.conj(flm[:, m_start_ind + 1 :]), + (-1) ** (torch.arange(1, L) % 2) * torch.conj(flm[:, m_start_ind + 1 :]), dims=[-1], ) diff --git a/s2fft/recursions/risbo.py b/s2fft/recursions/risbo.py index 74738e3..843e65d 100644 --- a/s2fft/recursions/risbo.py +++ b/s2fft/recursions/risbo.py @@ -89,9 +89,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray: ddj = dd[i, k] / j - dl[k - el + L - 1, i - el + L - 1] += ( - sqrt_jmi * sqrt_jmk * ddj * coshb - ) + dl[k - el + L - 1, i - el + L - 1] += sqrt_jmi * sqrt_jmk * ddj * coshb dl[k - el + L - 1, i + 1 - el + L - 1] -= ( sqrt_ip1 * sqrt_jmk * ddj * sinhb ) diff --git a/s2fft/recursions/risbo_jax.py b/s2fft/recursions/risbo_jax.py index 6f2c82b..911db3f 100644 --- a/s2fft/recursions/risbo_jax.py +++ b/s2fft/recursions/risbo_jax.py @@ -68,29 +68,21 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray: dlj = dl[k - (el - 1) + L - 1][:, i - (el - 1) + L - 1] dd = dd.at[:j, :j].add( - jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) - * dlj - * coshb + jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) * dlj * coshb ) dd = dd.at[:j, 1 : j + 1].add( - jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) - * dlj - * sinhb + jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) * dlj * sinhb ) dd = dd.at[1 : j + 1, :j].add( - jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) - * dlj - * sinhb + jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) * dlj * sinhb ) dd = dd.at[1 : j + 1, 1 : j + 1].add( - jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) - * dlj - * coshb + jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) * dlj * coshb ) - dl = dl.at[ - -el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1 - ].multiply(0.0) + dl = dl.at[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1].multiply( + 0.0 + ) j = 2 * el i = jnp.arange(j) diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 537d865..7672331 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -1,10 +1,11 @@ import numpy as np +import pyssht as ssht import pytest import torch -from s2fft.precompute_transforms.spherical import inverse, forward -from s2fft.precompute_transforms import construct as c + from s2fft.base_transforms import spherical as base -import pyssht as ssht +from s2fft.precompute_transforms import construct as c +from s2fft.precompute_transforms.spherical import forward, inverse from s2fft.sampling import s2_samples as samples L_to_test = [12] @@ -47,9 +48,7 @@ def test_transform_inverse( if method.lower() == "jax" else c.spin_spherical_kernel ) - kernel = kfunc( - L, spin, reality, sampling, forward=False, recursion=recursion - ) + kernel = kfunc(L, spin, reality, sampling, forward=False, recursion=recursion) tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 if method.lower() == "torch": @@ -178,9 +177,7 @@ def test_transform_forward( if method.lower() == "jax" else c.spin_spherical_kernel ) - kernel = kfunc( - L, spin, reality, sampling, forward=True, recursion=recursion - ) + kernel = kfunc(L, spin, reality, sampling, forward=True, recursion=recursion) tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 if method.lower() == "torch": diff --git a/tests/test_wigner_precompute.py b/tests/test_wigner_precompute.py index d83de61..ae0a9a4 100644 --- a/tests/test_wigner_precompute.py +++ b/tests/test_wigner_precompute.py @@ -1,11 +1,12 @@ import numpy as np import pytest +import so3 import torch -from s2fft.precompute_transforms.wigner import inverse, forward -from s2fft.precompute_transforms import construct as c + from s2fft.base_transforms import wigner as base +from s2fft.precompute_transforms import construct as c +from s2fft.precompute_transforms.wigner import forward, inverse from s2fft.sampling import so3_samples as samples -import so3 L_to_test = [6] N_to_test = [2, 6] @@ -172,9 +173,7 @@ def test_inverse_wigner_transform_healpix( method, nside, ) - np.testing.assert_allclose( - np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5 - ) + np.testing.assert_allclose(np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5) # Test Gradients flmn_grad_test = torch.from_numpy(flmn) diff --git a/tests/test_wigner_recursions.py b/tests/test_wigner_recursions.py index d8ffce3..305b075 100644 --- a/tests/test_wigner_recursions.py +++ b/tests/test_wigner_recursions.py @@ -102,15 +102,11 @@ def test_trapani_interfaces(): dl_jax = recursions.trapani.init(dl_jax, L, implementation="jax") for el in range(1, L): - dl_loop = recursions.trapani.compute_full( - dl_loop, L, el, implementation="loop" - ) + dl_loop = recursions.trapani.compute_full(dl_loop, L, el, implementation="loop") dl_vect = recursions.trapani.compute_full( dl_vect, L, el, implementation="vectorized" ) - dl_jax = recursions.trapani.compute_full( - dl_jax, L, el, implementation="jax" - ) + dl_jax = recursions.trapani.compute_full(dl_jax, L, el, implementation="jax") np.testing.assert_allclose( dl_loop[ -el + (L - 1) : el + (L - 1) + 1, @@ -134,13 +130,11 @@ def test_trapani_interfaces(): atol=1e-10, ) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError) as _: recursions.trapani.init(dl_loop, L, implementation="unexpected") - with pytest.raises(ValueError) as e: - recursions.trapani.compute_full( - dl_jax, L, el, implementation="unexpected" - ) + with pytest.raises(ValueError) as _: + recursions.trapani.compute_full(dl_jax, L, el, implementation="unexpected") def test_risbo_with_ssht(): @@ -240,10 +234,8 @@ def test_turok_slice_jax_with_ssht(L: int, spin: int, sampling: str): for el in range(L): if el >= np.abs(spin): - print("beta {}, el {}, spin {}".format(beta, el, spin)) - dl_turok = recursions.turok_jax.compute_slice( - beta, el, L, -spin - ) + print(f"beta {beta}, el {el}, spin {spin}") + dl_turok = recursions.turok_jax.compute_slice(beta, el, L, -spin) np.testing.assert_allclose( dl_turok[L - 1 - el : L - 1 + el + 1], @@ -256,11 +248,11 @@ def test_turok_slice_jax_with_ssht(L: int, spin: int, sampling: str): def test_turok_exceptions(): L = 10 - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError) as _: recursions.turok.compute_full(np.pi / 2, L, L) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError) as _: recursions.turok.compute_slice(beta=np.pi / 2, el=L - 1, L=L, mm=L) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError) as _: recursions.turok.compute_slice(beta=np.pi / 2, el=L, L=L, mm=0)