Skip to content

Commit

Permalink
fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
CosmoMatt committed Oct 25, 2024
1 parent e08fa78 commit 0a6828d
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions s2fft/precompute_transforms/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def wigner_subset_to_s2(
"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(f"Fourier-Wigner algorithm does not support {sampling} sampling.")
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW
Expand All @@ -55,7 +57,9 @@ def wigner_subset_to_s2(
n = -spins

# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
x = np.zeros((flmn.shape[0], n_dim, xnlm_size, xnlm_size, flmn.shape[-1]), dtype=flmn.dtype)
x = np.zeros(
(flmn.shape[0], n_dim, xnlm_size, xnlm_size, flmn.shape[-1]), dtype=flmn.dtype
)
x[:, :, m_offset:, m_offset:, :] = np.einsum(
"bnlmc,lam,lan,l->bnamc",
flmn,
Expand All @@ -65,7 +69,9 @@ def wigner_subset_to_s2(
)

# APPLY SIGN FUNCTION AND PHASE SHIFT
x = np.einsum("bnamc,m,n,a->bnamc", x, 1j ** (-m), 1j ** (n), np.exp(1j * m * theta0))
x = np.einsum(
"bnamc,m,n,a->bnamc", x, 1j ** (-m), 1j ** (n), np.exp(1j * m * theta0)
)

# IFFT OVER THETA AND PHI
x = np.fft.ifftshift(x, axes=(-3, -2))
Expand Down Expand Up @@ -106,7 +112,9 @@ def wigner_subset_to_s2_jax(
"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(f"Fourier-Wigner algorithm does not support {sampling} sampling.")
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW
Expand All @@ -123,7 +131,9 @@ def wigner_subset_to_s2_jax(
n = -spins

# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
x = jnp.zeros((flmn.shape[0], n_dim, xnlm_size, xnlm_size, flmn.shape[-1]), dtype=flmn.dtype)
x = jnp.zeros(
(flmn.shape[0], n_dim, xnlm_size, xnlm_size, flmn.shape[-1]), dtype=flmn.dtype
)
x = x.at[:, :, m_offset:, m_offset:, :].set(
jnp.einsum(
"bnlmc,lam,lan,l->bnamc",
Expand All @@ -135,7 +145,9 @@ def wigner_subset_to_s2_jax(
)

# APPLY SIGN FUNCTION AND PHASE SHIFT
x = jnp.einsum("bnamc,m,n,a->bnamc", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0))
x = jnp.einsum(
"bnamc,m,n,a->bnamc", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0)
)

# IFFT OVER THETA AND PHI
x = jnp.fft.ifftshift(x, axes=(-3, -2))
Expand Down Expand Up @@ -258,7 +270,9 @@ def s2_to_wigner_subset(
"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(f"Fourier-Wigner algorithm does not support {sampling} sampling.")
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, Quads = DW
Expand Down Expand Up @@ -342,7 +356,9 @@ def s2_to_wigner_subset_jax(
"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(f"Fourier-Wigner algorithm does not support {sampling} sampling.")
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, Quads = DW
Expand Down

0 comments on commit 0a6828d

Please sign in to comment.