diff --git a/deformable_gym/envs/sampler.py b/deformable_gym/envs/sampler.py index c75286e..d33ac89 100644 --- a/deformable_gym/envs/sampler.py +++ b/deformable_gym/envs/sampler.py @@ -94,7 +94,7 @@ def __init__( points_per_axis = [np.linspace( low[i], high[i], n_points_per_axis[i]) for i in range(self.n_dims)] - self.grid = np.array(np.meshgrid(*points_per_axis)).T.reshape(-1, 3) + self.grid = np.array(np.meshgrid(*points_per_axis)).T.reshape(-1, self.n_dims) self.n_samples = len(self.grid) self.n_calls = 0 diff --git a/tests/envs/test_samplers.py b/tests/envs/test_samplers.py index 94790e5..6751370 100644 --- a/tests/envs/test_samplers.py +++ b/tests/envs/test_samplers.py @@ -30,7 +30,7 @@ def uniform_target_pose() -> npt.NDArray: @pytest.fixture def grid_target_pose() -> npt.NDArray: - target = np.array([1, 2, 3]) + target = np.array([1, 2, 3, 4]) return target @@ -60,9 +60,9 @@ def uniform_sampler() -> UniformSampler: @pytest.fixture def grid_sampler() -> GridSampler: return GridSampler( - low=np.array([1, 2, 3]), - high=np.array([2, 3, 4]), - n_points_per_axis=np.array([5, 3, 1]) + low=np.array([1, 2, 3, 4]), + high=np.array([2, 3, 4, 5]), + n_points_per_axis=np.array([5, 3, 1, 1]) )