Skip to content

Commit

Permalink
Handle edge cases properly
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFabisch committed Nov 21, 2024
1 parent 8835162 commit 64b6c17
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
31 changes: 31 additions & 0 deletions pytransform3d/test/test_transform_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,34 @@ def test_temporal_transform_manager_incorrect_frame():
tm.get_transform("B", "W")
with pytest.raises(ValueError):
tm.get_transform("A", "B")


def test_temporal_transform_manager_out_of_bounds():
duration = 10.0 # [s]
sample_period = 0.5 # [s]
velocity_x = 1 # [m/s]
time_A, pq_arr_A = create_sinusoidal_movement(
duration, sample_period, velocity_x, y_start_offset=0.0, start_time=0.0
)
transform_WA = NumpyTimeseriesTransform(time_A, pq_arr_A)

time_B, pq_arr_B = create_sinusoidal_movement(
duration, sample_period, velocity_x, y_start_offset=2.0, start_time=0.1
)
transform_WB = NumpyTimeseriesTransform(time_B, pq_arr_B)

tm = TemporalTransformManager()
tm.add_transform("A", "W", transform_WA)
tm.add_transform("B", "W", transform_WB)

assert min(time_A) == 0.0
assert min(time_B) == 0.1
A2B_at_start_time = tm.get_transform_at_time("A", "B", 0.0)
A2B_before_start_time = tm.get_transform_at_time("A", "B", -0.1)
assert_array_almost_equal(A2B_at_start_time, A2B_before_start_time)

assert max(time_A) == 9.5
assert max(time_B) == 9.6
A2B_at_end_time = tm.get_transform_at_time("A", "B", 9.6)
A2B_after_end_time = tm.get_transform_at_time("A", "B", 10.0)
assert_array_almost_equal(A2B_at_end_time, A2B_after_end_time)
43 changes: 30 additions & 13 deletions pytransform3d/transform_manager/_temporal_transform_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def __init__(self, time, pqs):
if self._pqs.shape[1] != 7:
raise ValueError("`pqs` matrix shall have 7 columns.")

self._min_time = min(self.time)
self._max_time = max(self.time)

def as_matrix(self, query_time):
"""Get transformation matrix at given time.
Expand All @@ -113,7 +116,7 @@ def as_matrix(self, query_time):
Returns
-------
A2B_t : array, shape (4, 4) or (..., 4, 4)
Homogeneous transformation matrix at given time. . or times
Homogeneous transformation matrix / matrices at given time / times.
"""
pq = self._interpolate_pq_using_sclerp(query_time)
transforms = transforms_from_pqs(pq)
Expand All @@ -136,27 +139,37 @@ def _interpolate_pq_using_sclerp(self, query_time):
min_index = 0
max_index = self.time.shape[0] - 2
idxs_timestep_earlier_wrt_query_time = np.clip(
idxs_timestep_earlier_wrt_query_time,
min_index,
max_index
)
idxs_timestep_earlier_wrt_query_time, min_index, max_index)
idxs_timestep_later_wrt_query_time = \
idxs_timestep_earlier_wrt_query_time + 1
before_start = query_time_arr <= self._min_time
idxs_timestep_later_wrt_query_time[
before_start] = idxs_timestep_earlier_wrt_query_time[before_start]
after_end = query_time_arr >= self._max_time
idxs_timestep_earlier_wrt_query_time[
after_end] = idxs_timestep_later_wrt_query_time[after_end]

# dual quaternion from preceding sample
t_prev = self.time[idxs_timestep_earlier_wrt_query_time]
pq_prev = self._pqs[idxs_timestep_earlier_wrt_query_time, :]
pq_prev = self._pqs[idxs_timestep_earlier_wrt_query_time]
dq_prev = dual_quaternions_from_pqs(pq_prev)

# dual quaternion from successive sample
t_next = self.time[idxs_timestep_earlier_wrt_query_time + 1]
pq_next = self._pqs[idxs_timestep_earlier_wrt_query_time + 1, :]
t_next = self.time[idxs_timestep_later_wrt_query_time]
pq_next = self._pqs[idxs_timestep_later_wrt_query_time]
dq_next = dual_quaternions_from_pqs(pq_next)

# since sclerp works with relative (0-1) positions
rel_delta_t = (query_time - t_prev) / (t_next - t_prev)
# scale t, since sclerp works with relative times t in [0, 1]
rel_delta_t = np.empty_like(query_time_arr)
edge_case = t_prev == t_next
rel_delta_t[edge_case] = 0.0
interpolation_case = ~edge_case
rel_delta_t[interpolation_case] = (
query_time[interpolation_case] - t_prev[interpolation_case]
) / (t_next[interpolation_case] - t_prev[interpolation_case])
dqs_interpolated = dual_quaternions_sclerp(
dq_prev, dq_next, rel_delta_t)
res = pqs_from_dual_quaternions(dqs_interpolated)
return res
return pqs_from_dual_quaternions(dqs_interpolated)


class TemporalTransformManager(TransformGraphBase):
Expand Down Expand Up @@ -210,7 +223,9 @@ def get_transform_at_time(self, from_frame, to_frame, time):
Name of the frame in which the transformation is defined
time : Union[float, array-like shape (...)]
Time or times at which we request the transformation.
Time or times at which we request the transformation. If the query
time is out of bounds, it will be clipped to either the first or
last available time.
Returns
-------
Expand All @@ -236,6 +251,8 @@ def get_transform(self, from_frame, to_frame):
"""Request a transformation.
The internal current_time will be used for time based transformations.
If the query time is out of bounds, it will be clipped to either the
first or the last available time.
Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class StaticTransform(TimeVaryingTransform):
class NumpyTimeseriesTransform(TimeVaryingTransform):
time: np.ndarray
_pqs: np.ndarray
_min_time: float
_max_time: float

def __init__(self, time: npt.ArrayLike, pqs: npt.ArrayLike): ...

Expand Down

0 comments on commit 64b6c17

Please sign in to comment.