Skip to content

Commit

Permalink
address JDM review, switch from FFT to DFT
Browse files Browse the repository at this point in the history
  • Loading branch information
CosmoMatt committed Nov 11, 2024
1 parent 0a6828d commit 4665eec
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 44 deletions.
93 changes: 54 additions & 39 deletions s2fft/precompute_transforms/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ def wigner_subset_to_s2(
Transforms an arbitrary subset of Wigner coefficients onto a subset of spin signals
on the sphere.
This function takes a collection of spin spherical harmonic coefficients each with
a different (though not necessarily unique) spin and maps them back to their
corresponding pixel-space representations. Following this operation one may
liftn this collection of spin signals to a signal on SO(3) by exploiting the
correct Mackey functions.
This function takes a collection of spin spherical harmonic coefficients each with
a different (though not necessarily unique) spin and maps them back to their
corresponding pixel-space representations.
Args:
flmn (np.ndarray): Collection of spin spherical harmonic coefficients
Expand All @@ -33,15 +31,25 @@ def wigner_subset_to_s2(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss"}. Defaults to "mw".
Raises:
ValueError: If sampling scheme is not recognised.
ValueError: If the number of spins does not match the number of Wigner coefficients.
Returns:
np.ndarray: A collection of spin signals with shape :math:`[batch, n_s, n_\theta, n_\phi, channels]`.
np.ndarray: A collection of spin signals with
shape :math:`[batch, n_s, n_\theta, n_\phi, channels]`.
"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

if flmn.shape[1] != spins.shape[0]:
raise ValueError(
f"Number of spins specified {spins.shape[0]} does not match the number of Wigner coefficients {flmn.shape[1]}"
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW

Expand Down Expand Up @@ -93,9 +101,7 @@ def wigner_subset_to_s2_jax(
This function takes a collection of spin spherical harmonic coefficients each with
a different (though not necessarily unique) spin and maps them back to their
corresponding pixel-space representations. Following this operation one may
liftn this collection of spin signals to a signal on SO(3) by exploiting the
correct Mackey functions.
corresponding pixel-space representations.
Args:
flmn (jnp.ndarray): Collection of spin spherical harmonic coefficients
Expand All @@ -107,6 +113,10 @@ def wigner_subset_to_s2_jax(
sampling (str, optional): Sampling scheme. Supported sampling schemes include
{"mw", "mwss"}. Defaults to "mw".
Raises:
ValueError: If sampling scheme is not recognised.
ValueError: If the number of spins does not match the number of Wigner coefficients.
Returns:
jnp.ndarray: A collection of spin signals with shape :math:`[batch, n_s, n_\theta, n_\phi, channels]`.
Expand All @@ -116,6 +126,11 @@ def wigner_subset_to_s2_jax(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

if flmn.shape[1] != spins.shape[0]:
raise ValueError(
f"Number of spins specified {spins.shape[0]} does not match the number of Wigner coefficients {flmn.shape[1]}"
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW

Expand Down Expand Up @@ -167,10 +182,10 @@ def so3_to_wigner_subset(
Transforms a signal on the rotation group to an arbitrary subset of its Wigner
coefficients.
This function takes a signal on the rotation group SO(3) and computes a subset of
spin spherical harmonic coefficients corresponding to slices across the requested
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
This function takes a signal on the rotation group SO(3) and computes a subset of
spin spherical harmonic coefficients corresponding to slices across the requested
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
Args:
f (np.ndarray): Signal on the rotation group with shape :math:`[batch, n_\gamma, n_\theta,n_\phi, channels]`.
Expand All @@ -186,12 +201,11 @@ def so3_to_wigner_subset(
np.ndarray: Collection of spin spherical harmonic coefficients with shape :math:`[batch, n_s, L, 2L-1, channels]`.
"""
# COMPUTE FFT OVER GAMMA
x = np.fft.fft(f, axis=-4, norm="forward")
x = np.fft.fftshift(x, axes=-4)

# EXTRACT REQUESTED SPIN COMPONENTS
x = x[:, N - 1 - spins]
# COMPUTE DFT OVER GAMMA SUBSET
e = np.exp(
-2j * np.pi * np.einsum("g,n->gn", np.arange(f.shape[1]) / f.shape[1], -spins)
)
x = np.einsum("bgtpc,gn->bntpc", f, e) / f.shape[1]

return s2_to_wigner_subset(x, spins, DW, L, sampling)

Expand All @@ -209,10 +223,10 @@ def so3_to_wigner_subset_jax(
Transforms a signal on the rotation group to an arbitrary subset of its Wigner
coefficients (JAX).
This function takes a signal on the rotation group SO(3) and computes a subset of
spin spherical harmonic coefficients corresponding to slices across the requested
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
This function takes a signal on the rotation group SO(3) and computes a subset of
spin spherical harmonic coefficients corresponding to slices across the requested
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
Args:
f (jnp.ndarray): Signal on the rotation group with shape :math:`[batch, n_\gamma, n_\theta,n_\phi, channels]`.
Expand All @@ -229,12 +243,13 @@ def so3_to_wigner_subset_jax(
with shape :math:`[batch, n_s, L, 2L-1, channels]`.
"""
# COMPUTE FFT OVER GAMMA
x = jnp.fft.fft(f, axis=-4, norm="forward")
x = jnp.fft.fftshift(x, axes=-4)

# EXTRACT REQUESTED SPIN COMPONENTS
x = x[:, N - 1 - spins]
# COMPUTE DFT OVER GAMMA SUBSET
e = jnp.exp(
-2j
* jnp.pi
* jnp.einsum("g,n->gn", jnp.arange(f.shape[1]) / f.shape[1], -spins)
)
x = jnp.einsum("bgtpc,gn->bntpc", f, e) / f.shape[1]

return s2_to_wigner_subset_jax(x, spins, DW, L, sampling)

Expand All @@ -250,11 +265,11 @@ def s2_to_wigner_subset(
Transforms from a collection of arbitrary spin signals on the sphere to the
corresponding collection of their harmonic coefficients.
This function takes a multimodal collection of spin spherical harmonic signals
on the sphere and transforms them into their spin spherical harmonic coefficients.
These cofficients may then be combined into a subset of Wigner coefficients for
downstream analysis. In this way one may combine input features across a variety
of spins into a unified representation.
This function takes a multimodal collection of spin spherical harmonic signals
on the sphere and transforms them into their spin spherical harmonic coefficients.
These coefficients may then be combined into a subset of Wigner coefficients for
downstream analysis. In this way one may combine input features across a variety
of spins into a unified representation.
Args:
fs (np.ndarray): Collection of spin signal maps on the sphere with shape :math:`[batch, n_s, n_\theta,n_\phi, channels]`.
Expand Down Expand Up @@ -336,11 +351,11 @@ def s2_to_wigner_subset_jax(
Transforms from a collection of arbitrary spin signals on the sphere to the
corresponding collection of their harmonic coefficients (JAX).
This function takes a multimodal collection of spin spherical harmonic signals
on the sphere and transforms them into their spin spherical harmonic coefficients.
These cofficients may then be combined into a subset of Wigner coefficients for
downstream analysis. In this way one may combine input features across a variety
of spins into a unified representation.
This function takes a multimodal collection of spin spherical harmonic signals
on the sphere and transforms them into their spin spherical harmonic coefficients.
These coefficients may then be combined into a subset of Wigner coefficients for
downstream analysis. In this way one may combine input features across a variety
of spins into a unified representation.
Args:
fs (jnp.ndarray): Collection of spin signal maps on the sphere with shape :math:`[batch, n_s, n_\theta,n_\phi, channels]`.
Expand Down
6 changes: 1 addition & 5 deletions tests/test_lifting_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,7 @@ def test_custom_forward_from_so3(
spins = -np.arange(-N + 1, N)

# FUNCTION SWITCH
func = (
ops.so3_to_wigner_subset_jax
if method == "jax"
else ops.so3_to_wigner_subset_jax
)
func = ops.so3_to_wigner_subset_jax if method == "jax" else ops.so3_to_wigner_subset

# CREATE CORRECT SHAPE (BATCH: 1, CHANNELS: 1)
f = f.reshape((1,) + f.shape + (1,))
Expand Down

0 comments on commit 4665eec

Please sign in to comment.