Skip to content

Commit

Permalink
Added callback for validating parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ndelima-ekumen committed Sep 20, 2024
1 parent 83a9be1 commit d72c3eb
Showing 1 changed file with 52 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
QoSProfile,
ReliabilityPolicy,
)
from rclpy.parameter import Parameter
from rcl_interfaces.msg import SetParametersResult
from sensor_msgs.msg import CompressedImage, Image
from std_msgs.msg import Header
from vision_msgs.msg import (
Expand Down Expand Up @@ -77,8 +79,12 @@ class FruitDetectionNode(Node):
RECT_COLOR = (0, 0, 255)
SCORE_THRESHOLD = 0.7
LOGGING_THROTTLE = 1
MINIMUM_BOX_SIZE_X = 50
MINIMUM_BOX_SIZE_Y = 50
DEFAULT_BBOX_SIZE_X = 50
DEFAULT_BBOX_SIZE_Y = 50
MINIMUM_BBOX_SIZE_X = 0
MINIMUM_BBOX_SIZE_Y = 0
MAXIMUM_BBOX_SIZE_X = 640
MAXIMUM_BBOX_SIZE_Y = 480

def __init__(self) -> None:
"""Initialize the node."""
Expand All @@ -90,13 +96,15 @@ def __init__(self) -> None:
)
self.declare_parameter(
"bbox_min_x",
FruitDetectionNode.MINIMUM_BOX_SIZE_X,
FruitDetectionNode.DEFAULT_BBOX_SIZE_X,
)
self.declare_parameter(
"bbox_min_y",
FruitDetectionNode.MINIMUM_BOX_SIZE_Y,
FruitDetectionNode.DEFAULT_BBOX_SIZE_Y,
)

self.add_on_set_parameters_callback(self.validate_parameters)

self.__model_path = (
self.get_parameter("model_path").get_parameter_value().string_value
)
Expand Down Expand Up @@ -143,6 +151,34 @@ def __init__(self) -> None:
self.get_logger().info("ingestion;inference;plot;detection;publish;")
self.ingest_transform = get_transform()

def validate_parameters(self, params):
"""
Validate parameter changes.
:param params: list of parameters.
:return: SetParametersResult.
"""
parameters_are_valid = True
for param in params:
if param.name == "bbox_min_x":
if param.type_ != Parameter.Type.INTEGER or not (
FruitDetectionNode.MINIMUM_BBOX_SIZE_X
<= param.value
<= FruitDetectionNode.MAXIMUM_BBOX_SIZE_X
):
parameters_are_valid = False
break
elif param.name == "bbox_min_y":
if param.type_ != Parameter.Type.INTEGER or not (
FruitDetectionNode.MINIMUM_BBOX_SIZE_Y
<= param.value
<= FruitDetectionNode.MAXIMUM_BBOX_SIZE_Y
):
parameters_are_valid = False
break
return SetParametersResult(successful=parameters_are_valid)

def load_model(self):
"""Load the torch model."""
self.model = fasterrcnn_resnet50_fpn(weights=None)
Expand Down Expand Up @@ -202,16 +238,22 @@ def score_frame(self, frame):
for i, (box, score, label) in enumerate(
zip(output["boxes"], output["scores"], output["labels"])
):
bbox_min_x = (
self.get_parameter("bbox_min_x")
.get_parameter_value()
.integer_value # Minimum bbox x size
)
bbox_min_y = (
self.get_parameter("bbox_min_y")
.get_parameter_value()
.integer_value # Minimum bbox y size
)
if (
score >= FruitDetectionNode.SCORE_THRESHOLD
and self.bbox_has_minimum_size(
box,
self.get_parameter("bbox_min_x")
.get_parameter_value()
.integer_value,
self.get_parameter("bbox_min_y")
.get_parameter_value()
.integer_value,
bbox_min_x,
bbox_min_y,
)
):
results.append(
Expand Down

0 comments on commit d72c3eb

Please sign in to comment.