Skip to content

Commit

Permalink
run linting on test scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
CosmoMatt committed Oct 10, 2024
1 parent ba58339 commit 3551865
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 133 deletions.
76 changes: 19 additions & 57 deletions s2fft/precompute_transforms/construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def spin_spherical_kernel(
if recursion.lower() == "auto":
# This mode automatically determines which recursion is best suited for the
# current parameter configuration.
recursion = (
"risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"
)
recursion = "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"

dl = []
m_start_ind = L - 1 if reality else 0
Expand Down Expand Up @@ -111,13 +109,9 @@ def spin_spherical_kernel(
# - The complexity of this approach is O(L^4).
# - This approach is stable for arbitrary abs(spins) <= L.
if sampling.lower() in ["healpix", "gl"]:
delta = np.zeros(
(len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64
)
delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64)
for el in range(L):
delta = recursions.risbo.compute_full_vectorised(
delta, thetas, L, el
)
delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el)
dl[:, el] = delta[:, m_start_ind:, L - 1 - spin]

# MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
Expand All @@ -144,19 +138,13 @@ def spin_spherical_kernel(
delta[:, L - 1 - spin],
1j ** (-spin - m_value[m_start_ind:]),
)
temp = np.einsum(
"am,a->am", temp, np.exp(1j * m_value * thetas[0])
)
temp = np.fft.irfft(
temp[L - 1 :], n=nsamps, axis=0, norm="forward"
)
temp = np.einsum("am,a->am", temp, np.exp(1j * m_value * thetas[0]))
temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")

dl[:, el] = temp[: len(thetas)]

# Fold in normalisation to avoid recomputation at run-time.
dl = np.einsum(
"tlm,l->tlm", dl, np.sqrt((2 * np.arange(L) + 1) / (4 * np.pi))
)
dl = np.einsum("tlm,l->tlm", dl, np.sqrt((2 * np.arange(L) + 1) / (4 * np.pi)))

else:
raise ValueError(f"Recursion method {recursion} not recognised.")

Check warning on line 150 in s2fft/precompute_transforms/construct.py

View check run for this annotation

Codecov / codecov/patch

s2fft/precompute_transforms/construct.py#L150

Added line #L150 was not covered by tests
Expand Down Expand Up @@ -234,9 +222,7 @@ def spin_spherical_kernel_jax(
if recursion.lower() == "auto":
# This mode automatically determines which recursion is best suited for the
# current parameter configuration.
recursion = (
"risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"
)
recursion = "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"

dl = []
m_start_ind = L - 1 if reality else 0
Expand Down Expand Up @@ -283,9 +269,7 @@ def spin_spherical_kernel_jax(
# - The complexity of this approach is O(L^4).
# - This approach is stable for arbitrary abs(spins) <= L.
if sampling.lower() in ["healpix", "gl"]:
delta = jnp.zeros(
(len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64
)
delta = jnp.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64)
vfunc = jax.vmap(
recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None)
)
Expand All @@ -309,32 +293,24 @@ def spin_spherical_kernel_jax(

# Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
for el in range(L):
delta = recursions.risbo_jax.compute_full(
delta, jnp.pi / 2, L, el
)
delta = recursions.risbo_jax.compute_full(delta, jnp.pi / 2, L, el)
m_value = jnp.arange(-L + 1, L)
temp = jnp.einsum(
"am,a,m->am",
delta[:, m_start_ind:],
delta[:, L - 1 - spin],
1j ** (-spin - m_value[m_start_ind:]),
)
temp = jnp.einsum(
"am,a->am", temp, jnp.exp(1j * m_value * thetas[0])
)
temp = jnp.fft.irfft(
temp[L - 1 :], n=nsamps, axis=0, norm="forward"
)
temp = jnp.einsum("am,a->am", temp, jnp.exp(1j * m_value * thetas[0]))
temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")

dl = dl.at[:, el].set(temp[: len(thetas)])

else:
raise ValueError(f"Recursion method {recursion} not recognised.")

Check warning on line 310 in s2fft/precompute_transforms/construct.py

View check run for this annotation

Codecov / codecov/patch

s2fft/precompute_transforms/construct.py#L310

Added line #L310 was not covered by tests

# Fold in normalisation to avoid recomputation at run-time.
dl = jnp.einsum(
"tlm,l->tlm", dl, jnp.sqrt((2 * jnp.arange(L) + 1) / (4 * jnp.pi))
)
dl = jnp.einsum("tlm,l->tlm", dl, jnp.sqrt((2 * jnp.arange(L) + 1) / (4 * jnp.pi)))

# Fold in quadrature to avoid recomputation at run-time.
if forward:
Expand Down Expand Up @@ -433,9 +409,7 @@ def wigner_kernel(
if mode.lower() == "direct":
delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64)
for el in range(L):
delta = recursions.risbo.compute_full_vectorised(
delta, thetas, L, el
)
delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el)
dl[:, :, el] = np.moveaxis(delta, -1, 0)[L - 1 + n]

# MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
Expand Down Expand Up @@ -464,9 +438,7 @@ def wigner_kernel(
1j ** (-m_value),
1j ** (n),
)
temp = np.einsum(
"amn,a->amn", temp, np.exp(1j * m_value * thetas[0])
)
temp = np.einsum("amn,a->amn", temp, np.exp(1j * m_value * thetas[0]))
temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
dl[:, :, el] = np.moveaxis(temp[: len(thetas)], -1, 0)

Expand Down Expand Up @@ -574,12 +546,8 @@ def wigner_kernel_jax(
# - The complexity of this approach is ALWAYS O(L^4).
# - This approach is stable for arbitrary abs(spins) <= L.
if mode.lower() == "direct":
delta = jnp.zeros(
(len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64
)
vfunc = jax.vmap(
recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None)
)
delta = jnp.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64)
vfunc = jax.vmap(recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None))
for el in range(L):
delta = vfunc(delta, thetas, L, el)
dl = dl.at[:, :, el].set(jnp.moveaxis(delta, -1, 0)[L - 1 + n])
Expand Down Expand Up @@ -610,12 +578,8 @@ def wigner_kernel_jax(
1j ** (-m_value),
1j ** (n),
)
temp = jnp.einsum(
"amn,a->amn", temp, jnp.exp(1j * m_value * thetas[0])
)
temp = jnp.fft.irfft(
temp[L - 1 :], n=nsamps, axis=0, norm="forward"
)
temp = jnp.einsum("amn,a->amn", temp, jnp.exp(1j * m_value * thetas[0]))
temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
dl = dl.at[:, :, el].set(jnp.moveaxis(temp[: len(thetas)], -1, 0))

else:
Expand Down Expand Up @@ -646,9 +610,7 @@ def wigner_kernel_jax(
return dl


def healpix_phase_shifts(
L: int, nside: int, forward: bool = False
) -> np.ndarray:
def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray:
r"""
Generates a phase shift vector for HEALPix for all :math:`\theta` rings.
Expand Down
34 changes: 9 additions & 25 deletions s2fft/precompute_transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,9 @@ def inverse(
if method == "numpy":
return inverse_transform(flm, kernel, L, sampling, reality, spin, nside)
elif method == "jax":
return inverse_transform_jax(
flm, kernel, L, sampling, reality, spin, nside
)
return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside)
elif method == "torch":
return inverse_transform_torch(
flm, kernel, L, sampling, reality, spin, nside
)
return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside)
else:
raise ValueError(f"Method {method} not recognised.")

Expand Down Expand Up @@ -193,9 +189,7 @@ def inverse_transform_jax(
if sampling.lower() == "healpix":
if reality:
ftm = ftm.at[:, m_offset : m_start_ind + m_offset].set(
jnp.flip(
jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1
)
jnp.flip(jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1)
)
f = hp.healpix_ifft(ftm, L, nside, "jax", reality)

Expand Down Expand Up @@ -252,9 +246,7 @@ def inverse_transform_torch(
m_offset = 1 if sampling in ["mwss", "healpix"] else 0
m_start_ind = L - 1 if reality else 0

ftm = torch.zeros(
samples.ftm_shape(L, sampling, nside), dtype=torch.complex128
)
ftm = torch.zeros(samples.ftm_shape(L, sampling, nside), dtype=torch.complex128)
if sampling.lower() == "healpix":
ftm[:, m_start_ind + m_offset :] += torch.einsum(
"...tlm, ...lm -> ...tm", kernel, flm[:, m_start_ind:]
Expand Down Expand Up @@ -348,13 +340,9 @@ def forward(
if method == "numpy":
return forward_transform(f, kernel, L, sampling, reality, spin, nside)
elif method == "jax":
return forward_transform_jax(
f, kernel, L, sampling, reality, spin, nside
)
return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside)
elif method == "torch":
return forward_transform_torch(
f, kernel, L, sampling, reality, spin, nside
)
return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside)
else:
raise ValueError(f"Method {method} not recognised.")

Expand Down Expand Up @@ -495,8 +483,7 @@ def forward_transform_jax(
if reality:
flm = flm.at[:, :m_start_ind].set(
jnp.flip(
(-1) ** (jnp.arange(1, L) % 2)
* jnp.conj(flm[:, m_start_ind + 1 :]),
(-1) ** (jnp.arange(1, L) % 2) * jnp.conj(flm[:, m_start_ind + 1 :]),
axis=-1,
)
)
Expand Down Expand Up @@ -564,9 +551,7 @@ def forward_transform_torch(

flm = torch.zeros(samples.flm_shape(L), dtype=torch.complex128)
if sampling.lower() == "healpix":
flm[:, m_start_ind:] = torch.einsum(
"...tlm, ...tm -> ...lm", kernel, ftm
)
flm[:, m_start_ind:] = torch.einsum("...tlm, ...tm -> ...lm", kernel, ftm)
else:
flm[:, m_start_ind:].real = torch.einsum(
"...tlm, ...tm -> ...lm", kernel, ftm.real
Expand All @@ -577,8 +562,7 @@ def forward_transform_torch(

if reality:
flm[:, :m_start_ind] = torch.flip(
(-1) ** (torch.arange(1, L) % 2)
* torch.conj(flm[:, m_start_ind + 1 :]),
(-1) ** (torch.arange(1, L) % 2) * torch.conj(flm[:, m_start_ind + 1 :]),
dims=[-1],
)

Expand Down
4 changes: 1 addition & 3 deletions s2fft/recursions/risbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:

ddj = dd[i, k] / j

dl[k - el + L - 1, i - el + L - 1] += (
sqrt_jmi * sqrt_jmk * ddj * coshb
)
dl[k - el + L - 1, i - el + L - 1] += sqrt_jmi * sqrt_jmk * ddj * coshb
dl[k - el + L - 1, i + 1 - el + L - 1] -= (
sqrt_ip1 * sqrt_jmk * ddj * sinhb
)
Expand Down
22 changes: 7 additions & 15 deletions s2fft/recursions/risbo_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,21 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
dlj = dl[k - (el - 1) + L - 1][:, i - (el - 1) + L - 1]

dd = dd.at[:j, :j].add(
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True)
* dlj
* coshb
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) * dlj * coshb
)
dd = dd.at[:j, 1 : j + 1].add(
jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True)
* dlj
* sinhb
jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) * dlj * sinhb
)
dd = dd.at[1 : j + 1, :j].add(
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True)
* dlj
* sinhb
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) * dlj * sinhb
)
dd = dd.at[1 : j + 1, 1 : j + 1].add(
jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True)
* dlj
* coshb
jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) * dlj * coshb
)

dl = dl.at[
-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1
].multiply(0.0)
dl = dl.at[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1].multiply(
0.0
)

j = 2 * el
i = jnp.arange(j)
Expand Down
15 changes: 6 additions & 9 deletions tests/test_spherical_precompute.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
import pyssht as ssht
import pytest
import torch
from s2fft.precompute_transforms.spherical import inverse, forward
from s2fft.precompute_transforms import construct as c

from s2fft.base_transforms import spherical as base
import pyssht as ssht
from s2fft.precompute_transforms import construct as c
from s2fft.precompute_transforms.spherical import forward, inverse
from s2fft.sampling import s2_samples as samples

L_to_test = [12]
Expand Down Expand Up @@ -47,9 +48,7 @@ def test_transform_inverse(
if method.lower() == "jax"
else c.spin_spherical_kernel
)
kernel = kfunc(
L, spin, reality, sampling, forward=False, recursion=recursion
)
kernel = kfunc(L, spin, reality, sampling, forward=False, recursion=recursion)

tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12
if method.lower() == "torch":
Expand Down Expand Up @@ -178,9 +177,7 @@ def test_transform_forward(
if method.lower() == "jax"
else c.spin_spherical_kernel
)
kernel = kfunc(
L, spin, reality, sampling, forward=True, recursion=recursion
)
kernel = kfunc(L, spin, reality, sampling, forward=True, recursion=recursion)

tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12
if method.lower() == "torch":
Expand Down
11 changes: 5 additions & 6 deletions tests/test_wigner_precompute.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import numpy as np
import pytest
import so3
import torch
from s2fft.precompute_transforms.wigner import inverse, forward
from s2fft.precompute_transforms import construct as c

from s2fft.base_transforms import wigner as base
from s2fft.precompute_transforms import construct as c
from s2fft.precompute_transforms.wigner import forward, inverse
from s2fft.sampling import so3_samples as samples
import so3

L_to_test = [6]
N_to_test = [2, 6]
Expand Down Expand Up @@ -172,9 +173,7 @@ def test_inverse_wigner_transform_healpix(
method,
nside,
)
np.testing.assert_allclose(
np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5
)
np.testing.assert_allclose(np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5)

# Test Gradients
flmn_grad_test = torch.from_numpy(flmn)
Expand Down
Loading

0 comments on commit 3551865

Please sign in to comment.