Skip to content

Commit

Permalink
Merge pull request #277 from ami-iit/fix_ad_through_axis_angle
Browse files Browse the repository at this point in the history
Fix running AD through `jaxsim.math.Rotation.from_axis_angle`
  • Loading branch information
diegoferigo authored Nov 5, 2024
2 parents b274aab + 73eee01 commit bd59060
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax
import jax.numpy as jnp
import jaxlie

Expand All @@ -8,6 +7,7 @@


class Rotation:

@staticmethod
def x(theta: jtp.Float) -> jtp.Matrix:
"""
Expand All @@ -19,6 +19,7 @@ def x(theta: jtp.Float) -> jtp.Matrix:
Returns:
jtp.Matrix: 3D rotation matrix.
"""

return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()

@staticmethod
Expand All @@ -32,6 +33,7 @@ def y(theta: jtp.Float) -> jtp.Matrix:
Returns:
jtp.Matrix: 3D rotation matrix.
"""

return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()

@staticmethod
Expand All @@ -45,6 +47,7 @@ def z(theta: jtp.Float) -> jtp.Matrix:
Returns:
jtp.Matrix: 3D rotation matrix.
"""

return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()

@staticmethod
Expand All @@ -53,17 +56,18 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
Generate a 3D rotation matrix from an axis-angle representation.
Args:
vector (jtp.Vector): Axis-angle representation as a 3D vector.
vector: Axis-angle representation or the rotation as a 3D vector.
Returns:
jtp.Matrix: 3D rotation matrix.
The SO(3) rotation matrix.
"""

vector = vector.squeeze()
theta = jnp.linalg.norm(vector)

def theta_is_not_zero(theta_and_v: tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:
theta, v = theta_and_v
def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

v = axis
theta = jnp.linalg.norm(v)

s = jnp.sin(theta)
c = jnp.cos(theta)
Expand All @@ -77,9 +81,19 @@ def theta_is_not_zero(theta_and_v: tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:

return R.transpose()

return jax.lax.cond(
pred=(theta == 0.0),
true_fun=lambda operand: jnp.eye(3),
false_fun=theta_is_not_zero,
operand=(theta, vector),
# Use the double-where trick to prevent JAX problems when the
# jax.jit and jax.grad transforms are applied.
return jnp.where(
jnp.linalg.norm(vector) > 0,
theta_is_not_zero(
axis=jnp.where(
jnp.linalg.norm(vector) > 0,
vector,
# The following line is a workaround to prevent division by 0.
# Considering the outer where, this branch is never executed.
jnp.ones(3),
)
),
# Return an identity rotation matrix when the input vector is zero.
jnp.eye(3),
)

0 comments on commit bd59060

Please sign in to comment.