Skip to content

Commit

Permalink
Correct some docstrings; only warn for high spin for certain methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonmcewen committed Apr 8, 2024
1 parent 4c54ede commit 5ea9c8d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 23 deletions.
8 changes: 4 additions & 4 deletions s2fft/transforms/c_backend_spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _ssht_inverse_bwd(res, f):
_, L, spin, reality, ssht_sampling, _ssht_backend = res
sampling_str = ["MW", "MWSS", "DH", "GL"]
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
if ssht_sampling < 2:
if ssht_sampling < 2: # MW or MWSS sampling
flm = np.conj(
pyssht.inverse_adjoint(
np.conj(f),
Expand All @@ -92,7 +92,7 @@ def _ssht_inverse_bwd(res, f):
backend=_backend,
)
)
else:
else: # DH or GL samping
quad_weights = quadrature_jax.quad_weights_transform(
L, sampling_str[ssht_sampling].lower()
)
Expand Down Expand Up @@ -192,7 +192,7 @@ def _ssht_forward_bwd(res, flm):
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
flm_1d = samples.flm_2d_to_1d(flm, L)

if ssht_sampling < 2:
if ssht_sampling < 2: # MW or MWSS sampling
f = np.conj(
pyssht.forward_adjoint(
np.conj(flm_1d),
Expand All @@ -203,7 +203,7 @@ def _ssht_forward_bwd(res, flm):
backend=_backend,
)
)
else:
else: # DH or GL sampling
quad_weights = quadrature_jax.quad_weights_transform(
L, sampling_str[ssht_sampling].lower()
)
Expand Down
19 changes: 8 additions & 11 deletions s2fft/transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def inverse(
between devices is noticable, however as L increases one will asymptotically
recover acceleration by the number of devices.
"""
if spin >= 8:

if spin >= 8 and method in ["numpy", "jax"]:
raise Warning("Recursive transform may provide lower precision beyond spin ~ 8")

if method == "numpy":
return inverse_numpy(flm, L, spin, nside, sampling, reality, precomps, L_lower)
elif method == "jax":
Expand Down Expand Up @@ -130,8 +132,6 @@ def inverse_numpy(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
False.
Expand Down Expand Up @@ -229,8 +229,6 @@ def inverse_jax(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
False.
Expand Down Expand Up @@ -354,7 +352,8 @@ def forward(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
method (str, optional): Execution mode in {"numpy", "jax", "jax_ssht", "jax_healpy"}.
Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
Expand Down Expand Up @@ -391,8 +390,10 @@ def forward(
between devices is noticable, however as L increases one will asymptotically
recover acceleration by the number of devices.
"""
if spin >= 8:

if spin >= 8 and method in ["numpy", "jax"]:
raise Warning("Recursive transform may provide lower precision beyond spin ~ 8")

if method == "numpy":
return forward_numpy(f, L, spin, nside, sampling, reality, precomps, L_lower)
elif method == "jax":
Expand Down Expand Up @@ -443,8 +444,6 @@ def forward_numpy(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
False.
Expand Down Expand Up @@ -570,8 +569,6 @@ def forward_jax(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
False.
Expand Down
14 changes: 6 additions & 8 deletions s2fft/transforms/wigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ def inverse(
[1] McEwen, Jason D. and Yves Wiaux. “A Novel Sampling Theorem on the Sphere.”
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
"""
if N >= 8:

if N >= 8 and method in ["numpy", "jax"]:
raise Warning("Recursive transform may provide lower precision beyond N ~ 8")

if method == "numpy":
return inverse_numpy(flmn, L, N, nside, sampling, reality, precomps, L_lower)
elif method == "jax":
Expand Down Expand Up @@ -144,8 +146,6 @@ def inverse_numpy(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
False.
Expand Down Expand Up @@ -505,8 +505,10 @@ def forward(
[1] McEwen, Jason D. and Yves Wiaux. “A Novel Sampling Theorem on the Sphere.”
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
"""
if N >= 8:

if N >= 8 and method in ["numpy", "jax"]:
raise Warning("Recursive transform may provide lower precision beyond N ~ 8")

if method == "numpy":
return forward_numpy(f, L, N, nside, sampling, reality, precomps, L_lower)
elif method == "jax":
Expand Down Expand Up @@ -559,8 +561,6 @@ def forward_numpy(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
False.
Expand Down Expand Up @@ -655,8 +655,6 @@ def forward_jax(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy".
reality (bool, optional): Whether the signal on the sphere is real. If so,
conjugate symmetry is exploited to reduce computational costs. Defaults to
False.
Expand Down

0 comments on commit 5ea9c8d

Please sign in to comment.