From ef33ae95123e2714f5568073e9e662f8877d7d32 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 22 Jan 2024 15:45:56 +0100 Subject: [PATCH] jax.numpy.ndarray -> jax.Array --- e3nn_jax/_src/basic.py | 4 +- e3nn_jax/_src/irreps.py | 50 ++-- e3nn_jax/_src/irreps_array.py | 10 +- e3nn_jax/_src/mlp_flax.py | 2 +- e3nn_jax/_src/mlp_haiku.py | 2 +- e3nn_jax/_src/radius_graph.py | 8 +- e3nn_jax/_src/rotation.py | 280 +++++++++--------- e3nn_jax/_src/s2grid.py | 32 +- e3nn_jax/_src/scatter.py | 28 +- e3nn_jax/_src/spherical_harmonics/__init__.py | 6 +- .../_src/symmetric_tensor_product_haiku.py | 4 +- 11 files changed, 213 insertions(+), 213 deletions(-) diff --git a/e3nn_jax/_src/basic.py b/e3nn_jax/_src/basic.py index efbc16f..393b8b9 100644 --- a/e3nn_jax/_src/basic.py +++ b/e3nn_jax/_src/basic.py @@ -21,7 +21,7 @@ def from_chunks( Args: irreps (Irreps): irreps - chunks (list of optional `jax.numpy.ndarray`): list of arrays + chunks (list of optional `jax.Array`): list of arrays leading_shape (tuple of int): leading shape of the arrays (without the irreps) Returns: @@ -82,7 +82,7 @@ def as_irreps_array(array: Union[jax.Array, e3nn.IrrepsArray], *, backend=None): """Convert an array to an IrrepsArray. Args: - array (jax.numpy.ndarray or IrrepsArray): array to convert + array (jax.Array or IrrepsArray): array to convert Returns: IrrepsArray diff --git a/e3nn_jax/_src/irreps.py b/e3nn_jax/_src/irreps.py index 07c1500..f3e00f4 100644 --- a/e3nn_jax/_src/irreps.py +++ b/e3nn_jax/_src/irreps.py @@ -119,12 +119,12 @@ def D_from_log_coordinates(self, log_coordinates, k=0): (matrix) Representation of :math:`O(3)`. :math:`D` is the representation of :math:`SO(3)`. Args: - log_coordinates (`jax.numpy.ndarray`): of shape :math:`(..., 3)` - k (optional `jax.numpy.ndarray`): of shape :math:`(...)` + log_coordinates (`jax.Array`): of shape :math:`(..., 3)` + k (optional `jax.Array`): of shape :math:`(...)` How many times the parity is applied. Returns: - `jax.numpy.ndarray`: of shape :math:`(..., 2l+1, 2l+1)` + `jax.Array`: of shape :math:`(..., 2l+1, 2l+1)` See Also: Irreps.D_from_log_coordinates @@ -144,17 +144,17 @@ def D_from_angles(self, alpha, beta, gamma, k=0): (matrix) Representation of :math:`O(3)`. :math:`D` is the representation of :math:`SO(3)`. Args: - alpha (`jax.numpy.ndarray`): of shape :math:`(...)` + alpha (`jax.Array`): of shape :math:`(...)` Rotation :math:`\alpha` around Y axis, applied third. - beta (`jax.numpy.ndarray`): of shape :math:`(...)` + beta (`jax.Array`): of shape :math:`(...)` Rotation :math:`\beta` around X axis, applied second. - gamma (`jax.numpy.ndarray`): of shape :math:`(...)` + gamma (`jax.Array`): of shape :math:`(...)` Rotation :math:`\gamma` around Y axis, applied first. - k (optional `jax.numpy.ndarray`): of shape :math:`(...)` + k (optional `jax.Array`): of shape :math:`(...)` How many times the parity is applied. Returns: - `jax.numpy.ndarray`: of shape :math:`(..., 2l+1, 2l+1)` + `jax.Array`: of shape :math:`(..., 2l+1, 2l+1)` See Also: Irreps.D_from_angles @@ -196,11 +196,11 @@ def D_from_quaternion(self, q, k=0): r"""Matrix of the representation, see `Irrep.D_from_angles`. Args: - q (`jax.numpy.ndarray`): shape :math:`(..., 4)` - k (optional `jax.numpy.ndarray`): shape :math:`(...)` + q (`jax.Array`): shape :math:`(..., 4)` + k (optional `jax.Array`): shape :math:`(...)` Returns: - `jax.numpy.ndarray`: shape :math:`(..., 2l+1, 2l+1)` + `jax.Array`: shape :math:`(..., 2l+1, 2l+1)` """ return self.D_from_angles(*quaternion_to_angles(q), k) @@ -208,11 +208,11 @@ def D_from_matrix(self, R): r"""Matrix of the representation. Args: - R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)` - k (`jax.numpy.ndarray`, optional): array of shape :math:`(...)` + R (`jax.Array`): array of shape :math:`(..., 3, 3)` + k (`jax.Array`, optional): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 2l+1, 2l+1)` + `jax.Array`: array of shape :math:`(..., 2l+1, 2l+1)` Examples: >>> m = Irrep(1, -1).D_from_matrix(-jnp.eye(3)) @@ -238,7 +238,7 @@ def generators(self): r"""Generators of the representation of :math:`SO(3)`. Returns: - `jax.numpy.ndarray`: array of shape :math:`(3, 2l+1, 2l+1)` + `jax.Array`: array of shape :math:`(3, 2l+1, 2l+1)` See Also: `generators` @@ -868,11 +868,11 @@ def D_from_log_coordinates(self, log_coordinates, k=0): r"""Matrix of the representation. Args: - log_coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - k (`jax.numpy.ndarray`, optional): array of shape :math:`(...)` + log_coordinates (`jax.Array`): array of shape :math:`(..., 3)` + k (`jax.Array`, optional): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` + `jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` """ return jax.scipy.linalg.block_diag( *[ @@ -892,7 +892,7 @@ def D_from_angles(self, alpha, beta, gamma, k=0): k (int): parity operation Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` + `jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` """ return jax.scipy.linalg.block_diag( *[ @@ -906,11 +906,11 @@ def D_from_quaternion(self, q, k=0): r"""Matrix of the representation. Args: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` - k (`jax.numpy.ndarray`, optional): array of shape :math:`(...)` + q (`jax.Array`): array of shape :math:`(..., 4)` + k (`jax.Array`, optional): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` + `jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` """ return self.D_from_angles(*quaternion_to_angles(q), k) @@ -918,10 +918,10 @@ def D_from_matrix(self, R): r"""Matrix of the representation. Args: - R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)` + R (`jax.Array`): array of shape :math:`(..., 3, 3)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` + `jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` """ d = jnp.sign(jnp.linalg.det(R)) R = d[..., None, None] * R @@ -937,7 +937,7 @@ def generators(self) -> jax.Array: r"""Generators of the representation. Returns: - `jax.numpy.ndarray`: array of shape :math:`(3, \mathrm{dim}, \mathrm{dim})` + `jax.Array`: array of shape :math:`(3, \mathrm{dim}, \mathrm{dim})` """ return jax.vmap(jax.scipy.linalg.block_diag)( *[ir.generators() for mul, ir in self for _ in range(mul)] diff --git a/e3nn_jax/_src/irreps_array.py b/e3nn_jax/_src/irreps_array.py index 07945bd..4becb10 100644 --- a/e3nn_jax/_src/irreps_array.py +++ b/e3nn_jax/_src/irreps_array.py @@ -51,7 +51,7 @@ class IrrepsArray: Args: irreps (Irreps): representation of the data - array (`jax.numpy.ndarray`): the data, an array of shape ``(..., irreps.dim)`` + array (`jax.Array`): the data, an array of shape ``(..., irreps.dim)`` zero_flags (tuple of bool, optional): whether each chunk of the data is zero Examples: @@ -962,7 +962,7 @@ def transform_by_log_coordinates( r"""Rotate data by a rotation given by log coordinates. Args: - log_coordinates (`jax.numpy.ndarray`): log coordinates + log_coordinates (`jax.Array`): log coordinates k (int): parity operation Returns: @@ -1039,7 +1039,7 @@ def transform_by_quaternion(self, q: jax.Array, k: int = 0) -> "IrrepsArray": r"""Rotate data by a rotation given by a quaternion. Args: - q (`jax.numpy.ndarray`): quaternion + q (`jax.Array`): quaternion k (int): parity operation Returns: @@ -1055,7 +1055,7 @@ def transform_by_axis_angle( r"""Rotate data by a rotation given by an axis and an angle. Args: - axis (`jax.numpy.ndarray`): axis + axis (`jax.Array`): axis angle (float): angle (in radians) k (int): parity operation @@ -1070,7 +1070,7 @@ def transform_by_matrix(self, R: jax.Array) -> "IrrepsArray": r"""Rotate data by a rotation given by a matrix. Args: - R (`jax.numpy.ndarray`): rotation matrix + R (`jax.Array`): rotation matrix Returns: `IrrepsArray`: rotated data diff --git a/e3nn_jax/_src/mlp_flax.py b/e3nn_jax/_src/mlp_flax.py index b00c721..2fe7503 100644 --- a/e3nn_jax/_src/mlp_flax.py +++ b/e3nn_jax/_src/mlp_flax.py @@ -33,7 +33,7 @@ def __call__( ) -> Union[jax.Array, e3nn.IrrepsArray]: """Evaluate the MLP - Input and output are either `jax.numpy.ndarray` or `IrrepsArray`. + Input and output are either `jax.Array` or `IrrepsArray`. If the input is a `IrrepsArray`, it must contain only scalars. Args: diff --git a/e3nn_jax/_src/mlp_haiku.py b/e3nn_jax/_src/mlp_haiku.py index 34f502e..2793b00 100644 --- a/e3nn_jax/_src/mlp_haiku.py +++ b/e3nn_jax/_src/mlp_haiku.py @@ -59,7 +59,7 @@ def __call__( ) -> Union[jax.Array, e3nn.IrrepsArray]: """Evaluate the MLP - Input and output are either `jax.numpy.ndarray` or `IrrepsArray`. + Input and output are either `jax.Array` or `IrrepsArray`. If the input is a `IrrepsArray`, it must contain only scalars. Args: diff --git a/e3nn_jax/_src/radius_graph.py b/e3nn_jax/_src/radius_graph.py index b3bf551..43fa7fb 100644 --- a/e3nn_jax/_src/radius_graph.py +++ b/e3nn_jax/_src/radius_graph.py @@ -19,17 +19,17 @@ def radius_graph( r"""Try to use ``matscipy.neighbours.neighbour_list`` instead. Args: - pos (`jax.numpy.ndarray`): array of shape ``(n, 3)`` + pos (`jax.Array`): array of shape ``(n, 3)`` r_max (float): - batch (`jax.numpy.ndarray`): indices + batch (`jax.Array`): indices size (int): size of the output loop (bool): whether to include self-loops Returns: (tuple): tuple containing: - jax.numpy.ndarray: source indices - jax.numpy.ndarray: destination indices + jax.Array: source indices + jax.Array: destination indices Examples: >>> key = jax.random.PRNGKey(0) diff --git a/e3nn_jax/_src/rotation.py b/e3nn_jax/_src/rotation.py index 9fb0ff1..929953a 100644 --- a/e3nn_jax/_src/rotation.py +++ b/e3nn_jax/_src/rotation.py @@ -12,7 +12,7 @@ def rand_matrix(key, shape=(), dtype=jnp.float32): shape: a tuple of nonnegative integers representing the result shape. Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ return angles_to_matrix(*rand_angles(key, shape, dtype=dtype)) @@ -21,10 +21,10 @@ def rotation_angle_from_matrix(R): r"""Angle of rotation from a rotation matrix. Args: - m (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)` + m (`jax.Array`): array of shape :math:`(..., 3, 3)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(...)` + `jax.Array`: array of shape :math:`(...)` """ trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] return jnp.arccos(jnp.clip((trace - 1.0) / 2.0, -1.0, 1.0)) @@ -42,9 +42,9 @@ def identity_angles(shape=(), dtype=jnp.float32): Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ return jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), jnp.zeros(shape, dtype) @@ -59,9 +59,9 @@ def rand_angles(key, shape=(), dtype=jnp.float32): Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ x, y, z = jax.random.uniform(key, (3,) + shape, dtype=dtype) return 2 * jnp.pi * x, jnp.arccos(2 * z - 1), 2 * jnp.pi * y @@ -73,19 +73,19 @@ def compose_angles(a1, b1, c1, a2, b2, c2): Computes :math:`(a, b, c)` such that :math:`R(a, b, c) = R(a_1, b_1, c_1) \circ R(a_2, b_2, c_2)` Args: - alpha1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - alpha2 (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta2 (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma2 (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha1 (`jax.Array`): array of shape :math:`(...)` + beta1 (`jax.Array`): array of shape :math:`(...)` + gamma1 (`jax.Array`): array of shape :math:`(...)` + alpha2 (`jax.Array`): array of shape :math:`(...)` + beta2 (`jax.Array`): array of shape :math:`(...)` + gamma2 (`jax.Array`): array of shape :math:`(...)` Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ a1, b1, c1, a2, b2, c2 = jnp.broadcast_arrays(a1, b1, c1, a2, b2, c2) @@ -96,16 +96,16 @@ def inverse_angles(a, b, c): r"""Angles of the inverse rotation. Args: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ return -c, -b, -a @@ -114,15 +114,15 @@ def rotation_angle_from_angles(a1, b1, c1, a2, b2, c2): r"""Angle of rotation from two triplets of angles. Args: - alpha1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - alpha2 (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta2 (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma2 (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha1 (`jax.Array`): array of shape :math:`(...)` + beta1 (`jax.Array`): array of shape :math:`(...)` + gamma1 (`jax.Array`): array of shape :math:`(...)` + alpha2 (`jax.Array`): array of shape :math:`(...)` + beta2 (`jax.Array`): array of shape :math:`(...)` + gamma2 (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(...)` + `jax.Array`: array of shape :math:`(...)` """ R1 = angles_to_matrix(a1, b1, c1) R2 = angles_to_matrix(a2, b2, c2) @@ -140,7 +140,7 @@ def identity_quaternion(shape=(), dtype=jnp.float32): shape: a tuple of nonnegative integers representing the result shape. Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 4)` + `jax.Array`: array of shape :math:`(..., 4)` """ q = jnp.zeros(shape + (4,), dtype=dtype) return q.at[..., 0].set(1) # or -1... @@ -154,7 +154,7 @@ def rand_quaternion(key, shape=(), dtype=jnp.float32): shape: a tuple of nonnegative integers representing the result shape. Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 4)` + `jax.Array`: array of shape :math:`(..., 4)` """ return angles_to_quaternion(*rand_angles(key, shape, dtype)) @@ -163,11 +163,11 @@ def compose_quaternion(q1, q2): r"""Compose two quaternions: :math:`q_1 \circ q_2`. Args: - q1 (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` - q2 (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q1 (`jax.Array`): array of shape :math:`(..., 4)` + q2 (`jax.Array`): array of shape :math:`(..., 4)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 4)` + `jax.Array`: array of shape :math:`(..., 4)` """ q1, q2 = jnp.broadcast_arrays(q1, q2) return jnp.stack( @@ -199,10 +199,10 @@ def inverse_quaternion(q): Works only for unit quaternions. Args: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 4)` + `jax.Array`: array of shape :math:`(..., 4)` """ return q.at[..., 1:].multiply(-1) @@ -211,11 +211,11 @@ def rotation_angle_from_quaternion(q1, q2): r"""Rotation angle between two quaternions. Args: - q1 (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` - q2 (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q1 (`jax.Array`): array of shape :math:`(..., 4)` + q2 (`jax.Array`): array of shape :math:`(..., 4)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(...)` + `jax.Array`: array of shape :math:`(...)` """ q1, q2 = jnp.broadcast_arrays(q1, q2) dot = jnp.sum(q1 * q2, axis=-1) @@ -236,8 +236,8 @@ def rand_axis_angle(key, shape=(), dtype=jnp.float32): Returns: (tuple): tuple containing: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` """ return angles_to_axis_angle(*rand_angles(key, shape, dtype)) @@ -246,16 +246,16 @@ def compose_axis_angle(axis1, angle1, axis2, angle2): r"""Compose :math:`(\vec x_1, \alpha_1)` with :math:`(\vec x_2, \alpha_2)`. Args: - axis1 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - axis2 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle2 (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis1 (`jax.Array`): array of shape :math:`(..., 3)` + angle1 (`jax.Array`): array of shape :math:`(...)` + axis2 (`jax.Array`): array of shape :math:`(..., 3)` + angle2 (`jax.Array`): array of shape :math:`(...)` Returns: (tuple): tuple containing: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` """ return quaternion_to_axis_angle( compose_quaternion( @@ -269,13 +269,13 @@ def rotation_angle_from_axis_angle(axis1, angle1, axis2, angle2): r"""Rotation angle between :math:`(\vec x_1, \alpha_1)` and :math:`(\vec x_2, \alpha_2)`. Args: - axis1 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle1 (`jax.numpy.ndarray`): array of shape :math:`(...)` - axis2 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle2 (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis1 (`jax.Array`): array of shape :math:`(..., 3)` + angle1 (`jax.Array`): array of shape :math:`(...)` + axis2 (`jax.Array`): array of shape :math:`(..., 3)` + angle2 (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(...)` + `jax.Array`: array of shape :math:`(...)` """ return rotation_angle_from_quaternion( axis_angle_to_quaternion(axis1, angle1), @@ -293,7 +293,7 @@ def identity_log_coordinates(shape=(), dtype=jnp.float32): shape: a tuple of nonnegative integers representing the result shape. Returns: - log coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log coordinates (`jax.Array`): array of shape :math:`(..., 3)` """ return jnp.zeros(shape + (3,), dtype=dtype) @@ -306,7 +306,7 @@ def rand_log_coordinates(key, shape=(), dtype=jnp.float32): shape: a tuple of nonnegative integers representing the result shape. Returns: - log coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log coordinates (`jax.Array`): array of shape :math:`(..., 3)` """ return axis_angle_to_log_coordinates(*rand_axis_angle(key, shape, dtype)) @@ -315,11 +315,11 @@ def compose_log_coordinates(log1, log2): r"""Compose :math:`\vec \alpha_1` with :math:`\vec \alpha_2`. Args: - log1 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - log2 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log1 (`jax.Array`): array of shape :math:`(..., 3)` + log2 (`jax.Array`): array of shape :math:`(..., 3)` Returns: - log coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log coordinates (`jax.Array`): array of shape :math:`(..., 3)` """ return quaternion_to_log_coordinates( compose_quaternion( @@ -332,10 +332,10 @@ def inverse_log_coordinates(log): r"""Inverse of log coordinates. Args: - log (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log (`jax.Array`): array of shape :math:`(..., 3)` Returns: - log coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log coordinates (`jax.Array`): array of shape :math:`(..., 3)` """ return -log @@ -344,11 +344,11 @@ def rotation_angle_from_log_coordinates(log1, log2): r"""Rotation angle between a pair of log coordinates. Args: - log1 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - log2 (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log1 (`jax.Array`): array of shape :math:`(..., 3)` + log2 (`jax.Array`): array of shape :math:`(..., 3)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(...)` + `jax.Array`: array of shape :math:`(...)` """ return rotation_angle_from_quaternion( log_coordinates_to_quaternion(log1), @@ -363,10 +363,10 @@ def matrix_x(angle): r"""Matrix of rotation around X axis. Args: - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + angle (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ c = jnp.cos(angle) s = jnp.sin(angle) @@ -386,10 +386,10 @@ def matrix_y(angle): r"""Matrix of rotation around Y axis. Args: - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + angle (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ c = jnp.cos(angle) s = jnp.sin(angle) @@ -409,10 +409,10 @@ def matrix_z(angle): r"""Matrix of rotation around Z axis. Args: - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + angle (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ c = jnp.cos(angle) s = jnp.sin(angle) @@ -432,12 +432,12 @@ def angles_to_matrix(alpha, beta, gamma): r"""Conversion from angles to matrix. Args: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ alpha, beta, gamma = jnp.broadcast_arrays(alpha, beta, gamma) return matrix_y(alpha) @ matrix_x(beta) @ matrix_y(gamma) @@ -448,14 +448,14 @@ def matrix_to_angles(R): Warning: this function is not differentiable at rotation angles :math:`\pi`. Args: - R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)` + R (`jax.Array`): array of shape :math:`(..., 3, 3)` Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ # assert jnp.allclose(jnp.linalg.det(R), 1) x = R @ jnp.array([0.0, 1.0, 0.0], dtype=R.dtype) @@ -469,12 +469,12 @@ def angles_to_quaternion(alpha, beta, gamma): r"""Conversion from angles to quaternion. Args: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` Returns: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` """ alpha, beta, gamma = jnp.broadcast_arrays(alpha, beta, gamma) qa = axis_angle_to_quaternion(jnp.array([0.0, 1.0, 0.0], alpha.dtype), alpha) @@ -487,10 +487,10 @@ def matrix_to_quaternion(R): r"""Conversion from matrix :math:`R` to quaternion :math:`q`. Args: - R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)` + R (`jax.Array`): array of shape :math:`(..., 3, 3)` Returns: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` """ return axis_angle_to_quaternion(*matrix_to_axis_angle(R)) @@ -499,11 +499,11 @@ def axis_angle_to_quaternion(xyz, angle): r"""Conversion from axis-angle to quaternion. Args: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` Returns: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` """ angle = jnp.asarray(angle) xyz, angle = jnp.broadcast_arrays(xyz, angle[..., None]) @@ -517,13 +517,13 @@ def quaternion_to_axis_angle(q): r"""Conversion from quaternion to axis-angle. Args: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` Returns: (tuple): tuple containing: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` """ angle = 2 * jnp.arccos(jnp.clip(q[..., 0], -1, 1)) axis = _normalize(q[..., 1:]) @@ -541,13 +541,13 @@ def matrix_to_axis_angle(R): Warning: this function is not differentiable at rotation angles :math:`\pi`. Args: - R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)` + R (`jax.Array`): array of shape :math:`(..., 3, 3)` Returns: (tuple): tuple containing: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` """ # assert jnp.allclose(jnp.linalg.det(R), 1) angle = rotation_angle_from_matrix(R) @@ -567,15 +567,15 @@ def angles_to_axis_angle(alpha, beta, gamma): r"""Conversion from angles to axis-angle. Args: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` Returns: (tuple): tuple containing: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` """ return matrix_to_axis_angle(angles_to_matrix(alpha, beta, gamma)) @@ -584,11 +584,11 @@ def axis_angle_to_matrix(axis, angle): r"""Conversion from axis-angle to matrix. Args: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ angle = jnp.asarray(angle) axis, angle = jnp.broadcast_arrays(axis, angle[..., None]) @@ -602,10 +602,10 @@ def quaternion_to_matrix(q): r"""Conversion from quaternion to matrix. Args: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ return axis_angle_to_matrix(*quaternion_to_axis_angle(q)) @@ -614,14 +614,14 @@ def quaternion_to_angles(q): r"""Conversion from quaternion to angles. Args: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ return matrix_to_angles(quaternion_to_matrix(q)) @@ -630,15 +630,15 @@ def axis_angle_to_angles(axis, angle): r"""Conversion from axis-angle to angles. Args: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ return matrix_to_angles(axis_angle_to_matrix(axis, angle)) @@ -647,10 +647,10 @@ def log_coordinates_to_matrix(log_coordinates: jax.Array) -> jax.Array: r"""Conversion from log coordinates to matrix. Args: - log_coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log_coordinates (`jax.Array`): array of shape :math:`(..., 3)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3, 3)` + `jax.Array`: array of shape :math:`(..., 3, 3)` """ shape = log_coordinates.shape[:-1] log_coordinates = log_coordinates.reshape(-1, 3) @@ -672,13 +672,13 @@ def log_coordinates_to_axis_angle(log_coordinates): r"""Conversion from log coordinates to axis-angle. Args: - log_coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log_coordinates (`jax.Array`): array of shape :math:`(..., 3)` Returns: (tuple): tuple containing: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` """ n2 = jnp.sum(log_coordinates**2, axis=-1) n2_ = jnp.where(n2 > 0.0, n2, 1.0) @@ -691,10 +691,10 @@ def log_coordinates_to_quaternion(log_coordinates): r"""Conversion from log coordinates to quaternion. Args: - log_coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log_coordinates (`jax.Array`): array of shape :math:`(..., 3)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 4)` + `jax.Array`: array of shape :math:`(..., 4)` """ return axis_angle_to_quaternion(*log_coordinates_to_axis_angle(log_coordinates)) @@ -703,14 +703,14 @@ def log_coordinates_to_angles(log_coordinates): r"""Conversion from log coordinates to angles. Args: - log_coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + log_coordinates (`jax.Array`): array of shape :math:`(..., 3)` Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` """ return matrix_to_angles(log_coordinates_to_matrix(log_coordinates)) @@ -719,11 +719,11 @@ def axis_angle_to_log_coordinates(axis, angle): r"""Conversion from axis-angle to log coordinates. Args: - axis (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` - angle (`jax.numpy.ndarray`): array of shape :math:`(...)` + axis (`jax.Array`): array of shape :math:`(..., 3)` + angle (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3)` + `jax.Array`: array of shape :math:`(..., 3)` """ angle = jnp.asarray(angle) axis, angle = jnp.broadcast_arrays(axis, angle[..., None]) @@ -734,10 +734,10 @@ def matrix_to_log_coordinates(R): r"""Conversion from matrix to log coordinates. Args: - R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)` + R (`jax.Array`): array of shape :math:`(..., 3, 3)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3)` + `jax.Array`: array of shape :math:`(..., 3)` """ return axis_angle_to_log_coordinates(*matrix_to_axis_angle(R)) @@ -746,12 +746,12 @@ def angles_to_log_coordinates(alpha, beta, gamma): r"""Conversion from angles to log coordinates. Args: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` - gamma (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` + gamma (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3)` + `jax.Array`: array of shape :math:`(..., 3)` """ return axis_angle_to_log_coordinates(*angles_to_axis_angle(alpha, beta, gamma)) @@ -760,10 +760,10 @@ def quaternion_to_log_coordinates(q): r"""Conversion from quaternion to log coordinates. Args: - q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)` + q (`jax.Array`): array of shape :math:`(..., 4)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3)` + `jax.Array`: array of shape :math:`(..., 3)` """ return axis_angle_to_log_coordinates(*quaternion_to_axis_angle(q)) @@ -775,11 +775,11 @@ def angles_to_xyz(alpha, beta): r"""Convert :math:`(\alpha, \beta)` into a point :math:`(x, y, z)` on the sphere. Args: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` Returns: - `jax.numpy.ndarray`: array of shape :math:`(..., 3)` + `jax.Array`: array of shape :math:`(..., 3)` Examples: >>> angles_to_xyz(1.7, 0.0) + 0.0 @@ -804,13 +804,13 @@ def xyz_to_angles(xyz): \beta = \arccos(y) Args: - xyz (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` + xyz (`jax.Array`): array of shape :math:`(..., 3)` Returns: (tuple): tuple containing: - alpha (`jax.numpy.ndarray`): array of shape :math:`(...)` - beta (`jax.numpy.ndarray`): array of shape :math:`(...)` + alpha (`jax.Array`): array of shape :math:`(...)` + beta (`jax.Array`): array of shape :math:`(...)` """ xyz = _normalize(xyz) xyz = jnp.clip(xyz, -1, 1) diff --git a/e3nn_jax/_src/s2grid.py b/e3nn_jax/_src/s2grid.py index 8297490..aa94bba 100644 --- a/e3nn_jax/_src/s2grid.py +++ b/e3nn_jax/_src/s2grid.py @@ -491,8 +491,8 @@ def pad_to_plot( rescales the surface so that the maximum amplitude is equal to the radius Returns: - r (`jax.numpy.ndarray`): vectors on the sphere, shape ``(res_beta + 2, res_alpha + 1, 3)`` - f (`jax.numpy.ndarray`): padded signal, shape ``(res_beta + 2, res_alpha + 1)`` + r (`jax.Array`): vectors on the sphere, shape ``(res_beta + 2, res_alpha + 1, 3)`` + f (`jax.Array`): padded signal, shape ``(res_beta + 2, res_alpha + 1)`` """ f, y, alpha = self.grid_values, self.grid_y, self.grid_alpha assert f.ndim == 2 and f.shape == ( @@ -602,13 +602,13 @@ def sample(self, key: jax.Array) -> Tuple[jax.Array, jax.Array]: The probability distribution does not need to be normalized. Args: - key (`jax.numpy.ndarray`): random key + key (`jax.Array`): random key Returns: (tuple): tuple containing: - beta_index (`jax.numpy.ndarray`): index of the sampled beta - alpha_index (`jax.numpy.ndarray`): index of the sampled alpha + beta_index (`jax.Array`): index of the sampled beta + alpha_index (`jax.Array`): index of the sampled alpha Examples: @@ -694,7 +694,7 @@ def s2_dirac( The integral of the Dirac delta is 1. Args: - position (`jax.numpy.ndarray` or `IrrepsArray`): position of the delta, shape ``(3,)``. + position (`jax.Array` or `IrrepsArray`): position of the delta, shape ``(3,)``. It will be normalized to have a norm of 1. lmax (int): maximum degree of the spherical harmonics expansion @@ -1246,7 +1246,7 @@ def to_s2point( Args: coeffs (`IrrepsArray`): coefficient array of shape ``(*shape1, irreps)`` - point (`jax.numpy.ndarray`): point on the sphere of shape ``(*shape2, 3)`` + point (`jax.Array`): point on the sphere of shape ``(*shape2, 3)`` normalization ({'norm', 'component', 'integral'}): normalization of the basis Returns: @@ -1401,11 +1401,11 @@ def _spherical_harmonics_s2grid( Returns: (tuple): tuple containing: - y (`jax.numpy.ndarray`): array of shape ``(res_beta)`` - alphas (`jax.numpy.ndarray`): array of shape ``(res_alpha)`` - sh_y (`jax.numpy.ndarray`): array of shape ``(res_beta, lmax + 1, lmax + 1)`` - sh_alpha (`jax.numpy.ndarray`): array of shape ``(res_alpha, 2 * lmax + 1)`` - qw (`jax.numpy.ndarray`): array of shape ``(res_beta)`` + y (`jax.Array`): array of shape ``(res_beta)`` + alphas (`jax.Array`): array of shape ``(res_alpha)`` + sh_y (`jax.Array`): array of shape ``(res_beta, lmax + 1, lmax + 1)`` + sh_alpha (`jax.Array`): array of shape ``(res_alpha, 2 * lmax + 1)`` + qw (`jax.Array`): array of shape ``(res_beta)`` """ y, alphas, qw = _s2grid(res_beta, res_alpha, quadrature) y, alphas, qw = jax.tree_util.tree_map( @@ -1500,10 +1500,10 @@ def _normalization( def _rfft(x: jax.Array, l: int) -> jax.Array: r"""Real fourier transform Args: - x (`jax.numpy.ndarray`): input array of shape ``(..., res_beta, res_alpha)`` + x (`jax.Array`): input array of shape ``(..., res_beta, res_alpha)`` l (int): value of `l` for which the transform is being run Returns: - `jax.numpy.ndarray`: transformed values - array of shape ``(..., res_beta, 2*l+1)`` + `jax.Array`: transformed values - array of shape ``(..., res_beta, 2*l+1)`` """ x_reshaped = x.reshape((-1, x.shape[-1])) x_transformed_c = jnp.fft.rfft(x_reshaped) # (..., 2*l+1) @@ -1521,10 +1521,10 @@ def _rfft(x: jax.Array, l: int) -> jax.Array: def _irfft(x: jax.Array, res: int) -> jax.Array: r"""Inverse of the real fourier transform Args: - x (`jax.numpy.ndarray`): array of shape ``(..., 2*l + 1)`` + x (`jax.Array`): array of shape ``(..., 2*l + 1)`` res (int): output resolution, has to be an odd number Returns: - `jax.numpy.ndarray`: positions on the sphere, array of shape ``(..., res)`` + `jax.Array`: positions on the sphere, array of shape ``(..., res)`` """ assert res % 2 == 1 diff --git a/e3nn_jax/_src/scatter.py b/e3nn_jax/_src/scatter.py index cf48d83..0a33740 100644 --- a/e3nn_jax/_src/scatter.py +++ b/e3nn_jax/_src/scatter.py @@ -11,10 +11,10 @@ def _distinct_but_small(x: jax.Array) -> jax.Array: """Maps the input to the integers 0, 1, 2, ..., n-1, where n is the number of distinct elements in x. Args: - x (`jax.numpy.ndarray`): array of integers + x (`jax.Array`): array of integers Returns: - `jax.numpy.ndarray`: array of integers of same size + `jax.Array`: array of integers of same size """ shape = x.shape x = jnp.ravel(x) @@ -44,15 +44,15 @@ def scatter_sum( output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])]) Args: - data (`jax.numpy.ndarray` or `IrrepsArray`): array of shape ``(n1,..nd, ...)`` - dst (optional, `jax.numpy.ndarray`): array of shape ``(n1,..nd)``. If not specified, ``nel`` must be specified. - nel (optional, `jax.numpy.ndarray`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified. + data (`jax.Array` or `IrrepsArray`): array of shape ``(n1,..nd, ...)`` + dst (optional, `jax.Array`): array of shape ``(n1,..nd)``. If not specified, ``nel`` must be specified. + nel (optional, `jax.Array`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified. output_size (optional, int): size of output array. If not specified, ``nel`` must be specified or ``map_back`` must be ``True``. map_back (bool): whether to map back to the input position Returns: - `jax.numpy.ndarray` or `IrrepsArray`: output array of shape ``(output_size, ...)`` + `jax.Array` or `IrrepsArray`: output array of shape ``(output_size, ...)`` """ return _scatter_op( "sum", @@ -87,15 +87,15 @@ def scatter_mean( output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])]) / nel[i] Args: - data (`jax.numpy.ndarray` or `IrrepsArray`): array of shape ``(n1,..nd, ...)`` - dst (optional, `jax.numpy.ndarray`): array of shape ``(n1,..nd)``. If not specified, ``nel`` must be specified. - nel (optional, `jax.numpy.ndarray`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified. + data (`jax.Array` or `IrrepsArray`): array of shape ``(n1,..nd, ...)`` + dst (optional, `jax.Array`): array of shape ``(n1,..nd)``. If not specified, ``nel`` must be specified. + nel (optional, `jax.Array`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified. output_size (optional, int): size of output array. If not specified, ``nel`` must be specified or ``map_back`` must be ``True``. map_back (bool): whether to map back to the input position Returns: - `jax.numpy.ndarray` or `IrrepsArray`: output array of shape ``(output_size, ...)`` + `jax.Array` or `IrrepsArray`: output array of shape ``(output_size, ...)`` """ if map_back and nel is not None: assert dst is None @@ -178,16 +178,16 @@ def scatter_max( output[i] = max(initial, *data[sum(nel[:i]):sum(nel[:i+1])]) Args: - data (`jax.numpy.ndarray` or `IrrepsArray`): array of shape ``(n, ...)`` - dst (optional, `jax.numpy.ndarray`): array of shape ``(n,)``. If not specified, ``nel`` must be specified. - nel (optional, `jax.numpy.ndarray`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified. + data (`jax.Array` or `IrrepsArray`): array of shape ``(n, ...)`` + dst (optional, `jax.Array`): array of shape ``(n,)``. If not specified, ``nel`` must be specified. + nel (optional, `jax.Array`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified. initial (float): initial value to compare to output_size (optional, int): size of output array. If not specified, ``nel`` must be specified or ``map_back`` must be ``True``. map_back (bool): whether to map back to the input position Returns: - `jax.numpy.ndarray` or `IrrepsArray`: output array of shape ``(output_size, ...)`` + `jax.Array` or `IrrepsArray`: output array of shape ``(output_size, ...)`` """ if isinstance(data, e3nn.IrrepsArray): if not data.irreps.is_scalar(): diff --git a/e3nn_jax/_src/spherical_harmonics/__init__.py b/e3nn_jax/_src/spherical_harmonics/__init__.py index dc1758e..16b1708 100644 --- a/e3nn_jax/_src/spherical_harmonics/__init__.py +++ b/e3nn_jax/_src/spherical_harmonics/__init__.py @@ -24,13 +24,13 @@ def sh( Args: irreps_out (`Irreps` or int or Sequence[int]): the output irreps - input (`jax.numpy.ndarray`): cartesian coordinates, shape (..., 3) + input (`jax.Array`): cartesian coordinates, shape (..., 3) normalize (bool): if True, the polynomials are restricted to the sphere normalization (str): normalization of the constant :math:`\text{cste}`. Default is 'component' algorithm (Tuple[str]): algorithm to use for the computation. (legendre|recursive, dense|sparse, [custom_jvp]) Returns: - `jax.numpy.ndarray`: polynomials of the spherical harmonics + `jax.Array`: polynomials of the spherical harmonics """ input = e3nn.IrrepsArray("1e", input) return spherical_harmonics( @@ -89,7 +89,7 @@ def spherical_harmonics( Args: irreps_out (`Irreps` or list of int or int): output irreps - input (`IrrepsArray` or `jax.numpy.ndarray`): cartesian coordinates + input (`IrrepsArray` or `jax.Array`): cartesian coordinates normalize (bool): if True, the polynomials are restricted to the sphere normalization (str): normalization of the constant :math:`\text{cste}`. Default is 'component' algorithm (Tuple[str]): algorithm to use for the computation. (legendre|recursive, dense|sparse, [custom_jvp]) diff --git a/e3nn_jax/_src/symmetric_tensor_product_haiku.py b/e3nn_jax/_src/symmetric_tensor_product_haiku.py index 4b59cdd..b3b230a 100644 --- a/e3nn_jax/_src/symmetric_tensor_product_haiku.py +++ b/e3nn_jax/_src/symmetric_tensor_product_haiku.py @@ -31,7 +31,7 @@ class SymmetricTensorProduct(hk.Module): orders (tuple of int): orders of the tensor product keep_irrep_out (optional, set of Irrep): irreps to keep in the output get_parameter (optional, callable): function to get the parameters, by default it uses ``hk.get_parameter`` - it should have the signature ``get_parameter(name, shape) -> ndarray`` and return a normal distribution + it should have the signature ``get_parameter(name, shape) -> Array`` and return a normal distribution with variance 1 """ @@ -95,7 +95,7 @@ def fn(x: e3nn.IrrepsArray): if order in self.orders: for (mul, ir_out), u in zip(U.irreps, U.chunks): - # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim] + # u: Array [(irreps_x.dim)^order, multiplicity, ir_out.dim] u = ( u / u.shape[-2] ) # normalize both U and the contraction with w