diff --git a/liegroups/numpy/so3.py b/liegroups/numpy/so3.py index aff0512..e4ecc41 100644 --- a/liegroups/numpy/so3.py +++ b/liegroups/numpy/so3.py @@ -78,9 +78,9 @@ def from_quaternion(cls, quat, ordering='wxyz'): if not np.isclose(np.linalg.norm(quat), 1.): raise ValueError("Quaternion must be unit length") - if ordering is 'xyzw': + if ordering == 'xyzw': qx, qy, qz, qw = quat - elif ordering is 'wxyz': + elif ordering == 'wxyz': qw, qx, qy, qz = quat else: raise ValueError( @@ -295,9 +295,9 @@ def to_quaternion(self, ordering='wxyz'): qz = (R[1, 0] - R[0, 1]) / d # Check ordering last - if ordering is 'xyzw': + if ordering == 'xyzw': quat = np.array([qx, qy, qz, qw]) - elif ordering is 'wxyz': + elif ordering == 'wxyz': quat = np.array([qw, qx, qy, qz]) else: raise ValueError( @@ -384,9 +384,9 @@ class SO3Quaternion(_base.VectorLieGroupBase): dof = 3 def from_array(self, arr, ordering='wxyz'): - if ordering is 'xyzw': + if ordering == 'xyzw': self.data = arr[[3, 0, 1, 2]] - elif ordering is 'wxyz': + elif ordering == 'wxyz': self.data = arr else: raise ValueError( diff --git a/liegroups/torch/so3.py b/liegroups/torch/so3.py index 1495c25..1d317c3 100644 --- a/liegroups/torch/so3.py +++ b/liegroups/torch/so3.py @@ -68,12 +68,12 @@ def from_quaternion(cls, quat, ordering='wxyz'): if not utils.allclose(quat.norm(p=2, dim=1), 1.): raise ValueError("Quaternions must be unit length") - if ordering is 'xyzw': + if ordering == 'xyzw': qx = quat[:, 0] qy = quat[:, 1] qz = quat[:, 2] qw = quat[:, 3] - elif ordering is 'wxyz': + elif ordering == 'wxyz': qw = quat[:, 0] qx = quat[:, 1] qy = quat[:, 2] @@ -359,12 +359,12 @@ def to_quaternion(self, ordering='wxyz'): qz[far_zero_inds] = (R_fz[:, 1, 0] - R_fz[:, 0, 1]) / d # Check ordering last - if ordering is 'xyzw': + if ordering == 'xyzw': quat = torch.cat([qx.unsqueeze_(dim=1), qy.unsqueeze_(dim=1), qz.unsqueeze_(dim=1), qw.unsqueeze_(dim=1)], dim=1).squeeze_() - elif ordering is 'wxyz': + elif ordering == 'wxyz': quat = torch.cat([qw.unsqueeze_(dim=1), qx.unsqueeze_(dim=1), qy.unsqueeze_(dim=1),