diff --git a/s2fft/transforms/c_backend_spherical.py b/s2fft/transforms/c_backend_spherical.py index 2fddd2a8..92631844 100644 --- a/s2fft/transforms/c_backend_spherical.py +++ b/s2fft/transforms/c_backend_spherical.py @@ -81,7 +81,7 @@ def _ssht_inverse_bwd(res, f): _, L, spin, reality, ssht_sampling, _ssht_backend = res sampling_str = ["MW", "MWSS", "DH", "GL"] _backend = "SSHT" if _ssht_backend == 0 else "ducc0" - if ssht_sampling < 2: + if ssht_sampling < 2: # MW or MWSS sampling flm = np.conj( pyssht.inverse_adjoint( np.conj(f), @@ -92,7 +92,7 @@ def _ssht_inverse_bwd(res, f): backend=_backend, ) ) - else: + else: # DH or GL samping quad_weights = quadrature_jax.quad_weights_transform( L, sampling_str[ssht_sampling].lower() ) @@ -192,7 +192,7 @@ def _ssht_forward_bwd(res, flm): _backend = "SSHT" if _ssht_backend == 0 else "ducc0" flm_1d = samples.flm_2d_to_1d(flm, L) - if ssht_sampling < 2: + if ssht_sampling < 2: # MW or MWSS sampling f = np.conj( pyssht.forward_adjoint( np.conj(flm_1d), @@ -203,7 +203,7 @@ def _ssht_forward_bwd(res, flm): backend=_backend, ) ) - else: + else: # DH or GL sampling quad_weights = quadrature_jax.quad_weights_transform( L, sampling_str[ssht_sampling].lower() ) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index e44fa8d6..1b453f1d 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -78,8 +78,10 @@ def inverse( between devices is noticable, however as L increases one will asymptotically recover acceleration by the number of devices. """ - if spin >= 8: + + if spin >= 8 and method in ["numpy", "jax"]: raise Warning("Recursive transform may provide lower precision beyond spin ~ 8") + if method == "numpy": return inverse_numpy(flm, L, spin, nside, sampling, reality, precomps, L_lower) elif method == "jax": @@ -130,8 +132,6 @@ def inverse_numpy( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". - reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -229,8 +229,6 @@ def inverse_jax( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". - reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -354,7 +352,8 @@ def forward( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". + method (str, optional): Execution mode in {"numpy", "jax", "jax_ssht", "jax_healpy"}. + Defaults to "numpy". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -391,8 +390,10 @@ def forward( between devices is noticable, however as L increases one will asymptotically recover acceleration by the number of devices. """ - if spin >= 8: + + if spin >= 8 and method in ["numpy", "jax"]: raise Warning("Recursive transform may provide lower precision beyond spin ~ 8") + if method == "numpy": return forward_numpy(f, L, spin, nside, sampling, reality, precomps, L_lower) elif method == "jax": @@ -443,8 +444,6 @@ def forward_numpy( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". - reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -570,8 +569,6 @@ def forward_jax( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". - reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. diff --git a/s2fft/transforms/wigner.py b/s2fft/transforms/wigner.py index 4fd3937d..6f979027 100644 --- a/s2fft/transforms/wigner.py +++ b/s2fft/transforms/wigner.py @@ -90,8 +90,10 @@ def inverse( [1] McEwen, Jason D. and Yves Wiaux. “A Novel Sampling Theorem on the Sphere.” IEEE Transactions on Signal Processing 59 (2011): 5876-5887. """ - if N >= 8: + + if N >= 8 and method in ["numpy", "jax"]: raise Warning("Recursive transform may provide lower precision beyond N ~ 8") + if method == "numpy": return inverse_numpy(flmn, L, N, nside, sampling, reality, precomps, L_lower) elif method == "jax": @@ -144,8 +146,6 @@ def inverse_numpy( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". - reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -505,8 +505,10 @@ def forward( [1] McEwen, Jason D. and Yves Wiaux. “A Novel Sampling Theorem on the Sphere.” IEEE Transactions on Signal Processing 59 (2011): 5876-5887. """ - if N >= 8: + + if N >= 8 and method in ["numpy", "jax"]: raise Warning("Recursive transform may provide lower precision beyond N ~ 8") + if method == "numpy": return forward_numpy(f, L, N, nside, sampling, reality, precomps, L_lower) elif method == "jax": @@ -559,8 +561,6 @@ def forward_numpy( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". - reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False. @@ -655,8 +655,6 @@ def forward_jax( sampling (str, optional): Sampling scheme. Supported sampling schemes include {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". - method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". - reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to False.