Skip to content

Commit

Permalink
Merge pull request #87 from e3nn/so3
Browse files Browse the repository at this point in the history
Don't pre-multiply by angle measure.
  • Loading branch information
ameya98 authored Dec 2, 2024
2 parents 058758a + 8f22a3a commit 9a82808
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
21 changes: 10 additions & 11 deletions e3nn_jax/_src/s2grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,9 @@ def __repr__(self) -> str:
else:
return f"SphericalSignal({self.grid_values})"

def __mul__(self, scalar: Union[float, "SphericalSignal"]) -> "SphericalSignal":
def __mul__(self, other: Union[float, "SphericalSignal"]) -> "SphericalSignal":
"""Multiply SphericalSignal by a scalar."""
if isinstance(scalar, SphericalSignal):
other = scalar
if isinstance(other, SphericalSignal):
if self.quadrature != other.quadrature:
raise ValueError(
"Multiplication of SphericalSignals with different quadrature is not supported."
Expand All @@ -231,22 +230,22 @@ def __mul__(self, scalar: Union[float, "SphericalSignal"]) -> "SphericalSignal":
p_arg=self.p_arg,
)

if isinstance(scalar, e3nn.IrrepsArray):
if scalar.irreps != e3nn.Irreps("0e"):
if isinstance(other, e3nn.IrrepsArray):
if other.irreps != e3nn.Irreps("0e"):
raise ValueError("Scalar must be a 0e IrrepsArray.")
scalar = scalar.array[..., 0]
other = other.array[..., 0]

scalar = jnp.asarray(scalar)[..., None, None]
other = jnp.asarray(other)[..., None, None]
return SphericalSignal(
self.grid_values * scalar,
self.grid_values * other,
self.quadrature,
p_val=self.p_val,
p_arg=self.p_arg,
)

def __rmul__(self, scalar: float) -> "SphericalSignal":
"""Multiply SphericalSignal by a scalar."""
return self * scalar
def __rmul__(self, other: Union[float, "SphericalSignal"]) -> "SphericalSignal":
"""Multiply SphericalSignal by a compatible SphericalSignal or scalar."""
return self * other

def __truediv__(self, scalar: float) -> "SphericalSignal":
"""Divide SphericalSignal by a scalar."""
Expand Down
24 changes: 20 additions & 4 deletions e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Tuple
from typing import Callable, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -107,16 +107,32 @@ def from_function(
batch_dims = fs.shape[:-3]
assert fs.shape == (*batch_dims, res_theta, res_beta, res_alpha)

# Account for angle-dependency in Haar measure.
fs = fs * (1 - jnp.cos(angles))[..., None, None]
s2_signals = s2_signals.replace_values(fs)
assert s2_signals.shape == (*batch_dims, res_theta, res_beta, res_alpha)
return SO3Signal(s2_signals)

def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
if isinstance(other, float):
return SO3Signal(self.s2_signals * other)

if self.shape != other.shape:
raise ValueError(
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
)
return SO3Signal(self.s2_signals * other.s2_signals)

def __truediv__(self, other: float) -> "SO3Signal":
return self * (1 / other)

def integrate_over_angles(self) -> SphericalSignal:
"""Integrate the signal over the angles in the axis-angle parametrization."""
# Account for angle-dependency in Haar measure.
grid_values = self.s2_signals.grid_values * (1 - jnp.cos(self.grid_theta))[..., None, None]

# Trapezoidal rule for integration.
delta_theta = self.grid_theta[1] - self.grid_theta[0]
return self.s2_signals.replace_values(
grid_values=jnp.sum(self.s2_signals.grid_values, axis=-3) * delta_theta
grid_values=jnp.sum(grid_values, axis=-3) * delta_theta
)

def integrate(self) -> SphericalSignal:
Expand Down

0 comments on commit 9a82808

Please sign in to comment.