Skip to content

Commit

Permalink
Merge pull request #200 from astro-informatics/feature/acceleration
Browse files Browse the repository at this point in the history
Feature/acceleration
  • Loading branch information
CosmoMatt authored Apr 15, 2024
2 parents da9eaf0 + 7755673 commit bc7cbd8
Show file tree
Hide file tree
Showing 12 changed files with 498 additions and 342 deletions.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
author = "Matthew Price, Jason McEwen, Matthew Graham, Sofia Miñano, Devaraj Gopinathan"

# The short X.Y version
version = "1.1.0"
version = "1.1.1"
# The full version, including alpha/beta/rc tags
release = "1.1.0"
release = "1.1.1"


# -- General configuration ---------------------------------------------------
Expand Down
39 changes: 30 additions & 9 deletions s2fft/precompute_transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,25 @@ def inverse_transform_jax(
)
)
ftm *= (-1) ** spin
if reality:
ftm = ftm.at[:, m_offset : m_start_ind + m_offset].set(
jnp.flip(jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1)
)

if sampling.lower() == "healpix":
if reality:
ftm = ftm.at[:, m_offset : m_start_ind + m_offset].set(
jnp.flip(jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1)
)
f = hp.healpix_ifft(ftm, L, nside, "jax", reality)

else:
f = jnp.conj(jnp.fft.ifftshift(ftm, axes=-1))
f = jnp.conj(jnp.fft.fft(f, axis=-1, norm="backward"))
if reality:
f = jnp.fft.irfft(
ftm[:, m_start_ind + m_offset :],
samples.nphi_equiang(L, sampling),
axis=-1,
norm="forward",
)
else:
f = jnp.fft.ifftshift(ftm, axes=-1)
f = jnp.fft.ifft(f, axis=-1, norm="forward")

return jnp.real(f) if reality else f


Expand Down Expand Up @@ -247,11 +255,24 @@ def inverse_transform_torch(
)

if sampling.lower() == "healpix":
if reality:
ftm[:, m_offset : m_start_ind + m_offset] = torch.flip(
torch.conj(ftm[:, m_start_ind + m_offset + 1 :]), dims=[-1]
)
f = hp.healpix_ifft(ftm, L, nside, "torch", reality)

else:
f = torch.conj(torch.fft.ifftshift(ftm, dim=[-1]))
f = torch.conj(torch.fft.fft(f, axis=-1, norm="backward"))
if reality:
f = torch.fft.irfft(
ftm[:, m_start_ind + m_offset :],
samples.nphi_equiang(L, sampling),
axis=-1,
norm="forward",
)
else:
f = torch.fft.ifftshift(ftm, dim=[-1])
f = torch.fft.ifft(f, axis=-1, norm="forward")

return f.real if reality else f


Expand Down
1 change: 1 addition & 0 deletions s2fft/sampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import s2_samples
from . import so3_samples
from . import reindex
186 changes: 186 additions & 0 deletions s2fft/sampling/reindex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from jax import jit
import jax.numpy as jnp
from functools import partial


@partial(jit, static_argnums=(1))
def flm_1d_to_2d_fast(flm_1d: jnp.ndarray, L: int) -> jnp.ndarray:
r"""Convert from 1D indexed harmnonic coefficients to 2D indexed coefficients (JAX).
Note:
Storage conventions for harmonic coefficients :math:`flm_{(\ell,m)}`, for
e.g. :math:`L = 3`, are as follows.
.. math::
\text{ 2D data format}:
\begin{bmatrix}
0 & 0 & flm_{(0,0)} & 0 & 0 \\
0 & flm_{(1,-1)} & flm_{(1,0)} & flm_{(1,1)} & 0 \\
flm_{(2,-2)} & flm_{(2,-1)} & flm_{(2,0)} & flm_{(2,1)} & flm_{(2,2)}
\end{bmatrix}
.. math::
\text{1D data format}: [flm_{0,0}, flm_{1,-1}, flm_{1,0}, flm_{1,1}, \dots]
Args:
flm_1d (jnp.ndarray): 1D indexed harmonic coefficients.
L (int): Harmonic band-limit.
Returns:
jnp.ndarray: 2D indexed harmonic coefficients.
"""
flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128)
els = jnp.arange(L)
offset = els**2 + els
for el in range(L):
m_array = jnp.arange(-el, el + 1)
flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_1d[offset[el] + m_array])
return flm_2d


@partial(jit, static_argnums=(1))
def flm_2d_to_1d_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray:
r"""Convert from 2D indexed harmonic coefficients to 1D indexed coefficients (JAX).
Note:
Storage conventions for harmonic coefficients :math:`flm_{(\ell,m)}`, for
e.g. :math:`L = 3`, are as follows.
.. math::
\text{ 2D data format}:
\begin{bmatrix}
0 & 0 & flm_{(0,0)} & 0 & 0 \\
0 & flm_{(1,-1)} & flm_{(1,0)} & flm_{(1,1)} & 0 \\
flm_{(2,-2)} & flm_{(2,-1)} & flm_{(2,0)} & flm_{(2,1)} & flm_{(2,2)}
\end{bmatrix}
.. math::
\text{1D data format}: [flm_{0,0}, flm_{1,-1}, flm_{1,0}, flm_{1,1}, \dots]
Args:
flm_2d (jnp.ndarray): 2D indexed harmonic coefficients.
L (int): Harmonic band-limit.
Returns:
jnp.ndarray: 1D indexed harmonic coefficients.
"""
flm_1d = jnp.zeros(L**2, dtype=jnp.complex128)
els = jnp.arange(L)
offset = els**2 + els
for el in range(L):
m_array = jnp.arange(-el, el + 1)
flm_1d = flm_1d.at[offset[el] + m_array].set(flm_2d[el, L - 1 + m_array])
return flm_1d


