diff --git a/include/common_robotics_utilities/path_processing.hpp b/include/common_robotics_utilities/path_processing.hpp index c34c0a3..4fcd435 100644 --- a/include/common_robotics_utilities/path_processing.hpp +++ b/include/common_robotics_utilities/path_processing.hpp @@ -9,6 +9,7 @@ #include #include +#include #include namespace common_robotics_utilities @@ -188,7 +189,7 @@ inline Container AttemptShortcut( return Container(); } -template> inline Container ShortcutSmoothPath( const Container& path, @@ -205,7 +206,7 @@ inline Container ShortcutSmoothPath( const std::function& state_interpolation_fn, - PRNG& prng) + const utility::UniformUnitRealFunction& uniform_unit_real_fn) { Container current_path = path; uint32_t num_iterations = 0; @@ -216,21 +217,19 @@ inline Container ShortcutSmoothPath( { num_iterations++; // Attempt a shortcut - const int64_t base_index - = std::uniform_int_distribution( - 0, current_path.size() - 1)(prng); + const int64_t base_index = utility::GetUniformRandomIndex( + uniform_unit_real_fn, static_cast(current_path.size())); // Pick an offset fraction - const double offset_fraction - = std::uniform_real_distribution( - -max_shortcut_fraction, max_shortcut_fraction)(prng); + const double offset_fraction = math::Interpolate( + -max_shortcut_fraction, max_shortcut_fraction, uniform_unit_real_fn()); // Compute the offset index const int64_t offset_index = base_index + static_cast(std::floor( static_cast(current_path.size()) * offset_fraction)); // We need to clamp it to the bounds of the current path - const int64_t safe_offset_index - = utility::ClampValue(offset_index, INT64_C(0), - static_cast(current_path.size() - 1)); + const int64_t safe_offset_index = utility::ClampValue( + offset_index, INT64_C(0), + static_cast(current_path.size() - 1)); // Get start & end indices const size_t start_index = static_cast(std::min(base_index, safe_offset_index)); diff --git a/include/common_robotics_utilities/simple_rrt_planner.hpp b/include/common_robotics_utilities/simple_rrt_planner.hpp index 9552973..48e1b55 100644 --- a/include/common_robotics_utilities/simple_rrt_planner.hpp +++ b/include/common_robotics_utilities/simple_rrt_planner.hpp @@ -879,11 +879,12 @@ RRTPlanMultiPath( /// terminated. The provided int64_t values are the current size of the /// start and goal tree, respectively. These may be useful for a /// size-limited planning problem. -/// @param rng a PRNG for use in internal sampling and tree swaps. +/// @param uniform_unit_real_fn Returns a uniformly distributed double from +/// [0.0, 1.0). Used internally for tree sampling and swapping. /// @return paths + statistics where paths is the vector of solution paths /// and statistics is a map of useful statistics collected /// while planning. -template, typename Container=std::vector> inline MultipleSolutionPlanningResults @@ -902,7 +903,7 @@ BiRRTPlanMultiPath( const double tree_sampling_bias, const double p_switch_tree, const BiRRTTerminationCheckFunction& termination_check_fn, - RNG& rng) + const utility::UniformUnitRealFunction& uniform_unit_real_fn) { if ((tree_sampling_bias < 0.0) || (tree_sampling_bias > 1.0)) { @@ -993,13 +994,12 @@ BiRRTPlanMultiPath( TreeType& target_tree = (start_tree_active) ? goal_tree : start_tree; // Select our sampling type const bool sample_from_tree - = (unit_real_distribution(rng) <= tree_sampling_bias); + = (uniform_unit_real_fn() <= tree_sampling_bias); int64_t target_tree_node_index = -1; if (sample_from_tree) { - std::uniform_int_distribution - tree_sampling_distribution(0, target_tree.Size() - 1); - target_tree_node_index = tree_sampling_distribution(rng); + target_tree_node_index = utility::GetUniformRandomIndex( + uniform_unit_real_fn, target_tree.Size()); } // Sample a target state const StateType target_state @@ -1107,7 +1107,7 @@ BiRRTPlanMultiPath( statistics["failed_samples"] += 1.0; } // Decide if we should switch the active tree - if (unit_real_distribution(rng) <= p_switch_tree) + if (uniform_unit_real_fn() <= p_switch_tree) { start_tree_active = !start_tree_active; statistics["active_tree_swaps"] += 1.0; @@ -1285,11 +1285,12 @@ RRTPlanSinglePath( /// terminated. The provided int64_t values are the current size of the /// start and goal tree, respectively. These may be useful for a /// size-limited planning problem. -/// @param rng a PRNG for use in internal sampling and tree swaps. +/// @param uniform_unit_real_fn Returns a uniformly distributed double from +/// [0.0, 1.0). Used internally for tree sampling and swapping. /// @return path + statistics where path is the solution path and /// statistics is a map of useful statistics collected while /// planning. -template, typename Container=std::vector> inline SingleSolutionPlanningResults @@ -1308,7 +1309,7 @@ BiRRTPlanSinglePath( const double tree_sampling_bias, const double p_switch_tree, const BiRRTTerminationCheckFunction& termination_check_fn, - RNG& rng) + const utility::UniformUnitRealFunction& uniform_unit_real_fn) { bool solution_found = false; const GoalBridgeCallbackFunction @@ -1333,11 +1334,11 @@ BiRRTPlanSinglePath( current_goal_tree_size)); }; const auto birrt_result = - BiRRTPlanMultiPath( + BiRRTPlanMultiPath( start_tree, goal_tree, state_sampling_fn, nearest_neighbor_fn, propagation_fn, state_added_callback_fn, states_connected_fn, internal_goal_bridge_callback_fn, tree_sampling_bias, p_switch_tree, - internal_termination_check_fn, rng); + internal_termination_check_fn, uniform_unit_real_fn); if (birrt_result.Paths().size() > 0) { return SingleSolutionPlanningResults( @@ -1449,33 +1450,32 @@ inline BiRRTTerminationCheckFunction MakeBiRRTTimeoutTerminationFunction( /// direction RRT with fixed goal states, you interleave sampling random states /// (accomplished here by calling @param state_sampling_fn) and "sampling" the /// known goal states (here, @param goal_states) with probablity -/// @param goal_bias. This helper function copies the provided -/// @param state_sampling_fn, @param goal_states, and @param goal_bias, but -/// passes @param rng by reference. Thus, the lifetime of @param rng must cover -/// the entire lifetime of the std::function this returns! -template> +/// @param goal_bias. @param uniform_unit_real_fn Returns a uniformly +/// distributed double from [0.0, 1.0). +template> inline SamplingFunction MakeStateAndGoalsSamplingFunction( const std::function& state_sampling_fn, const Container& goal_states, - const double goal_bias, PRNG& rng) + const double goal_bias, + const utility::UniformUnitRealFunction& uniform_unit_real_fn) { class StateAndGoalsSamplingFunction { private: Container goal_samples_; const double goal_bias_ = 0.0; - std::uniform_real_distribution unit_real_dist_; - std::uniform_int_distribution goal_sampling_dist_; + utility::UniformUnitRealFunction uniform_unit_real_fn_; const std::function state_sampling_fn_; public: StateAndGoalsSamplingFunction( - const Container& goal_samples, - const double goal_bias, - const std::function& state_sampling_fn) + const Container& goal_samples, + const double goal_bias, + const utility::UniformUnitRealFunction& uniform_unit_real_fn, + const std::function& state_sampling_fn) : goal_samples_(goal_samples), goal_bias_(goal_bias), - unit_real_dist_(0.0, 1.0), state_sampling_fn_(state_sampling_fn) + uniform_unit_real_fn_(uniform_unit_real_fn), + state_sampling_fn_(state_sampling_fn) { if ((goal_bias_ < 0.0) || (goal_bias_ > 1.0)) { @@ -1486,13 +1486,11 @@ inline SamplingFunction MakeStateAndGoalsSamplingFunction( { throw std::invalid_argument("goal_samples is empty"); } - goal_sampling_dist_ = - std::uniform_int_distribution(0, goal_samples_.size() - 1); } - SampleType Sample(PRNG& rng) + SampleType Sample() { - if (unit_real_dist_(rng) > goal_bias_) + if (uniform_unit_real_fn_() > goal_bias_) { return state_sampling_fn_(); } @@ -1504,18 +1502,19 @@ inline SamplingFunction MakeStateAndGoalsSamplingFunction( } else { - return goal_samples_.at(goal_sampling_dist_(rng)); + return goal_samples_.at(utility::GetUniformRandomIndex( + uniform_unit_real_fn_, goal_samples_.size())); } } } }; StateAndGoalsSamplingFunction sampling_fn_helper( - goal_states, goal_bias, state_sampling_fn); + goal_states, goal_bias, uniform_unit_real_fn, state_sampling_fn); std::function sampling_function - = [sampling_fn_helper, &rng] (void) mutable + = [sampling_fn_helper] (void) mutable { - return sampling_fn_helper.Sample(rng); + return sampling_fn_helper.Sample(); }; return sampling_function; } diff --git a/include/common_robotics_utilities/utility.hpp b/include/common_robotics_utilities/utility.hpp index 32c0856..ba67179 100644 --- a/include/common_robotics_utilities/utility.hpp +++ b/include/common_robotics_utilities/utility.hpp @@ -1,9 +1,11 @@ #pragma once #include +#include #include #include #include +#include #include #include #include @@ -85,6 +87,44 @@ namespace common_robotics_utilities { namespace utility { +/// Signature for function that returns a double uniformly sampled from +/// the interval [0.0, 1.0). +using UniformUnitRealFunction = std::function; + +/// Given a UniformUnitRealFunction @param uniform_unit_real_fn, returns an +/// index in [0, container_size - 1]. +template +SizeType GetUniformRandomIndex( + const UniformUnitRealFunction& uniform_unit_real_fn, + const SizeType container_size) +{ + static_assert( + std::is_integral::value, "SizeType must be an integral type"); + if (container_size < 1) + { + throw std::invalid_argument("container_size must be >= 1"); + } + return static_cast(std::floor( + uniform_unit_real_fn() * static_cast(container_size))); +} + +/// Given a UniformUnitRealFunction @param uniform_unit_real_fn, returns a +/// value in [start, end]. +template +SizeType GetUniformRandomInRange( + const UniformUnitRealFunction& uniform_unit_real_fn, + const SizeType start, const SizeType end) +{ + if (start > end) + { + throw std::invalid_argument("start must be <= end"); + } + const SizeType range = end - start; + const SizeType offset = + GetUniformRandomIndex(uniform_unit_real_fn, range + 1); + return start + offset; +} + template inline T ClampValue(const T& val, const T& min, const T& max) { diff --git a/test/planning_test.cpp b/test/planning_test.cpp index 73bb828..eaf6435 100644 --- a/test/planning_test.cpp +++ b/test/planning_test.cpp @@ -190,11 +190,10 @@ bool CheckEdgeCollisionFree( return true; } -template WaypointVector SmoothWaypoints( const WaypointVector& waypoints, const std::function& check_edge_fn, - PRNG& prng) + const utility::UniformUnitRealFunction& uniform_unit_real_fn) { // Parameters for shortcut smoothing const uint32_t max_iterations = 100; @@ -203,11 +202,11 @@ WaypointVector SmoothWaypoints( const double max_shortcut_fraction = 0.5; const double resample_shortcuts_interval = 0.5; const bool check_for_marginal_shortcuts = false; - return path_processing::ShortcutSmoothPath( + return path_processing::ShortcutSmoothPath( waypoints, max_iterations, max_failed_iterations, max_backtracking_steps, max_shortcut_fraction, resample_shortcuts_interval, check_for_marginal_shortcuts, check_edge_fn, WaypointDistance, - InterpolateWaypoint, prng); + InterpolateWaypoint, uniform_unit_real_fn); } void DrawRoadmap( @@ -256,12 +255,15 @@ WaypointVector GenerateAllPossible8ConnectedChildren( Waypoint(waypoint.first + 1, waypoint.second + 1)}; } -template -Waypoint SampleWaypoint(const TestMap& map, PRNG& rng) +Waypoint SampleWaypoint( + const TestMap& map, + const utility::UniformUnitRealFunction& uniform_unit_real_fn) { - std::uniform_int_distribution row_dist(1, map.rows() - 1); - std::uniform_int_distribution col_dist(1, map.rows() - 1); - return Waypoint(row_dist(rng), col_dist(rng)); + const ssize_t row = utility::GetUniformRandomInRange( + uniform_unit_real_fn, 1, map.rows() - 1); + const ssize_t col = utility::GetUniformRandomInRange( + uniform_unit_real_fn, 1, map.cols() - 1); + return Waypoint(row, col); } GTEST_TEST(PlanningTest, Test) @@ -292,6 +294,12 @@ GTEST_TEST(PlanningTest, Test) const int64_t prng_seed = 42; std::mt19937_64 prng(prng_seed); + std::uniform_real_distribution uniform_unit_dist(0.0, 1.0); + utility::UniformUnitRealFunction uniform_unit_real_fn = [&] () + { + return uniform_unit_dist(prng); + }; + const WaypointVector keypoints = {Waypoint(1, 1), Waypoint(18, 18), Waypoint(7, 13), Waypoint(9, 5)}; @@ -313,7 +321,7 @@ GTEST_TEST(PlanningTest, Test) const std::function state_sampling_fn = [&] (void) { - return SampleWaypoint(test_env, prng); + return SampleWaypoint(test_env, uniform_unit_real_fn); }; // Functions to check planning results @@ -345,7 +353,7 @@ GTEST_TEST(PlanningTest, Test) std::cout << "Checking raw path" << std::endl; check_path(path); const auto smoothed_path = - SmoothWaypoints(path, check_edge_validity_fn, prng); + SmoothWaypoints(path, check_edge_validity_fn, uniform_unit_real_fn); std::cout << "Checking smoothed path" << std::endl; check_path(smoothed_path); const auto resampled_path = ResampleWaypoints(smoothed_path); @@ -533,7 +541,7 @@ GTEST_TEST(PlanningTest, Test) // Query-specific RRT helpers const auto rrt_sample_fn = simple_rrt_planner::MakeStateAndGoalsSamplingFunction( - state_sampling_fn, {goal}, rrt_goal_bias, prng); + state_sampling_fn, {goal}, rrt_goal_bias, uniform_unit_real_fn); const simple_rrt_planner::CheckGoalReachedFunction rrt_goal_reached_fn = [&] (const Waypoint& state) { @@ -582,14 +590,14 @@ GTEST_TEST(PlanningTest, Test) const auto birrt_extent_path = simple_rrt_planner::BiRRTPlanSinglePath< - std::mt19937_64, Waypoint, WaypointPlannerTree, WaypointVector>( + Waypoint, WaypointPlannerTree, WaypointVector>( birrt_extend_start_tree, birrt_extend_goal_tree, state_sampling_fn, birrt_nearest_neighbors_fn, birrt_extend_fn, {}, birrt_states_connected_fn, {}, birrt_tree_sampling_bias, birrt_p_switch_trees, simple_rrt_planner ::MakeBiRRTTimeoutTerminationFunction(rrt_timeout), - prng).Path(); + uniform_unit_real_fn).Path(); check_plan(test_env, {start}, {goal}, birrt_extent_path); // Plan with BiRRT-Connect @@ -602,14 +610,14 @@ GTEST_TEST(PlanningTest, Test) const auto birrt_connect_path = simple_rrt_planner::BiRRTPlanSinglePath< - std::mt19937_64, Waypoint, WaypointPlannerTree, WaypointVector>( + Waypoint, WaypointPlannerTree, WaypointVector>( birrt_connect_start_tree, birrt_connect_goal_tree, state_sampling_fn, birrt_nearest_neighbors_fn, birrt_connect_fn, {}, birrt_states_connected_fn, {}, birrt_tree_sampling_bias, birrt_p_switch_trees, simple_rrt_planner ::MakeBiRRTTimeoutTerminationFunction(rrt_timeout), - prng).Path(); + uniform_unit_real_fn).Path(); check_plan(test_env, {start}, {goal}, birrt_connect_path); } } @@ -641,14 +649,14 @@ GTEST_TEST(PlanningTest, Test) const auto birrt_connect_path = simple_rrt_planner::BiRRTPlanSinglePath< - std::mt19937_64, Waypoint, WaypointPlannerTree, WaypointVector>( + Waypoint, WaypointPlannerTree, WaypointVector>( birrt_connect_start_tree, birrt_connect_goal_tree, state_sampling_fn, birrt_nearest_neighbors_fn, birrt_connect_fn, {}, birrt_states_connected_fn, {}, birrt_tree_sampling_bias, birrt_p_switch_trees, simple_rrt_planner ::MakeBiRRTTimeoutTerminationFunction(rrt_timeout), - prng).Path(); + uniform_unit_real_fn).Path(); check_plan(test_env, starts, goals, birrt_connect_path); // Use one of the trees to check tree serialization & deserialization