Skip to content

Commit

Permalink
Merge pull request #90 from e3nn/so3
Browse files Browse the repository at this point in the history
Fix multiplication.
  • Loading branch information
ameya98 authored Dec 3, 2024
2 parents f9ac41f + 7c0a392 commit 941afa1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,14 @@ def from_function(
return SO3Signal(s2_signals)

def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
if isinstance(other, float):
return SO3Signal(self.s2_signals * other)

if self.shape != other.shape:
raise ValueError(
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
)
return SO3Signal(self.s2_signals * other.s2_signals)
if isinstance(other, SO3Signal):
if self.shape != other.shape:
raise ValueError(
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
)
return SO3Signal(self.s2_signals * other.s2_signals)

return SO3Signal(self.s2_signals * other)

def __truediv__(self, other: float) -> "SO3Signal":
return self * (1 / other)
Expand Down

0 comments on commit 941afa1

Please sign in to comment.