@partial(jit, static_argnums=(1))
def flm_hp_to_2d_fast(flm_hp: jnp.ndarray, L: int) -> jnp.ndarray:
r"""Converts from HEALPix (healpy) indexed harmonic coefficients to 2D indexed
coefficients (JAX).
Notes:
HEALPix implicitly assumes conjugate symmetry and thus only stores positive `m`
coefficients. Here we unpack that into harmonic coefficients of an
explicitly real signal.
Warning:
Note that the harmonic band-limit `L` differs to the HEALPix `lmax` convention,
where `L = lmax + 1`.
Note:
Storage conventions for harmonic coefficients :math:`f_{(\ell,m)}`, for
e.g. :math:`L = 3`, are as follows.
.. math::
\text{ 2D data format}:
\begin{bmatrix}
0 & 0 & flm_{(0,0)} & 0 & 0 \\
0 & flm_{(1,-1)} & flm_{(1,0)} & flm_{(1,1)} & 0 \\
flm_{(2,-2)} & flm_{(2,-1)} & flm_{(2,0)} & flm_{(2,1)} & flm_{(2,2)}
\end{bmatrix}
.. math::
\text{HEALPix}: [flm_{(0,0)}, \dots, flm_{(2,0)}, flm_{(1,1)}, \dots, flm_{(L-1,1)}, \dots]
Note:
Returns harmonic coefficients of an explicitly real signal.
Args:
flm_hp (jnp.ndarray): HEALPix indexed harmonic coefficients.
L (int): Harmonic band-limit.
Returns:
jnp.ndarray: 2D indexed harmonic coefficients.
"""
flm_2d = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128)

for el in range(L):
flm_2d = flm_2d.at[el, L - 1].set(flm_hp[el])
m_array = jnp.arange(1, el + 1)
hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el
flm_2d = flm_2d.at[el, L - 1 + m_array].set(flm_hp[hp_idx])
flm_2d = flm_2d.at[el, L - 1 - m_array].set(
(-1) ** m_array * jnp.conj(flm_hp[hp_idx])
)

return flm_2d


@partial(jit, static_argnums=(1))
def flm_2d_to_hp_fast(flm_2d: jnp.ndarray, L: int) -> jnp.ndarray:
r"""Converts from 2D indexed harmonic coefficients to HEALPix (healpy) indexed
coefficients (JAX).
Note:
HEALPix implicitly assumes conjugate symmetry and thus only stores positive `m`
coefficients. So this function discards the negative `m` values. This process
is NOT invertible! See the `healpy api docs <https://healpy.readthedocs.io/en/latest/generated/healpy.sphtfunc.alm2map.html>`_
for details on healpy indexing and lengths.
Note:
Storage conventions for harmonic coefficients :math:`f_{(\ell,m)}`, for
e.g. :math:`L = 3`, are as follows.
.. math::
\text{ 2D data format}:
\begin{bmatrix}
0 & 0 & flm_{(0,0)} & 0 & 0 \\
0 & flm_{(1,-1)} & flm_{(1,0)} & flm_{(1,1)} & 0 \\
flm_{(2,-2)} & flm_{(2,-1)} & flm_{(2,0)} & flm_{(2,1)} & flm_{(2,2)}
\end{bmatrix}
.. math::
\text{HEALPix}: [flm_{(0,0)}, \dots, flm_{(2,0)}, flm_{(1,1)}, \dots, flm_{(L-1,1)}, \dots]
Warning:
Returns harmonic coefficients of an explicitly real signal.
Warning:
Note that the harmonic band-limit `L` differs to the HEALPix `lmax` convention,
where `L = lmax + 1`.
Args:
flm_2d (jnp.ndarray): 2D indexed harmonic coefficients.
L (int): Harmonic band-limit.
Returns:
jnp.ndarray: HEALPix indexed harmonic coefficients.
"""
flm_hp = jnp.zeros(int(L * (L + 1) / 2), dtype=jnp.complex128)

for el in range(L):
m_array = jnp.arange(el + 1)
hp_idx = m_array * (2 * L - 1 - m_array) // 2 + el
flm_hp = flm_hp.at[hp_idx].set(flm_2d[el, L - 1 + m_array])

return flm_hp
1 change: 0 additions & 1 deletion s2fft/sampling/s2_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,6 @@ def flm_hp_to_2d(flm_hp: np.ndarray, L: int) -> np.ndarray:
Returns:
np.ndarray: 2D indexed harmonic coefficients.
"""
flm_2d = np.zeros(flm_shape(L), dtype=np.complex128)

Expand Down
Loading

0 comments on commit bc7cbd8

Please sign in to comment.