diff --git a/detection_ws/src/fruit_detection/fruit_detection/fruit_detection_node.py b/detection_ws/src/fruit_detection/fruit_detection/fruit_detection_node.py index 1dddd2f..2c7b514 100644 --- a/detection_ws/src/fruit_detection/fruit_detection/fruit_detection_node.py +++ b/detection_ws/src/fruit_detection/fruit_detection/fruit_detection_node.py @@ -21,6 +21,8 @@ QoSProfile, ReliabilityPolicy, ) +from rclpy.parameter import Parameter +from rcl_interfaces.msg import SetParametersResult from sensor_msgs.msg import CompressedImage, Image from std_msgs.msg import Header from vision_msgs.msg import ( @@ -77,8 +79,12 @@ class FruitDetectionNode(Node): RECT_COLOR = (0, 0, 255) SCORE_THRESHOLD = 0.7 LOGGING_THROTTLE = 1 - MINIMUM_BOX_SIZE_X = 50 - MINIMUM_BOX_SIZE_Y = 50 + DEFAULT_BBOX_SIZE_X = 50 + DEFAULT_BBOX_SIZE_Y = 50 + MINIMUM_BBOX_SIZE_X = 0 + MINIMUM_BBOX_SIZE_Y = 0 + MAXIMUM_BBOX_SIZE_X = 640 + MAXIMUM_BBOX_SIZE_Y = 480 def __init__(self) -> None: """Initialize the node.""" @@ -90,13 +96,15 @@ def __init__(self) -> None: ) self.declare_parameter( "bbox_min_x", - FruitDetectionNode.MINIMUM_BOX_SIZE_X, + FruitDetectionNode.DEFAULT_BBOX_SIZE_X, ) self.declare_parameter( "bbox_min_y", - FruitDetectionNode.MINIMUM_BOX_SIZE_Y, + FruitDetectionNode.DEFAULT_BBOX_SIZE_Y, ) + self.add_on_set_parameters_callback(self.validate_parameters) + self.__model_path = ( self.get_parameter("model_path").get_parameter_value().string_value ) @@ -143,6 +151,34 @@ def __init__(self) -> None: self.get_logger().info("ingestion;inference;plot;detection;publish;") self.ingest_transform = get_transform() + def validate_parameters(self, params): + """ + Validate parameter changes. + + :param params: list of parameters. + + :return: SetParametersResult. + """ + parameters_are_valid = True + for param in params: + if param.name == "bbox_min_x": + if param.type_ != Parameter.Type.INTEGER or not ( + FruitDetectionNode.MINIMUM_BBOX_SIZE_X + <= param.value + <= FruitDetectionNode.MAXIMUM_BBOX_SIZE_X + ): + parameters_are_valid = False + break + elif param.name == "bbox_min_y": + if param.type_ != Parameter.Type.INTEGER or not ( + FruitDetectionNode.MINIMUM_BBOX_SIZE_Y + <= param.value + <= FruitDetectionNode.MAXIMUM_BBOX_SIZE_Y + ): + parameters_are_valid = False + break + return SetParametersResult(successful=parameters_are_valid) + def load_model(self): """Load the torch model.""" self.model = fasterrcnn_resnet50_fpn(weights=None) @@ -202,16 +238,22 @@ def score_frame(self, frame): for i, (box, score, label) in enumerate( zip(output["boxes"], output["scores"], output["labels"]) ): + bbox_min_x = ( + self.get_parameter("bbox_min_x") + .get_parameter_value() + .integer_value # Minimum bbox x size + ) + bbox_min_y = ( + self.get_parameter("bbox_min_y") + .get_parameter_value() + .integer_value # Minimum bbox y size + ) if ( score >= FruitDetectionNode.SCORE_THRESHOLD and self.bbox_has_minimum_size( box, - self.get_parameter("bbox_min_x") - .get_parameter_value() - .integer_value, - self.get_parameter("bbox_min_y") - .get_parameter_value() - .integer_value, + bbox_min_x, + bbox_min_y, ) ): results.append(