Skip to content

Commit

Permalink
add where
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jan 15, 2024
1 parent 2ee804e commit f3ea865
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `e3nn.where` function

### Changed
- replace `jnp.ndarray` by `jax.Array`

Expand Down
3 changes: 3 additions & 0 deletions docs/api/irreps_array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ IrrepsArray


.. autofunction:: e3nn_jax.sum


.. autofunction:: e3nn_jax.where
3 changes: 3 additions & 0 deletions e3nn_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
normal,
dot,
cross,
where,
)
from e3nn_jax._src.basic import sum_ as sum
from e3nn_jax._src.spherical_harmonics import spherical_harmonics, sh, legendre
Expand Down Expand Up @@ -115,6 +116,7 @@
)
from e3nn_jax._src.utils.vmap import vmap


# make submodules flax and haiku available
from e3nn_jax import flax, haiku, equinox
from e3nn_jax import utils
Expand Down Expand Up @@ -187,6 +189,7 @@
"normal",
"dot",
"cross",
"where",
"sum",
"spherical_harmonics",
"sh",
Expand Down
43 changes: 43 additions & 0 deletions e3nn_jax/_src/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,46 @@ def normal(
return e3nn.from_chunks(irreps, list, leading_shape, dtype)
else:
raise ValueError("Normalization needs to be 'norm' or 'component'")


def where(mask: jax.Array, x: e3nn.IrrepsArray, y: e3nn.IrrepsArray):
"""Selects elements from `x` or `y`, depending on `mask`.
Equivalent to:
>>> mask = jnp.array([True, False])
>>> x = e3nn.IrrepsArray("0e", jnp.array([[1.0], [2.0]]))
>>> y = e3nn.zeros_like(x)
>>> e3nn.IrrepsArray("0e", jnp.where(mask[..., None], x.array, y.array))
Args:
mask: Boolean array of shape `(...)`.
x: IrrepsArray of shape `(..., irreps.dim)`.
y: IrrepsArray of shape `(..., irreps.dim)`.
Returns:
IrrepsArray of shape `(..., irreps.dim)`.
"""
mask = jnp.asarray(mask)
x = e3nn.as_irreps_array(x)
y = e3nn.as_irreps_array(y)

if x.irreps != y.irreps:
raise ValueError(f"e3nn.where: x.irreps ({x.irreps}) != y.irreps ({y.irreps})")

array = jnp.where(mask[..., None], x.array, y.array)

def f(x: Optional[jax.Array], y: Optional[jax.Array]) -> Optional[jax.Array]:
if x is None and y is None:
return None
elif x is None:
return jnp.where(mask[..., None, None], 0.0, y)
elif y is None:
return jnp.where(mask[..., None, None], x, 0.0)
else:
return jnp.where(mask[..., None, None], x, y)

chunks = [f(x, y) for x, y in zip(x.chunks, y.chunks)]

zero_flags = [x and y for x, y in zip(x.zero_flags, y.zero_flags)]

return e3nn.IrrepsArray(x.irreps, array, zero_flags=zero_flags, chunks=chunks)
5 changes: 1 addition & 4 deletions e3nn_jax/_src/irreps_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,7 @@ def chunks(self) -> List[Optional[jax.Array]]:
if self._chunks is None:
jnp = _infer_backend(self.array)
leading_shape = self.array.shape[:-1]
if self.zero_flags is None:
zeros = [False] * len(self.irreps)
else:
zeros = self.zero_flags
zeros = self.zero_flags

if len(self.irreps) == 1:
mul, ir = self.irreps[0]
Expand Down
12 changes: 12 additions & 0 deletions tests/_src/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,15 @@ def test_stack2():
assert y.irreps == "0e + 1e"
assert_array_equals_chunks(y)
assert y.zero_flags == (True, False)


def test_where():
mask = jnp.array([True, False])
x = e3nn.IrrepsArray("0e", jnp.array([[1.0], [2.0]]))
y = e3nn.zeros_like(x)

A = e3nn.IrrepsArray("0e", jnp.where(mask[..., None], x.array, y.array))
B = e3nn.where(mask, x, y)

assert A.irreps == B.irreps
np.testing.assert_allclose(A.array, B.array)

0 comments on commit f3ea865

Please sign in to comment.