From 9bc256fc99601daa80820c1226dd3acf8d2fbb6b Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Fri, 4 Aug 2023 15:24:46 +0530 Subject: [PATCH 01/32] chore: initial commit --- .../object_detection/faster_rcnn/__init__.py | 22 + .../faster_rcnn/faster_rcnn.py | 449 ++++++++++++++++++ .../faster_rcnn/faster_rcnn_test.py | 0 .../faster_rcnn/feature_pyramid.py | 75 +++ .../object_detection/faster_rcnn/rcnn_head.py | 71 +++ .../object_detection/faster_rcnn/rpn_head.py | 105 ++++ 6 files changed, 722 insertions(+) create mode 100644 keras_cv/models/object_detection/faster_rcnn/__init__.py create mode 100644 keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py create mode 100644 keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py create mode 100644 keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py create mode 100644 keras_cv/models/object_detection/faster_rcnn/rcnn_head.py create mode 100644 keras_cv/models/object_detection/faster_rcnn/rpn_head.py diff --git a/keras_cv/models/object_detection/faster_rcnn/__init__.py b/keras_cv/models/object_detection/faster_rcnn/__init__.py new file mode 100644 index 0000000000..e366f22b29 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import ( + FeaturePyramid, +) +from keras_cv.models.object_detection.faster_rcnn.rcnn_head import ( + RCNNHead, +) +from keras_cv.models.object_detection.faster_rcnn.rpn_head import ( # noqa: E501 + RPNHead, +) \ No newline at end of file diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py new file mode 100644 index 0000000000..f628cab5d2 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -0,0 +1,449 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
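The `__init__.py` above re-exports the three building blocks this commit adds. A minimal orientation sketch, assuming a keras-cv checkout with this patch applied:

```python
# Sketch only; assumes a keras-cv checkout that includes this branch.
from keras_cv.models.object_detection.faster_rcnn import (
    FeaturePyramid,
    RCNNHead,
    RPNHead,
)

# Each export is a plain keras.layers.Layer and can be built on its own.
fpn = FeaturePyramid()
rcnn_head = RCNNHead(num_classes=20)
rpn_head = RPNHead(num_anchors_per_location=3)
```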
+ +# TODO(ariG23498): Remove TF import +import tensorflow as tf + +import keras_cv +from keras_cv import bounding_box +from keras_cv import layers as cv_layers +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.bounding_box.converters import _decode_deltas_to_boxes +# from keras_cv.models.backbones.backbone_presets import backbone_presets +# from keras_cv.models.backbones.backbone_presets import ( +# backbone_presets_with_weights, +# ) +from keras_cv.models.object_detection.__internal__ import unpack_input +from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid +from keras_cv.models.object_detection.faster_rcnn import RPNHead +from keras_cv.models.object_detection.faster_rcnn import RCNNHead +# from keras_cv.models.object_detection.retinanet import RetinaNetLabelEncoder +# from keras_cv.models.object_detection.retinanet.retinanet_presets import ( +# retinanet_presets, +# ) +from keras_cv.models.task import Task +# from keras_cv.utils.python_utils import classproperty +from keras_cv.utils.train import get_feature_extractor + +# All the imports from legacy +from keras_cv.bounding_box.utils import _clip_boxes +from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator +from keras_cv.layers.object_detection.box_matcher import BoxMatcher +from keras_cv.layers.object_detection.roi_align import _ROIAligner +from keras_cv.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.layers.object_detection.roi_sampler import _ROISampler +from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder +from keras_cv.models.object_detection import predict_utils + +BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +# TODO(tanzheny): add more configurations +@keras_cv_export("keras_cv.models.FasterRCNN") +class FasterRCNN(Task): + """A Keras model implementing the FasterRCNN architecture. + + Implements the FasterRCNN architecture for object detection. The constructor + requires `num_classes`, `bounding_box_format` and a `backbone`. + + References: + - [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf) + + Usage: + ```python + retinanet = keras_cv.models.FasterRCNN( + num_classes=20, + bounding_box_format="xywh", + backbone=None, + ) + ``` + + Args: + num_classes: the number of classes in your dataset excluding the + background class. classes should be represented by integers in the + range [0, num_classes). + bounding_box_format: The format of bounding boxes of model output. Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + backbone: Optional `keras.Model`. Must implement the + `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" + and layer names as values. If `None`, defaults to + `keras_cv.models.ResNet50Backbone()`. + anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is + used in the model to match ground truth boxes and labels with + anchors, or with region proposals. By default it uses the sizes and + ratios from the paper, that is optimized for image size between + [640, 800]. The users should pass their own anchor generator if the + input image size differs from paper. For now, only anchor generator + with per level dict output is supported, + label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor, + a bounding box Tensor and a bounding box class Tensor to its + `call()` method, and returns RetinaNet training targets. 
It returns + box and class targets as well as sample weights. + rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature + map and returns a box delta prediction (in reference to rois) and + multi-class prediction (all foreground classes + one background + class). By default it uses the rcnn head from paper, which is 2 FC + layer with 1024 dimension, 1 box regressor and 1 softmax classifier. + prediction_decoder: (Optional) a `keras.layers.Layer` that takes input + box prediction and softmaxed score prediction, and returns NMSed box + prediction, NMSed softmaxed score prediction, NMSed class + prediction, and NMSed valid detection. + """ # noqa: E501 + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + anchor_generator=None, + label_encoder=None, + rcnn_head=None, + prediction_decoder=None, + **kwargs, + ): + self.bounding_box_format = bounding_box_format + super().__init__(**kwargs) + scales = [2**x for x in [0]] + aspect_ratios = [0.5, 1.0, 2.0] + self.anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format="yxyx", + sizes={ + "P2": 32.0, + "P3": 64.0, + "P4": 128.0, + "P5": 256.0, + "P6": 512.0, + }, + scales=scales, + aspect_ratios=aspect_ratios, + strides={f"P{i}": 2**i for i in range(2, 7)}, + clip_boxes=True, + ) + self.rpn_head = RPNHead( + num_anchors_per_location=len(scales) * len(aspect_ratios) + ) + self.roi_generator = ROIGenerator( + bounding_box_format="yxyx", + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), + ) + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + ) + self.roi_sampler = _ROISampler( + bounding_box_format="yxyx", + roi_matcher=self.box_matcher, + background_class=num_classes, + num_sampled_rois=512, + ) + self.roi_pooler = _ROIAligner(bounding_box_format="yxyx") + self.rcnn_head = rcnn_head or RCNNHead(num_classes) + self.backbone = backbone or keras_cv.models.ResNet50Backbone() + extractor_levels = ["P2", "P3", "P4", "P5"] + extractor_layer_names = [ + self.backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + self.feature_extractor = get_feature_extractor( + self.backbone, extractor_layer_names, extractor_levels + ) + self.feature_pyramid = FeaturePyramid() + self.rpn_labeler = label_encoder or _RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format="yxyx", + positive_threshold=0.7, + negative_threshold=0.3, + samples_per_image=256, + positive_fraction=0.5, + box_variance=BOX_VARIANCE, + ) + self._prediction_decoder = ( + prediction_decoder + or cv_layers.MultiClassNonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + max_detections_per_class=10, + max_detections=10, + ) + ) + + def _call_rpn(self, images, anchors, training=None): + image_shape = ops.shape(images[0]) + backbone_outputs = self.feature_extractor(images, training=training) + feature_map = self.feature_pyramid(backbone_outputs, training=training) + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training) + # the decoded format is center_xywh, convert to yxyx + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=training + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + rpn_boxes = ops.concat(tf.nest.flatten(rpn_boxes), axis=1) + rpn_scores = 
ops.concat(tf.nest.flatten(rpn_scores), axis=1) + return rois, feature_map, rpn_boxes, rpn_scores + + def _call_rcnn(self, rois, feature_map, training=None): + feature_map = self.roi_pooler(feature_map, rois) + # [BS, H*W*K, pool_shape*C] + feature_map = ops.reshape( + feature_map, ops.concat([ops.shape(rois)[:2], [-1]], axis=0) + ) + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( + feature_map, training=training + ) + return rcnn_box_pred, rcnn_cls_pred + + def call(self, images, training=None): + image_shape = ops.shape(images[0]) + anchors = self.anchor_generator(image_shape=image_shape) + rois, feature_map, _, _ = self._call_rpn( + images, anchors, training=training + ) + box_pred, cls_pred = self._call_rcnn( + rois, feature_map, training=training + ) + if not training: + # box_pred is on "center_yxhw" format, convert to target format. + box_pred = _decode_deltas_to_boxes( + anchors=rois, + boxes_delta=box_pred, + anchor_format="yxyx", + box_format=self.bounding_box_format, + variance=[0.1, 0.1, 0.2, 0.2], + ) + + return box_pred, cls_pred + + # TODO(tanzhenyu): Support compile with metrics. + def compile( + self, + box_loss=None, + classification_loss=None, + rpn_box_loss=None, + rpn_classification_loss=None, + weight_decay=0.0001, + loss=None, + **kwargs, + ): + # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed. + # https://github.com/keras-team/keras-cv/issues/915 + if "metrics" in kwargs.keys(): + raise ValueError( + "`FasterRCNN` does not currently support the use of " + "`metrics` due to performance and distribution concerns. " + "Please use the `PyCOCOCallback` to evaluate COCO metrics." + ) + if loss is not None: + raise ValueError( + "`FasterRCNN` does not accept a `loss` to `compile()`. " + "Instead, please pass `box_loss` and `classification_loss`. " + "`loss` will be ignored during training." 
+ ) + box_loss = _validate_and_get_loss(box_loss, "box_loss") + classification_loss = _validate_and_get_loss( + classification_loss, "classification_loss" + ) + rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss") + if rpn_classification_loss == "BinaryCrossentropy": + rpn_classification_loss = keras.losses.BinaryCrossentropy( + from_logits=True, reduction=keras.losses.Reduction.SUM + ) + rpn_classification_loss = _validate_and_get_loss( + rpn_classification_loss, "rpn_cls_loss" + ) + if not rpn_classification_loss.from_logits: + raise ValueError( + "`rpn_classification_loss` must come with `from_logits`=True" + ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.weight_decay = weight_decay + losses = { + "box": self.box_loss, + "classification": self.cls_loss, + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + } + super().compile(loss=losses, **kwargs) + + def compute_loss(self, images, boxes, classes, training): + local_batch = images.get_shape().as_list()[0] + if tf.distribute.has_strategy(): + num_sync = tf.distribute.get_strategy().num_replicas_in_sync + else: + num_sync = 1 + global_batch = local_batch * num_sync + anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.rpn_labeler( + tf.concat(tf.nest.flatten(anchors), axis=0), boxes, classes + ) + rpn_box_weights /= ( + self.rpn_labeler.samples_per_image * global_batch * 0.25 + ) + rpn_cls_weights /= self.rpn_labeler.samples_per_image * global_batch + rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( + images, anchors, training=training + ) + rois = tf.stop_gradient(rois) + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + ) = self.roi_sampler(rois, boxes, classes) + box_weights /= self.roi_sampler.num_sampled_rois * global_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * global_batch + box_pred, cls_pred = self._call_rcnn( + rois, feature_map, training=training + ) + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + } + y_pred = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + } + return super().compute_loss( + x=images, y=y_true, y_pred=y_pred, sample_weight=weights + ) + + def train_step(self, data): + images, y = unpack_input(data) + + boxes = y["boxes"] + if len(y["classes"].shape) != 2: + raise ValueError( + "Expected 'classes' to be a tf.Tensor of rank 2. " + f"Got y['classes'].shape={y['classes'].shape}." 
+ ) + # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere + classes = tf.expand_dims(y["classes"], axis=-1) + with tf.GradientTape() as tape: + total_loss = self.compute_loss( + images, boxes, classes, training=True + ) + reg_losses = [] + if self.weight_decay: + for var in self.trainable_variables: + if "bn" not in var.name: + reg_losses.append( + self.weight_decay * tf.nn.l2_loss(var) + ) + l2_loss = tf.math.add_n(reg_losses) + total_loss += l2_loss + self.optimizer.minimize(total_loss, self.trainable_variables, tape=tape) + return self.compute_metrics(images, {}, {}, sample_weight={}) + + def test_step(self, data): + images, y = unpack_input(data) + + boxes = y["boxes"] + if len(y["classes"].shape) != 2: + raise ValueError( + "Expected 'classes' to be a tf.Tensor of rank 2. " + f"Got y['classes'].shape={y['classes'].shape}." + ) + classes = tf.expand_dims(y["classes"], axis=-1) + self.compute_loss(images, boxes, classes, training=False) + return self.compute_metrics(images, {}, {}, sample_weight={}) + + def make_predict_function(self, force=False): + return predict_utils.make_predict_function(self, force=force) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + + def decode_predictions(self, predictions, images): + # no-op if default decoder is used. + box_pred, scores_pred = predictions + box_pred = bounding_box.convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + images=images, + ) + y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1]) + box_pred = bounding_box.convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + images=images, + ) + y_pred["boxes"] = box_pred + return y_pred + + def get_config(self): + return { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "backbone": self.backbone, + "anchor_generator": self.anchor_generator, + "label_encoder": self.rpn_labeler, + "prediction_decoder": self._prediction_decoder, + "feature_pyramid": self.feature_pyramid, + "rcnn_head": self.rcnn_head, + } + + +def _validate_and_get_loss(loss, loss_name): + if isinstance(loss, str): + loss = keras.losses.get(loss) + if loss is None or not isinstance(loss, keras.losses.Loss): + raise ValueError( + f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, " + f"got {loss}" + ) + if loss.reduction != keras.losses.Reduction.SUM: + logging.info( + f"FasterRCNN only accepts `SUM` reduction, got {loss.reduction}, " + "automatically converted." + ) + loss.reduction = keras.losses.Reduction.SUM + return loss diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py b/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py new file mode 100644 index 0000000000..04337b8cbc --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py @@ -0,0 +1,75 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras + + +@keras_cv_export( + "keras_cv.models.faster_rcnn.FeaturePyramid", + package="keras_cv.models.faster_rcnn", +) +class FeaturePyramid(keras.layers.Layer): + """Builds the Feature Pyramid with the feature maps from the backbone.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.conv_c2_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + self.conv_c3_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + self.conv_c4_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + self.conv_c5_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + + self.conv_c2_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c3_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c4_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c5_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c6_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c6_pool = keras.layers.MaxPool2D() + self.upsample_2x = keras.layers.UpSampling2D(2) + + def call(self, inputs, training=None): + c2_output = inputs["P2"] + c3_output = inputs["P3"] + c4_output = inputs["P4"] + c5_output = inputs["P5"] + + c6_output = self.conv_c6_pool(c5_output) + p6_output = c6_output + p5_output = self.conv_c5_1x1(c5_output) + p4_output = self.conv_c4_1x1(c4_output) + p3_output = self.conv_c3_1x1(c3_output) + p2_output = self.conv_c2_1x1(c2_output) + + p4_output = p4_output + self.upsample_2x(p5_output) + p3_output = p3_output + self.upsample_2x(p4_output) + p2_output = p2_output + self.upsample_2x(p3_output) + + p6_output = self.conv_c6_3x3(p6_output) + p5_output = self.conv_c5_3x3(p5_output) + p4_output = self.conv_c4_3x3(p4_output) + p3_output = self.conv_c3_3x3(p3_output) + p2_output = self.conv_c2_3x3(p2_output) + + return { + "P2": p2_output, + "P3": p3_output, + "P4": p4_output, + "P5": p5_output, + "P6": p6_output, + } + + def get_config(self): + config = {} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file diff --git a/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py b/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py new file mode 100644 index 0000000000..f9c4a28822 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py @@ -0,0 +1,71 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
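The pyramid above consumes the backbone's P2-P5 maps and adds a max-pooled P6. A small sketch with dummy inputs (channel counts are arbitrary here; the lateral 1x1 convolutions project everything to 256):

```python
import numpy as np

from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid

# Dummy backbone features for a 256x256 input (strides 4, 8, 16, 32).
inputs = {
    "P2": np.zeros((1, 64, 64, 256), "float32"),
    "P3": np.zeros((1, 32, 32, 512), "float32"),
    "P4": np.zeros((1, 16, 16, 1024), "float32"),
    "P5": np.zeros((1, 8, 8, 2048), "float32"),
}
outputs = FeaturePyramid()(inputs)

# Five levels come back, each with 256 channels; P6 is a stride-2
# max-pool of the raw C5 map followed by the 3x3 projection.
print({level: tuple(feat.shape) for level, feat in outputs.items()})
# {"P2": (1, 64, 64, 256), ..., "P5": (1, 8, 8, 256), "P6": (1, 4, 4, 256)}
```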
+ +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras + +@keras_cv_export( + "keras_cv.models.faster_rcnn.RCNNHead", + package="keras_cv.models.faster_rcnn", +) +class RCNNHead(keras.layers.Layer): + def __init__( + self, + num_classes, + conv_dims=[], + fc_dims=[1024, 1024], + **kwargs, + ): + super().__init__(**kwargs) + self.num_classes = num_classes + self.conv_dims = conv_dims + self.fc_dims = fc_dims + self.convs = [] + for conv_dim in conv_dims: + layer = keras.layers.Conv2D( + filters=conv_dim, + kernel_size=3, + strides=1, + padding="same", + activation="relu", + ) + self.convs.append(layer) + self.fcs = [] + for fc_dim in fc_dims: + layer = keras.layers.Dense(units=fc_dim, activation="relu") + self.fcs.append(layer) + self.box_pred = keras.layers.Dense(units=4) + self.cls_score = keras.layers.Dense( + units=num_classes + 1, activation="softmax" + ) + + def call(self, feature_map, training=None): + x = feature_map + for conv in self.convs: + x = conv(x) + for fc in self.fcs: + x = fc(x) + rcnn_boxes = self.box_pred(x) + rcnn_scores = self.cls_score(x) + return rcnn_boxes, rcnn_scores + + def get_config(self): + config = { + "num_classes": self.num_classes, + "conv_dims": self.conv_dims, + "fc_dims": self.fc_dims, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py new file mode 100644 index 0000000000..16414e4921 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -0,0 +1,105 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
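Before moving on to the RPN head, a quick sketch of how the `RCNNHead` defined above is exercised. In the full model the ROI-aligned features arrive already flattened per ROI; the 7x7 pooled window and 256 channels below are illustrative assumptions, not values fixed by this file:

```python
import numpy as np

from keras_cv.models.object_detection.faster_rcnn import RCNNHead

# Dummy pooled ROI features: 2 images, 512 sampled ROIs, each flattened
# to an (assumed) 7*7*256 vector by the ROI aligner upstream.
roi_features = np.zeros((2, 512, 7 * 7 * 256), "float32")

head = RCNNHead(num_classes=20)
boxes, scores = head(roi_features)
print(boxes.shape)   # (2, 512, 4)   class-agnostic box deltas per ROI
print(scores.shape)  # (2, 512, 21)  20 foreground classes + background, softmaxed
```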
+ +# TODO @ariG23498 +# Device a way to remove tf import +import tensorflow as tf + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops + +@keras_cv_export( + "keras_cv.models.faster_rcnn.RPNHead", + package="keras_cv.models.faster_rcnn", +) +class RPNHead(keras.layers.Layer): + def __init__( + self, + num_anchors_per_location=3, + **kwargs, + ): + super().__init__(**kwargs) + self.num_anchors = num_anchors_per_location + + def build(self, input_shape): + if isinstance(input_shape, (dict, list, tuple)): + # TODO @ariG23498 + # Device a way to remove tf import + input_shape = tf.nest.flatten(input_shape) + input_shape = input_shape[0] + filters = input_shape[-1] + self.conv = keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=1, + padding="same", + activation="relu", + kernel_initializer="truncated_normal", + ) + self.objectness_logits = keras.layers.Conv2D( + filters=self.num_anchors * 1, + kernel_size=1, + strides=1, + padding="same", + kernel_initializer="truncated_normal", + ) + self.anchor_deltas = keras.layers.Conv2D( + filters=self.num_anchors * 4, + kernel_size=1, + strides=1, + padding="same", + kernel_initializer="truncated_normal", + ) + + def call(self, feature_map, training=None): + def call_single_level(f_map): + batch_size = f_map.get_shape().as_list()[0] or tf.shape(f_map)[0] + # [BS, H, W, C] + t = self.conv(f_map) + # [BS, H, W, K] + rpn_scores = self.objectness_logits(t) + # [BS, H, W, K * 4] + rpn_boxes = self.anchor_deltas(t) + # [BS, H*W*K, 4] + rpn_boxes = ops.reshape(rpn_boxes, [batch_size, -1, 4]) + # [BS, H*W*K, 1] + rpn_scores = ops.reshape(rpn_scores, [batch_size, -1, 1]) + return rpn_boxes, rpn_scores + + if not isinstance(feature_map, (dict, list, tuple)): + return call_single_level(feature_map) + elif isinstance(feature_map, (list, tuple)): + rpn_boxes = [] + rpn_scores = [] + for f_map in feature_map: + rpn_box, rpn_score = call_single_level(f_map) + rpn_boxes.append(rpn_box) + rpn_scores.append(rpn_score) + return rpn_boxes, rpn_scores + else: + rpn_boxes = {} + rpn_scores = {} + for lvl, f_map in feature_map.items(): + rpn_box, rpn_score = call_single_level(f_map) + rpn_boxes[lvl] = rpn_box + rpn_scores[lvl] = rpn_score + return rpn_boxes, rpn_scores + + def get_config(self): + config = { + "num_anchors_per_location": self.num_anchors, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file From a8ad7c4e6cad9eeae26075b51d40622027f934c9 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Thu, 17 Aug 2023 17:40:41 +0530 Subject: [PATCH 02/32] review comments --- keras_cv/models/__init__.py | 3 + .../object_detection/faster_rcnn/__init__.py | 5 +- .../faster_rcnn/faster_rcnn.py | 72 ++++++++++++++----- .../object_detection/faster_rcnn/rpn_head.py | 10 +-- 4 files changed, 64 insertions(+), 26 deletions(-) diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index 3e5847a346..e2010b251f 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -123,3 +123,6 @@ from keras_cv.models.segmentation import DeepLabV3Plus from keras_cv.models.stable_diffusion import StableDiffusion from keras_cv.models.stable_diffusion import StableDiffusionV2 +from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import ( + FasterRCNN, +) \ No newline at end of file diff --git a/keras_cv/models/object_detection/faster_rcnn/__init__.py 
b/keras_cv/models/object_detection/faster_rcnn/__init__.py index e366f22b29..c66ec02fc2 100644 --- a/keras_cv/models/object_detection/faster_rcnn/__init__.py +++ b/keras_cv/models/object_detection/faster_rcnn/__init__.py @@ -17,6 +17,9 @@ from keras_cv.models.object_detection.faster_rcnn.rcnn_head import ( RCNNHead, ) -from keras_cv.models.object_detection.faster_rcnn.rpn_head import ( # noqa: E501 +from keras_cv.models.object_detection.faster_rcnn.rpn_head import ( # noqa: E501 RPNHead, +) +from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import (# noqa: E501 + FasterRCNN, ) \ No newline at end of file diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index f628cab5d2..b847dab38b 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO(ariG23498): Remove TF import -import tensorflow as tf +import tree import keras_cv from keras_cv import bounding_box @@ -62,26 +61,17 @@ class FasterRCNN(Task): References: - [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf) - Usage: - ```python - retinanet = keras_cv.models.FasterRCNN( - num_classes=20, - bounding_box_format="xywh", - backbone=None, - ) - ``` - Args: + backbone: `keras.Model`. Must implement the + `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" + and layer names as values. If `None`, defaults to + `keras_cv.models.ResNet50Backbone()`. num_classes: the number of classes in your dataset excluding the background class. classes should be represented by integers in the range [0, num_classes). bounding_box_format: The format of bounding boxes of model output. Refer [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) for more details on supported bounding box formats. - backbone: Optional `keras.Model`. Must implement the - `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" - and layer names as values. If `None`, defaults to - `keras_cv.models.ResNet50Backbone()`. anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is used in the model to match ground truth boxes and labels with anchors, or with region proposals. By default it uses the sizes and @@ -102,6 +92,44 @@ class FasterRCNN(Task): box prediction and softmaxed score prediction, and returns NMSed box prediction, NMSed softmaxed score prediction, NMSed class prediction, and NMSed valid detection. 
+ + Examples: + + ```python + images = np.ones((1, 512, 512, 3)) + labels = { + "boxes": [ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], + "classes": [[1, 1, 1]], + } + model = keras_cv.models.FasterRCNN( + num_classes=20, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet50Backbone.from_preset( + "resnet50_imagenet" + ) + ) + + # Evaluate model without box decoding and NMS + model(images) + + # Prediction with box decoding and NMS + model.predict(images) + + # Train model + model.compile( + classification_loss='focal', + box_loss='smoothl1', + optimizer=keras.optimizers.SGD(global_clipnorm=10.0), + jit_compile=False, + ) + model.fit(images, labels) + ``` """ # noqa: E501 def __init__( @@ -198,8 +226,8 @@ def _call_rpn(self, images, anchors, training=None): decoded_rpn_boxes, rpn_scores, training=training ) rois = _clip_boxes(rois, "yxyx", image_shape) - rpn_boxes = ops.concat(tf.nest.flatten(rpn_boxes), axis=1) - rpn_scores = ops.concat(tf.nest.flatten(rpn_scores), axis=1) + rpn_boxes = ops.concat(tree.flatten(rpn_boxes), axis=1) + rpn_scores = ops.concat(tree.flatten(rpn_scores), axis=1) return rois, feature_map, rpn_boxes, rpn_scores def _call_rcnn(self, rois, feature_map, training=None): @@ -304,7 +332,7 @@ def compute_loss(self, images, boxes, classes, training): rpn_cls_targets, rpn_cls_weights, ) = self.rpn_labeler( - tf.concat(tf.nest.flatten(anchors), axis=0), boxes, classes + ops.concat(tree.flatten(anchors), axis=0), boxes, classes ) rpn_box_weights /= ( self.rpn_labeler.samples_per_image * global_batch * 0.25 @@ -431,6 +459,14 @@ def get_config(self): "rcnn_head": self.rcnn_head, } + # def presets(cls): + # return super().presets + + # def presets_with_weights(cls): + # return super().presets_with_weights + + # def backbone_presets(cls): + # return super().backbone_presets def _validate_and_get_loss(loss, loss_name): if isinstance(loss, str): diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py index 16414e4921..d5f8eebfff 100644 --- a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
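The hunk below only swaps `tf.nest` for `tree`; the head's behaviour is unchanged. A sketch of that behaviour on a dict of pyramid features (dummy shapes, assuming this branch is installed):

```python
import numpy as np

from keras_cv.models.object_detection.faster_rcnn import RPNHead

# Two pyramid levels with 256 channels; K = 3 anchors per location.
feature_maps = {
    "P4": np.zeros((1, 16, 16, 256), "float32"),
    "P5": np.zeros((1, 8, 8, 256), "float32"),
}
rpn_head = RPNHead(num_anchors_per_location=3)
rpn_boxes, rpn_scores = rpn_head(feature_maps)

# Per level: [BS, H*W*K, 4] anchor deltas and [BS, H*W*K, 1] objectness logits.
print(rpn_boxes["P4"].shape)   # (1, 16*16*3, 4) = (1, 768, 4)
print(rpn_scores["P5"].shape)  # (1, 8*8*3, 1)  = (1, 192, 1)
```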
-# TODO @ariG23498 -# Device a way to remove tf import -import tensorflow as tf +import tree from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras @@ -35,9 +33,7 @@ def __init__( def build(self, input_shape): if isinstance(input_shape, (dict, list, tuple)): - # TODO @ariG23498 - # Device a way to remove tf import - input_shape = tf.nest.flatten(input_shape) + input_shape = tree.flatten(input_shape) input_shape = input_shape[0] filters = input_shape[-1] self.conv = keras.layers.Conv2D( @@ -65,7 +61,7 @@ def build(self, input_shape): def call(self, feature_map, training=None): def call_single_level(f_map): - batch_size = f_map.get_shape().as_list()[0] or tf.shape(f_map)[0] + batch_size = f_map.get_shape().as_list()[0] or ops.shape(f_map)[0] # [BS, H, W, C] t = self.conv(f_map) # [BS, H, W, K] From ed3337cd6bc6f0ec26b1dc5a965ef52eff38efa5 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Fri, 18 Aug 2023 19:47:52 +0530 Subject: [PATCH 03/32] chore: train test step modification --- keras_cv/models/__init__.py | 4 +- keras_cv/models/legacy/__init__.py | 3 - .../legacy/object_detection/__init__.py | 13 - .../object_detection/faster_rcnn/__init__.py | 13 - .../faster_rcnn/faster_rcnn.py | 618 ------------------ .../faster_rcnn/faster_rcnn_test.py | 107 --- .../object_detection/faster_rcnn/__init__.py | 13 +- .../faster_rcnn/faster_rcnn.py | 157 ++--- .../faster_rcnn/feature_pyramid.py | 2 +- .../object_detection/faster_rcnn/rcnn_head.py | 2 +- .../object_detection/faster_rcnn/rpn_head.py | 3 +- 11 files changed, 79 insertions(+), 856 deletions(-) delete mode 100644 keras_cv/models/legacy/object_detection/__init__.py delete mode 100644 keras_cv/models/legacy/object_detection/faster_rcnn/__init__.py delete mode 100644 keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn.py delete mode 100644 keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index 626dfbac83..bf238d1bcb 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -131,6 +131,7 @@ ResNetV2Backbone, ) from keras_cv.models.classification.image_classifier import ImageClassifier +from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import FasterRCNN from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import ( YOLOV8Backbone, @@ -141,6 +142,3 @@ from keras_cv.models.segmentation import DeepLabV3Plus from keras_cv.models.stable_diffusion import StableDiffusion from keras_cv.models.stable_diffusion import StableDiffusionV2 -from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import ( - FasterRCNN, -) \ No newline at end of file diff --git a/keras_cv/models/legacy/__init__.py b/keras_cv/models/legacy/__init__.py index 20df5826f0..628137597f 100644 --- a/keras_cv/models/legacy/__init__.py +++ b/keras_cv/models/legacy/__init__.py @@ -35,9 +35,6 @@ from keras_cv.models.legacy.mlp_mixer import MLPMixerB16 from keras_cv.models.legacy.mlp_mixer import MLPMixerB32 from keras_cv.models.legacy.mlp_mixer import MLPMixerL16 -from keras_cv.models.legacy.object_detection.faster_rcnn.faster_rcnn import ( - FasterRCNN, -) from keras_cv.models.legacy.regnet import RegNetX002 from keras_cv.models.legacy.regnet import RegNetX004 from keras_cv.models.legacy.regnet import RegNetX006 diff --git a/keras_cv/models/legacy/object_detection/__init__.py b/keras_cv/models/legacy/object_detection/__init__.py deleted file 
mode 100644 index 65be099991..0000000000 --- a/keras_cv/models/legacy/object_detection/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2022 The KerasCV Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/keras_cv/models/legacy/object_detection/faster_rcnn/__init__.py b/keras_cv/models/legacy/object_detection/faster_rcnn/__init__.py deleted file mode 100644 index 3992ffb59a..0000000000 --- a/keras_cv/models/legacy/object_detection/faster_rcnn/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023 The KerasCV Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn.py deleted file mode 100644 index f380e3c1ba..0000000000 --- a/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn.py +++ /dev/null @@ -1,618 +0,0 @@ -# Copyright 2022 The KerasCV Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
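For the public API, this commit is a relocation rather than a rewrite; a short sketch of the import path before and after (assuming this branch):

```python
# Sketch of the import paths around this commit:
from keras_cv.models import FasterRCNN           # new canonical location
# from keras_cv.models.legacy import FasterRCNN  # path removed by this commit
```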
- -import tensorflow as tf -from absl import logging -from tensorflow import keras - -import keras_cv -from keras_cv import bounding_box -from keras_cv import layers as cv_layers -from keras_cv.bounding_box.converters import _decode_deltas_to_boxes -from keras_cv.bounding_box.utils import _clip_boxes -from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator -from keras_cv.layers.object_detection.box_matcher import BoxMatcher -from keras_cv.layers.object_detection.roi_align import _ROIAligner -from keras_cv.layers.object_detection.roi_generator import ROIGenerator -from keras_cv.layers.object_detection.roi_sampler import _ROISampler -from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder -from keras_cv.models.object_detection import predict_utils -from keras_cv.models.object_detection.__internal__ import unpack_input -from keras_cv.utils.train import get_feature_extractor - -BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] - - -class FeaturePyramid(keras.layers.Layer): - """Builds the Feature Pyramid with the feature maps from the backbone.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.conv_c2_1x1 = keras.layers.Conv2D(256, 1, 1, "same") - self.conv_c3_1x1 = keras.layers.Conv2D(256, 1, 1, "same") - self.conv_c4_1x1 = keras.layers.Conv2D(256, 1, 1, "same") - self.conv_c5_1x1 = keras.layers.Conv2D(256, 1, 1, "same") - - self.conv_c2_3x3 = keras.layers.Conv2D(256, 3, 1, "same") - self.conv_c3_3x3 = keras.layers.Conv2D(256, 3, 1, "same") - self.conv_c4_3x3 = keras.layers.Conv2D(256, 3, 1, "same") - self.conv_c5_3x3 = keras.layers.Conv2D(256, 3, 1, "same") - self.conv_c6_3x3 = keras.layers.Conv2D(256, 3, 1, "same") - self.conv_c6_pool = keras.layers.MaxPool2D() - self.upsample_2x = keras.layers.UpSampling2D(2) - - def call(self, inputs, training=None): - c2_output = inputs["P2"] - c3_output = inputs["P3"] - c4_output = inputs["P4"] - c5_output = inputs["P5"] - - c6_output = self.conv_c6_pool(c5_output) - p6_output = c6_output - p5_output = self.conv_c5_1x1(c5_output) - p4_output = self.conv_c4_1x1(c4_output) - p3_output = self.conv_c3_1x1(c3_output) - p2_output = self.conv_c2_1x1(c2_output) - - p4_output = p4_output + self.upsample_2x(p5_output) - p3_output = p3_output + self.upsample_2x(p4_output) - p2_output = p2_output + self.upsample_2x(p3_output) - - p6_output = self.conv_c6_3x3(p6_output) - p5_output = self.conv_c5_3x3(p5_output) - p4_output = self.conv_c4_3x3(p4_output) - p3_output = self.conv_c3_3x3(p3_output) - p2_output = self.conv_c2_3x3(p2_output) - - return { - "P2": p2_output, - "P3": p3_output, - "P4": p4_output, - "P5": p5_output, - "P6": p6_output, - } - - def get_config(self): - config = {} - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) - - -class RPNHead(keras.layers.Layer): - def __init__( - self, - num_anchors_per_location=3, - **kwargs, - ): - super().__init__(**kwargs) - self.num_anchors = num_anchors_per_location - - def build(self, input_shape): - if isinstance(input_shape, (dict, list, tuple)): - input_shape = tf.nest.flatten(input_shape) - input_shape = input_shape[0] - filters = input_shape[-1] - self.conv = keras.layers.Conv2D( - filters=filters, - kernel_size=3, - strides=1, - padding="same", - activation="relu", - kernel_initializer="truncated_normal", - ) - self.objectness_logits = keras.layers.Conv2D( - filters=self.num_anchors * 1, - kernel_size=1, - strides=1, - padding="same", - kernel_initializer="truncated_normal", - ) - self.anchor_deltas = 
keras.layers.Conv2D( - filters=self.num_anchors * 4, - kernel_size=1, - strides=1, - padding="same", - kernel_initializer="truncated_normal", - ) - - def call(self, feature_map, training=None): - def call_single_level(f_map): - batch_size = f_map.get_shape().as_list()[0] or tf.shape(f_map)[0] - # [BS, H, W, C] - t = self.conv(f_map) - # [BS, H, W, K] - rpn_scores = self.objectness_logits(t) - # [BS, H, W, K * 4] - rpn_boxes = self.anchor_deltas(t) - # [BS, H*W*K, 4] - rpn_boxes = tf.reshape(rpn_boxes, [batch_size, -1, 4]) - # [BS, H*W*K, 1] - rpn_scores = tf.reshape(rpn_scores, [batch_size, -1, 1]) - return rpn_boxes, rpn_scores - - if not isinstance(feature_map, (dict, list, tuple)): - return call_single_level(feature_map) - elif isinstance(feature_map, (list, tuple)): - rpn_boxes = [] - rpn_scores = [] - for f_map in feature_map: - rpn_box, rpn_score = call_single_level(f_map) - rpn_boxes.append(rpn_box) - rpn_scores.append(rpn_score) - return rpn_boxes, rpn_scores - else: - rpn_boxes = {} - rpn_scores = {} - for lvl, f_map in feature_map.items(): - rpn_box, rpn_score = call_single_level(f_map) - rpn_boxes[lvl] = rpn_box - rpn_scores[lvl] = rpn_score - return rpn_boxes, rpn_scores - - def get_config(self): - config = { - "num_anchors_per_location": self.num_anchors, - } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) - - -# class agnostic regression -class RCNNHead(keras.layers.Layer): - def __init__( - self, - num_classes, - conv_dims=[], - fc_dims=[1024, 1024], - **kwargs, - ): - super().__init__(**kwargs) - self.num_classes = num_classes - self.conv_dims = conv_dims - self.fc_dims = fc_dims - self.convs = [] - for conv_dim in conv_dims: - layer = keras.layers.Conv2D( - filters=conv_dim, - kernel_size=3, - strides=1, - padding="same", - activation="relu", - ) - self.convs.append(layer) - self.fcs = [] - for fc_dim in fc_dims: - layer = keras.layers.Dense(units=fc_dim, activation="relu") - self.fcs.append(layer) - self.box_pred = keras.layers.Dense(units=4) - self.cls_score = keras.layers.Dense( - units=num_classes + 1, activation="softmax" - ) - - def call(self, feature_map, training=None): - x = feature_map - for conv in self.convs: - x = conv(x) - for fc in self.fcs: - x = fc(x) - rcnn_boxes = self.box_pred(x) - rcnn_scores = self.cls_score(x) - return rcnn_boxes, rcnn_scores - - def get_config(self): - config = { - "num_classes": self.num_classes, - "conv_dims": self.conv_dims, - "fc_dims": self.fc_dims, - } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) - - -# TODO(tanzheny): add more configurations -@keras.utils.register_keras_serializable(package="keras_cv") -class FasterRCNN(keras.Model): - """A Keras model implementing the FasterRCNN architecture. - - Implements the FasterRCNN architecture for object detection. The constructor - requires `num_classes`, `bounding_box_format` and a `backbone`. - - References: - - [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf) - - Usage: - ```python - retinanet = keras_cv.models.FasterRCNN( - num_classes=20, - bounding_box_format="xywh", - backbone=None, - ) - ``` - - Args: - num_classes: the number of classes in your dataset excluding the - background class. classes should be represented by integers in the - range [0, num_classes). - bounding_box_format: The format of bounding boxes of model output. Refer - [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) - for more details on supported bounding box formats. 
- backbone: Optional `keras.Model`. Must implement the - `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" - and layer names as values. If `None`, defaults to - `keras_cv.models.ResNet50Backbone()`. - anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is - used in the model to match ground truth boxes and labels with - anchors, or with region proposals. By default it uses the sizes and - ratios from the paper, that is optimized for image size between - [640, 800]. The users should pass their own anchor generator if the - input image size differs from paper. For now, only anchor generator - with per level dict output is supported, - label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor, - a bounding box Tensor and a bounding box class Tensor to its - `call()` method, and returns RetinaNet training targets. It returns - box and class targets as well as sample weights. - rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature - map and returns a box delta prediction (in reference to rois) and - multi-class prediction (all foreground classes + one background - class). By default it uses the rcnn head from paper, which is 2 FC - layer with 1024 dimension, 1 box regressor and 1 softmax classifier. - prediction_decoder: (Optional) a `keras.layers.Layer` that takes input - box prediction and softmaxed score prediction, and returns NMSed box - prediction, NMSed softmaxed score prediction, NMSed class - prediction, and NMSed valid detection. - """ # noqa: E501 - - def __init__( - self, - num_classes, - bounding_box_format, - backbone=None, - anchor_generator=None, - label_encoder=None, - rcnn_head=None, - prediction_decoder=None, - **kwargs, - ): - self.bounding_box_format = bounding_box_format - super().__init__(**kwargs) - scales = [2**x for x in [0]] - aspect_ratios = [0.5, 1.0, 2.0] - self.anchor_generator = anchor_generator or AnchorGenerator( - bounding_box_format="yxyx", - sizes={ - "P2": 32.0, - "P3": 64.0, - "P4": 128.0, - "P5": 256.0, - "P6": 512.0, - }, - scales=scales, - aspect_ratios=aspect_ratios, - strides={f"P{i}": 2**i for i in range(2, 7)}, - clip_boxes=True, - ) - self.rpn_head = RPNHead( - num_anchors_per_location=len(scales) * len(aspect_ratios) - ) - self.roi_generator = ROIGenerator( - bounding_box_format="yxyx", - nms_score_threshold_train=float("-inf"), - nms_score_threshold_test=float("-inf"), - ) - self.box_matcher = BoxMatcher( - thresholds=[0.0, 0.5], match_values=[-2, -1, 1] - ) - self.roi_sampler = _ROISampler( - bounding_box_format="yxyx", - roi_matcher=self.box_matcher, - background_class=num_classes, - num_sampled_rois=512, - ) - self.roi_pooler = _ROIAligner(bounding_box_format="yxyx") - self.rcnn_head = rcnn_head or RCNNHead(num_classes) - self.backbone = backbone or keras_cv.models.ResNet50Backbone() - extractor_levels = ["P2", "P3", "P4", "P5"] - extractor_layer_names = [ - self.backbone.pyramid_level_inputs[i] for i in extractor_levels - ] - self.feature_extractor = get_feature_extractor( - self.backbone, extractor_layer_names, extractor_levels - ) - self.feature_pyramid = FeaturePyramid() - self.rpn_labeler = label_encoder or _RpnLabelEncoder( - anchor_format="yxyx", - ground_truth_box_format="yxyx", - positive_threshold=0.7, - negative_threshold=0.3, - samples_per_image=256, - positive_fraction=0.5, - box_variance=BOX_VARIANCE, - ) - self._prediction_decoder = ( - prediction_decoder - or cv_layers.MultiClassNonMaxSuppression( - bounding_box_format=bounding_box_format, - from_logits=False, - 
max_detections_per_class=10, - max_detections=10, - ) - ) - - def _call_rpn(self, images, anchors, training=None): - image_shape = tf.shape(images[0]) - backbone_outputs = self.feature_extractor(images, training=training) - feature_map = self.feature_pyramid(backbone_outputs, training=training) - # [BS, num_anchors, 4], [BS, num_anchors, 1] - rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training) - # the decoded format is center_xywh, convert to yxyx - decoded_rpn_boxes = _decode_deltas_to_boxes( - anchors=anchors, - boxes_delta=rpn_boxes, - anchor_format="yxyx", - box_format="yxyx", - variance=BOX_VARIANCE, - ) - rois, _ = self.roi_generator( - decoded_rpn_boxes, rpn_scores, training=training - ) - rois = _clip_boxes(rois, "yxyx", image_shape) - rpn_boxes = tf.concat(tf.nest.flatten(rpn_boxes), axis=1) - rpn_scores = tf.concat(tf.nest.flatten(rpn_scores), axis=1) - return rois, feature_map, rpn_boxes, rpn_scores - - def _call_rcnn(self, rois, feature_map, training=None): - feature_map = self.roi_pooler(feature_map, rois) - # [BS, H*W*K, pool_shape*C] - feature_map = tf.reshape( - feature_map, tf.concat([tf.shape(rois)[:2], [-1]], axis=0) - ) - # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] - rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( - feature_map, training=training - ) - return rcnn_box_pred, rcnn_cls_pred - - def call(self, images, training=None): - image_shape = tf.shape(images[0]) - anchors = self.anchor_generator(image_shape=image_shape) - rois, feature_map, _, _ = self._call_rpn( - images, anchors, training=training - ) - box_pred, cls_pred = self._call_rcnn( - rois, feature_map, training=training - ) - if not training: - # box_pred is on "center_yxhw" format, convert to target format. - box_pred = _decode_deltas_to_boxes( - anchors=rois, - boxes_delta=box_pred, - anchor_format="yxyx", - box_format=self.bounding_box_format, - variance=[0.1, 0.1, 0.2, 0.2], - ) - - return box_pred, cls_pred - - # TODO(tanzhenyu): Support compile with metrics. - def compile( - self, - box_loss=None, - classification_loss=None, - rpn_box_loss=None, - rpn_classification_loss=None, - weight_decay=0.0001, - loss=None, - **kwargs, - ): - # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed. - # https://github.com/keras-team/keras-cv/issues/915 - if "metrics" in kwargs.keys(): - raise ValueError( - "`FasterRCNN` does not currently support the use of " - "`metrics` due to performance and distribution concerns. " - "Please use the `PyCOCOCallback` to evaluate COCO metrics." - ) - if loss is not None: - raise ValueError( - "`FasterRCNN` does not accept a `loss` to `compile()`. " - "Instead, please pass `box_loss` and `classification_loss`. " - "`loss` will be ignored during training." 
- ) - box_loss = _validate_and_get_loss(box_loss, "box_loss") - classification_loss = _validate_and_get_loss( - classification_loss, "classification_loss" - ) - rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss") - if rpn_classification_loss == "BinaryCrossentropy": - rpn_classification_loss = keras.losses.BinaryCrossentropy( - from_logits=True, reduction=keras.losses.Reduction.SUM - ) - rpn_classification_loss = _validate_and_get_loss( - rpn_classification_loss, "rpn_cls_loss" - ) - if not rpn_classification_loss.from_logits: - raise ValueError( - "`rpn_classification_loss` must come with `from_logits`=True" - ) - - self.rpn_box_loss = rpn_box_loss - self.rpn_cls_loss = rpn_classification_loss - self.box_loss = box_loss - self.cls_loss = classification_loss - self.weight_decay = weight_decay - losses = { - "box": self.box_loss, - "classification": self.cls_loss, - "rpn_box": self.rpn_box_loss, - "rpn_classification": self.rpn_cls_loss, - } - super().compile(loss=losses, **kwargs) - - def compute_loss(self, images, boxes, classes, training): - local_batch = images.get_shape().as_list()[0] - if tf.distribute.has_strategy(): - num_sync = tf.distribute.get_strategy().num_replicas_in_sync - else: - num_sync = 1 - global_batch = local_batch * num_sync - anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) - ( - rpn_box_targets, - rpn_box_weights, - rpn_cls_targets, - rpn_cls_weights, - ) = self.rpn_labeler( - tf.concat(tf.nest.flatten(anchors), axis=0), boxes, classes - ) - rpn_box_weights /= ( - self.rpn_labeler.samples_per_image * global_batch * 0.25 - ) - rpn_cls_weights /= self.rpn_labeler.samples_per_image * global_batch - rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( - images, anchors, training=training - ) - rois = tf.stop_gradient(rois) - ( - rois, - box_targets, - box_weights, - cls_targets, - cls_weights, - ) = self.roi_sampler(rois, boxes, classes) - box_weights /= self.roi_sampler.num_sampled_rois * global_batch * 0.25 - cls_weights /= self.roi_sampler.num_sampled_rois * global_batch - box_pred, cls_pred = self._call_rcnn( - rois, feature_map, training=training - ) - y_true = { - "rpn_box": rpn_box_targets, - "rpn_classification": rpn_cls_targets, - "box": box_targets, - "classification": cls_targets, - } - y_pred = { - "rpn_box": rpn_box_pred, - "rpn_classification": rpn_cls_pred, - "box": box_pred, - "classification": cls_pred, - } - weights = { - "rpn_box": rpn_box_weights, - "rpn_classification": rpn_cls_weights, - "box": box_weights, - "classification": cls_weights, - } - return super().compute_loss( - x=images, y=y_true, y_pred=y_pred, sample_weight=weights - ) - - def train_step(self, data): - images, y = unpack_input(data) - - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." 
- ) - # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere - classes = tf.expand_dims(y["classes"], axis=-1) - with tf.GradientTape() as tape: - total_loss = self.compute_loss( - images, boxes, classes, training=True - ) - reg_losses = [] - if self.weight_decay: - for var in self.trainable_variables: - if "bn" not in var.name: - reg_losses.append( - self.weight_decay * tf.nn.l2_loss(var) - ) - l2_loss = tf.math.add_n(reg_losses) - total_loss += l2_loss - self.optimizer.minimize(total_loss, self.trainable_variables, tape=tape) - return self.compute_metrics(images, {}, {}, sample_weight={}) - - def test_step(self, data): - images, y = unpack_input(data) - - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." - ) - classes = tf.expand_dims(y["classes"], axis=-1) - self.compute_loss(images, boxes, classes, training=False) - return self.compute_metrics(images, {}, {}, sample_weight={}) - - def make_predict_function(self, force=False): - return predict_utils.make_predict_function(self, force=force) - - @property - def prediction_decoder(self): - return self._prediction_decoder - - @prediction_decoder.setter - def prediction_decoder(self, prediction_decoder): - self._prediction_decoder = prediction_decoder - self.make_predict_function(force=True) - - def decode_predictions(self, predictions, images): - # no-op if default decoder is used. - box_pred, scores_pred = predictions - box_pred = bounding_box.convert_format( - box_pred, - source=self.bounding_box_format, - target=self.prediction_decoder.bounding_box_format, - images=images, - ) - y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1]) - box_pred = bounding_box.convert_format( - y_pred["boxes"], - source=self.prediction_decoder.bounding_box_format, - target=self.bounding_box_format, - images=images, - ) - y_pred["boxes"] = box_pred - return y_pred - - def get_config(self): - return { - "num_classes": self.num_classes, - "bounding_box_format": self.bounding_box_format, - "backbone": self.backbone, - "anchor_generator": self.anchor_generator, - "label_encoder": self.rpn_labeler, - "prediction_decoder": self._prediction_decoder, - "feature_pyramid": self.feature_pyramid, - "rcnn_head": self.rcnn_head, - } - - -def _validate_and_get_loss(loss, loss_name): - if isinstance(loss, str): - loss = keras.losses.get(loss) - if loss is None or not isinstance(loss, keras.losses.Loss): - raise ValueError( - f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, " - f"got {loss}" - ) - if loss.reduction != keras.losses.Reduction.SUM: - logging.info( - f"FasterRCNN only accepts `SUM` reduction, got {loss.reduction}, " - "automatically converted." - ) - loss.reduction = keras.losses.Reduction.SUM - return loss diff --git a/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py deleted file mode 100644 index b8930af944..0000000000 --- a/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2022 The KerasCV Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import tensorflow as tf -from absl.testing import parameterized -from tensorflow import keras -from tensorflow.keras import optimizers - -from keras_cv.models import ResNet18V2Backbone -from keras_cv.models.legacy.object_detection.faster_rcnn.faster_rcnn import ( - FasterRCNN, -) -from keras_cv.models.object_detection.__test_utils__ import ( - _create_bounding_box_dataset, -) -from keras_cv.tests.test_case import TestCase - - -class FasterRCNNTest(TestCase): - # TODO(ianstenbit): Make FasterRCNN support shapes that are not multiples - # of 128, perhaps by adding a flag to the anchor generator for whether to - # include anchors centered outside of the image. (RetinaNet does use those, - # while FasterRCNN doesn't). For more context on why this is the case, see - # https://github.com/keras-team/keras-cv/pull/1882 - @parameterized.parameters( - ((2, 640, 384, 3),), - ((2, 512, 512, 3),), - ((2, 128, 128, 3),), - ) - def test_faster_rcnn_infer(self, batch_shape): - model = FasterRCNN( - num_classes=80, - bounding_box_format="xyxy", - backbone=ResNet18V2Backbone(), - ) - images = tf.random.normal(batch_shape) - outputs = model(images, training=False) - # 1000 proposals in inference - self.assertAllEqual([2, 1000, 81], outputs[1].shape) - self.assertAllEqual([2, 1000, 4], outputs[0].shape) - - @parameterized.parameters( - ((2, 640, 384, 3),), - ((2, 512, 512, 3),), - ((2, 128, 128, 3),), - ) - def test_faster_rcnn_train(self, batch_shape): - model = FasterRCNN( - num_classes=80, - bounding_box_format="xyxy", - backbone=ResNet18V2Backbone(), - ) - images = tf.random.normal(batch_shape) - outputs = model(images, training=True) - self.assertAllEqual([2, 1000, 81], outputs[1].shape) - self.assertAllEqual([2, 1000, 4], outputs[0].shape) - - def test_invalid_compile(self): - model = FasterRCNN( - num_classes=80, - bounding_box_format="yxyx", - backbone=ResNet18V2Backbone(), - ) - with self.assertRaisesRegex(ValueError, "only accepts"): - model.compile(rpn_box_loss="binary_crossentropy") - with self.assertRaisesRegex(ValueError, "only accepts"): - model.compile( - rpn_classification_loss=keras.losses.BinaryCrossentropy( - from_logits=False - ) - ) - - @pytest.mark.large # Fit is slow, so mark these large. 
- def test_faster_rcnn_with_dictionary_input_format(self): - faster_rcnn = FasterRCNN( - num_classes=20, - bounding_box_format="xywh", - backbone=ResNet18V2Backbone(), - ) - - images, boxes = _create_bounding_box_dataset("xywh") - dataset = tf.data.Dataset.from_tensor_slices( - {"images": images, "bounding_boxes": boxes} - ).batch(5, drop_remainder=True) - - faster_rcnn.compile( - optimizer=optimizers.Adam(), - box_loss="Huber", - classification_loss="SparseCategoricalCrossentropy", - rpn_box_loss="Huber", - rpn_classification_loss="BinaryCrossentropy", - ) - - faster_rcnn.fit(dataset, epochs=1) - faster_rcnn.evaluate(dataset) diff --git a/keras_cv/models/object_detection/faster_rcnn/__init__.py b/keras_cv/models/object_detection/faster_rcnn/__init__.py index c66ec02fc2..02cac23b6e 100644 --- a/keras_cv/models/object_detection/faster_rcnn/__init__.py +++ b/keras_cv/models/object_detection/faster_rcnn/__init__.py @@ -11,15 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import ( - FeaturePyramid, -) -from keras_cv.models.object_detection.faster_rcnn.rcnn_head import ( - RCNNHead, -) -from keras_cv.models.object_detection.faster_rcnn.rpn_head import ( # noqa: E501 - RPNHead, -) -from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import (# noqa: E501 +from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import ( # noqa: E501 FasterRCNN, -) \ No newline at end of file +) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index b847dab38b..7868684eb0 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy + import tree +from absl import logging import keras_cv from keras_cv import bounding_box @@ -21,21 +24,6 @@ from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.bounding_box.converters import _decode_deltas_to_boxes -# from keras_cv.models.backbones.backbone_presets import backbone_presets -# from keras_cv.models.backbones.backbone_presets import ( -# backbone_presets_with_weights, -# ) -from keras_cv.models.object_detection.__internal__ import unpack_input -from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid -from keras_cv.models.object_detection.faster_rcnn import RPNHead -from keras_cv.models.object_detection.faster_rcnn import RCNNHead -# from keras_cv.models.object_detection.retinanet import RetinaNetLabelEncoder -# from keras_cv.models.object_detection.retinanet.retinanet_presets import ( -# retinanet_presets, -# ) -from keras_cv.models.task import Task -# from keras_cv.utils.python_utils import classproperty -from keras_cv.utils.train import get_feature_extractor # All the imports from legacy from keras_cv.bounding_box.utils import _clip_boxes @@ -45,7 +33,19 @@ from keras_cv.layers.object_detection.roi_generator import ROIGenerator from keras_cv.layers.object_detection.roi_sampler import _ROISampler from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder -from keras_cv.models.object_detection import predict_utils +from keras_cv.models.backbones.backbone_presets import backbone_presets +from keras_cv.models.backbones.backbone_presets import ( + backbone_presets_with_weights, +) +from keras_cv.models.object_detection.__internal__ import unpack_input +from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import ( + FeaturePyramid, +) +from keras_cv.models.object_detection.faster_rcnn.rcnn_head import RCNNHead +from keras_cv.models.object_detection.faster_rcnn.rpn_head import RPNHead +from keras_cv.models.task import Task +from keras_cv.utils.python_utils import classproperty +from keras_cv.utils.train import get_feature_extractor BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] @@ -92,7 +92,7 @@ class FasterRCNN(Task): box prediction and softmaxed score prediction, and returns NMSed box prediction, NMSed softmaxed score prediction, NMSed class prediction, and NMSed valid detection. 
- + Examples: ```python @@ -129,7 +129,7 @@ class FasterRCNN(Task): jit_compile=False, ) model.fit(images, labels) - ``` + ``` """ # noqa: E501 def __init__( @@ -200,11 +200,12 @@ def __init__( ) self._prediction_decoder = ( prediction_decoder - or cv_layers.MultiClassNonMaxSuppression( + or cv_layers.NonMaxSuppression( bounding_box_format=bounding_box_format, from_logits=False, - max_detections_per_class=10, - max_detections=10, + iou_threshold=0.5, + confidence_threshold=0.5, + max_detections=100, ) ) @@ -226,15 +227,15 @@ def _call_rpn(self, images, anchors, training=None): decoded_rpn_boxes, rpn_scores, training=training ) rois = _clip_boxes(rois, "yxyx", image_shape) - rpn_boxes = ops.concat(tree.flatten(rpn_boxes), axis=1) - rpn_scores = ops.concat(tree.flatten(rpn_scores), axis=1) + rpn_boxes = ops.concatenate(tree.flatten(rpn_boxes), axis=1) + rpn_scores = ops.concatenate(tree.flatten(rpn_scores), axis=1) return rois, feature_map, rpn_boxes, rpn_scores def _call_rcnn(self, rois, feature_map, training=None): feature_map = self.roi_pooler(feature_map, rois) # [BS, H*W*K, pool_shape*C] feature_map = ops.reshape( - feature_map, ops.concat([ops.shape(rois)[:2], [-1]], axis=0) + feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0) ) # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( @@ -320,11 +321,6 @@ def compile( def compute_loss(self, images, boxes, classes, training): local_batch = images.get_shape().as_list()[0] - if tf.distribute.has_strategy(): - num_sync = tf.distribute.get_strategy().num_replicas_in_sync - else: - num_sync = 1 - global_batch = local_batch * num_sync anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) ( rpn_box_targets, @@ -332,16 +328,16 @@ def compute_loss(self, images, boxes, classes, training): rpn_cls_targets, rpn_cls_weights, ) = self.rpn_labeler( - ops.concat(tree.flatten(anchors), axis=0), boxes, classes + ops.concatenate(tree.flatten(anchors), axis=0), boxes, classes ) rpn_box_weights /= ( - self.rpn_labeler.samples_per_image * global_batch * 0.25 + self.rpn_labeler.samples_per_image * local_batch * 0.25 ) - rpn_cls_weights /= self.rpn_labeler.samples_per_image * global_batch + rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( images, anchors, training=training ) - rois = tf.stop_gradient(rois) + rois = ops.stop_gradient(rois) ( rois, box_targets, @@ -349,8 +345,8 @@ def compute_loss(self, images, boxes, classes, training): cls_targets, cls_weights, ) = self.roi_sampler(rois, boxes, classes) - box_weights /= self.roi_sampler.num_sampled_rois * global_batch * 0.25 - cls_weights /= self.roi_sampler.num_sampled_rois * global_batch + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * local_batch box_pred, cls_pred = self._call_rcnn( rois, feature_map, training=training ) @@ -376,48 +372,24 @@ def compute_loss(self, images, boxes, classes, training): x=images, y=y_true, y_pred=y_pred, sample_weight=weights ) - def train_step(self, data): - images, y = unpack_input(data) - - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." 
- ) - # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere - classes = tf.expand_dims(y["classes"], axis=-1) - with tf.GradientTape() as tape: - total_loss = self.compute_loss( - images, boxes, classes, training=True - ) - reg_losses = [] - if self.weight_decay: - for var in self.trainable_variables: - if "bn" not in var.name: - reg_losses.append( - self.weight_decay * tf.nn.l2_loss(var) - ) - l2_loss = tf.math.add_n(reg_losses) - total_loss += l2_loss - self.optimizer.minimize(total_loss, self.trainable_variables, tape=tape) - return self.compute_metrics(images, {}, {}, sample_weight={}) - - def test_step(self, data): - images, y = unpack_input(data) - - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." - ) - classes = tf.expand_dims(y["classes"], axis=-1) - self.compute_loss(images, boxes, classes, training=False) - return self.compute_metrics(images, {}, {}, sample_weight={}) - - def make_predict_function(self, force=False): - return predict_utils.make_predict_function(self, force=force) + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().test_step(*args, (x, y)) + + def predict_step(self, *args): + outputs = super().predict_step(*args) + if type(outputs) is tuple: + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + else: + return self.decode_predictions(outputs, args[-1]) @property def prediction_decoder(self): @@ -459,14 +431,29 @@ def get_config(self): "rcnn_head": self.rcnn_head, } - # def presets(cls): - # return super().presets - - # def presets_with_weights(cls): - # return super().presets_with_weights - - # def backbone_presets(cls): - # return super().backbone_presets + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + # return copy.deepcopy({**backbone_presets, **fasterrcnn_presets}) + return copy.deepcopy({**backbone_presets}) + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return copy.deepcopy( + # {**backbone_presets_with_weights, **fasterrcnn_presets} + { + **backbone_presets_with_weights, + } + ) + + @classproperty + def backbone_presets(cls): + """Dictionary of preset names and configurations of compatible + backbones.""" + return copy.deepcopy(backbone_presets) + def _validate_and_get_loss(loss, loss_name): if isinstance(loss, str): diff --git a/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py b/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py index 04337b8cbc..18648a4ccc 100644 --- a/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py +++ b/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py @@ -72,4 +72,4 @@ def call(self, inputs, training=None): def get_config(self): config = {} base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py b/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py index f9c4a28822..4caec64076 100644 --- a/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py +++ 
b/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py @@ -15,6 +15,7 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras + @keras_cv_export( "keras_cv.models.faster_rcnn.RCNNHead", package="keras_cv.models.faster_rcnn", @@ -68,4 +69,3 @@ def get_config(self): } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) - diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py index d5f8eebfff..fbd3bd8dc3 100644 --- a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -18,6 +18,7 @@ from keras_cv.backend import keras from keras_cv.backend import ops + @keras_cv_export( "keras_cv.models.faster_rcnn.RPNHead", package="keras_cv.models.faster_rcnn", @@ -98,4 +99,4 @@ def get_config(self): "num_anchors_per_location": self.num_anchors, } base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file + return dict(list(base_config.items()) + list(config.items())) From 005f70d05d2fe68dbe8b4c99e973adc051f88322 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 28 Aug 2023 23:45:48 +0530 Subject: [PATCH 04/32] review nits --- keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 7868684eb0..023e9881dd 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -64,8 +64,7 @@ class FasterRCNN(Task): Args: backbone: `keras.Model`. Must implement the `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" - and layer names as values. If `None`, defaults to - `keras_cv.models.ResNet50Backbone()`. + and layer names as values. num_classes: the number of classes in your dataset excluding the background class. classes should be represented by integers in the range [0, num_classes). From da5a01e7ae2dd41e79d98bb61a0c8adc58bdca93 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Fri, 1 Sep 2023 12:10:20 +0530 Subject: [PATCH 05/32] chore: adding test --- .../faster_rcnn/faster_rcnn_test.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py index e69de29bb2..2ca334acb9 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -0,0 +1,104 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras_cv.backend import keras +from keras_cv.models import ResNet18V2Backbone +from keras_cv.models.object_detection.__test_utils__ import ( + _create_bounding_box_dataset, +) +from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import FasterRCNN +from keras_cv.tests.test_case import TestCase + + +class FasterRCNNTest(TestCase): + # TODO(ianstenbit): Make FasterRCNN support shapes that are not multiples + # of 128, perhaps by adding a flag to the anchor generator for whether to + # include anchors centered outside of the image. (RetinaNet does use those, + # while FasterRCNN doesn't). For more context on why this is the case, see + # https://github.com/keras-team/keras-cv/pull/1882 + @parameterized.parameters( + ((2, 640, 384, 3),), + ((2, 512, 512, 3),), + ((2, 128, 128, 3),), + ) + def test_faster_rcnn_infer(self, batch_shape): + model = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=ResNet18V2Backbone(), + ) + images = tf.random.normal(batch_shape) + outputs = model(images, training=False) + # 1000 proposals in inference + self.assertAllEqual([2, 1000, 81], outputs[1].shape) + self.assertAllEqual([2, 1000, 4], outputs[0].shape) + + @parameterized.parameters( + ((2, 640, 384, 3),), + ((2, 512, 512, 3),), + ((2, 128, 128, 3),), + ) + def test_faster_rcnn_train(self, batch_shape): + model = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=ResNet18V2Backbone(), + ) + images = tf.random.normal(batch_shape) + outputs = model(images, training=True) + self.assertAllEqual([2, 1000, 81], outputs[1].shape) + self.assertAllEqual([2, 1000, 4], outputs[0].shape) + + def test_invalid_compile(self): + model = FasterRCNN( + num_classes=80, + bounding_box_format="yxyx", + backbone=ResNet18V2Backbone(), + ) + with self.assertRaisesRegex(ValueError, "only accepts"): + model.compile(rpn_box_loss="binary_crossentropy") + with self.assertRaisesRegex(ValueError, "only accepts"): + model.compile( + rpn_classification_loss=keras.losses.BinaryCrossentropy( + from_logits=False + ) + ) + + @pytest.mark.large # Fit is slow, so mark these large. 
+ def test_faster_rcnn_with_dictionary_input_format(self): + faster_rcnn = FasterRCNN( + num_classes=20, + bounding_box_format="xywh", + backbone=ResNet18V2Backbone(), + ) + + images, boxes = _create_bounding_box_dataset("xywh") + dataset = tf.data.Dataset.from_tensor_slices( + {"images": images, "bounding_boxes": boxes} + ).batch(5, drop_remainder=True) + + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + faster_rcnn.fit(dataset, epochs=1) + faster_rcnn.evaluate(dataset) From ac005b81d6d3af73cd9672baca59ec84fc0dc97c Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Fri, 8 Sep 2023 01:42:00 +0530 Subject: [PATCH 06/32] chore: reformat compute loss --- .../object_detection/faster_rcnn/faster_rcnn.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 023e9881dd..3c44343eb9 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -318,7 +318,17 @@ def compile( } super().compile(loss=losses, **kwargs) - def compute_loss(self, images, boxes, classes, training): + def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): + images = x + boxes = y["boxes"] + if len(y["classes"].shape) != 2: + raise ValueError( + "Expected 'classes' to be a tf.Tensor of rank 2. " + f"Got y['classes'].shape={y['classes'].shape}." + ) + # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere + classes = ops.expand_dims(y["classes"], axis=-1) + local_batch = images.get_shape().as_list()[0] anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) ( @@ -334,7 +344,7 @@ def compute_loss(self, images, boxes, classes, training): ) rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( - images, anchors, training=training + images, anchors, training=kwargs["training"] ) rois = ops.stop_gradient(rois) ( @@ -347,7 +357,7 @@ def compute_loss(self, images, boxes, classes, training): box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 cls_weights /= self.roi_sampler.num_sampled_rois * local_batch box_pred, cls_pred = self._call_rcnn( - rois, feature_map, training=training + rois, feature_map, training=kwargs["training"] ) y_true = { "rpn_box": rpn_box_targets, From 613e29f07dee92f58ecdf0d2f351a52df9cbead7 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Fri, 15 Sep 2023 19:36:29 +0530 Subject: [PATCH 07/32] chore: faster rcnn call and predict work --- .../layers/object_detection/roi_generator.py | 4 +- .../object_detection/faster_rcnn/__init__.py | 6 +- .../faster_rcnn/faster_rcnn.py | 32 +++---- .../faster_rcnn/faster_rcnn_presets.py | 0 .../faster_rcnn/faster_rcnn_test.py | 84 ++++++++++++++++++- .../object_detection/faster_rcnn/rpn_head.py | 8 ++ 6 files changed, 113 insertions(+), 21 deletions(-) create mode 100644 keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py diff --git a/keras_cv/layers/object_detection/roi_generator.py b/keras_cv/layers/object_detection/roi_generator.py index db779f8d3f..a4ad53e4b1 100644 --- a/keras_cv/layers/object_detection/roi_generator.py +++ b/keras_cv/layers/object_detection/roi_generator.py @@ -23,7 +23,7 @@ from keras_cv import bounding_box from 
keras_cv.api_export import keras_cv_export from keras_cv.backend import assert_tf_keras - +from keras_cv.backend import ops @keras_cv_export("keras_cv.layers.ROIGenerator") class ROIGenerator(keras.layers.Layer): @@ -148,7 +148,7 @@ def per_level_gen(boxes, scores): # scores can also be [batch_size, num_boxes, 1] if len(scores_shape) == 3: scores = tf.squeeze(scores, axis=-1) - _, num_boxes = scores.get_shape().as_list() + num_boxes = ops.shape(boxes)[1] level_pre_nms_topk = min(num_boxes, pre_nms_topk) level_post_nms_topk = min(num_boxes, post_nms_topk) scores, sorted_indices = tf.nn.top_k( diff --git a/keras_cv/models/object_detection/faster_rcnn/__init__.py b/keras_cv/models/object_detection/faster_rcnn/__init__.py index 02cac23b6e..73c9c2ba56 100644 --- a/keras_cv/models/object_detection/faster_rcnn/__init__.py +++ b/keras_cv/models/object_detection/faster_rcnn/__init__.py @@ -11,6 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import ( # noqa: E501 - FasterRCNN, -) +from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import FeaturePyramid +from keras_cv.models.object_detection.faster_rcnn.rcnn_head import RCNNHead +from keras_cv.models.object_detection.faster_rcnn.rpn_head import RPNHead diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 3c44343eb9..977c241826 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -24,8 +24,10 @@ from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.bounding_box.converters import _decode_deltas_to_boxes - -# All the imports from legacy +# from keras_cv.models.backbones.backbone_presets import backbone_presets +# from keras_cv.models.backbones.backbone_presets import ( +# backbone_presets_with_weights +# ) from keras_cv.bounding_box.utils import _clip_boxes from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator from keras_cv.layers.object_detection.box_matcher import BoxMatcher @@ -38,11 +40,10 @@ backbone_presets_with_weights, ) from keras_cv.models.object_detection.__internal__ import unpack_input -from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import ( - FeaturePyramid, -) -from keras_cv.models.object_detection.faster_rcnn.rcnn_head import RCNNHead -from keras_cv.models.object_detection.faster_rcnn.rpn_head import RPNHead +from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid +from keras_cv.models.object_detection.faster_rcnn import RCNNHead +from keras_cv.models.object_detection.faster_rcnn import RPNHead +# from keras_cv.models.object_detection.faster_rcnn.faster_rcnn_presets import faster_rcnn_presets from keras_cv.models.task import Task from keras_cv.utils.python_utils import classproperty from keras_cv.utils.train import get_feature_extractor @@ -56,10 +57,10 @@ class FasterRCNN(Task): """A Keras model implementing the FasterRCNN architecture. Implements the FasterRCNN architecture for object detection. The constructor - requires `num_classes`, `bounding_box_format` and a `backbone`. + requires `backbone`, `num_classes`, and a `bounding_box_format`. 
References: - - [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf) + - [FasterRCNN](https://arxiv.org/abs/1506.01497) Args: backbone: `keras.Model`. Must implement the @@ -260,8 +261,11 @@ def call(self, images, training=None): box_format=self.bounding_box_format, variance=[0.1, 0.1, 0.2, 0.2], ) - - return box_pred, cls_pred + outputs = { + "boxes": box_pred, + "classes": cls_pred, + } + return outputs # TODO(tanzhenyu): Support compile with metrics. def compile( @@ -344,7 +348,7 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): ) rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( - images, anchors, training=kwargs["training"] + images, anchors, ) rois = ops.stop_gradient(rois) ( @@ -357,7 +361,7 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 cls_weights /= self.roi_sampler.num_sampled_rois * local_batch box_pred, cls_pred = self._call_rcnn( - rois, feature_map, training=kwargs["training"] + rois, feature_map, ) y_true = { "rpn_box": rpn_box_targets, @@ -411,7 +415,7 @@ def prediction_decoder(self, prediction_decoder): def decode_predictions(self, predictions, images): # no-op if default decoder is used. - box_pred, scores_pred = predictions + box_pred, scores_pred = predictions["boxes"], predictions["classes"] box_pred = bounding_box.convert_format( box_pred, source=self.bounding_box_format, diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py index 2ca334acb9..a6a57c713f 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + +import numpy as np import pytest import tensorflow as tf from absl.testing import parameterized +import keras_cv from keras_cv.backend import keras -from keras_cv.models import ResNet18V2Backbone +from keras_cv.backend import ops +# from keras_cv.models.backbones.test_backbone_presets import ( +# test_backbone_presets, +# ) from keras_cv.models.object_detection.__test_utils__ import ( _create_bounding_box_dataset, ) @@ -26,6 +33,79 @@ class FasterRCNNTest(TestCase): + def test_faster_rcnn_construction(self): + faster_rcnn = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone() + ) + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + @pytest.mark.large # Fit is slow, so mark these large. 
+ def test_faster_rcnn_call(self):
+ faster_rcnn = keras_cv.models.FasterRCNN(
+ num_classes=80,
+ bounding_box_format="xywh",
+ backbone=keras_cv.models.ResNet18V2Backbone(),
+ )
+ images = np.random.uniform(size=(2, 512, 512, 3))
+ _ = faster_rcnn(images)
+ _ = faster_rcnn.predict(images)
+
+ def test_wrong_logits(self):
+ faster_rcnn = keras_cv.models.FasterRCNN(
+ num_classes=80,
+ bounding_box_format="xywh",
+ backbone=keras_cv.models.ResNet18V2Backbone(),
+ )
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "from_logits",
+ ):
+ faster_rcnn.compile(
+ optimizer=keras.optimizers.SGD(learning_rate=0.25),
+ box_loss=keras_cv.losses.SmoothL1Loss(
+ l1_cutoff=1.0, reduction="none"
+ ),
+ classification_loss=keras_cv.losses.FocalLoss(
+ from_logits=False, reduction="none"
+ ),
+ rpn_box_loss=keras_cv.losses.SmoothL1Loss(
+ l1_cutoff=1.0, reduction="none"
+ ),
+ rpn_classification_loss=keras_cv.losses.FocalLoss(
+ from_logits=False, reduction="none"
+ ),
+ )
+
+ def test_weights_contained_in_trainable_variables(self):
+ bounding_box_format = "xyxy"
+ faster_rcnn = keras_cv.models.FasterRCNN(
+ num_classes=80,
+ bounding_box_format=bounding_box_format,
+ backbone=keras_cv.models.ResNet18V2Backbone(),
+ )
+ faster_rcnn.backbone.trainable = False
+ faster_rcnn.compile(
+ optimizer=keras.optimizers.Adam(),
+ box_loss="Huber",
+ classification_loss="SparseCategoricalCrossentropy",
+ rpn_box_loss="Huber",
+ rpn_classification_loss="BinaryCrossentropy",
+ )
+ xs, ys = _create_bounding_box_dataset(bounding_box_format)
+
+ # call once
+ _ = faster_rcnn(xs)
+ self.assertEqual(len(faster_rcnn.trainable_variables), 32)
+
 # TODO(ianstenbit): Make FasterRCNN support shapes that are not multiples
 # of 128, perhaps by adding a flag to the anchor generator for whether to
 # include anchors centered outside of the image. (RetinaNet does use those,
 # while FasterRCNN doesn't). For more context on why this is the case, see
 # https://github.com/keras-team/keras-cv/pull/1882
 @parameterized.parameters(
@@ -79,7 +159,7 @@ def test_invalid_compile(self):
 )
 )

- @pytest.mark.large # Fit is slow, so mark these large.
+ # @pytest.mark.large # Fit is slow, so mark these large.
 def test_faster_rcnn_with_dictionary_input_format(self):
 faster_rcnn = FasterRCNN(
 num_classes=20,
 bounding_box_format="xywh",
diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py
index fbd3bd8dc3..1bc2394a0a 100644
--- a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py
+++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py
@@ -24,6 +24,14 @@
 package="keras_cv.models.faster_rcnn",
 )
 class RPNHead(keras.layers.Layer):
+ """A Keras layer implementing the RPN architecture.
+
+ The Region Proposal Network (RPN) was first proposed in [FasterRCNN](https://arxiv.org/abs/1506.01497).
+ This is an end-to-end trainable layer which proposes regions for a detector (RCNN).
+
+ Args:
+ num_anchors_per_location: The number of anchors per location.
+ """ def __init__( self, num_anchors_per_location=3, From 5bf2bc9a9ef7c2bb050123ada585236f13917b69 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Sat, 16 Sep 2023 14:07:14 +0530 Subject: [PATCH 08/32] chore: porting roi align to keras core --- keras_cv/layers/object_detection/roi_align.py | 160 +++++++++--------- 1 file changed, 81 insertions(+), 79 deletions(-) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 2c45060147..eb40aabcbc 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -22,6 +22,8 @@ from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras +from keras_cv.backend import ops +from keras_cv.backend import keras def _feature_bilinear_interpolation( @@ -49,7 +51,7 @@ def _feature_bilinear_interpolation( A 5-D tensor representing feature crop of shape [batch_size, num_boxes, output_size, output_size, num_filters]. """ - features_shape = tf.shape(features) + features_shape = ops.shape(features) batch_size, num_boxes, output_size, num_filters = ( features_shape[0], features_shape[1], @@ -58,22 +60,22 @@ def _feature_bilinear_interpolation( ) output_size = output_size // 2 - kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1]) - kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2]) + kernel_y = ops.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1]) + kernel_x = ops.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2]) # Use implicit broadcast to generate the interpolation kernel. The # multiplier `4` is for avg pooling. interpolation_kernel = kernel_y * kernel_x * 4 # Interpolate the gathered features with computed interpolation kernels. 
- features *= tf.cast( - tf.expand_dims(interpolation_kernel, axis=-1), dtype=features.dtype + features *= ops.cast( + ops.expand_dims(interpolation_kernel, axis=-1), dtype=features.dtype ) - features = tf.reshape( + features = ops.reshape( features, [batch_size * num_boxes, output_size * 2, output_size * 2, num_filters], ) - features = tf.nn.avg_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID") - features = tf.reshape( + features = ops.nn.average_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID") + features = ops.reshape( features, [batch_size, num_boxes, output_size, output_size, num_filters] ) return features @@ -108,10 +110,10 @@ def _compute_grid_positions( box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2] box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2] """ - boxes_shape = tf.shape(boxes) + boxes_shape = ops.shape(boxes) batch_size, num_boxes = boxes_shape[0], boxes_shape[1] if batch_size is None: - batch_size = tf.shape(boxes)[0] + batch_size = ops.shape(boxes)[0] box_grid_x = [] box_grid_y = [] for i in range(output_size): @@ -121,29 +123,29 @@ def _compute_grid_positions( box_grid_y.append( boxes[:, :, 0] + (i + sample_offset) * boxes[:, :, 2] / output_size ) - box_grid_x = tf.stack(box_grid_x, axis=2) - box_grid_y = tf.stack(box_grid_y, axis=2) + box_grid_x = ops.stack(box_grid_x, axis=2) + box_grid_y = ops.stack(box_grid_y, axis=2) - box_grid_y0 = tf.floor(box_grid_y) - box_grid_x0 = tf.floor(box_grid_x) - box_grid_x0 = tf.maximum(tf.cast(0.0, dtype=box_grid_x0.dtype), box_grid_x0) - box_grid_y0 = tf.maximum(tf.cast(0.0, dtype=box_grid_y0.dtype), box_grid_y0) + box_grid_y0 = ops.floor(box_grid_y) + box_grid_x0 = ops.floor(box_grid_x) + box_grid_x0 = ops.maximum(ops.cast(0.0, dtype=box_grid_x0.dtype), box_grid_x0) + box_grid_y0 = ops.maximum(ops.cast(0.0, dtype=box_grid_y0.dtype), box_grid_y0) - box_grid_x0 = tf.minimum( - box_grid_x0, tf.expand_dims(boundaries[:, :, 1], -1) + box_grid_x0 = ops.minimum( + box_grid_x0, ops.expand_dims(boundaries[:, :, 1], -1) ) - box_grid_x1 = tf.minimum( - box_grid_x0 + 1, tf.expand_dims(boundaries[:, :, 1], -1) + box_grid_x1 = ops.minimum( + box_grid_x0 + 1, ops.expand_dims(boundaries[:, :, 1], -1) ) - box_grid_y0 = tf.minimum( - box_grid_y0, tf.expand_dims(boundaries[:, :, 0], -1) + box_grid_y0 = ops.minimum( + box_grid_y0, ops.expand_dims(boundaries[:, :, 0], -1) ) - box_grid_y1 = tf.minimum( - box_grid_y0 + 1, tf.expand_dims(boundaries[:, :, 0], -1) + box_grid_y1 = ops.minimum( + box_grid_y0 + 1, ops.expand_dims(boundaries[:, :, 0], -1) ) - box_gridx0x1 = tf.stack([box_grid_x0, box_grid_x1], axis=-1) - box_gridy0y1 = tf.stack([box_grid_y0, box_grid_y1], axis=-1) + box_gridx0x1 = ops.stack([box_grid_x0, box_grid_x1], axis=-1) + box_gridy0y1 = ops.stack([box_grid_y0, box_grid_y1], axis=-1) # The RoIAlign feature f can be computed by bilinear interpolation of four # neighboring feature points f0, f1, f2, and f3. 
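# A minimal standalone sketch of that four-point interpolation, assuming
# scalar feature values; the helper name below is illustrative and not part
# of the patch. With ly = y - y0, lx = x - x0, hy = 1 - ly, hx = 1 - lx, the
# separable kernels in the next hunk compute
#   f = hy*hx*f00 + hy*lx*f01 + ly*hx*f10 + ly*lx*f11
def bilinear_sample(f00, f01, f10, f11, ly, lx):
    hy, hx = 1.0 - ly, 1.0 - lx
    return hy * hx * f00 + hy * lx * f01 + ly * hx * f10 + ly * lx * f11

# Sampling exactly on the top-left grid point returns that point's value,
# and the centre of the 2x2 cell returns the average of the four values.
assert bilinear_sample(1.0, 2.0, 3.0, 4.0, ly=0.0, lx=0.0) == 1.0
assert bilinear_sample(1.0, 2.0, 3.0, 4.0, ly=0.5, lx=0.5) == 2.5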
@@ -155,11 +157,11 @@ def _compute_grid_positions( lx = box_grid_x - box_grid_x0 hy = 1.0 - ly hx = 1.0 - lx - kernel_y = tf.reshape( - tf.stack([hy, ly], axis=3), [batch_size, num_boxes, output_size, 2, 1] + kernel_y = ops.reshape( + ops.stack([hy, ly], axis=3), [batch_size, num_boxes, output_size, 2, 1] ) - kernel_x = tf.reshape( - tf.stack([hx, lx], axis=3), [batch_size, num_boxes, output_size, 2, 1] + kernel_x = ops.reshape( + ops.stack([hx, lx], axis=3), [batch_size, num_boxes, output_size, 2, 1] ) return kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 @@ -194,7 +196,7 @@ def multilevel_crop_and_resize( [batch_size, num_boxes, output_size, output_size, num_filters]. """ - with tf.name_scope("multilevel_crop_and_resize"): + with keras.backend.name_scope("multilevel_crop_and_resize"): levels_str = list(features.keys()) # Levels are represented by strings with a prefix "P" to represent # pyramid levels. The integer level can be obtained by looking at @@ -202,7 +204,7 @@ def multilevel_crop_and_resize( levels = [int(level_str[1:]) for level_str in levels_str] min_level = min(levels) max_level = max(levels) - features_shape = tf.shape(features[f"P{min_level}"]) + features_shape = ops.shape(features[f"P{min_level}"]) batch_size, max_feature_height, max_feature_width, num_filters = ( features_shape[0], features_shape[1], @@ -210,7 +212,7 @@ def multilevel_crop_and_resize( features_shape[3], ) - num_boxes = tf.shape(boxes)[1] + num_boxes = ops.shape(boxes)[1] # Stack feature pyramid into a features_all of shape # [batch_size, levels, height, width, num_filters]. @@ -218,15 +220,15 @@ def multilevel_crop_and_resize( feature_heights = [] feature_widths = [] for level in range(min_level, max_level + 1): - shape = features[f"P{level}"].get_shape().as_list() + shape = ops.shape(features[f"P{level}"]) feature_heights.append(shape[1]) feature_widths.append(shape[2]) # Concat tensor of [batch_size, height_l * width_l, num_filters] for # each level. features_all.append( - tf.reshape(features[f"P{level}"], [batch_size, -1, num_filters]) + ops.reshape(features[f"P{level}"], [batch_size, -1, num_filters]) ) - features_r2 = tf.reshape(tf.concat(features_all, 1), [-1, num_filters]) + features_r2 = ops.reshape(ops.concatenate(features_all, 1), [-1, num_filters]) # Calculate height_l * width_l for each level. level_dim_sizes = [ @@ -238,59 +240,59 @@ def multilevel_crop_and_resize( for i in range(len(feature_widths) - 1): level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i]) batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1] - level_dim_offsets = tf.constant(level_dim_offsets, tf.int32) - height_dim_sizes = tf.constant(feature_widths, tf.int32) + level_dim_offsets = keras.backend.constant(level_dim_offsets, "int32") + height_dim_sizes = keras.backend.constant(feature_widths, "int32") # Assigns boxes to the right level. box_width = boxes[:, :, 3] - boxes[:, :, 1] box_height = boxes[:, :, 2] - boxes[:, :, 0] - areas_sqrt = tf.sqrt( - tf.cast(box_height, tf.float32) * tf.cast(box_width, tf.float32) + areas_sqrt = ops.sqrt( + ops.cast(box_height, "float32") * ops.cast(box_width, "float32") ) # following the FPN paper to divide by 224. - levels = tf.cast( - tf.math.floordiv( - tf.math.log(tf.math.divide_no_nan(areas_sqrt, 224.0)), - tf.math.log(2.0), + levels = ops.cast( + ops.numpy.floor_divide( + ops.numpy.log(ops.numpy.divide(areas_sqrt, 224.0)), # tf.math.divide_no_nan + ops.numpy.log(2.0), ) + 4.0, - dtype=tf.int32, + dtype="int32", ) # Maps levels between [min_level, max_level]. 
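# A minimal standalone sketch of the level-assignment heuristic computed
# above and clamped immediately below, assuming a P2-P6 pyramid; the helper
# name and defaults are illustrative, not part of the patch. Following the
# FPN paper: level = floor(log2(sqrt(box_area) / 224)) + 4, clamped to
# [min_level, max_level].
import numpy as np

def fpn_level(box_height, box_width, min_level=2, max_level=6):
    level = np.floor(np.log2(np.sqrt(box_height * box_width) / 224.0)) + 4.0
    return int(np.clip(level, min_level, max_level))

# A 224x224 box lands on P4, a 112x112 box on P3, and very large boxes
# saturate at the coarsest level.
assert fpn_level(224, 224) == 4
assert fpn_level(112, 112) == 3
assert fpn_level(2048, 2048) == 6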
- levels = tf.minimum(max_level, tf.maximum(levels, min_level)) + levels = ops.minimum(max_level, ops.maximum(levels, min_level)) # Projects box location and sizes to corresponding feature levels. - scale_to_level = tf.cast( - tf.pow(tf.constant(2.0), tf.cast(levels, tf.float32)), + scale_to_level = ops.cast( + ops.numpy.power(keras.backend.constant(2.0), ops.cast(levels, "float32")), dtype=boxes.dtype, ) - boxes /= tf.expand_dims(scale_to_level, axis=2) + boxes /= ops.expand_dims(scale_to_level, axis=2) box_width /= scale_to_level box_height /= scale_to_level - boxes = tf.concat( + boxes = ops.concatenate( [ boxes[:, :, 0:2], - tf.expand_dims(box_height, -1), - tf.expand_dims(box_width, -1), + ops.expand_dims(box_height, -1), + ops.expand_dims(box_width, -1), ], axis=-1, ) # Maps levels to [0, max_level-min_level]. levels -= min_level - level_strides = tf.pow([[2.0]], tf.cast(levels, tf.float32)) - boundary = tf.cast( - tf.concat( + level_strides = ops.numpy.power([[2.0]], ops.cast(levels, "float32")) + boundary = ops.cast( + ops.concatenate( [ - tf.expand_dims( - [[tf.cast(max_feature_height, tf.float32)]] + ops.expand_dims( + [[ops.cast(max_feature_height, "float32")]] / level_strides - 1, axis=-1, ), - tf.expand_dims( - [[tf.cast(max_feature_width, tf.float32)]] + ops.expand_dims( + [[ops.cast(max_feature_width, "float32")]] / level_strides - 1, axis=-1, @@ -309,42 +311,42 @@ def multilevel_crop_and_resize( box_gridx0x1, ) = _compute_grid_positions(boxes, boundary, output_size, sample_offset) - x_indices = tf.cast( - tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]), - dtype=tf.int32, + x_indices = ops.cast( + ops.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]), + dtype="int32", ) - y_indices = tf.cast( - tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]), - dtype=tf.int32, + y_indices = ops.cast( + ops.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]), + dtype="int32", ) - batch_size_offset = tf.tile( - tf.reshape( - tf.range(batch_size) * batch_dim_size, [batch_size, 1, 1, 1] + batch_size_offset = ops.tile( + ops.reshape( + ops.arange(batch_size) * batch_dim_size, [batch_size, 1, 1, 1] ), [1, num_boxes, output_size * 2, output_size * 2], ) # Get level offset for each box. Each box belongs to one level. - levels_offset = tf.tile( - tf.reshape( - tf.gather(level_dim_offsets, levels), + levels_offset = ops.tile( + ops.reshape( + keras.backend.gather(level_dim_offsets, levels), [batch_size, num_boxes, 1, 1], ), [1, 1, output_size * 2, output_size * 2], ) - y_indices_offset = tf.tile( - tf.reshape( + y_indices_offset = ops.tile( + ops.reshape( y_indices - * tf.expand_dims(tf.gather(height_dim_sizes, levels), -1), + * ops.expand_dims(keras.backend.gather(height_dim_sizes, levels), -1), [batch_size, num_boxes, output_size * 2, 1], ), [1, 1, 1, output_size * 2], ) - x_indices_offset = tf.tile( - tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]), + x_indices_offset = ops.tile( + ops.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]), [1, 1, output_size * 2, 1], ) - indices = tf.reshape( + indices = ops.reshape( batch_size_offset + levels_offset + y_indices_offset @@ -354,8 +356,8 @@ def multilevel_crop_and_resize( # TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get # similar performance. 
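# A toy sketch of the flat-index arithmetic feeding the gather below; shapes
# and names here are assumptions for illustration, not part of the patch.
# All pyramid levels are flattened into one [sum(H_l * W_l), C] buffer per
# image, so a sample at (batch b, level l, row y, col x) is read at
#   b * batch_dim_size + level_dim_offsets[l] + y * W_l + x
import numpy as np

H, W, C = 4, 5, 3
level = np.arange(H * W * C).reshape(H, W, C)  # a single level of one image
flat = level.reshape(H * W, C)
y, x = 2, 3
assert np.array_equal(flat[y * W + x], level[y, x])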
- features_per_box = tf.reshape( - tf.gather(features_r2, indices), + features_per_box = ops.reshape( + keras.backend.gather(features_r2, indices), [ batch_size, num_boxes, From 7d6ef6f1b69b3c6360eca75bea571034a077c8b3 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Sat, 16 Sep 2023 15:04:34 +0530 Subject: [PATCH 09/32] chore: port roi sampler to keras core --- keras_cv/layers/object_detection/roi_align.py | 2 +- .../layers/object_detection/roi_sampler.py | 40 ++++++++++--------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index eb40aabcbc..3b833cb10d 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -17,8 +17,8 @@ from typing import Optional from typing import Tuple +# TODO (ariG23498): remove tf and correct the type imports import tensorflow as tf -from tensorflow import keras from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras diff --git a/keras_cv/layers/object_detection/roi_sampler.py b/keras_cv/layers/object_detection/roi_sampler.py index fe63e31ba9..b4680be607 100644 --- a/keras_cv/layers/object_detection/roi_sampler.py +++ b/keras_cv/layers/object_detection/roi_sampler.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO (ariG23498): remove tf and correct the type imports import tensorflow as tf -from tensorflow import keras + +import numpy as np # used for newaxis from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras @@ -21,6 +23,8 @@ from keras_cv.layers.object_detection import box_matcher from keras_cv.layers.object_detection import sampling from keras_cv.utils import target_gather +from keras_cv.backend import ops +from keras_cv.backend import keras @keras.utils.register_keras_serializable(package="keras_cv") @@ -102,8 +106,8 @@ def call( """ if self.append_gt_boxes: # num_rois += num_gt - rois = tf.concat([rois, gt_boxes], axis=1) - num_rois = rois.get_shape().as_list()[1] + rois = ops.concatenate([rois, gt_boxes], axis=1) + num_rois = ops.shape(rois)[1] if num_rois is None: raise ValueError( f"`rois` must have static shape, got {rois.get_shape()}" @@ -126,27 +130,27 @@ def call( # [batch_size, num_rois] | [batch_size, num_rois] matched_gt_cols, matched_vals = self.roi_matcher(similarity_mat) # [batch_size, num_rois] - positive_matches = tf.math.equal(matched_vals, 1) - negative_matches = tf.math.equal(matched_vals, -1) + positive_matches = ops.equal(matched_vals, 1) + negative_matches = ops.equal(matched_vals, -1) self._positives.update_state( - tf.reduce_sum(tf.cast(positive_matches, tf.float32), axis=-1) + ops.sum(ops.cast(positive_matches, "float32"), axis=-1) ) self._negatives.update_state( - tf.reduce_sum(tf.cast(negative_matches, tf.float32), axis=-1) + ops.sum(ops.cast(negative_matches, "float32"), axis=-1) ) # [batch_size, num_rois, 1] - background_mask = tf.expand_dims( - tf.logical_not(positive_matches), axis=-1 + background_mask = ops.expand_dims( + ops.logical_not(positive_matches), axis=-1 ) # [batch_size, num_rois, 1] matched_gt_classes = target_gather._target_gather( gt_classes, matched_gt_cols ) # also set all background matches to `background_class` - matched_gt_classes = tf.where( + matched_gt_classes = ops.where( background_mask, - tf.cast( - self.background_class * tf.ones_like(matched_gt_classes), + ops.cast( + self.background_class * 
ops.ones_like(matched_gt_classes), gt_classes.dtype, ), matched_gt_classes, @@ -163,9 +167,9 @@ def call( variance=[0.1, 0.1, 0.2, 0.2], ) # also set all background matches to 0 coordinates - encoded_matched_gt_boxes = tf.where( + encoded_matched_gt_boxes = ops.where( background_mask, - tf.zeros_like(matched_gt_boxes), + ops.zeros_like(matched_gt_boxes), encoded_matched_gt_boxes, ) # [batch_size, num_rois] @@ -176,7 +180,7 @@ def call( self.positive_fraction, ) # [batch_size, num_sampled_rois] in the range of [0, num_rois) - sampled_indicators, sampled_indices = tf.math.top_k( + sampled_indicators, sampled_indices = ops.math.top_k( sampled_indicators, k=self.num_sampled_rois, sorted=True ) # [batch_size, num_sampled_rois, 4] @@ -192,12 +196,12 @@ def call( # [batch_size, num_sampled_rois, 1] # all negative samples will be ignored in regression sampled_box_weights = target_gather._target_gather( - tf.cast(positive_matches[..., tf.newaxis], gt_boxes.dtype), + ops.cast(positive_matches[..., np.newaxis], gt_boxes.dtype), sampled_indices, ) # [batch_size, num_sampled_rois, 1] - sampled_indicators = sampled_indicators[..., tf.newaxis] - sampled_class_weights = tf.cast(sampled_indicators, gt_classes.dtype) + sampled_indicators = sampled_indicators[..., np.newaxis] + sampled_class_weights = ops.cast(sampled_indicators, gt_classes.dtype) return ( sampled_rois, sampled_gt_boxes, From f1e3e1720aba65326f1c80cb4f115f57cc8e947f Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Sat, 16 Sep 2023 15:14:37 +0530 Subject: [PATCH 10/32] chore: port rpn label encoder to keras core --- .../object_detection/rpn_label_encoder.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/keras_cv/layers/object_detection/rpn_label_encoder.py b/keras_cv/layers/object_detection/rpn_label_encoder.py index 5cd9d88415..252c3d80c8 100644 --- a/keras_cv/layers/object_detection/rpn_label_encoder.py +++ b/keras_cv/layers/object_detection/rpn_label_encoder.py @@ -13,9 +13,10 @@ # limitations under the License. from typing import Mapping - +import tree +import numpy as np # Used for newaxis +# TODO (ariG23498): remove tf and correct the type imports import tensorflow as tf -from tensorflow import keras from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras @@ -23,6 +24,8 @@ from keras_cv.layers.object_detection import box_matcher from keras_cv.layers.object_detection import sampling from keras_cv.utils import target_gather +from keras_cv.backend import ops +from keras_cv.backend import keras @keras.utils.register_keras_serializable(package="keras_cv") @@ -112,7 +115,7 @@ def call( anchors = anchors_dict if isinstance(anchors, dict): pack = True - anchors = tf.concat(tf.nest.flatten(anchors), axis=0) + anchors = ops.concatenate(tree.flatten(anchors), axis=0) anchors = bounding_box.convert_format( anchors, source=self.anchor_format, target="yxyx" ) @@ -126,14 +129,14 @@ def call( # [num_anchors] or [batch_size, num_anchors] matched_gt_indices, matched_vals = self.box_matcher(similarity_mat) # [num_anchors] or [batch_size, num_anchors] - positive_matches = tf.math.equal(matched_vals, 1) + positive_matches = ops.equal(matched_vals, 1) # currently SyncOnReadVariable does not support `assign_add` in # cross-replica. 
# self._positives.update_state( # tf.reduce_sum(tf.cast(positive_matches, tf.float32), axis=-1) # ) - negative_matches = tf.math.equal(matched_vals, -1) + negative_matches = ops.equal(matched_vals, -1) # [num_anchors, 4] or [batch_size, num_anchors, 4] matched_gt_boxes = target_gather._target_gather( gt_boxes, matched_gt_indices @@ -148,18 +151,18 @@ def call( variance=self.box_variance, ) # [num_anchors, 1] or [batch_size, num_anchors, 1] - box_sample_weights = tf.cast( - positive_matches[..., tf.newaxis], gt_boxes.dtype + box_sample_weights = ops.cast( + positive_matches[..., np.newaxis], gt_boxes.dtype ) # [num_anchors, 1] or [batch_size, num_anchors, 1] - positive_mask = tf.expand_dims(positive_matches, axis=-1) + positive_mask = ops.expand_dims(positive_matches, axis=-1) # set all negative and ignored matches to 0, and all positive matches to # 1 [num_anchors, 1] or [batch_size, num_anchors, 1] - positive_classes = tf.ones_like(positive_mask, dtype=gt_classes.dtype) - negative_classes = tf.zeros_like(positive_mask, dtype=gt_classes.dtype) + positive_classes = ops.ones_like(positive_mask, dtype=gt_classes.dtype) + negative_classes = ops.zeros_like(positive_mask, dtype=gt_classes.dtype) # [num_anchors, 1] or [batch_size, num_anchors, 1] - class_targets = tf.where( + class_targets = ops.where( positive_mask, positive_classes, negative_classes ) # [num_anchors] or [batch_size, num_anchors] @@ -170,8 +173,8 @@ def call( self.positive_fraction, ) # [num_anchors, 1] or [batch_size, num_anchors, 1] - class_sample_weights = tf.cast( - sampled_indicators[..., tf.newaxis], gt_classes.dtype + class_sample_weights = ops.cast( + sampled_indicators[..., np.newaxis], gt_classes.dtype ) if pack: encoded_box_targets = self.unpack_targets( From 6478cbf1704edec642f7cf84b6dfe049eb06bae3 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Sat, 16 Sep 2023 16:43:17 +0530 Subject: [PATCH 11/32] chore: adding tests and fix lint --- keras_cv/layers/object_detection/roi_align.py | 42 ++++-- .../layers/object_detection/roi_generator.py | 1 + .../layers/object_detection/roi_sampler.py | 8 +- .../object_detection/rpn_label_encoder.py | 10 +- .../object_detection/faster_rcnn/__init__.py | 4 +- .../faster_rcnn/faster_rcnn.py | 65 +++++++-- .../faster_rcnn/faster_rcnn_test.py | 125 ++++++++++++++++-- .../object_detection/faster_rcnn/rpn_head.py | 5 +- 8 files changed, 216 insertions(+), 44 deletions(-) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 3b833cb10d..6f0290dc2f 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -22,8 +22,8 @@ from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras -from keras_cv.backend import ops from keras_cv.backend import keras +from keras_cv.backend import ops def _feature_bilinear_interpolation( @@ -60,8 +60,12 @@ def _feature_bilinear_interpolation( ) output_size = output_size // 2 - kernel_y = ops.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1]) - kernel_x = ops.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2]) + kernel_y = ops.reshape( + kernel_y, [batch_size, num_boxes, output_size * 2, 1] + ) + kernel_x = ops.reshape( + kernel_x, [batch_size, num_boxes, 1, output_size * 2] + ) # Use implicit broadcast to generate the interpolation kernel. The # multiplier `4` is for avg pooling. 
interpolation_kernel = kernel_y * kernel_x * 4 @@ -74,7 +78,9 @@ def _feature_bilinear_interpolation( features, [batch_size * num_boxes, output_size * 2, output_size * 2, num_filters], ) - features = ops.nn.average_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID") + features = ops.nn.average_pool( + features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID" + ) features = ops.reshape( features, [batch_size, num_boxes, output_size, output_size, num_filters] ) @@ -128,8 +134,12 @@ def _compute_grid_positions( box_grid_y0 = ops.floor(box_grid_y) box_grid_x0 = ops.floor(box_grid_x) - box_grid_x0 = ops.maximum(ops.cast(0.0, dtype=box_grid_x0.dtype), box_grid_x0) - box_grid_y0 = ops.maximum(ops.cast(0.0, dtype=box_grid_y0.dtype), box_grid_y0) + box_grid_x0 = ops.maximum( + ops.cast(0.0, dtype=box_grid_x0.dtype), box_grid_x0 + ) + box_grid_y0 = ops.maximum( + ops.cast(0.0, dtype=box_grid_y0.dtype), box_grid_y0 + ) box_grid_x0 = ops.minimum( box_grid_x0, ops.expand_dims(boundaries[:, :, 1], -1) @@ -226,9 +236,13 @@ def multilevel_crop_and_resize( # Concat tensor of [batch_size, height_l * width_l, num_filters] for # each level. features_all.append( - ops.reshape(features[f"P{level}"], [batch_size, -1, num_filters]) + ops.reshape( + features[f"P{level}"], [batch_size, -1, num_filters] + ) ) - features_r2 = ops.reshape(ops.concatenate(features_all, 1), [-1, num_filters]) + features_r2 = ops.reshape( + ops.concatenate(features_all, 1), [-1, num_filters] + ) # Calculate height_l * width_l for each level. level_dim_sizes = [ @@ -253,7 +267,9 @@ def multilevel_crop_and_resize( # following the FPN paper to divide by 224. levels = ops.cast( ops.numpy.floor_divide( - ops.numpy.log(ops.numpy.divide(areas_sqrt, 224.0)), # tf.math.divide_no_nan + ops.numpy.log( + ops.numpy.divide(areas_sqrt, 224.0) + ), # tf.math.divide_no_nan ops.numpy.log(2.0), ) + 4.0, @@ -264,7 +280,9 @@ def multilevel_crop_and_resize( # Projects box location and sizes to corresponding feature levels. scale_to_level = ops.cast( - ops.numpy.power(keras.backend.constant(2.0), ops.cast(levels, "float32")), + ops.numpy.power( + keras.backend.constant(2.0), ops.cast(levels, "float32") + ), dtype=boxes.dtype, ) boxes /= ops.expand_dims(scale_to_level, axis=2) @@ -337,7 +355,9 @@ def multilevel_crop_and_resize( y_indices_offset = ops.tile( ops.reshape( y_indices - * ops.expand_dims(keras.backend.gather(height_dim_sizes, levels), -1), + * ops.expand_dims( + keras.backend.gather(height_dim_sizes, levels), -1 + ), [batch_size, num_boxes, output_size * 2, 1], ), [1, 1, 1, output_size * 2], diff --git a/keras_cv/layers/object_detection/roi_generator.py b/keras_cv/layers/object_detection/roi_generator.py index a4ad53e4b1..42b5accb05 100644 --- a/keras_cv/layers/object_detection/roi_generator.py +++ b/keras_cv/layers/object_detection/roi_generator.py @@ -25,6 +25,7 @@ from keras_cv.backend import assert_tf_keras from keras_cv.backend import ops + @keras_cv_export("keras_cv.layers.ROIGenerator") class ROIGenerator(keras.layers.Layer): """ diff --git a/keras_cv/layers/object_detection/roi_sampler.py b/keras_cv/layers/object_detection/roi_sampler.py index b4680be607..86b8cfb16a 100644 --- a/keras_cv/layers/object_detection/roi_sampler.py +++ b/keras_cv/layers/object_detection/roi_sampler.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numpy as np # used for newaxis + # TODO (ariG23498): remove tf and correct the type imports import tensorflow as tf -import numpy as np # used for newaxis - from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras +from keras_cv.backend import keras +from keras_cv.backend import ops from keras_cv.bounding_box import iou from keras_cv.layers.object_detection import box_matcher from keras_cv.layers.object_detection import sampling from keras_cv.utils import target_gather -from keras_cv.backend import ops -from keras_cv.backend import keras @keras.utils.register_keras_serializable(package="keras_cv") diff --git a/keras_cv/layers/object_detection/rpn_label_encoder.py b/keras_cv/layers/object_detection/rpn_label_encoder.py index 252c3d80c8..8ac27900a2 100644 --- a/keras_cv/layers/object_detection/rpn_label_encoder.py +++ b/keras_cv/layers/object_detection/rpn_label_encoder.py @@ -13,19 +13,21 @@ # limitations under the License. from typing import Mapping -import tree -import numpy as np # Used for newaxis + +import numpy as np # Used for newaxis + # TODO (ariG23498): remove tf and correct the type imports import tensorflow as tf +import tree from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras +from keras_cv.backend import keras +from keras_cv.backend import ops from keras_cv.bounding_box import iou from keras_cv.layers.object_detection import box_matcher from keras_cv.layers.object_detection import sampling from keras_cv.utils import target_gather -from keras_cv.backend import ops -from keras_cv.backend import keras @keras.utils.register_keras_serializable(package="keras_cv") diff --git a/keras_cv/models/object_detection/faster_rcnn/__init__.py b/keras_cv/models/object_detection/faster_rcnn/__init__.py index 73c9c2ba56..d5f9e37b30 100644 --- a/keras_cv/models/object_detection/faster_rcnn/__init__.py +++ b/keras_cv/models/object_detection/faster_rcnn/__init__.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import FeaturePyramid +from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import ( + FeaturePyramid, +) from keras_cv.models.object_detection.faster_rcnn.rcnn_head import RCNNHead from keras_cv.models.object_detection.faster_rcnn.rpn_head import RPNHead diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 20fc631275..c6dfc4bfd3 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -19,10 +19,12 @@ from keras_cv import bounding_box from keras_cv import layers as cv_layers +from keras_cv import models from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.bounding_box.converters import _decode_deltas_to_boxes + # from keras_cv.models.backbones.backbone_presets import backbone_presets # from keras_cv.models.backbones.backbone_presets import ( # backbone_presets_with_weights @@ -42,6 +44,7 @@ from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid from keras_cv.models.object_detection.faster_rcnn import RCNNHead from keras_cv.models.object_detection.faster_rcnn import RPNHead + # from keras_cv.models.object_detection.faster_rcnn.faster_rcnn_presets import faster_rcnn_presets from keras_cv.models.task import Task from keras_cv.utils.python_utils import classproperty @@ -140,8 +143,10 @@ def __init__( label_encoder=None, rcnn_head=None, prediction_decoder=None, + feature_pyramid=None, **kwargs, ): + self.num_classes = num_classes self.bounding_box_format = bounding_box_format super().__init__(**kwargs) scales = [2**x for x in [0]] @@ -187,7 +192,7 @@ def __init__( self.feature_extractor = get_feature_extractor( self.backbone, extractor_layer_names, extractor_levels ) - self.feature_pyramid = FeaturePyramid() + self.feature_pyramid = feature_pyramid or FeaturePyramid() self.rpn_labeler = label_encoder or _RpnLabelEncoder( anchor_format="yxyx", ground_truth_box_format="yxyx", @@ -347,7 +352,8 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): ) rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( - images, anchors, + images, + anchors, ) rois = ops.stop_gradient(rois) ( @@ -360,7 +366,8 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 cls_weights /= self.roi_sampler.num_sampled_rois * local_batch box_pred, cls_pred = self._call_rcnn( - rois, feature_map, + rois, + feature_map, ) y_true = { "rpn_box": rpn_box_targets, @@ -435,14 +442,54 @@ def get_config(self): return { "num_classes": self.num_classes, "bounding_box_format": self.bounding_box_format, - "backbone": self.backbone, - "anchor_generator": self.anchor_generator, - "label_encoder": self.rpn_labeler, - "prediction_decoder": self._prediction_decoder, - "feature_pyramid": self.feature_pyramid, - "rcnn_head": self.rcnn_head, + "backbone": keras.saving.serialize_keras_object(self.backbone), + "anchor_generator": keras.saving.serialize_keras_object( + self.anchor_generator + ), + "label_encoder": keras.saving.serialize_keras_object( + self.rpn_labeler + ), + "prediction_decoder": keras.saving.serialize_keras_object( + self._prediction_decoder + ), + "feature_pyramid": keras.saving.serialize_keras_object( + 
self.feature_pyramid + ), + "rcnn_head": keras.saving.serialize_keras_object(self.rcnn_head), } + @classmethod + def from_config(cls, config): + if "rcnn_head" in config and isinstance(config["rcnn_head"], dict): + config["rcnn_head"] = keras.layers.deserialize(config["rcnn_head"]) + if "feature_pyramid" in config and isinstance( + config["feature_pyramid"], dict + ): + config["feature_pyramid"] = keras.layers.deserialize( + config["feature_pyramid"] + ) + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) + if "anchor_generator" in config and isinstance( + config["anchor_generator"], dict + ): + config["anchor_generator"] = keras.layers.deserialize( + config["anchor_generator"] + ) + if "backbone" in config and isinstance(config["backbone"], dict): + config["backbone"] = keras.layers.deserialize(config["backbone"]) + return super().from_config(config) + @classproperty def presets(cls): """Dictionary of preset names and configurations.""" diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py index a6a57c713f..b871fb9431 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -22,6 +22,7 @@ import keras_cv from keras_cv.backend import keras from keras_cv.backend import ops + # from keras_cv.models.backbones.test_backbone_presets import ( # test_backbone_presets, # ) @@ -37,7 +38,7 @@ def test_faster_rcnn_construction(self): faster_rcnn = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone() + backbone=keras_cv.models.ResNet18V2Backbone(), ) faster_rcnn.compile( optimizer=keras.optimizers.Adam(), @@ -68,7 +69,7 @@ def test_wrong_logits(self): with self.assertRaisesRegex( ValueError, "from_logits", - ): + ): faster_rcnn.compile( optimizer=keras.optimizers.SGD(learning_rate=0.25), box_loss=keras_cv.losses.SmoothL1Loss( @@ -106,6 +107,104 @@ def test_weights_contained_in_trainable_variables(self): _ = faster_rcnn(xs) self.assertEqual(len(faster_rcnn.trainable_variables), 32) + @pytest.mark.large # Fit is slow, so mark these large. + def test_no_nans(self): + faster_rcnn = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone(), + ) + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + # only a -1 box + xs = np.ones((1, 512, 512, 3), "float32") + ys = { + "classes": np.array([[-1]], "float32"), + "boxes": np.array([[[0, 0, 0, 0]]], "float32"), + } + ds = tf.data.Dataset.from_tensor_slices((xs, ys)) + ds = ds.repeat(2) + ds = ds.batch(2, drop_remainder=True) + faster_rcnn.fit(ds, epochs=1) + + weights = faster_rcnn.get_weights() + for weight in weights: + self.assertFalse(ops.any(ops.isnan(weight))) + + @pytest.mark.large # Fit is slow, so mark these large. 
+ def test_weights_change(self): + faster_rcnn = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone(), + ) + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + images, boxes = _create_bounding_box_dataset("xyxy") + ds = tf.data.Dataset.from_tensor_slices( + {"images": images, "bounding_boxes": boxes} + ).batch(5, drop_remainder=True) + + # call once + _ = faster_rcnn(ops.ones((1, 512, 512, 3))) + original_fpn_weights = faster_rcnn.feature_pyramid.get_weights() + original_rpn_head_weights = faster_rcnn.rpn_head.get_weights() + original_rcnn_head_weights = faster_rcnn.rcnn_head.get_weights() + + faster_rcnn.fit(ds, epochs=1) + fpn_after_fit = faster_rcnn.feature_pyramid.get_weights() + rpn_head_after_fit_weights = faster_rcnn.rpn_head.get_weights() + rcnn_head_after_fit_weights = faster_rcnn.rcnn_head.get_weights() + + for w1, w2 in zip( + original_rcnn_head_weights, + rcnn_head_after_fit_weights, + ): + self.assertNotAllClose(w1, w2) + + for w1, w2 in zip( + original_rpn_head_weights, rpn_head_after_fit_weights + ): + self.assertNotAllClose(w1, w2) + + for w1, w2 in zip(original_fpn_weights, fpn_after_fit): + self.assertNotAllClose(w1, w2) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_saved_model(self): + model = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone(), + ) + input_batch = ops.ones(shape=(1, 512, 512, 3)) + model_output = model(input_batch) + save_path = os.path.join(self.get_temp_dir(), "faster_rcnn.keras") + model.save(save_path) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, keras_cv.models.FasterRCNN) + + # Check that output matches. + restored_output = restored_model(input_batch) + self.assertAllClose( + tf.nest.map_structure(ops.convert_to_numpy, model_output), + tf.nest.map_structure(ops.convert_to_numpy, restored_output), + ) + # TODO(ianstenbit): Make FasterRCNN support shapes that are not multiples # of 128, perhaps by adding a flag to the anchor generator for whether to # include anchors centered outside of the image. 
(RetinaNet does use those, @@ -120,13 +219,13 @@ def test_faster_rcnn_infer(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(), ) - images = tf.random.normal(batch_shape) + images = ops.random.normal(batch_shape) outputs = model(images, training=False) # 1000 proposals in inference - self.assertAllEqual([2, 1000, 81], outputs[1].shape) - self.assertAllEqual([2, 1000, 4], outputs[0].shape) + self.assertAllEqual([2, 1000, 81], outputs["classes"].shape) + self.assertAllEqual([2, 1000, 4], outputs["boxes"].shape) @parameterized.parameters( ((2, 640, 384, 3),), @@ -137,18 +236,18 @@ def test_faster_rcnn_train(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(), ) - images = tf.random.normal(batch_shape) + images = ops.random.normal(batch_shape) outputs = model(images, training=True) - self.assertAllEqual([2, 1000, 81], outputs[1].shape) - self.assertAllEqual([2, 1000, 4], outputs[0].shape) + self.assertAllEqual([2, 1000, 81], outputs["classes"].shape) + self.assertAllEqual([2, 1000, 4], outputs["boxes"].shape) def test_invalid_compile(self): model = FasterRCNN( num_classes=80, bounding_box_format="yxyx", - backbone=ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(), ) with self.assertRaisesRegex(ValueError, "only accepts"): model.compile(rpn_box_loss="binary_crossentropy") @@ -159,12 +258,12 @@ def test_invalid_compile(self): ) ) - # @pytest.mark.large # Fit is slow, so mark these large. + @pytest.mark.large # Fit is slow, so mark these large. def test_faster_rcnn_with_dictionary_input_format(self): faster_rcnn = FasterRCNN( num_classes=20, bounding_box_format="xywh", - backbone=ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(), ) images, boxes = _create_bounding_box_dataset("xywh") diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py index 1bc2394a0a..d3014abc88 100644 --- a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -24,14 +24,15 @@ package="keras_cv.models.faster_rcnn", ) class RPNHead(keras.layers.Layer): - """ A Keras layer implementing the RPN architecture. - + """A Keras layer implementing the RPN architecture. + Region Proposal Networks (RPN) was first suggested in [FasterRCNN](https://arxiv.org/abs/1506.01497). This is an end to end trainable layer which proposes regions for a detector (RCNN). Args: num_achors_per_location: The number of anchors per location. 
""" + def __init__( self, num_anchors_per_location=3, From 7741edcf9e94757c511a74d10bfebb5b5a424313 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Sat, 16 Sep 2023 16:49:28 +0530 Subject: [PATCH 12/32] fix: lint --- keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py | 4 +++- keras_cv/models/object_detection/faster_rcnn/rpn_head.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index c6dfc4bfd3..923f17a520 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -45,7 +45,9 @@ from keras_cv.models.object_detection.faster_rcnn import RCNNHead from keras_cv.models.object_detection.faster_rcnn import RPNHead -# from keras_cv.models.object_detection.faster_rcnn.faster_rcnn_presets import faster_rcnn_presets +# from keras_cv.models.object_detection.faster_rcnn.faster_rcnn_presets import ( +# faster_rcnn_presets +# ) from keras_cv.models.task import Task from keras_cv.utils.python_utils import classproperty from keras_cv.utils.train import get_feature_extractor diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py index d3014abc88..11d853125f 100644 --- a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -26,8 +26,10 @@ class RPNHead(keras.layers.Layer): """A Keras layer implementing the RPN architecture. - Region Proposal Networks (RPN) was first suggested in [FasterRCNN](https://arxiv.org/abs/1506.01497). - This is an end to end trainable layer which proposes regions for a detector (RCNN). + Region Proposal Networks (RPN) was first suggested in + [FasterRCNN](https://arxiv.org/abs/1506.01497). + This is an end to end trainable layer which proposes regions + for a detector (RCNN). Args: num_achors_per_location: The number of anchors per location. From 13a26e615c7d37f7bd068f57e484d37639bcc207 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Sat, 16 Sep 2023 16:53:43 +0530 Subject: [PATCH 13/32] chore: adding copyright to faster rcnn presets script --- .../faster_rcnn/faster_rcnn_presets.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py index e69de29bb2..1f8b847a76 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py @@ -0,0 +1,14 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""FastrerRCNN Task presets.""" \ No newline at end of file From 3b42ecc14cc9962de36e5259d51f6a7bac328ff4 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Thu, 21 Sep 2023 10:31:00 +0530 Subject: [PATCH 14/32] chore: removing tf imports --- keras_cv/layers/object_detection/roi_align.py | 44 ++++++++----------- .../layers/object_detection/roi_sampler.py | 9 ++-- .../object_detection/rpn_label_encoder.py | 11 ++--- .../faster_rcnn/faster_rcnn_presets.py | 2 +- 4 files changed, 26 insertions(+), 40 deletions(-) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 6f0290dc2f..0a71cbe5bd 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -12,14 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict -from typing import Mapping -from typing import Optional -from typing import Tuple - -# TODO (ariG23498): remove tf and correct the type imports -import tensorflow as tf - from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras from keras_cv.backend import keras @@ -27,7 +19,9 @@ def _feature_bilinear_interpolation( - features: tf.Tensor, kernel_y: tf.Tensor, kernel_x: tf.Tensor + features, + kernel_y, + kernel_x, ) -> tf.Tensor: """ Feature bilinear interpolation. @@ -88,11 +82,11 @@ def _feature_bilinear_interpolation( def _compute_grid_positions( - boxes: tf.Tensor, - boundaries: tf.Tensor, - output_size: int, - sample_offset: float, -) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: + boxes, + boundaries, + output_size, + sample_offset, +): """ Computes the grid position w.r.t. the corresponding feature map. @@ -106,8 +100,8 @@ def _compute_grid_positions( boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing the boundary (in (y, x)) of the corresponding feature map for each box. Any resampled grid points that go beyond the boundary will be clipped. - output_size: a scalar indicating the output crop size. - sample_offset: a float number in [0, 1] indicates the subpixel sample + output_size: a `int` scalar indicating the output crop size. + sample_offset: a `float` number in [0, 1] indicates the subpixel sample offset from grid point. Returns: @@ -177,11 +171,11 @@ def _compute_grid_positions( def multilevel_crop_and_resize( - features: Dict[str, tf.Tensor], - boxes: tf.Tensor, - output_size: int = 7, - sample_offset: float = 0.5, -) -> tf.Tensor: + features, + boxes, + output_size=7, + sample_offset=0.5, +): """ Crop and resize on multilevel feature pyramid. @@ -190,7 +184,7 @@ def multilevel_crop_and_resize( and resizing it using the corresponding feature map of that level. Args: - features: A dictionary with key as pyramid level and value as features. + features: A dictionary with key as pyramid level and value as features (tensors). The pyramid level keys need to be represented by strings like so: "P2", "P3", "P4", and so on. 
The features are in shape of [batch_size, height_l, width_l, @@ -427,9 +421,9 @@ def __init__( def call( self, - features: Mapping[str, tf.Tensor], - boxes: tf.Tensor, - training: Optional[bool] = None, + features, + boxes, + training=None, ): """ diff --git a/keras_cv/layers/object_detection/roi_sampler.py b/keras_cv/layers/object_detection/roi_sampler.py index 86b8cfb16a..0d2d74edca 100644 --- a/keras_cv/layers/object_detection/roi_sampler.py +++ b/keras_cv/layers/object_detection/roi_sampler.py @@ -14,9 +14,6 @@ import numpy as np # used for newaxis -# TODO (ariG23498): remove tf and correct the type imports -import tensorflow as tf - from keras_cv import bounding_box from keras_cv.backend import assert_tf_keras from keras_cv.backend import keras @@ -88,9 +85,9 @@ def __init__( def call( self, - rois: tf.Tensor, - gt_boxes: tf.Tensor, - gt_classes: tf.Tensor, + rois, + gt_boxes, + gt_classes, ): """ Args: diff --git a/keras_cv/layers/object_detection/rpn_label_encoder.py b/keras_cv/layers/object_detection/rpn_label_encoder.py index 8ac27900a2..69a1d2f630 100644 --- a/keras_cv/layers/object_detection/rpn_label_encoder.py +++ b/keras_cv/layers/object_detection/rpn_label_encoder.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Mapping - import numpy as np # Used for newaxis - -# TODO (ariG23498): remove tf and correct the type imports -import tensorflow as tf import tree from keras_cv import bounding_box @@ -97,9 +92,9 @@ def __init__( def call( self, - anchors_dict: Mapping[str, tf.Tensor], - gt_boxes: tf.Tensor, - gt_classes: tf.Tensor, + anchors_dict, + gt_boxes, + gt_classes, ): """ Args: diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py index 1f8b847a76..1c056e7835 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""FastrerRCNN Task presets.""" \ No newline at end of file +"""FastrerRCNN Task presets.""" From be9178b5ce2cde88fbb36198cf3af48383f751cd Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Wed, 27 Sep 2023 12:30:46 +0530 Subject: [PATCH 15/32] fix imports --- keras_cv/layers/object_detection/roi_align.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 0a71cbe5bd..5424e70493 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -22,7 +22,7 @@ def _feature_bilinear_interpolation( features, kernel_y, kernel_x, -) -> tf.Tensor: +): """ Feature bilinear interpolation. 
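The `multilevel_crop_and_resize` routine reworked in the roi_align patches above routes each proposal box to a single pyramid level with the FPN paper's heuristic, `level = floor(log2(sqrt(box_area) / 224)) + 4`, and only then gathers the crop from that level's feature map. The snippet below is an illustrative sketch of that routing rule only, not code from this patch series; it assumes boxes in [y1, x1, y2, x2] pixel coordinates and a pyramid limited to levels P2-P5 (the level range and clipping are assumptions, the 224-pixel canonical size and the `areas_sqrt` quantity come from the diff itself).

    # Illustrative sketch (not part of the patches): FPN box-to-level assignment.
    import numpy as np

    def assign_fpn_level(boxes, min_level=2, max_level=5, canonical_size=224.0):
        # Square root of each box area, mirroring the `areas_sqrt` tensor in the layer.
        heights = boxes[:, 2] - boxes[:, 0]
        widths = boxes[:, 3] - boxes[:, 1]
        areas_sqrt = np.sqrt(heights * widths)
        # floor(log2(sqrt(area) / 224)) + 4, clipped to the levels assumed to exist.
        levels = np.floor(np.log2(areas_sqrt / canonical_size)) + 4.0
        return np.clip(levels, min_level, max_level).astype("int32")

    # A 224x224 proposal reads from P4; a 56x56 proposal falls back to the finer P2.
    print(assign_fpn_level(np.array([[0.0, 0.0, 224.0, 224.0],
                                     [0.0, 0.0, 56.0, 56.0]])))  # -> [4 2]

Larger proposals therefore pool from the coarser, more semantic levels while small proposals keep the higher-resolution features, which is why the ROI aligner works on the per-level feature dictionary rather than a single feature map.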
From e59d2b4da29efc5da2f925676dc3aa501bc0d50a Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 6 Nov 2023 17:53:14 +0530 Subject: [PATCH 16/32] fix: style --- keras_cv/layers/object_detection/roi_align.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 5424e70493..8d4f347878 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -184,7 +184,8 @@ def multilevel_crop_and_resize( and resizing it using the corresponding feature map of that level. Args: - features: A dictionary with key as pyramid level and value as features (tensors). + features: A dictionary with key as pyramid level and value as + features (tensors). The pyramid level keys need to be represented by strings like so: "P2", "P3", "P4", and so on. The features are in shape of [batch_size, height_l, width_l, From 001162c00ae84302839a88c34fb65cd51a07e7fa Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Tue, 7 Nov 2023 23:12:54 +0530 Subject: [PATCH 17/32] chore: making the model functional in init --- .../faster_rcnn/faster_rcnn.py | 107 +++++++++++------- 1 file changed, 64 insertions(+), 43 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 923f17a520..f857d99302 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -148,13 +148,10 @@ def __init__( feature_pyramid=None, **kwargs, ): - self.num_classes = num_classes - self.bounding_box_format = bounding_box_format - super().__init__(**kwargs) scales = [2**x for x in [0]] aspect_ratios = [0.5, 1.0, 2.0] - self.anchor_generator = anchor_generator or AnchorGenerator( - bounding_box_format="yxyx", + anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format=bounding_box_format, sizes={ "P2": 32.0, "P3": 64.0, @@ -167,14 +164,68 @@ def __init__( strides={f"P{i}": 2**i for i in range(2, 7)}, clip_boxes=True, ) - self.rpn_head = RPNHead( + rpn_head = RPNHead( num_anchors_per_location=len(scales) * len(aspect_ratios) ) - self.roi_generator = ROIGenerator( - bounding_box_format="yxyx", + roi_generator = ROIGenerator( + bounding_box_format=bounding_box_format, nms_score_threshold_train=float("-inf"), nms_score_threshold_test=float("-inf"), ) + roi_pooler = _ROIAligner(bounding_box_format=bounding_box_format) + rcnn_head = rcnn_head or RCNNHead(num_classes) + backbone = backbone or models.ResNet50Backbone() + extractor_levels = ["P2", "P3", "P4", "P5"] + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + feature_extractor = get_feature_extractor( + backbone, extractor_layer_names, extractor_levels + ) + feature_pyramid = feature_pyramid or FeaturePyramid() + + # Begin construction of forward pass + image_shape = feature_extractor.input_shape[1:] + images = keras.layers.Input(image_shape, name="images") + anchors = anchor_generator(image_shape=image_shape) + + # Calling the RPN block + backbone_outputs = feature_extractor(images) + feature_map = feature_pyramid(backbone_outputs) + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = rpn_head(feature_map) + # the decoded format is center_xywh, convert to yxyx + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format=bounding_box_format, + 
box_format=bounding_box_format, + variance=BOX_VARIANCE, + ) + rois, _ = roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, bounding_box_format, image_shape) + + # Calling the RCNN block + feature_map = roi_pooler(feature_map, rois) + # [BS, H*W*K, pool_shape*C] + feature_map = ops.reshape( + feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0) + ) + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( + feature_map, + ) + + inputs = {"images": images} + outputs = {"box": rcnn_box_pred, "classification": rcnn_cls_pred} + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + self.num_classes = num_classes + self.bounding_box_format = bounding_box_format + self.anchor_generator = anchor_generator + self.rpn_head = rpn_head + self.roi_generator = roi_generator self.box_matcher = BoxMatcher( thresholds=[0.0, 0.5], match_values=[-2, -1, 1] ) @@ -184,17 +235,11 @@ def __init__( background_class=num_classes, num_sampled_rois=512, ) - self.roi_pooler = _ROIAligner(bounding_box_format="yxyx") - self.rcnn_head = rcnn_head or RCNNHead(num_classes) - self.backbone = backbone or models.ResNet50Backbone() - extractor_levels = ["P2", "P3", "P4", "P5"] - extractor_layer_names = [ - self.backbone.pyramid_level_inputs[i] for i in extractor_levels - ] - self.feature_extractor = get_feature_extractor( - self.backbone, extractor_layer_names, extractor_levels - ) - self.feature_pyramid = feature_pyramid or FeaturePyramid() + self.roi_pooler = roi_pooler + self.rcnn_head = rcnn_head + self.backbone = backbone + self.feature_extractor = feature_extractor + self.feature_pyramid = feature_pyramid self.rpn_labeler = label_encoder or _RpnLabelEncoder( anchor_format="yxyx", ground_truth_box_format="yxyx", @@ -249,30 +294,6 @@ def _call_rcnn(self, rois, feature_map, training=None): ) return rcnn_box_pred, rcnn_cls_pred - def call(self, images, training=None): - image_shape = ops.shape(images[0]) - anchors = self.anchor_generator(image_shape=image_shape) - rois, feature_map, _, _ = self._call_rpn( - images, anchors, training=training - ) - box_pred, cls_pred = self._call_rcnn( - rois, feature_map, training=training - ) - if not training: - # box_pred is on "center_yxhw" format, convert to target format. - box_pred = _decode_deltas_to_boxes( - anchors=rois, - boxes_delta=box_pred, - anchor_format="yxyx", - box_format=self.bounding_box_format, - variance=[0.1, 0.1, 0.2, 0.2], - ) - outputs = { - "boxes": box_pred, - "classes": cls_pred, - } - return outputs - # TODO(tanzhenyu): Support compile with metrics. 
def compile( self, From 9aab0e947bdae0b3cc4dd1b4d44c939245316523 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 18 Dec 2023 11:37:55 +0530 Subject: [PATCH 18/32] chore: adding static image shapes to backbone in tests --- .../faster_rcnn/faster_rcnn_test.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py index b871fb9431..5ac248315d 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -53,7 +53,7 @@ def test_faster_rcnn_call(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xywh", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) images = np.random.uniform(size=(2, 512, 512, 3)) _ = faster_rcnn(images) @@ -63,7 +63,7 @@ def test_wrong_logits(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xywh", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) with self.assertRaisesRegex( @@ -91,7 +91,7 @@ def test_weights_contained_in_trainable_variables(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format=bounding_box_format, - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) faster_rcnn.backbone.trainable = False faster_rcnn.compile( @@ -112,7 +112,7 @@ def test_no_nans(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) faster_rcnn.compile( optimizer=keras.optimizers.Adam(), @@ -142,7 +142,7 @@ def test_weights_change(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) faster_rcnn.compile( optimizer=keras.optimizers.Adam(), @@ -187,7 +187,7 @@ def test_saved_model(self): model = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) input_batch = ops.ones(shape=(1, 512, 512, 3)) model_output = model(input_batch) @@ -219,7 +219,7 @@ def test_faster_rcnn_infer(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) images = ops.random.normal(batch_shape) outputs = model(images, training=False) @@ -236,7 +236,7 @@ def test_faster_rcnn_train(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) images = ops.random.normal(batch_shape) outputs = model(images, training=True) @@ -247,7 +247,7 @@ def test_invalid_compile(self): model = FasterRCNN( num_classes=80, bounding_box_format="yxyx", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) with 
self.assertRaisesRegex(ValueError, "only accepts"): model.compile(rpn_box_loss="binary_crossentropy") @@ -263,7 +263,7 @@ def test_faster_rcnn_with_dictionary_input_format(self): faster_rcnn = FasterRCNN( num_classes=20, bounding_box_format="xywh", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) images, boxes = _create_bounding_box_dataset("xywh") From 49815d1e52d5c8a3e5d366c9850974b28fffc7cd Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 18 Dec 2023 11:54:46 +0530 Subject: [PATCH 19/32] fix: parameterised input shape in test --- keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py | 2 +- .../models/object_detection/faster_rcnn/faster_rcnn_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index f857d99302..8ebe37c3f8 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -212,7 +212,7 @@ def __init__( feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0) ) # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] - rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( + rcnn_box_pred, rcnn_cls_pred = rcnn_head( feature_map, ) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py index 5ac248315d..a2630f446d 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -38,7 +38,7 @@ def test_faster_rcnn_construction(self): faster_rcnn = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), ) faster_rcnn.compile( optimizer=keras.optimizers.Adam(), @@ -219,7 +219,7 @@ def test_faster_rcnn_infer(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=batch_shape[1:]), ) images = ops.random.normal(batch_shape) outputs = model(images, training=False) @@ -236,7 +236,7 @@ def test_faster_rcnn_train(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone(input_shape=batch_shape[1:]), ) images = ops.random.normal(batch_shape) outputs = model(images, training=True) From 6061f01ed7bdc3cca97b262e80f47b40c3dd1be0 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 18 Dec 2023 15:19:41 +0530 Subject: [PATCH 20/32] fix: reshape --- keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 8ebe37c3f8..74c0edce6c 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -208,9 +208,7 @@ def __init__( # Calling the RCNN block feature_map = roi_pooler(feature_map, rois) # [BS, H*W*K, pool_shape*C] - feature_map = ops.reshape( - feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0) - ) + feature_map = 
keras.layers.Reshape(target_shape=(ops.shape(rois)[1], -1))(feature_map) # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] rcnn_box_pred, rcnn_cls_pred = rcnn_head( feature_map, From ef279a98ddb4bd3f4c065c699c283ff3bba6a39d Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 18 Dec 2023 15:47:01 +0530 Subject: [PATCH 21/32] fix: format and output dict --- .../faster_rcnn/faster_rcnn.py | 4 +- .../faster_rcnn/faster_rcnn_test.py | 52 +++++++++++++------ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 74c0edce6c..2dd4398985 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -208,7 +208,9 @@ def __init__( # Calling the RCNN block feature_map = roi_pooler(feature_map, rois) # [BS, H*W*K, pool_shape*C] - feature_map = keras.layers.Reshape(target_shape=(ops.shape(rois)[1], -1))(feature_map) + feature_map = keras.layers.Reshape( + target_shape=(ops.shape(rois)[1], -1) + )(feature_map) # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] rcnn_box_pred, rcnn_cls_pred = rcnn_head( feature_map, diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py index a2630f446d..817a0c7641 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -38,7 +38,9 @@ def test_faster_rcnn_construction(self): faster_rcnn = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) faster_rcnn.compile( optimizer=keras.optimizers.Adam(), @@ -53,7 +55,9 @@ def test_faster_rcnn_call(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xywh", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) images = np.random.uniform(size=(2, 512, 512, 3)) _ = faster_rcnn(images) @@ -63,7 +67,9 @@ def test_wrong_logits(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xywh", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) with self.assertRaisesRegex( @@ -91,7 +97,9 @@ def test_weights_contained_in_trainable_variables(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format=bounding_box_format, - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) faster_rcnn.backbone.trainable = False faster_rcnn.compile( @@ -112,7 +120,9 @@ def test_no_nans(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) faster_rcnn.compile( optimizer=keras.optimizers.Adam(), @@ -142,7 +152,9 @@ def test_weights_change(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + 
backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) faster_rcnn.compile( optimizer=keras.optimizers.Adam(), @@ -187,7 +199,9 @@ def test_saved_model(self): model = keras_cv.models.FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) input_batch = ops.ones(shape=(1, 512, 512, 3)) model_output = model(input_batch) @@ -219,13 +233,15 @@ def test_faster_rcnn_infer(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=batch_shape[1:]), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=batch_shape[1:] + ), ) images = ops.random.normal(batch_shape) outputs = model(images, training=False) # 1000 proposals in inference - self.assertAllEqual([2, 1000, 81], outputs["classes"].shape) - self.assertAllEqual([2, 1000, 4], outputs["boxes"].shape) + self.assertAllEqual([2, 1000, 81], outputs["classification"].shape) + self.assertAllEqual([2, 1000, 4], outputs["box"].shape) @parameterized.parameters( ((2, 640, 384, 3),), @@ -236,18 +252,22 @@ def test_faster_rcnn_train(self, batch_shape): model = FasterRCNN( num_classes=80, bounding_box_format="xyxy", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=batch_shape[1:]), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=batch_shape[1:] + ), ) images = ops.random.normal(batch_shape) outputs = model(images, training=True) - self.assertAllEqual([2, 1000, 81], outputs["classes"].shape) - self.assertAllEqual([2, 1000, 4], outputs["boxes"].shape) + self.assertAllEqual([2, 1000, 81], outputs["classification"].shape) + self.assertAllEqual([2, 1000, 4], outputs["box"].shape) def test_invalid_compile(self): model = FasterRCNN( num_classes=80, bounding_box_format="yxyx", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) with self.assertRaisesRegex(ValueError, "only accepts"): model.compile(rpn_box_loss="binary_crossentropy") @@ -263,7 +283,9 @@ def test_faster_rcnn_with_dictionary_input_format(self): faster_rcnn = FasterRCNN( num_classes=20, bounding_box_format="xywh", - backbone=keras_cv.models.ResNet18V2Backbone(input_shape=(512, 512, 3)), + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), ) images, boxes = _create_bounding_box_dataset("xywh") From 134f897651b1387db39212ae3d56ac90b95d0632 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Tue, 19 Dec 2023 13:12:44 +0530 Subject: [PATCH 22/32] chore: masking sample weights for box labels -1 --- keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 2dd4398985..4fec24a759 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -386,6 +386,12 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): cls_targets, cls_weights, ) = self.roi_sampler(rois, boxes, classes) + + # Mask weights for class -1 + positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32") + box_weights *= positive_mask + cls_weights *= positive_mask + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 cls_weights /= 
self.roi_sampler.num_sampled_rois * local_batch box_pred, cls_pred = self._call_rcnn( From e190e1b2309ee0bef519c9ff2f62c9c8e19ed0f0 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Tue, 19 Dec 2023 15:58:14 +0530 Subject: [PATCH 23/32] chore: fixing sample weights and decode predictions --- .../faster_rcnn/faster_rcnn.py | 34 +++++++++--- .../faster_rcnn/faster_rcnn_test.py | 52 +++++++++++++++++-- 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 4fec24a759..2456d4b9f3 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -387,17 +387,22 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): cls_weights, ) = self.roi_sampler(rois, boxes, classes) - # Mask weights for class -1 - positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32") - box_weights *= positive_mask - cls_weights *= positive_mask + positive_mask = ops.cast( + ops.greater(cls_targets, -1.0), dtype="float32" + ) + normalizer = ops.sum(positive_mask) box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 cls_weights /= self.roi_sampler.num_sampled_rois * local_batch + + box_weights /= normalizer + cls_weights /= normalizer + box_pred, cls_pred = self._call_rcnn( rois, feature_map, ) + y_true = { "rpn_box": rpn_box_targets, "rpn_classification": rpn_cls_targets, @@ -410,14 +415,26 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): "box": box_pred, "classification": cls_pred, } - weights = { + sample_weights = { "rpn_box": rpn_box_weights, "rpn_classification": rpn_cls_weights, "box": box_weights, "classification": cls_weights, } + zero_weights = { + "rpn_box": ops.zeros_like(rpn_box_weights), + "rpn_classification": ops.zeros_like(rpn_cls_weights), + "box": ops.zeros_like(box_weights), + "classification": ops.zeros_like(cls_weights), + } + + sample_weights = ops.cond( + normalizer == 0.0, + lambda: zero_weights, + lambda: sample_weights, + ) return super().compute_loss( - x=images, y=y_true, y_pred=y_pred, sample_weight=weights + x=images, y=y_true, y_pred=y_pred, sample_weight=sample_weights ) def train_step(self, *args): @@ -450,7 +467,10 @@ def prediction_decoder(self, prediction_decoder): def decode_predictions(self, predictions, images): # no-op if default decoder is used. 
- box_pred, scores_pred = predictions["boxes"], predictions["classes"] + box_pred, scores_pred = ( + predictions["box"], + predictions["classification"], + ) box_pred = bounding_box.convert_format( box_pred, source=self.bounding_box_format, diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py index 817a0c7641..ee92e38d7f 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -22,10 +22,9 @@ import keras_cv from keras_cv.backend import keras from keras_cv.backend import ops - -# from keras_cv.models.backbones.test_backbone_presets import ( -# test_backbone_presets, -# ) +from keras_cv.models.backbones.test_backbone_presets import ( + test_backbone_presets, +) from keras_cv.models.object_detection.__test_utils__ import ( _create_bounding_box_dataset, ) @@ -50,7 +49,6 @@ def test_faster_rcnn_construction(self): rpn_classification_loss="BinaryCrossentropy", ) - @pytest.mark.large # Fit is slow, so mark these large. def test_faster_rcnn_call(self): faster_rcnn = keras_cv.models.FasterRCNN( num_classes=80, @@ -303,3 +301,47 @@ def test_faster_rcnn_with_dictionary_input_format(self): faster_rcnn.fit(dataset, epochs=1) faster_rcnn.evaluate(dataset) + + # @pytest.mark.large # Fit is slow, so mark these large. + def test_fit_with_no_valid_gt_bbox(self): + bounding_box_format = "xywh" + faster_rcnn = FasterRCNN( + num_classes=20, + bounding_box_format=bounding_box_format, + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + xs, ys = _create_bounding_box_dataset(bounding_box_format) + # Make all bounding_boxes invalid and filter out them + ys["classes"] = -np.ones_like(ys["classes"]) + + faster_rcnn.fit(x=xs, y=ys, epochs=1) + + +@pytest.mark.large +class FasterRCNNSmokeTest(TestCase): + @parameterized.named_parameters( + *[(preset, preset) for preset in test_backbone_presets] + ) + @pytest.mark.extra_large + def test_backbone_preset(self, preset): + model = keras_cv.models.FasterRCNN.from_preset( + preset, + num_classes=20, + bounding_box_format="xywh", + ) + xs, _ = _create_bounding_box_dataset(bounding_box_format="xywh") + output = model(xs) + + # 64 represents number of parameters in a box + # 5376 is the number of anchors for a 512x512 image + self.assertEqual(output["boxes"].shape, (xs.shape[0], 5376, 64)) From 821b7aaf41d8175461c180ae0933b9000beace48 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Tue, 2 Jan 2024 19:56:12 +0530 Subject: [PATCH 24/32] chore: porting roi gen to keras 3 ops non max supression padded api from tf is yet to be ported --- .../layers/object_detection/roi_generator.py | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/keras_cv/layers/object_detection/roi_generator.py b/keras_cv/layers/object_detection/roi_generator.py index 42b5accb05..6461500df5 100644 --- a/keras_cv/layers/object_detection/roi_generator.py +++ b/keras_cv/layers/object_detection/roi_generator.py @@ -17,13 +17,11 @@ from typing import Tuple from typing import Union -import tensorflow as tf -from tensorflow import keras - from keras_cv import bounding_box from keras_cv.api_export import keras_cv_export from keras_cv.backend import 
assert_tf_keras from keras_cv.backend import ops +from keras_cv.backend import keras @keras_cv_export("keras_cv.layers.ROIGenerator") @@ -113,10 +111,10 @@ def __init__( def call( self, - multi_level_boxes: Union[tf.Tensor, Mapping[int, tf.Tensor]], - multi_level_scores: Union[tf.Tensor, Mapping[int, tf.Tensor]], + multi_level_boxes, + multi_level_scores, training: Optional[bool] = None, - ) -> Tuple[tf.Tensor, tf.Tensor]: + ): """ Args: multi_level_boxes: float Tensor. A dictionary or single Tensor of @@ -148,14 +146,14 @@ def per_level_gen(boxes, scores): scores_shape = scores.get_shape().as_list() # scores can also be [batch_size, num_boxes, 1] if len(scores_shape) == 3: - scores = tf.squeeze(scores, axis=-1) + scores = ops.squeeze(scores, axis=-1) num_boxes = ops.shape(boxes)[1] level_pre_nms_topk = min(num_boxes, pre_nms_topk) level_post_nms_topk = min(num_boxes, post_nms_topk) - scores, sorted_indices = tf.nn.top_k( + scores, sorted_indices = ops.top_k( scores, k=level_pre_nms_topk, sorted=True ) - boxes = tf.gather(boxes, sorted_indices, batch_dims=1) + boxes = ops.take(boxes, sorted_indices, batch_dims=1) # convert from input format to yxyx for the TF NMS operation boxes = bounding_box.convert_format( boxes, @@ -163,6 +161,7 @@ def per_level_gen(boxes, scores): target="yxyx", ) # TODO(tanzhenyu): consider supporting soft / batched nms for accl + import tensorflow as tf selected_indices, num_valid = tf.image.non_max_suppression_padded( boxes, scores, @@ -179,16 +178,16 @@ def per_level_gen(boxes, scores): source="yxyx", target=self.bounding_box_format, ) - level_rois = tf.gather(boxes, selected_indices, batch_dims=1) - level_roi_scores = tf.gather(scores, selected_indices, batch_dims=1) - level_rois = level_rois * tf.cast( - tf.reshape(tf.range(level_post_nms_topk), [1, -1, 1]) - < tf.reshape(num_valid, [-1, 1, 1]), + level_rois = ops.take(boxes, selected_indices, batch_dims=1) + level_roi_scores = ops.take(scores, selected_indices, batch_dims=1) + level_rois = level_rois * ops.cast( + ops.reshape(ops.arange(level_post_nms_topk), [1, -1, 1]) + < ops.reshape(num_valid, [-1, 1, 1]), level_rois.dtype, ) - level_roi_scores = level_roi_scores * tf.cast( - tf.reshape(tf.range(level_post_nms_topk), [1, -1]) - < tf.reshape(num_valid, [-1, 1]), + level_roi_scores = level_roi_scores * ops.cast( + ops.reshape(ops.range(level_post_nms_topk), [1, -1]) + < ops.reshape(num_valid, [-1, 1]), level_roi_scores.dtype, ) return level_rois, level_roi_scores @@ -205,14 +204,14 @@ def per_level_gen(boxes, scores): rois.append(level_rois) roi_scores.append(level_roi_scores) - rois = tf.concat(rois, axis=1) - roi_scores = tf.concat(roi_scores, axis=1) + rois = ops.concatenate(rois, axis=1) + roi_scores = ops.concatenate(roi_scores, axis=1) _, num_valid_rois = roi_scores.get_shape().as_list() overall_top_k = min(num_valid_rois, post_nms_topk) - roi_scores, sorted_indices = tf.nn.top_k( + roi_scores, sorted_indices = ops.top_k( roi_scores, k=overall_top_k, sorted=True ) - rois = tf.gather(rois, sorted_indices, batch_dims=1) + rois = ops.take(rois, sorted_indices, batch_dims=1) return rois, roi_scores From 922725545a2e34857d85bbc0d38ec4502449616e Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Wed, 10 Jan 2024 13:19:17 +0530 Subject: [PATCH 25/32] chore: port roi gen to keras 3 --- .../layers/object_detection/roi_generator.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/keras_cv/layers/object_detection/roi_generator.py b/keras_cv/layers/object_detection/roi_generator.py 
index 6461500df5..477e41c5e2 100644 --- a/keras_cv/layers/object_detection/roi_generator.py +++ b/keras_cv/layers/object_detection/roi_generator.py @@ -20,8 +20,9 @@ from keras_cv import bounding_box from keras_cv.api_export import keras_cv_export from keras_cv.backend import assert_tf_keras -from keras_cv.backend import ops from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.layers import NonMaxSuppression @keras_cv_export("keras_cv.layers.ROIGenerator") @@ -161,17 +162,13 @@ def per_level_gen(boxes, scores): target="yxyx", ) # TODO(tanzhenyu): consider supporting soft / batched nms for accl - import tensorflow as tf - selected_indices, num_valid = tf.image.non_max_suppression_padded( - boxes, - scores, - max_output_size=level_post_nms_topk, + selected_indices, num_valid = NonMaxSuppression( + bounding_box_format=self.bounding_box_format, + from_logits=True, iou_threshold=nms_iou_threshold, - score_threshold=nms_score_threshold, - pad_to_max_output_size=True, - sorted_input=True, - canonicalized_coordinates=True, - ) + confidence_threshold=nms_score_threshold, + max_detections=level_post_nms_topk, + )(box_prediction=boxes, class_prediction=scores) # convert back to input format boxes = bounding_box.convert_format( boxes, From 345764f802e0e7aa7a1f58449f79d36fb39cb538 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Wed, 10 Jan 2024 16:44:42 +0530 Subject: [PATCH 26/32] chore: removing asserts for keras 3 --- keras_cv/layers/object_detection/roi_align.py | 1 - keras_cv/layers/object_detection/roi_generator.py | 12 ++++-------- .../models/object_detection/faster_rcnn/rpn_head.py | 2 +- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 8d4f347878..45cdc74c45 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -412,7 +412,6 @@ def __init__( sample_offset: A `float` in [0, 1] of the subpixel sample offset. **kwargs: Additional keyword arguments passed to Layer. """ - assert_tf_keras("keras_cv.layers._ROIAligner") self._config_dict = { "bounding_box_format": bounding_box_format, "crop_size": target_size, diff --git a/keras_cv/layers/object_detection/roi_generator.py b/keras_cv/layers/object_detection/roi_generator.py index 477e41c5e2..05e5a69644 100644 --- a/keras_cv/layers/object_detection/roi_generator.py +++ b/keras_cv/layers/object_detection/roi_generator.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Mapping from typing import Optional -from typing import Tuple -from typing import Union from keras_cv import bounding_box from keras_cv.api_export import keras_cv_export @@ -97,7 +94,6 @@ def __init__( post_nms_topk_test: int = 1000, **kwargs, ): - assert_tf_keras("keras_cv.layers.ROIGenerator") super().__init__(**kwargs) self.bounding_box_format = bounding_box_format self.pre_nms_topk_train = pre_nms_topk_train @@ -154,7 +150,7 @@ def per_level_gen(boxes, scores): scores, sorted_indices = ops.top_k( scores, k=level_pre_nms_topk, sorted=True ) - boxes = ops.take(boxes, sorted_indices, batch_dims=1) + boxes = ops.take(boxes, sorted_indices) # convert from input format to yxyx for the TF NMS operation boxes = bounding_box.convert_format( boxes, @@ -175,8 +171,8 @@ def per_level_gen(boxes, scores): source="yxyx", target=self.bounding_box_format, ) - level_rois = ops.take(boxes, selected_indices, batch_dims=1) - level_roi_scores = ops.take(scores, selected_indices, batch_dims=1) + level_rois = ops.take(boxes, selected_indices) + level_roi_scores = ops.take(scores, selected_indices) level_rois = level_rois * ops.cast( ops.reshape(ops.arange(level_post_nms_topk), [1, -1, 1]) < ops.reshape(num_valid, [-1, 1, 1]), @@ -208,7 +204,7 @@ def per_level_gen(boxes, scores): roi_scores, sorted_indices = ops.top_k( roi_scores, k=overall_top_k, sorted=True ) - rois = ops.take(rois, sorted_indices, batch_dims=1) + rois = ops.take(rois, sorted_indices) return rois, roi_scores diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py index 11d853125f..d0bf087b1d 100644 --- a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -46,7 +46,7 @@ def __init__( def build(self, input_shape): if isinstance(input_shape, (dict, list, tuple)): input_shape = tree.flatten(input_shape) - input_shape = input_shape[0] + input_shape = input_shape[0:4] filters = input_shape[-1] self.conv = keras.layers.Conv2D( filters=filters, From 9e7eea03e34ea5d03a07b823b60d751e7abe6aa7 Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Wed, 28 Feb 2024 05:37:23 +0000 Subject: [PATCH 27/32] chore: adding faster rcnn to kokoro build script --- .kokoro/github/ubuntu/gpu/build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 76ac0631b4..c368571a33 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -69,6 +69,7 @@ then keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ + keras_cv/models/object_detection/faster_rcnn \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion else @@ -83,6 +84,7 @@ else keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ + keras_cv/models/object_detection/faster_rcnn \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion fi \ No newline at end of file From af47e3fe1274d96eb9580a6042895379c2ec3453 Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Wed, 28 Feb 2024 12:50:50 +0000 Subject: [PATCH 28/32] chore: changing a bunch of things and keeping it commited for reference --- demo.py | 47 ++ keras_cv/bounding_box/utils.py | 2 +- keras_cv/layers/object_detection/roi_align.py | 8 +- .../layers/object_detection/roi_generator.py | 2 +- 
.../faster_rcnn/faster_rcnn.py | 636 ++++-------------- .../faster_rcnn/faster_rcnn_port.py | 567 ++++++++++++++++ .../faster_rcnn/faster_rcnn_tf.py | 394 +++++++++++ .../object_detection/faster_rcnn/rpn_head.py | 2 +- 8 files changed, 1146 insertions(+), 512 deletions(-) create mode 100644 demo.py create mode 100644 keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py create mode 100644 keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000..5ae0890a8b --- /dev/null +++ b/demo.py @@ -0,0 +1,47 @@ +import keras +import numpy as np + +import keras_cv + +# Note: We absolutely need this while creating the Backbone +batch_size = 32 +image_shape = (512, 512, 3) + +images = np.ones((batch_size,) + image_shape) +labels = { + "boxes": np.array( + [ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], + dtype=np.float32, + ), + "classes": np.array([[1, 1, 1]], dtype=np.float32), +} +model = keras_cv.models.FasterRCNN( + batch_size=batch_size, + num_classes=20, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet50Backbone.from_preset( + "resnet50_imagenet", + input_shape=image_shape, + ), +) + +# # Evaluate model without box decoding and NMS +# model(images) + +# # Prediction with box decoding and NMS +# model.predict(images) + +# # Train model +# model.compile( +# classification_loss="focal", +# box_loss="smoothl1", +# optimizer=keras.optimizers.SGD(global_clipnorm=10.0), +# jit_compile=False, +# ) +# model.fit(images, labels) diff --git a/keras_cv/bounding_box/utils.py b/keras_cv/bounding_box/utils.py index fd85f2b893..dc0978a259 100644 --- a/keras_cv/bounding_box/utils.py +++ b/keras_cv/bounding_box/utils.py @@ -141,7 +141,7 @@ def _clip_boxes(boxes, box_format, image_shape): if isinstance(image_shape, list) or isinstance(image_shape, tuple): height, width, _ = image_shape - max_length = [height, width, height, width] + max_length = ops.stack([height, width, height, width], axis=-1) else: image_shape = ops.cast(image_shape, dtype=boxes.dtype) height = image_shape[0] diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index bf2b35ad9e..25b698e29b 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -243,9 +243,11 @@ def multilevel_crop_and_resize( for i in range(len(feature_widths) - 1): level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i]) batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1] + level_dim_offsets = ops.convert_to_tensor(level_dim_offsets) level_dim_offsets = ( ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets ) + feature_widths = ops.convert_to_tensor(feature_widths) height_dim_sizes = ( ops.ones_like(feature_widths, dtype="int32") * feature_widths ) @@ -271,7 +273,7 @@ def multilevel_crop_and_resize( # Projects box location and sizes to corresponding feature levels. scale_to_level = ops.cast( - ops.pow(2.0, ops.cast(levels, "float32")), + ops.power(2.0, ops.cast(levels, "float32")), dtype=boxes.dtype, ) boxes /= ops.expand_dims(scale_to_level, axis=2) @@ -288,7 +290,9 @@ def multilevel_crop_and_resize( # Maps levels to [0, max_level-min_level]. 
levels -= min_level - level_strides = ops.pow([[2.0]], ops.cast(levels, "float32")) + level_strides = ops.power([[2.0]], ops.cast(levels, "float32")) + print(f"{max_feature_height=}") + print(f"{level_strides=}") boundary = ops.cast( ops.concatenate( [ diff --git a/keras_cv/layers/object_detection/roi_generator.py b/keras_cv/layers/object_detection/roi_generator.py index da99dc080f..9cb4b75326 100644 --- a/keras_cv/layers/object_detection/roi_generator.py +++ b/keras_cv/layers/object_detection/roi_generator.py @@ -145,7 +145,7 @@ def per_level_gen(boxes, scores): # If so, remove the last dimension to make it 2D if len(scores_shape) == 3: scores = ops.squeeze(scores, axis=-1) - _, num_boxes = scores_shape + num_boxes = scores_shape[1] level_pre_nms_topk = min(num_boxes, pre_nms_topk) level_post_nms_topk = min(num_boxes, post_nms_topk) scores, sorted_indices = ops.top_k( diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 2456d4b9f3..5f95177e51 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -12,44 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - -import tree -from absl import logging - -from keras_cv import bounding_box -from keras_cv import layers as cv_layers from keras_cv import models from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.bounding_box.converters import _decode_deltas_to_boxes - -# from keras_cv.models.backbones.backbone_presets import backbone_presets -# from keras_cv.models.backbones.backbone_presets import ( -# backbone_presets_with_weights -# ) from keras_cv.bounding_box.utils import _clip_boxes from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator from keras_cv.layers.object_detection.box_matcher import BoxMatcher -from keras_cv.layers.object_detection.roi_align import _ROIAligner from keras_cv.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.layers.object_detection.roi_pool import ROIPooler from keras_cv.layers.object_detection.roi_sampler import _ROISampler -from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder -from keras_cv.models.backbones.backbone_presets import backbone_presets -from keras_cv.models.backbones.backbone_presets import ( - backbone_presets_with_weights, -) from keras_cv.models.object_detection.__internal__ import unpack_input from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid -from keras_cv.models.object_detection.faster_rcnn import RCNNHead from keras_cv.models.object_detection.faster_rcnn import RPNHead - -# from keras_cv.models.object_detection.faster_rcnn.faster_rcnn_presets import ( -# faster_rcnn_presets -# ) from keras_cv.models.task import Task -from keras_cv.utils.python_utils import classproperty from keras_cv.utils.train import get_feature_extractor BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] @@ -58,98 +35,59 @@ # TODO(tanzheny): add more configurations @keras_cv_export("keras_cv.models.FasterRCNN") class FasterRCNN(Task): - """A Keras model implementing the FasterRCNN architecture. - - Implements the FasterRCNN architecture for object detection. The constructor - requires `backbone`, `num_classes`, and a `bounding_box_format`. 
- - References: - - [FasterRCNN](https://arxiv.org/abs/1506.01497) - - Args: - backbone: `keras.Model`. Must implement the - `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" - and layer names as values. - num_classes: the number of classes in your dataset excluding the - background class. classes should be represented by integers in the - range [0, num_classes). - bounding_box_format: The format of bounding boxes of model output. Refer - [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) - for more details on supported bounding box formats. - anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is - used in the model to match ground truth boxes and labels with - anchors, or with region proposals. By default it uses the sizes and - ratios from the paper, that is optimized for image size between - [640, 800]. The users should pass their own anchor generator if the - input image size differs from paper. For now, only anchor generator - with per level dict output is supported, - label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor, - a bounding box Tensor and a bounding box class Tensor to its - `call()` method, and returns RetinaNet training targets. It returns - box and class targets as well as sample weights. - rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature - map and returns a box delta prediction (in reference to rois) and - multi-class prediction (all foreground classes + one background - class). By default it uses the rcnn head from paper, which is 2 FC - layer with 1024 dimension, 1 box regressor and 1 softmax classifier. - prediction_decoder: (Optional) a `keras.layers.Layer` that takes input - box prediction and softmaxed score prediction, and returns NMSed box - prediction, NMSed softmaxed score prediction, NMSed class - prediction, and NMSed valid detection. 
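For reference, a minimal sketch of the default anchor configuration the `anchor_generator` argument falls back to. The sizes, scales, aspect ratios, and strides mirror the values used elsewhere in this patch; the 512x512 input shape is only an assumption for the sake of the example.

```python
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator

scales = [2**x for x in [0]]
aspect_ratios = [0.5, 1.0, 2.0]

# Per-level anchor sizes and strides for FPN levels P2-P6, as used in this patch.
anchor_generator = AnchorGenerator(
    bounding_box_format="yxyx",
    sizes={"P2": 32.0, "P3": 64.0, "P4": 128.0, "P5": 256.0, "P6": 512.0},
    scales=scales,
    aspect_ratios=aspect_ratios,
    strides={f"P{i}": 2**i for i in range(2, 7)},
    clip_boxes=True,
)

# Returns a dict of anchors keyed by pyramid level ("P2", ..., "P6").
anchors = anchor_generator(image_shape=(512, 512, 3))
```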
- - Examples: - - ```python - images = np.ones((1, 512, 512, 3)) - labels = { - "boxes": [ - [ - [0, 0, 100, 100], - [100, 100, 200, 200], - [300, 300, 100, 100], - ] - ], - "classes": [[1, 1, 1]], - } - model = keras_cv.models.FasterRCNN( - num_classes=20, - bounding_box_format="xywh", - backbone=keras_cv.models.ResNet50Backbone.from_preset( - "resnet50_imagenet" - ) - ) - - # Evaluate model without box decoding and NMS - model(images) - - # Prediction with box decoding and NMS - model.predict(images) - - # Train model - model.compile( - classification_loss='focal', - box_loss='smoothl1', - optimizer=keras.optimizers.SGD(global_clipnorm=10.0), - jit_compile=False, - ) - model.fit(images, labels) - ``` - """ # noqa: E501 - def __init__( self, + batch_size, backbone, num_classes, bounding_box_format, anchor_generator=None, - label_encoder=None, - rcnn_head=None, - prediction_decoder=None, feature_pyramid=None, + *args, **kwargs, ): + + # Create the Input Layer + extractor_levels = ["P2", "P3", "P4", "P5"] + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + feature_extractor = get_feature_extractor( + backbone, extractor_layer_names, extractor_levels + ) + feature_pyramid = feature_pyramid or FeaturePyramid() + image_shape = feature_extractor.input_shape[1:] # excule the batch size + images = keras.layers.Input( + image_shape, batch_size=batch_size, name="images" + ) + print(f"{image_shape=}") + print(f"{images.shape=}") + + # Get the backbone outputs + backbone_outputs = feature_extractor(images) + feature_map = feature_pyramid(backbone_outputs) + print("backbone_outputs") + for key, value in backbone_outputs.items(): + print(f"\t{key}: {value.shape}") + print("feature_map") + for key, value in feature_map.items(): + print(f"\t{key}: {value.shape}") + + # Get the Region Proposal Boxes and Scores scales = [2**x for x in [0]] aspect_ratios = [0.5, 1.0, 2.0] + num_anchors_per_location = len(scales) * len(aspect_ratios) + rpn_head = RPNHead(num_anchors_per_location=num_anchors_per_location) + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = rpn_head(feature_map) + print("rpn_boxes") + for key, value in rpn_boxes.items(): + print(f"\t{key}: {value.shape}") + print("rpn_scores") + for key, value in rpn_scores.items(): + print(f"\t{key}: {value.shape}") + + # Create the anchors anchor_generator = anchor_generator or AnchorGenerator( bounding_box_format=bounding_box_format, sizes={ @@ -164,37 +102,14 @@ def __init__( strides={f"P{i}": 2**i for i in range(2, 7)}, clip_boxes=True, ) - rpn_head = RPNHead( - num_anchors_per_location=len(scales) * len(aspect_ratios) - ) - roi_generator = ROIGenerator( - bounding_box_format=bounding_box_format, - nms_score_threshold_train=float("-inf"), - nms_score_threshold_test=float("-inf"), - ) - roi_pooler = _ROIAligner(bounding_box_format=bounding_box_format) - rcnn_head = rcnn_head or RCNNHead(num_classes) - backbone = backbone or models.ResNet50Backbone() - extractor_levels = ["P2", "P3", "P4", "P5"] - extractor_layer_names = [ - backbone.pyramid_level_inputs[i] for i in extractor_levels - ] - feature_extractor = get_feature_extractor( - backbone, extractor_layer_names, extractor_levels - ) - feature_pyramid = feature_pyramid or FeaturePyramid() - - # Begin construction of forward pass - image_shape = feature_extractor.input_shape[1:] - images = keras.layers.Input(image_shape, name="images") + # Note: `image_shape` should not be of NoneType + # Need to assert before this line anchors = 
anchor_generator(image_shape=image_shape) + print("anchors") + for key, value in anchors.items(): + print(f"\t{key}: {value.shape}") - # Calling the RPN block - backbone_outputs = feature_extractor(images) - feature_map = feature_pyramid(backbone_outputs) - # [BS, num_anchors, 4], [BS, num_anchors, 1] - rpn_boxes, rpn_scores = rpn_head(feature_map) - # the decoded format is center_xywh, convert to yxyx + # decode the deltas to boxes decoded_rpn_boxes = _decode_deltas_to_boxes( anchors=anchors, boxes_delta=rpn_boxes, @@ -202,379 +117,86 @@ def __init__( box_format=bounding_box_format, variance=BOX_VARIANCE, ) - rois, _ = roi_generator(decoded_rpn_boxes, rpn_scores) - rois = _clip_boxes(rois, bounding_box_format, image_shape) - - # Calling the RCNN block - feature_map = roi_pooler(feature_map, rois) - # [BS, H*W*K, pool_shape*C] - feature_map = keras.layers.Reshape( - target_shape=(ops.shape(rois)[1], -1) - )(feature_map) - # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] - rcnn_box_pred, rcnn_cls_pred = rcnn_head( - feature_map, - ) - - inputs = {"images": images} - outputs = {"box": rcnn_box_pred, "classification": rcnn_cls_pred} + print("decoded_rpn_boxes") + for key, value in decoded_rpn_boxes.items(): + print(f"\t{key}: {value.shape}") - super().__init__(inputs=inputs, outputs=outputs, **kwargs) - - self.num_classes = num_classes - self.bounding_box_format = bounding_box_format - self.anchor_generator = anchor_generator - self.rpn_head = rpn_head - self.roi_generator = roi_generator - self.box_matcher = BoxMatcher( - thresholds=[0.0, 0.5], match_values=[-2, -1, 1] - ) - self.roi_sampler = _ROISampler( - bounding_box_format="yxyx", - roi_matcher=self.box_matcher, - background_class=num_classes, - num_sampled_rois=512, - ) - self.roi_pooler = roi_pooler - self.rcnn_head = rcnn_head - self.backbone = backbone - self.feature_extractor = feature_extractor - self.feature_pyramid = feature_pyramid - self.rpn_labeler = label_encoder or _RpnLabelEncoder( - anchor_format="yxyx", - ground_truth_box_format="yxyx", - positive_threshold=0.7, - negative_threshold=0.3, - samples_per_image=256, - positive_fraction=0.5, - box_variance=BOX_VARIANCE, - ) - self._prediction_decoder = ( - prediction_decoder - or cv_layers.NonMaxSuppression( - bounding_box_format=bounding_box_format, - from_logits=False, - iou_threshold=0.5, - confidence_threshold=0.5, - max_detections=100, - ) - ) - - def _call_rpn(self, images, anchors, training=None): - image_shape = ops.shape(images[0]) - backbone_outputs = self.feature_extractor(images, training=training) - feature_map = self.feature_pyramid(backbone_outputs, training=training) - # [BS, num_anchors, 4], [BS, num_anchors, 1] - rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training) - # the decoded format is center_xywh, convert to yxyx - decoded_rpn_boxes = _decode_deltas_to_boxes( - anchors=anchors, - boxes_delta=rpn_boxes, - anchor_format="yxyx", - box_format="yxyx", - variance=BOX_VARIANCE, - ) - rois, _ = self.roi_generator( - decoded_rpn_boxes, rpn_scores, training=training - ) - rois = _clip_boxes(rois, "yxyx", image_shape) - rpn_boxes = ops.concatenate(tree.flatten(rpn_boxes), axis=1) - rpn_scores = ops.concatenate(tree.flatten(rpn_scores), axis=1) - return rois, feature_map, rpn_boxes, rpn_scores - - def _call_rcnn(self, rois, feature_map, training=None): - feature_map = self.roi_pooler(feature_map, rois) - # [BS, H*W*K, pool_shape*C] - feature_map = ops.reshape( - feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0) - ) - # [BS, H*W*K, 
4], [BS, H*W*K, num_classes + 1] - rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( - feature_map, training=training - ) - return rcnn_box_pred, rcnn_cls_pred - - # TODO(tanzhenyu): Support compile with metrics. - def compile( - self, - box_loss=None, - classification_loss=None, - rpn_box_loss=None, - rpn_classification_loss=None, - weight_decay=0.0001, - loss=None, - **kwargs, - ): - # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed. - # https://github.com/keras-team/keras-cv/issues/915 - if "metrics" in kwargs.keys(): - raise ValueError( - "`FasterRCNN` does not currently support the use of " - "`metrics` due to performance and distribution concerns. " - "Please use the `PyCOCOCallback` to evaluate COCO metrics." - ) - if loss is not None: - raise ValueError( - "`FasterRCNN` does not accept a `loss` to `compile()`. " - "Instead, please pass `box_loss` and `classification_loss`. " - "`loss` will be ignored during training." - ) - box_loss = _validate_and_get_loss(box_loss, "box_loss") - classification_loss = _validate_and_get_loss( - classification_loss, "classification_loss" - ) - rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss") - if rpn_classification_loss == "BinaryCrossentropy": - rpn_classification_loss = keras.losses.BinaryCrossentropy( - from_logits=True, reduction=keras.losses.Reduction.SUM - ) - rpn_classification_loss = _validate_and_get_loss( - rpn_classification_loss, "rpn_cls_loss" - ) - if not rpn_classification_loss.from_logits: - raise ValueError( - "`rpn_classification_loss` must come with `from_logits`=True" - ) - - self.rpn_box_loss = rpn_box_loss - self.rpn_cls_loss = rpn_classification_loss - self.box_loss = box_loss - self.cls_loss = classification_loss - self.weight_decay = weight_decay - losses = { - "box": self.box_loss, - "classification": self.cls_loss, - "rpn_box": self.rpn_box_loss, - "rpn_classification": self.rpn_cls_loss, - } - super().compile(loss=losses, **kwargs) - - def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): - images = x - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." 
- ) - # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere - classes = ops.expand_dims(y["classes"], axis=-1) - - local_batch = images.get_shape().as_list()[0] - anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) - ( - rpn_box_targets, - rpn_box_weights, - rpn_cls_targets, - rpn_cls_weights, - ) = self.rpn_labeler( - ops.concatenate(tree.flatten(anchors), axis=0), boxes, classes - ) - rpn_box_weights /= ( - self.rpn_labeler.samples_per_image * local_batch * 0.25 - ) - rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch - rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( - images, - anchors, - ) - rois = ops.stop_gradient(rois) - ( - rois, - box_targets, - box_weights, - cls_targets, - cls_weights, - ) = self.roi_sampler(rois, boxes, classes) - - positive_mask = ops.cast( - ops.greater(cls_targets, -1.0), dtype="float32" - ) - normalizer = ops.sum(positive_mask) - - box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 - cls_weights /= self.roi_sampler.num_sampled_rois * local_batch - - box_weights /= normalizer - cls_weights /= normalizer - - box_pred, cls_pred = self._call_rcnn( - rois, - feature_map, - ) - - y_true = { - "rpn_box": rpn_box_targets, - "rpn_classification": rpn_cls_targets, - "box": box_targets, - "classification": cls_targets, - } - y_pred = { - "rpn_box": rpn_box_pred, - "rpn_classification": rpn_cls_pred, - "box": box_pred, - "classification": cls_pred, - } - sample_weights = { - "rpn_box": rpn_box_weights, - "rpn_classification": rpn_cls_weights, - "box": box_weights, - "classification": cls_weights, - } - zero_weights = { - "rpn_box": ops.zeros_like(rpn_box_weights), - "rpn_classification": ops.zeros_like(rpn_cls_weights), - "box": ops.zeros_like(box_weights), - "classification": ops.zeros_like(cls_weights), - } - - sample_weights = ops.cond( - normalizer == 0.0, - lambda: zero_weights, - lambda: sample_weights, - ) - return super().compute_loss( - x=images, y=y_true, y_pred=y_pred, sample_weight=sample_weights - ) - - def train_step(self, *args): - data = args[-1] - args = args[:-1] - x, y = unpack_input(data) - return super().train_step(*args, (x, y)) - - def test_step(self, *args): - data = args[-1] - args = args[:-1] - x, y = unpack_input(data) - return super().test_step(*args, (x, y)) - - def predict_step(self, *args): - outputs = super().predict_step(*args) - if type(outputs) is tuple: - return self.decode_predictions(outputs[0], args[-1]), outputs[1] - else: - return self.decode_predictions(outputs, args[-1]) - - @property - def prediction_decoder(self): - return self._prediction_decoder - - @prediction_decoder.setter - def prediction_decoder(self, prediction_decoder): - self._prediction_decoder = prediction_decoder - self.make_predict_function(force=True) - - def decode_predictions(self, predictions, images): - # no-op if default decoder is used. 
- box_pred, scores_pred = ( - predictions["box"], - predictions["classification"], - ) - box_pred = bounding_box.convert_format( - box_pred, - source=self.bounding_box_format, - target=self.prediction_decoder.bounding_box_format, - images=images, - ) - y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1]) - box_pred = bounding_box.convert_format( - y_pred["boxes"], - source=self.prediction_decoder.bounding_box_format, - target=self.bounding_box_format, - images=images, - ) - y_pred["boxes"] = box_pred - return y_pred - - def get_config(self): - return { - "num_classes": self.num_classes, - "bounding_box_format": self.bounding_box_format, - "backbone": keras.saving.serialize_keras_object(self.backbone), - "anchor_generator": keras.saving.serialize_keras_object( - self.anchor_generator - ), - "label_encoder": keras.saving.serialize_keras_object( - self.rpn_labeler - ), - "prediction_decoder": keras.saving.serialize_keras_object( - self._prediction_decoder - ), - "feature_pyramid": keras.saving.serialize_keras_object( - self.feature_pyramid - ), - "rcnn_head": keras.saving.serialize_keras_object(self.rcnn_head), - } - - @classmethod - def from_config(cls, config): - if "rcnn_head" in config and isinstance(config["rcnn_head"], dict): - config["rcnn_head"] = keras.layers.deserialize(config["rcnn_head"]) - if "feature_pyramid" in config and isinstance( - config["feature_pyramid"], dict - ): - config["feature_pyramid"] = keras.layers.deserialize( - config["feature_pyramid"] - ) - if "prediction_decoder" in config and isinstance( - config["prediction_decoder"], dict - ): - config["prediction_decoder"] = keras.layers.deserialize( - config["prediction_decoder"] - ) - if "label_encoder" in config and isinstance( - config["label_encoder"], dict - ): - config["label_encoder"] = keras.layers.deserialize( - config["label_encoder"] - ) - if "anchor_generator" in config and isinstance( - config["anchor_generator"], dict - ): - config["anchor_generator"] = keras.layers.deserialize( - config["anchor_generator"] - ) - if "backbone" in config and isinstance(config["backbone"], dict): - config["backbone"] = keras.layers.deserialize(config["backbone"]) - return super().from_config(config) - - @classproperty - def presets(cls): - """Dictionary of preset names and configurations.""" - # return copy.deepcopy({**backbone_presets, **fasterrcnn_presets}) - return copy.deepcopy({**backbone_presets}) - - @classproperty - def presets_with_weights(cls): - """Dictionary of preset names and configurations that include - weights.""" - return copy.deepcopy( - # {**backbone_presets_with_weights, **fasterrcnn_presets} - { - **backbone_presets_with_weights, - } + # Generate the Region of Interests + roi_generator = ROIGenerator( + bounding_box_format=bounding_box_format, + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), ) + rois, _ = roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, bounding_box_format, image_shape) + print(f"{rois.shape=}") - @classproperty - def backbone_presets(cls): - """Dictionary of preset names and configurations of compatible - backbones.""" - return copy.deepcopy(backbone_presets) - - -def _validate_and_get_loss(loss, loss_name): - if isinstance(loss, str): - loss = keras.losses.get(loss) - if loss is None or not isinstance(loss, keras.losses.Loss): - raise ValueError( - f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, " - f"got {loss}" - ) - if loss.reduction != keras.losses.Reduction.SUM: - logging.info( - f"FasterRCNN 
only accepts `SUM` reduction, got {loss.reduction}, " - "automatically converted." - ) - loss.reduction = keras.losses.Reduction.SUM - return loss + # Using the regions call the rcnn head + roi_pooler = ROIPooler( + bounding_box_format=bounding_box_format, + target_size=[7, 7], + image_shape=image_shape, + ) + feature_map_pooled = dict() + for key, value in feature_map.items(): + feature_map_pooled[key] = roi_pooler(value, rois) + + print("feature_map_pooled") + for key, value in feature_map_pooled.item(): + print(f"{key}: {value.shape}") + + # + # # Create the anchor generator + # scales = [2**x for x in [0]] + # aspect_ratios = [0.5, 1.0, 2.0] + # anchor_generator = anchor_generator or AnchorGenerator( + # bounding_box_format="yxyx", + # sizes={ + # "P2": 32.0, + # "P3": 64.0, + # "P4": 128.0, + # "P5": 256.0, + # "P6": 512.0, + # }, + # scales=scales, + # aspect_ratios=aspect_ratios, + # strides={f"P{i}": 2**i for i in range(2, 7)}, + # clip_boxes=True, + # ) + + # # Create the Region Proposal Network Head + # num_anchors_per_location = len(scales) * len(aspect_ratios) + # rpn_head = RPNHead(num_anchors_per_location=num_anchors_per_location) + + # # Create the Region of Interest Generator + # roi_generator = ROIGenerator( + # bounding_box_format="yxyx", + # nms_score_threshold_train=float("-inf"), + # nms_score_threshold_test=float("-inf"), + # ) + + # # Create the Box Matcher + # box_matcher = BoxMatcher( + # thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + # ) + + # # Create the Region of Interest Sampler + + # images = None + # box_pred = None + # class_pred = None + # inputs = {"images": images} + # outputs = {"box": box_pred, "classification": class_pred} + # super().__init__(inputs=inputs, outputs=outputs, *args, **kwargs) + + # def train_step(self, *args): + # data = args[-1] + # args = args[:-1] + # x, y = unpack_input(data) + # return super().train_step(*args, (x, y)) + + # def test_step(self, *args): + # data = args[-1] + # args = args[:-1] + # x, y = unpack_input(data) + # return super().test_step(*args, (x, y)) diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py new file mode 100644 index 0000000000..521902f348 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py @@ -0,0 +1,567 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy + +import tree +from absl import logging + +from keras_cv import bounding_box +from keras_cv import layers as cv_layers +from keras_cv import models +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.bounding_box.converters import _decode_deltas_to_boxes +from keras_cv.bounding_box.utils import _clip_boxes +from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator +from keras_cv.layers.object_detection.box_matcher import BoxMatcher +from keras_cv.layers.object_detection.roi_align import _ROIAligner +from keras_cv.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.layers.object_detection.roi_sampler import _ROISampler +from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder +from keras_cv.models.object_detection.__internal__ import unpack_input +from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid +from keras_cv.models.object_detection.faster_rcnn import RCNNHead +from keras_cv.models.object_detection.faster_rcnn import RPNHead +from keras_cv.models.task import Task +from keras_cv.utils.python_utils import classproperty +from keras_cv.utils.train import get_feature_extractor + +BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +# TODO(tanzheny): add more configurations +@keras_cv_export("keras_cv.models.FasterRCNN") +class FasterRCNN(Task): + """A Keras model implementing the FasterRCNN architecture. + + Implements the FasterRCNN architecture for object detection. The constructor + requires `backbone`, `num_classes`, and a `bounding_box_format`. + + References: + - [FasterRCNN](https://arxiv.org/abs/1506.01497) + + Args: + backbone: `keras.Model`. Must implement the + `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" + and layer names as values. + num_classes: the number of classes in your dataset excluding the + background class. classes should be represented by integers in the + range [0, num_classes). + bounding_box_format: The format of bounding boxes of model output. Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is + used in the model to match ground truth boxes and labels with + anchors, or with region proposals. By default it uses the sizes and + ratios from the paper, that is optimized for image size between + [640, 800]. The users should pass their own anchor generator if the + input image size differs from paper. For now, only anchor generator + with per level dict output is supported, + label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor, + a bounding box Tensor and a bounding box class Tensor to its + `call()` method, and returns RetinaNet training targets. It returns + box and class targets as well as sample weights. + rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature + map and returns a box delta prediction (in reference to rois) and + multi-class prediction (all foreground classes + one background + class). By default it uses the rcnn head from paper, which is 2 FC + layer with 1024 dimension, 1 box regressor and 1 softmax classifier. + prediction_decoder: (Optional) a `keras.layers.Layer` that takes input + box prediction and softmaxed score prediction, and returns NMSed box + prediction, NMSed softmaxed score prediction, NMSed class + prediction, and NMSed valid detection. 
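As an illustration of the default `rcnn_head` described above (two fully connected layers of width 1024, one box regressor, one softmax classifier), a minimal sketch with toy pooled features. The `[1, 64, 7 * 7 * 256]` shape is only an assumption about what the ROI pooler emits; the commented output shapes follow the shape notes used in this file.

```python
import numpy as np

from keras_cv.models.object_detection.faster_rcnn import RCNNHead

num_classes = 20
rcnn_head = RCNNHead(num_classes)

# Toy pooled ROI features: [BS, num_rois, pooled_height * pooled_width * channels]
pooled_features = np.zeros((1, 64, 7 * 7 * 256), dtype="float32")

# [BS, num_rois, 4] box deltas and [BS, num_rois, num_classes + 1] class scores
box_pred, cls_pred = rcnn_head(pooled_features)
```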
+ + Examples: + + ```python + images = np.ones((1, 512, 512, 3)) + labels = { + "boxes": [ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], + "classes": [[1, 1, 1]], + } + model = keras_cv.models.FasterRCNN( + num_classes=20, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet50Backbone.from_preset( + "resnet50_imagenet" + ) + ) + + # Evaluate model without box decoding and NMS + model(images) + + # Prediction with box decoding and NMS + model.predict(images) + + # Train model + model.compile( + classification_loss='focal', + box_loss='smoothl1', + optimizer=keras.optimizers.SGD(global_clipnorm=10.0), + jit_compile=False, + ) + model.fit(images, labels) + ``` + """ # noqa: E501 + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + anchor_generator=None, + label_encoder=None, + rcnn_head=None, + prediction_decoder=None, + feature_pyramid=None, + **kwargs, + ): + scales = [2**x for x in [0]] + aspect_ratios = [0.5, 1.0, 2.0] + anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format=bounding_box_format, + sizes={ + "P2": 32.0, + "P3": 64.0, + "P4": 128.0, + "P5": 256.0, + "P6": 512.0, + }, + scales=scales, + aspect_ratios=aspect_ratios, + strides={f"P{i}": 2**i for i in range(2, 7)}, + clip_boxes=True, + ) + rpn_head = RPNHead( + num_anchors_per_location=len(scales) * len(aspect_ratios) + ) + roi_generator = ROIGenerator( + bounding_box_format=bounding_box_format, + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), + ) + roi_pooler = _ROIAligner(bounding_box_format=bounding_box_format) + rcnn_head = rcnn_head or RCNNHead(num_classes) + backbone = backbone or models.ResNet50Backbone() + extractor_levels = ["P2", "P3", "P4", "P5"] + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + feature_extractor = get_feature_extractor( + backbone, extractor_layer_names, extractor_levels + ) + feature_pyramid = feature_pyramid or FeaturePyramid() + + # Begin construction of forward pass + image_shape = feature_extractor.input_shape[1:] + images = keras.layers.Input(image_shape, name="images") + anchors = anchor_generator(image_shape=image_shape) + + # Calling the RPN block + backbone_outputs = feature_extractor(images) + feature_map = feature_pyramid(backbone_outputs) + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = rpn_head(feature_map) + # the decoded format is center_xywh, convert to yxyx + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format=bounding_box_format, + box_format=bounding_box_format, + variance=BOX_VARIANCE, + ) + rois, _ = roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, bounding_box_format, image_shape) + + # Calling the RCNN block + feature_map = roi_pooler(feature_map, rois) + # [BS, H*W*K, pool_shape*C] + feature_map = keras.layers.Reshape( + target_shape=(ops.shape(rois)[1], -1) + )(feature_map) + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + rcnn_box_pred, rcnn_cls_pred = rcnn_head( + feature_map, + ) + + inputs = {"images": images} + outputs = {"box": rcnn_box_pred, "classification": rcnn_cls_pred} + + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + self.num_classes = num_classes + self.bounding_box_format = bounding_box_format + self.anchor_generator = anchor_generator + self.rpn_head = rpn_head + self.roi_generator = roi_generator + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], 
match_values=[-2, -1, 1] + ) + self.roi_sampler = _ROISampler( + bounding_box_format="yxyx", + roi_matcher=self.box_matcher, + background_class=num_classes, + num_sampled_rois=512, + ) + self.roi_pooler = roi_pooler + self.rcnn_head = rcnn_head + self.backbone = backbone + self.feature_extractor = feature_extractor + self.feature_pyramid = feature_pyramid + self.rpn_labeler = label_encoder or _RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format="yxyx", + positive_threshold=0.7, + negative_threshold=0.3, + samples_per_image=256, + positive_fraction=0.5, + box_variance=BOX_VARIANCE, + ) + self._prediction_decoder = ( + prediction_decoder + or cv_layers.NonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + iou_threshold=0.5, + confidence_threshold=0.5, + max_detections=100, + ) + ) + + def _call_rpn(self, images, anchors, training=None): + image_shape = ops.shape(images[0]) + backbone_outputs = self.feature_extractor(images, training=training) + feature_map = self.feature_pyramid(backbone_outputs, training=training) + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training) + # the decoded format is center_xywh, convert to yxyx + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=training + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + rpn_boxes = ops.concatenate(tree.flatten(rpn_boxes), axis=1) + rpn_scores = ops.concatenate(tree.flatten(rpn_scores), axis=1) + return rois, feature_map, rpn_boxes, rpn_scores + + def _call_rcnn(self, rois, feature_map, training=None): + feature_map = self.roi_pooler(feature_map, rois) + # [BS, H*W*K, pool_shape*C] + feature_map = ops.reshape( + feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0) + ) + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( + feature_map, training=training + ) + return rcnn_box_pred, rcnn_cls_pred + + # TODO(tanzhenyu): Support compile with metrics. + def compile( + self, + box_loss=None, + classification_loss=None, + rpn_box_loss=None, + rpn_classification_loss=None, + weight_decay=0.0001, + loss=None, + **kwargs, + ): + # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed. + # https://github.com/keras-team/keras-cv/issues/915 + if "metrics" in kwargs.keys(): + raise ValueError( + "`FasterRCNN` does not currently support the use of " + "`metrics` due to performance and distribution concerns. " + "Please use the `PyCOCOCallback` to evaluate COCO metrics." + ) + if loss is not None: + raise ValueError( + "`FasterRCNN` does not accept a `loss` to `compile()`. " + "Instead, please pass `box_loss` and `classification_loss`. " + "`loss` will be ignored during training." 
+ ) + box_loss = _validate_and_get_loss(box_loss, "box_loss") + classification_loss = _validate_and_get_loss( + classification_loss, "classification_loss" + ) + rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss") + if rpn_classification_loss == "BinaryCrossentropy": + rpn_classification_loss = keras.losses.BinaryCrossentropy( + from_logits=True, reduction=keras.losses.Reduction.SUM + ) + rpn_classification_loss = _validate_and_get_loss( + rpn_classification_loss, "rpn_cls_loss" + ) + if not rpn_classification_loss.from_logits: + raise ValueError( + "`rpn_classification_loss` must come with `from_logits`=True" + ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.weight_decay = weight_decay + losses = { + "box": self.box_loss, + "classification": self.cls_loss, + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + } + super().compile(loss=losses, **kwargs) + + def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): + images = x + boxes = y["boxes"] + if len(y["classes"].shape) != 2: + raise ValueError( + "Expected 'classes' to be a tf.Tensor of rank 2. " + f"Got y['classes'].shape={y['classes'].shape}." + ) + # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere + classes = ops.expand_dims(y["classes"], axis=-1) + + local_batch = images.get_shape().as_list()[0] + anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.rpn_labeler( + ops.concatenate(tree.flatten(anchors), axis=0), boxes, classes + ) + rpn_box_weights /= ( + self.rpn_labeler.samples_per_image * local_batch * 0.25 + ) + rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch + rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( + images, + anchors, + ) + rois = ops.stop_gradient(rois) + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + ) = self.roi_sampler(rois, boxes, classes) + + positive_mask = ops.cast( + ops.greater(cls_targets, -1.0), dtype="float32" + ) + normalizer = ops.sum(positive_mask) + + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * local_batch + + box_weights /= normalizer + cls_weights /= normalizer + + box_pred, cls_pred = self._call_rcnn( + rois, + feature_map, + ) + + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + } + y_pred = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + sample_weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + } + zero_weights = { + "rpn_box": ops.zeros_like(rpn_box_weights), + "rpn_classification": ops.zeros_like(rpn_cls_weights), + "box": ops.zeros_like(box_weights), + "classification": ops.zeros_like(cls_weights), + } + + sample_weights = ops.cond( + normalizer == 0.0, + lambda: zero_weights, + lambda: sample_weights, + ) + return super().compute_loss( + x=images, y=y_true, y_pred=y_pred, sample_weight=sample_weights + ) + + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = 
unpack_input(data) + return super().test_step(*args, (x, y)) + + def predict_step(self, *args): + outputs = super().predict_step(*args) + if type(outputs) is tuple: + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + else: + return self.decode_predictions(outputs, args[-1]) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + + def decode_predictions(self, predictions, images): + # no-op if default decoder is used. + box_pred, scores_pred = ( + predictions["box"], + predictions["classification"], + ) + box_pred = bounding_box.convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + images=images, + ) + y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1]) + box_pred = bounding_box.convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + images=images, + ) + y_pred["boxes"] = box_pred + return y_pred + + def get_config(self): + return { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "backbone": keras.saving.serialize_keras_object(self.backbone), + "anchor_generator": keras.saving.serialize_keras_object( + self.anchor_generator + ), + "label_encoder": keras.saving.serialize_keras_object( + self.rpn_labeler + ), + "prediction_decoder": keras.saving.serialize_keras_object( + self._prediction_decoder + ), + "feature_pyramid": keras.saving.serialize_keras_object( + self.feature_pyramid + ), + "rcnn_head": keras.saving.serialize_keras_object(self.rcnn_head), + } + + @classmethod + def from_config(cls, config): + if "rcnn_head" in config and isinstance(config["rcnn_head"], dict): + config["rcnn_head"] = keras.layers.deserialize(config["rcnn_head"]) + if "feature_pyramid" in config and isinstance( + config["feature_pyramid"], dict + ): + config["feature_pyramid"] = keras.layers.deserialize( + config["feature_pyramid"] + ) + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) + if "anchor_generator" in config and isinstance( + config["anchor_generator"], dict + ): + config["anchor_generator"] = keras.layers.deserialize( + config["anchor_generator"] + ) + if "backbone" in config and isinstance(config["backbone"], dict): + config["backbone"] = keras.layers.deserialize(config["backbone"]) + return super().from_config(config) + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + # return copy.deepcopy({**backbone_presets, **fasterrcnn_presets}) + return copy.deepcopy({**backbone_presets}) + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return copy.deepcopy( + # {**backbone_presets_with_weights, **fasterrcnn_presets} + { + **backbone_presets_with_weights, + } + ) + + @classproperty + def backbone_presets(cls): + """Dictionary of preset names and configurations of compatible + backbones.""" + return copy.deepcopy(backbone_presets) + + +def _validate_and_get_loss(loss, loss_name): + if 
isinstance(loss, str): + loss = keras.losses.get(loss) + if loss is None or not isinstance(loss, keras.losses.Loss): + raise ValueError( + f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, " + f"got {loss}" + ) + if loss.reduction != keras.losses.Reduction.SUM: + logging.info( + f"FasterRCNN only accepts `SUM` reduction, got {loss.reduction}, " + "automatically converted." + ) + loss.reduction = keras.losses.Reduction.SUM + return loss diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py new file mode 100644 index 0000000000..fa898cb564 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py @@ -0,0 +1,394 @@ +class FasterRCNN(keras.Model): + """A Keras model implementing the FasterRCNN architecture. + + Implements the FasterRCNN architecture for object detection. The constructor + requires `num_classes`, `bounding_box_format` and a `backbone`. + + References: + - [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf) + + Usage: + ```python + retinanet = keras_cv.models.FasterRCNN( + num_classes=20, + bounding_box_format="xywh", + backbone=None, + ) + ``` + + Args: + num_classes: the number of classes in your dataset excluding the + background class. classes should be represented by integers in the + range [0, num_classes). + bounding_box_format: The format of bounding boxes of model output. Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + backbone: Optional `keras.Model`. Must implement the + `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" + and layer names as values. If `None`, defaults to + `keras_cv.models.ResNet50Backbone()`. + anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is + used in the model to match ground truth boxes and labels with + anchors, or with region proposals. By default it uses the sizes and + ratios from the paper, that is optimized for image size between + [640, 800]. The users should pass their own anchor generator if the + input image size differs from paper. For now, only anchor generator + with per level dict output is supported, + label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor, + a bounding box Tensor and a bounding box class Tensor to its + `call()` method, and returns RetinaNet training targets. It returns + box and class targets as well as sample weights. + rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature + map and returns a box delta prediction (in reference to rois) and + multi-class prediction (all foreground classes + one background + class). By default it uses the rcnn head from paper, which is 2 FC + layer with 1024 dimension, 1 box regressor and 1 softmax classifier. + prediction_decoder: (Optional) a `keras.layers.Layer` that takes input + box prediction and softmaxed score prediction, and returns NMSed box + prediction, NMSed softmaxed score prediction, NMSed class + prediction, and NMSed valid detection. 
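For reference, a minimal sketch of the default label encoder this legacy implementation falls back to. The values mirror the `_RpnLabelEncoder(...)` call in the constructor below; note the layer may still assume the TensorFlow backend at this point in the patch series.

```python
from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder

# Matches anchors to ground truth boxes and produces RPN box/class targets
# plus sample weights, as described for `label_encoder` above.
rpn_labeler = _RpnLabelEncoder(
    anchor_format="yxyx",
    ground_truth_box_format="yxyx",
    positive_threshold=0.7,
    negative_threshold=0.3,
    samples_per_image=256,
    positive_fraction=0.5,
    box_variance=[0.1, 0.1, 0.2, 0.2],
)
```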
+ """ # noqa: E501 + + def __init__( + self, + num_classes, + bounding_box_format, + backbone=None, + anchor_generator=None, + label_encoder=None, + rcnn_head=None, + prediction_decoder=None, + **kwargs, + ): + self.bounding_box_format = bounding_box_format + super().__init__(**kwargs) + scales = [2**x for x in [0]] + aspect_ratios = [0.5, 1.0, 2.0] + self.anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format="yxyx", + sizes={ + "P2": 32.0, + "P3": 64.0, + "P4": 128.0, + "P5": 256.0, + "P6": 512.0, + }, + scales=scales, + aspect_ratios=aspect_ratios, + strides={f"P{i}": 2**i for i in range(2, 7)}, + clip_boxes=True, + ) + self.rpn_head = RPNHead( + num_anchors_per_location=len(scales) * len(aspect_ratios) + ) + self.roi_generator = ROIGenerator( + bounding_box_format="yxyx", + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), + ) + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + ) + self.roi_sampler = _ROISampler( + bounding_box_format="yxyx", + roi_matcher=self.box_matcher, + background_class=num_classes, + num_sampled_rois=512, + ) + self.roi_pooler = _ROIAligner(bounding_box_format="yxyx") + self.rcnn_head = rcnn_head or RCNNHead(num_classes) + self.backbone = backbone or models.ResNet50Backbone() + extractor_levels = ["P2", "P3", "P4", "P5"] + extractor_layer_names = [ + self.backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + self.feature_extractor = get_feature_extractor( + self.backbone, extractor_layer_names, extractor_levels + ) + self.feature_pyramid = FeaturePyramid() + self.rpn_labeler = label_encoder or _RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format="yxyx", + positive_threshold=0.7, + negative_threshold=0.3, + samples_per_image=256, + positive_fraction=0.5, + box_variance=BOX_VARIANCE, + ) + self._prediction_decoder = ( + prediction_decoder + or cv_layers.MultiClassNonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + max_detections_per_class=10, + max_detections=10, + ) + ) + + def _call_rpn(self, images, anchors, training=None): + image_shape = tf.shape(images[0]) + backbone_outputs = self.feature_extractor(images, training=training) + feature_map = self.feature_pyramid(backbone_outputs, training=training) + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training) + # the decoded format is center_xywh, convert to yxyx + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=training + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + rpn_boxes = tf.concat(tf.nest.flatten(rpn_boxes), axis=1) + rpn_scores = tf.concat(tf.nest.flatten(rpn_scores), axis=1) + return rois, feature_map, rpn_boxes, rpn_scores + + def _call_rcnn(self, rois, feature_map, training=None): + feature_map = self.roi_pooler(feature_map, rois) + # [BS, H*W*K, pool_shape*C] + feature_map = tf.reshape( + feature_map, tf.concat([tf.shape(rois)[:2], [-1]], axis=0) + ) + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( + feature_map, training=training + ) + return rcnn_box_pred, rcnn_cls_pred + + def call(self, images, training=None): + image_shape = tf.shape(images[0]) + anchors = self.anchor_generator(image_shape=image_shape) + rois, feature_map, _, _ = self._call_rpn( + 
images, anchors, training=training + ) + box_pred, cls_pred = self._call_rcnn( + rois, feature_map, training=training + ) + if not training: + # box_pred is on "center_yxhw" format, convert to target format. + box_pred = _decode_deltas_to_boxes( + anchors=rois, + boxes_delta=box_pred, + anchor_format="yxyx", + box_format=self.bounding_box_format, + variance=[0.1, 0.1, 0.2, 0.2], + ) + + return box_pred, cls_pred + + # TODO(tanzhenyu): Support compile with metrics. + def compile( + self, + box_loss=None, + classification_loss=None, + rpn_box_loss=None, + rpn_classification_loss=None, + weight_decay=0.0001, + loss=None, + **kwargs, + ): + # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed. + # https://github.com/keras-team/keras-cv/issues/915 + if "metrics" in kwargs.keys(): + raise ValueError( + "`FasterRCNN` does not currently support the use of " + "`metrics` due to performance and distribution concerns. " + "Please use the `PyCOCOCallback` to evaluate COCO metrics." + ) + if loss is not None: + raise ValueError( + "`FasterRCNN` does not accept a `loss` to `compile()`. " + "Instead, please pass `box_loss` and `classification_loss`. " + "`loss` will be ignored during training." + ) + box_loss = _validate_and_get_loss(box_loss, "box_loss") + classification_loss = _validate_and_get_loss( + classification_loss, "classification_loss" + ) + rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss") + if rpn_classification_loss == "BinaryCrossentropy": + rpn_classification_loss = keras.losses.BinaryCrossentropy( + from_logits=True, reduction=keras.losses.Reduction.SUM + ) + rpn_classification_loss = _validate_and_get_loss( + rpn_classification_loss, "rpn_cls_loss" + ) + if not rpn_classification_loss.from_logits: + raise ValueError( + "`rpn_classification_loss` must come with `from_logits`=True" + ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.weight_decay = weight_decay + losses = { + "box": self.box_loss, + "classification": self.cls_loss, + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + } + super().compile(loss=losses, **kwargs) + + def compute_loss(self, images, boxes, classes, training): + local_batch = images.get_shape().as_list()[0] + if tf.distribute.has_strategy(): + num_sync = tf.distribute.get_strategy().num_replicas_in_sync + else: + num_sync = 1 + global_batch = local_batch * num_sync + anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.rpn_labeler( + tf.concat(tf.nest.flatten(anchors), axis=0), boxes, classes + ) + rpn_box_weights /= ( + self.rpn_labeler.samples_per_image * global_batch * 0.25 + ) + rpn_cls_weights /= self.rpn_labeler.samples_per_image * global_batch + rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( + images, anchors, training=training + ) + rois = tf.stop_gradient(rois) + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + ) = self.roi_sampler(rois, boxes, classes) + box_weights /= self.roi_sampler.num_sampled_rois * global_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * global_batch + box_pred, cls_pred = self._call_rcnn( + rois, feature_map, training=training + ) + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + } + y_pred = { + "rpn_box": 
rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + } + return super().compute_loss( + x=images, y=y_true, y_pred=y_pred, sample_weight=weights + ) + + def train_step(self, data): + images, y = unpack_input(data) + + boxes = y["boxes"] + if len(y["classes"].shape) != 2: + raise ValueError( + "Expected 'classes' to be a tf.Tensor of rank 2. " + f"Got y['classes'].shape={y['classes'].shape}." + ) + # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere + classes = tf.expand_dims(y["classes"], axis=-1) + with tf.GradientTape() as tape: + total_loss = self.compute_loss( + images, boxes, classes, training=True + ) + reg_losses = [] + if self.weight_decay: + for var in self.trainable_variables: + if "bn" not in var.name: + reg_losses.append( + self.weight_decay * tf.nn.l2_loss(var) + ) + l2_loss = tf.math.add_n(reg_losses) + total_loss += l2_loss + self.optimizer.minimize(total_loss, self.trainable_variables, tape=tape) + return self.compute_metrics(images, {}, {}, sample_weight={}) + + def test_step(self, data): + images, y = unpack_input(data) + + boxes = y["boxes"] + if len(y["classes"].shape) != 2: + raise ValueError( + "Expected 'classes' to be a tf.Tensor of rank 2. " + f"Got y['classes'].shape={y['classes'].shape}." + ) + classes = tf.expand_dims(y["classes"], axis=-1) + self.compute_loss(images, boxes, classes, training=False) + return self.compute_metrics(images, {}, {}, sample_weight={}) + + def make_predict_function(self, force=False): + return predict_utils.make_predict_function(self, force=force) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + + def decode_predictions(self, predictions, images): + # no-op if default decoder is used. + box_pred, scores_pred = predictions + box_pred = bounding_box.convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + images=images, + ) + y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1]) + box_pred = bounding_box.convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + images=images, + ) + y_pred["boxes"] = box_pred + return y_pred + + def get_config(self): + return { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "backbone": self.backbone, + "anchor_generator": self.anchor_generator, + "label_encoder": self.rpn_labeler, + "prediction_decoder": self._prediction_decoder, + "feature_pyramid": self.feature_pyramid, + "rcnn_head": self.rcnn_head, + } + + +def _validate_and_get_loss(loss, loss_name): + if isinstance(loss, str): + loss = keras.losses.get(loss) + if loss is None or not isinstance(loss, keras.losses.Loss): + raise ValueError( + f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, " + f"got {loss}" + ) + if loss.reduction != keras.losses.Reduction.SUM: + logging.info( + f"FasterRCNN only accepts `SUM` reduction, got {loss.reduction}, " + "automatically converted." 
+ ) + loss.reduction = keras.losses.Reduction.SUM + return loss diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py index d0bf087b1d..2e8f581dfb 100644 --- a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -73,7 +73,7 @@ def build(self, input_shape): def call(self, feature_map, training=None): def call_single_level(f_map): - batch_size = f_map.get_shape().as_list()[0] or ops.shape(f_map)[0] + batch_size = ops.shape(f_map)[0] # [BS, H, W, C] t = self.conv(f_map) # [BS, H, W, K] From 2f5c0a2e4e570a3d55b3f916f2e68adf0982c645 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Wed, 13 Mar 2024 17:02:07 +0530 Subject: [PATCH 29/32] chore: update roi align --- demo.py | 4 +- keras_cv/layers/object_detection/roi_align.py | 187 ++---- .../faster_rcnn/faster_rcnn.py | 26 +- .../faster_rcnn/faster_rcnn_port.py | 567 ------------------ .../faster_rcnn/faster_rcnn_tf.py | 394 ------------ 5 files changed, 71 insertions(+), 1107 deletions(-) delete mode 100644 keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py delete mode 100644 keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py diff --git a/demo.py b/demo.py index 5ae0890a8b..1cbfca779e 100644 --- a/demo.py +++ b/demo.py @@ -1,8 +1,10 @@ -import keras import numpy as np import keras_cv +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + # Note: We absolutely need this while creating the Backbone batch_size = 32 image_shape = (512, 512, 3) diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 25b698e29b..5b8ed109d3 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -25,31 +25,26 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x): The RoIAlign feature f can be computed by bilinear interpolation of four neighboring feature points f0, f1, f2, and f3. + f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T - [f10, f11]] + [f10, f11]] f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11 f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11 kernel_y = [hy, ly] kernel_x = [hx, lx] Args: - features: The features are in shape of [batch_size, num_boxes, - output_size * 2, output_size * 2, num_filters]. - kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. - kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. + features: The features are in shape of [batch_size, num_boxes, output_size * + 2, output_size * 2, num_filters]. + kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. + kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. Returns: - A 5-D tensor representing feature crop of shape - [batch_size, num_boxes, output_size, output_size, num_filters]. - """ - features_shape = ops.shape(features) - batch_size, num_boxes, output_size, num_filters = ( - features_shape[0], - features_shape[1], - features_shape[2], - features_shape[4], - ) + A 5-D tensor representing feature crop of shape + [batch_size, num_boxes, output_size, output_size, num_filters]. 
+ """ + (batch_size, num_boxes, output_size, _, num_filters) = ops.shape(features) output_size = output_size // 2 kernel_y = ops.reshape( kernel_y, [batch_size, num_boxes, output_size * 2, 1] @@ -69,48 +64,38 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x): features, [batch_size * num_boxes, output_size * 2, output_size * 2, num_filters], ) - features = ops.nn.average_pool( - features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID" - ) + features = ops.average_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID") features = ops.reshape( features, [batch_size, num_boxes, output_size, output_size, num_filters] ) return features -def _compute_grid_positions( - boxes, - boundaries, - output_size, - sample_offset, -): - """ - Computes the grid position w.r.t. the corresponding feature map. +def _compute_grid_positions(boxes, boundaries, output_size, sample_offset): + """Compute the grid position w.r.t. + + the corresponding feature map. Args: - boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the + boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the information of each box w.r.t. the corresponding feature map. boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float) - in terms of the number of pixels of the corresponding feature map - size. - boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing + in terms of the number of pixels of the corresponding feature map size. + boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing the boundary (in (y, x)) of the corresponding feature map for each box. - Any resampled grid points that go beyond the boundary will be clipped. - output_size: a `int` scalar indicating the output crop size. - sample_offset: a `float` number in [0, 1] indicates the subpixel sample - offset from grid point. + Any resampled grid points that go beyond the bounary will be clipped. + output_size: a scalar indicating the output crop size. + sample_offset: a float number in [0, 1] indicates the subpixel sample offset + from grid point. Returns: - kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. - kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. - box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2] - box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2] + kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. + kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. 
+ box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2] + box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2] """ - boxes_shape = ops.shape(boxes) - batch_size, num_boxes = boxes_shape[0], boxes_shape[1] - if batch_size is None: - batch_size = ops.shape(boxes)[0] + batch_size, num_boxes, _ = ops.shape(boxes) box_grid_x = [] box_grid_y = [] for i in range(output_size): @@ -125,12 +110,8 @@ def _compute_grid_positions( box_grid_y0 = ops.floor(box_grid_y) box_grid_x0 = ops.floor(box_grid_x) - box_grid_x0 = ops.maximum( - ops.cast(0.0, dtype=box_grid_x0.dtype), box_grid_x0 - ) - box_grid_y0 = ops.maximum( - ops.cast(0.0, dtype=box_grid_y0.dtype), box_grid_y0 - ) + box_grid_x0 = ops.maximum(0.0, box_grid_x0) + box_grid_y0 = ops.maximum(0.0, box_grid_y0) box_grid_x0 = ops.minimum( box_grid_x0, ops.expand_dims(boundaries[:, :, 1], -1) @@ -168,52 +149,33 @@ def _compute_grid_positions( def multilevel_crop_and_resize( - features, - boxes, - output_size: int = 7, - sample_offset: float = 0.5, + features, boxes, output_size=7, sample_offset=0.5 ): - """ - Crop and resize on multilevel feature pyramid. + """Crop and resize on multilevel feature pyramid. Generate the (output_size, output_size) set of pixels for each input box by first locating the box into the correct feature level, and then cropping - and resizing it using the corresponding feature map of that level. + and resizing it using the correspoding feature map of that level. Args: - features: A dictionary with key as pyramid level and value as - features (tensors). - The pyramid level keys need to be represented by strings like so: - "P2", "P3", "P4", and so on. - The features are in shape of [batch_size, height_l, width_l, - num_filters]. - boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row - represents a box with [y1, x1, y2, x2] in un-normalized coordinates. - output_size: A scalar to indicate the output crop size. - sample_offset: a float number in [0, 1] indicates the subpixel sample - offset from grid point. + features: A dictionary with key as pyramid level and value as features. The + features are in shape of [batch_size, height_l, width_l, num_filters]. + boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row represents + a box with [y1, x1, y2, x2] in un-normalized coordinates. + output_size: A scalar to indicate the output crop size. Returns: - A 5-D tensor representing feature crop of shape - [batch_size, num_boxes, output_size, output_size, num_filters]. + A 5-D tensor representing feature crop of shape + [batch_size, num_boxes, output_size, output_size, num_filters]. """ - - levels_str = list(features.keys()) - # Levels are represented by strings with a prefix "P" to represent - # pyramid levels. The integer level can be obtained by looking at - # the value that follows the "P". - levels = [int(level_str[1:]) for level_str in levels_str] + levels = list(features.keys()) + levels = [int(level[1:]) for level in levels] min_level = min(levels) max_level = max(levels) - features_shape = ops.shape(features[f"P{min_level}"]) - batch_size, max_feature_height, max_feature_width, num_filters = ( - features_shape[0], - features_shape[1], - features_shape[2], - features_shape[3], + batch_size, max_feature_height, max_feature_width, num_filters = ops.shape( + features[f"P{min_level}"] ) - - num_boxes = ops.shape(boxes)[1] + _, num_boxes, _ = ops.shape(boxes) # Stack feature pyramid into a features_all of shape # [batch_size, levels, height, width, num_filters]. 
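
The kernel tensors built by `_compute_grid_positions` are the standard bilinear weights described in the `_feature_bilinear_interpolation` docstring (f = hy*hx*f00 + hy*lx*f01 + ly*hx*f10 + ly*lx*f11). A minimal NumPy sketch of that weighting for a single sample point on a single feature map is shown below; `sample_point_bilinear` is an illustrative name, not part of the KerasCV API, and the clipping mirrors what the layer does with its `boundaries` tensor.

```python
import numpy as np

def sample_point_bilinear(feature_map, y, x):
    """Bilinearly sample `feature_map` of shape (H, W, C) at float coords (y, x)."""
    # Integer corner of the 2x2 neighbourhood, clipped so that
    # (y0 + 1, x0 + 1) stays inside the map.
    y0 = int(np.clip(np.floor(y), 0, feature_map.shape[0] - 2))
    x0 = int(np.clip(np.floor(x), 0, feature_map.shape[1] - 2))
    ly, lx = y - y0, x - x0        # "low" kernel parts
    hy, hx = 1.0 - ly, 1.0 - lx    # "high" kernel parts
    f00 = feature_map[y0, x0]
    f01 = feature_map[y0, x0 + 1]
    f10 = feature_map[y0 + 1, x0]
    f11 = feature_map[y0 + 1, x0 + 1]
    # Matches the docstring formula: f = hy*hx*f00 + hy*lx*f01 + ly*hx*f10 + ly*lx*f11
    return hy * hx * f00 + hy * lx * f01 + ly * hx * f10 + ly * lx * f11
```

The layer vectorizes this weighting over the batch, the sampled boxes, and an `output_size * 2` grid of sample points, then average-pools with a 2x2 window to produce the final `output_size` x `output_size` crop.
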
@@ -224,14 +186,14 @@ def multilevel_crop_and_resize( shape = ops.shape(features[f"P{level}"]) feature_heights.append(shape[1]) feature_widths.append(shape[2]) - # Concat tensor of [batch_size, height_l * width_l, num_filters] for - # each level. + # Concat tensor of [batch_size, height_l * width_l, num_filters] for each + # levels. features_all.append( ops.reshape(features[f"P{level}"], [batch_size, -1, num_filters]) ) - features_r2 = ops.reshape( - ops.concatenate(features_all, 1), [-1, num_filters] - ) + features_r2 = ops.reshape( + ops.concatenate(features_all, 1), [-1, num_filters] + ) # Calculate height_l * width_l for each level. level_dim_sizes = [ @@ -243,28 +205,15 @@ def multilevel_crop_and_resize( for i in range(len(feature_widths) - 1): level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i]) batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1] - level_dim_offsets = ops.convert_to_tensor(level_dim_offsets) - level_dim_offsets = ( - ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets - ) - feature_widths = ops.convert_to_tensor(feature_widths) - height_dim_sizes = ( - ops.ones_like(feature_widths, dtype="int32") * feature_widths - ) + level_dim_offsets = ops.array(level_dim_offsets, dtype="int32") + height_dim_sizes = ops.array(feature_widths, dtype="int32") # Assigns boxes to the right level. box_width = boxes[:, :, 3] - boxes[:, :, 1] box_height = boxes[:, :, 2] - boxes[:, :, 0] - areas_sqrt = ops.sqrt( - ops.cast(box_height, "float32") * ops.cast(box_width, "float32") - ) - - # following the FPN paper to divide by 224. + areas_sqrt = ops.sqrt(box_height * box_width) levels = ops.cast( - ops.floor_divide( - ops.log(ops.divide(areas_sqrt, 224.0)), - ops.log(2.0), - ) + ops.floor_divide(ops.log(ops.divide(areas_sqrt, 224.0)), ops.log(2.0)) + 4.0, dtype="int32", ) @@ -273,7 +222,7 @@ def multilevel_crop_and_resize( # Projects box location and sizes to corresponding feature levels. scale_to_level = ops.cast( - ops.power(2.0, ops.cast(levels, "float32")), + ops.power(ops.array(2.0), ops.cast(levels, "float32")), dtype=boxes.dtype, ) boxes /= ops.expand_dims(scale_to_level, axis=2) @@ -291,8 +240,6 @@ def multilevel_crop_and_resize( # Maps levels to [0, max_level-min_level]. levels -= min_level level_strides = ops.power([[2.0]], ops.cast(levels, "float32")) - print(f"{max_feature_height=}") - print(f"{level_strides=}") boundary = ops.cast( ops.concatenate( [ @@ -313,12 +260,9 @@ def multilevel_crop_and_resize( ) # Compute grid positions. - ( - kernel_y, - kernel_x, - box_gridy0y1, - box_gridx0x1, - ) = _compute_grid_positions(boxes, boundary, output_size, sample_offset) + kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = _compute_grid_positions( + boxes, boundary, output_size, sample_offset=sample_offset + ) x_indices = ops.cast( ops.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]), @@ -338,8 +282,7 @@ def multilevel_crop_and_resize( # Get level offset for each box. Each box belongs to one level. levels_offset = ops.tile( ops.reshape( - ops.take(level_dim_offsets, levels), - [batch_size, num_boxes, 1, 1], + ops.take(level_dim_offsets, levels), [batch_size, num_boxes, 1, 1] ), [1, 1, output_size * 2, output_size * 2], ) @@ -359,17 +302,11 @@ def multilevel_crop_and_resize( [-1], ) - # TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get - # similar performance. + # TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get similar + # performance. 
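
The `areas_sqrt` arithmetic above is the FPN heuristic for choosing which pyramid level each RoI is pooled from: k = floor(log2(sqrt(box_area) / 224)) + 4, clipped to the levels actually present. A standalone sketch of that rule, assuming levels P2 through P6 as in the surrounding code (`assign_fpn_level` is an illustrative name, not a KerasCV symbol):

```python
import math

def assign_fpn_level(box_height, box_width, min_level=2, max_level=6):
    """Pick the pyramid level whose stride best matches the box scale."""
    area_sqrt = math.sqrt(box_height * box_width)
    level = math.floor(math.log2(area_sqrt / 224.0)) + 4
    return max(min_level, min(max_level, level))

# A ~224x224 box is pooled from P4, a ~448x448 box from P5, a ~112x112 box from P3.
assert assign_fpn_level(224, 224) == 4
assert assign_fpn_level(448, 448) == 5
assert assign_fpn_level(112, 112) == 3
```

Each box is then divided by `2**level` so that its coordinates index directly into the chosen feature map before the flattened `ops.take` gather that follows.
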
features_per_box = ops.reshape( - ops.take(features_r2, indices), - [ - batch_size, - num_boxes, - output_size * 2, - output_size * 2, - num_filters, - ], + ops.take(features_r2, indices, axis=0), + [batch_size, num_boxes, output_size * 2, output_size * 2, num_filters], ) # Bilinear interpolation. diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 5f95177e51..77bc993230 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -12,18 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_cv import models from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras -from keras_cv.backend import ops from keras_cv.bounding_box.converters import _decode_deltas_to_boxes from keras_cv.bounding_box.utils import _clip_boxes from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator -from keras_cv.layers.object_detection.box_matcher import BoxMatcher from keras_cv.layers.object_detection.roi_generator import ROIGenerator -from keras_cv.layers.object_detection.roi_pool import ROIPooler -from keras_cv.layers.object_detection.roi_sampler import _ROISampler -from keras_cv.models.object_detection.__internal__ import unpack_input +from keras_cv.layers.object_detection.roi_align import _ROIAligner from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid from keras_cv.models.object_detection.faster_rcnn import RPNHead from keras_cv.models.task import Task @@ -56,7 +51,7 @@ def __init__( backbone, extractor_layer_names, extractor_levels ) feature_pyramid = feature_pyramid or FeaturePyramid() - image_shape = feature_extractor.input_shape[1:] # excule the batch size + image_shape = feature_extractor.input_shape[1:] # exclude the batch size images = keras.layers.Input( image_shape, batch_size=batch_size, name="images" ) @@ -132,19 +127,10 @@ def __init__( print(f"{rois.shape=}") # Using the regions call the rcnn head - roi_pooler = ROIPooler( - bounding_box_format=bounding_box_format, - target_size=[7, 7], - image_shape=image_shape, - ) - feature_map_pooled = dict() - for key, value in feature_map.items(): - feature_map_pooled[key] = roi_pooler(value, rois) - - print("feature_map_pooled") - for key, value in feature_map_pooled.item(): - print(f"{key}: {value.shape}") - + roi_pooler = _ROIAligner(bounding_box_format="yxyx") + feature_map = roi_pooler(features=feature_map, boxes=rois) + print(f"{feature_map.shape=}") + # # # Create the anchor generator # scales = [2**x for x in [0]] diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py deleted file mode 100644 index 521902f348..0000000000 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_port.py +++ /dev/null @@ -1,567 +0,0 @@ -# Copyright 2023 The KerasCV Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import copy - -import tree -from absl import logging - -from keras_cv import bounding_box -from keras_cv import layers as cv_layers -from keras_cv import models -from keras_cv.api_export import keras_cv_export -from keras_cv.backend import keras -from keras_cv.backend import ops -from keras_cv.bounding_box.converters import _decode_deltas_to_boxes -from keras_cv.bounding_box.utils import _clip_boxes -from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator -from keras_cv.layers.object_detection.box_matcher import BoxMatcher -from keras_cv.layers.object_detection.roi_align import _ROIAligner -from keras_cv.layers.object_detection.roi_generator import ROIGenerator -from keras_cv.layers.object_detection.roi_sampler import _ROISampler -from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder -from keras_cv.models.object_detection.__internal__ import unpack_input -from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid -from keras_cv.models.object_detection.faster_rcnn import RCNNHead -from keras_cv.models.object_detection.faster_rcnn import RPNHead -from keras_cv.models.task import Task -from keras_cv.utils.python_utils import classproperty -from keras_cv.utils.train import get_feature_extractor - -BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] - - -# TODO(tanzheny): add more configurations -@keras_cv_export("keras_cv.models.FasterRCNN") -class FasterRCNN(Task): - """A Keras model implementing the FasterRCNN architecture. - - Implements the FasterRCNN architecture for object detection. The constructor - requires `backbone`, `num_classes`, and a `bounding_box_format`. - - References: - - [FasterRCNN](https://arxiv.org/abs/1506.01497) - - Args: - backbone: `keras.Model`. Must implement the - `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" - and layer names as values. - num_classes: the number of classes in your dataset excluding the - background class. classes should be represented by integers in the - range [0, num_classes). - bounding_box_format: The format of bounding boxes of model output. Refer - [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) - for more details on supported bounding box formats. - anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is - used in the model to match ground truth boxes and labels with - anchors, or with region proposals. By default it uses the sizes and - ratios from the paper, that is optimized for image size between - [640, 800]. The users should pass their own anchor generator if the - input image size differs from paper. For now, only anchor generator - with per level dict output is supported, - label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor, - a bounding box Tensor and a bounding box class Tensor to its - `call()` method, and returns RetinaNet training targets. It returns - box and class targets as well as sample weights. - rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature - map and returns a box delta prediction (in reference to rois) and - multi-class prediction (all foreground classes + one background - class). By default it uses the rcnn head from paper, which is 2 FC - layer with 1024 dimension, 1 box regressor and 1 softmax classifier. 
- prediction_decoder: (Optional) a `keras.layers.Layer` that takes input - box prediction and softmaxed score prediction, and returns NMSed box - prediction, NMSed softmaxed score prediction, NMSed class - prediction, and NMSed valid detection. - - Examples: - - ```python - images = np.ones((1, 512, 512, 3)) - labels = { - "boxes": [ - [ - [0, 0, 100, 100], - [100, 100, 200, 200], - [300, 300, 100, 100], - ] - ], - "classes": [[1, 1, 1]], - } - model = keras_cv.models.FasterRCNN( - num_classes=20, - bounding_box_format="xywh", - backbone=keras_cv.models.ResNet50Backbone.from_preset( - "resnet50_imagenet" - ) - ) - - # Evaluate model without box decoding and NMS - model(images) - - # Prediction with box decoding and NMS - model.predict(images) - - # Train model - model.compile( - classification_loss='focal', - box_loss='smoothl1', - optimizer=keras.optimizers.SGD(global_clipnorm=10.0), - jit_compile=False, - ) - model.fit(images, labels) - ``` - """ # noqa: E501 - - def __init__( - self, - backbone, - num_classes, - bounding_box_format, - anchor_generator=None, - label_encoder=None, - rcnn_head=None, - prediction_decoder=None, - feature_pyramid=None, - **kwargs, - ): - scales = [2**x for x in [0]] - aspect_ratios = [0.5, 1.0, 2.0] - anchor_generator = anchor_generator or AnchorGenerator( - bounding_box_format=bounding_box_format, - sizes={ - "P2": 32.0, - "P3": 64.0, - "P4": 128.0, - "P5": 256.0, - "P6": 512.0, - }, - scales=scales, - aspect_ratios=aspect_ratios, - strides={f"P{i}": 2**i for i in range(2, 7)}, - clip_boxes=True, - ) - rpn_head = RPNHead( - num_anchors_per_location=len(scales) * len(aspect_ratios) - ) - roi_generator = ROIGenerator( - bounding_box_format=bounding_box_format, - nms_score_threshold_train=float("-inf"), - nms_score_threshold_test=float("-inf"), - ) - roi_pooler = _ROIAligner(bounding_box_format=bounding_box_format) - rcnn_head = rcnn_head or RCNNHead(num_classes) - backbone = backbone or models.ResNet50Backbone() - extractor_levels = ["P2", "P3", "P4", "P5"] - extractor_layer_names = [ - backbone.pyramid_level_inputs[i] for i in extractor_levels - ] - feature_extractor = get_feature_extractor( - backbone, extractor_layer_names, extractor_levels - ) - feature_pyramid = feature_pyramid or FeaturePyramid() - - # Begin construction of forward pass - image_shape = feature_extractor.input_shape[1:] - images = keras.layers.Input(image_shape, name="images") - anchors = anchor_generator(image_shape=image_shape) - - # Calling the RPN block - backbone_outputs = feature_extractor(images) - feature_map = feature_pyramid(backbone_outputs) - # [BS, num_anchors, 4], [BS, num_anchors, 1] - rpn_boxes, rpn_scores = rpn_head(feature_map) - # the decoded format is center_xywh, convert to yxyx - decoded_rpn_boxes = _decode_deltas_to_boxes( - anchors=anchors, - boxes_delta=rpn_boxes, - anchor_format=bounding_box_format, - box_format=bounding_box_format, - variance=BOX_VARIANCE, - ) - rois, _ = roi_generator(decoded_rpn_boxes, rpn_scores) - rois = _clip_boxes(rois, bounding_box_format, image_shape) - - # Calling the RCNN block - feature_map = roi_pooler(feature_map, rois) - # [BS, H*W*K, pool_shape*C] - feature_map = keras.layers.Reshape( - target_shape=(ops.shape(rois)[1], -1) - )(feature_map) - # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] - rcnn_box_pred, rcnn_cls_pred = rcnn_head( - feature_map, - ) - - inputs = {"images": images} - outputs = {"box": rcnn_box_pred, "classification": rcnn_cls_pred} - - super().__init__(inputs=inputs, outputs=outputs, **kwargs) - - self.num_classes 
= num_classes - self.bounding_box_format = bounding_box_format - self.anchor_generator = anchor_generator - self.rpn_head = rpn_head - self.roi_generator = roi_generator - self.box_matcher = BoxMatcher( - thresholds=[0.0, 0.5], match_values=[-2, -1, 1] - ) - self.roi_sampler = _ROISampler( - bounding_box_format="yxyx", - roi_matcher=self.box_matcher, - background_class=num_classes, - num_sampled_rois=512, - ) - self.roi_pooler = roi_pooler - self.rcnn_head = rcnn_head - self.backbone = backbone - self.feature_extractor = feature_extractor - self.feature_pyramid = feature_pyramid - self.rpn_labeler = label_encoder or _RpnLabelEncoder( - anchor_format="yxyx", - ground_truth_box_format="yxyx", - positive_threshold=0.7, - negative_threshold=0.3, - samples_per_image=256, - positive_fraction=0.5, - box_variance=BOX_VARIANCE, - ) - self._prediction_decoder = ( - prediction_decoder - or cv_layers.NonMaxSuppression( - bounding_box_format=bounding_box_format, - from_logits=False, - iou_threshold=0.5, - confidence_threshold=0.5, - max_detections=100, - ) - ) - - def _call_rpn(self, images, anchors, training=None): - image_shape = ops.shape(images[0]) - backbone_outputs = self.feature_extractor(images, training=training) - feature_map = self.feature_pyramid(backbone_outputs, training=training) - # [BS, num_anchors, 4], [BS, num_anchors, 1] - rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training) - # the decoded format is center_xywh, convert to yxyx - decoded_rpn_boxes = _decode_deltas_to_boxes( - anchors=anchors, - boxes_delta=rpn_boxes, - anchor_format="yxyx", - box_format="yxyx", - variance=BOX_VARIANCE, - ) - rois, _ = self.roi_generator( - decoded_rpn_boxes, rpn_scores, training=training - ) - rois = _clip_boxes(rois, "yxyx", image_shape) - rpn_boxes = ops.concatenate(tree.flatten(rpn_boxes), axis=1) - rpn_scores = ops.concatenate(tree.flatten(rpn_scores), axis=1) - return rois, feature_map, rpn_boxes, rpn_scores - - def _call_rcnn(self, rois, feature_map, training=None): - feature_map = self.roi_pooler(feature_map, rois) - # [BS, H*W*K, pool_shape*C] - feature_map = ops.reshape( - feature_map, ops.concatenate([ops.shape(rois)[:2], [-1]], axis=0) - ) - # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] - rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( - feature_map, training=training - ) - return rcnn_box_pred, rcnn_cls_pred - - # TODO(tanzhenyu): Support compile with metrics. - def compile( - self, - box_loss=None, - classification_loss=None, - rpn_box_loss=None, - rpn_classification_loss=None, - weight_decay=0.0001, - loss=None, - **kwargs, - ): - # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed. - # https://github.com/keras-team/keras-cv/issues/915 - if "metrics" in kwargs.keys(): - raise ValueError( - "`FasterRCNN` does not currently support the use of " - "`metrics` due to performance and distribution concerns. " - "Please use the `PyCOCOCallback` to evaluate COCO metrics." - ) - if loss is not None: - raise ValueError( - "`FasterRCNN` does not accept a `loss` to `compile()`. " - "Instead, please pass `box_loss` and `classification_loss`. " - "`loss` will be ignored during training." 
- ) - box_loss = _validate_and_get_loss(box_loss, "box_loss") - classification_loss = _validate_and_get_loss( - classification_loss, "classification_loss" - ) - rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss") - if rpn_classification_loss == "BinaryCrossentropy": - rpn_classification_loss = keras.losses.BinaryCrossentropy( - from_logits=True, reduction=keras.losses.Reduction.SUM - ) - rpn_classification_loss = _validate_and_get_loss( - rpn_classification_loss, "rpn_cls_loss" - ) - if not rpn_classification_loss.from_logits: - raise ValueError( - "`rpn_classification_loss` must come with `from_logits`=True" - ) - - self.rpn_box_loss = rpn_box_loss - self.rpn_cls_loss = rpn_classification_loss - self.box_loss = box_loss - self.cls_loss = classification_loss - self.weight_decay = weight_decay - losses = { - "box": self.box_loss, - "classification": self.cls_loss, - "rpn_box": self.rpn_box_loss, - "rpn_classification": self.rpn_cls_loss, - } - super().compile(loss=losses, **kwargs) - - def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): - images = x - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." - ) - # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere - classes = ops.expand_dims(y["classes"], axis=-1) - - local_batch = images.get_shape().as_list()[0] - anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) - ( - rpn_box_targets, - rpn_box_weights, - rpn_cls_targets, - rpn_cls_weights, - ) = self.rpn_labeler( - ops.concatenate(tree.flatten(anchors), axis=0), boxes, classes - ) - rpn_box_weights /= ( - self.rpn_labeler.samples_per_image * local_batch * 0.25 - ) - rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch - rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( - images, - anchors, - ) - rois = ops.stop_gradient(rois) - ( - rois, - box_targets, - box_weights, - cls_targets, - cls_weights, - ) = self.roi_sampler(rois, boxes, classes) - - positive_mask = ops.cast( - ops.greater(cls_targets, -1.0), dtype="float32" - ) - normalizer = ops.sum(positive_mask) - - box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 - cls_weights /= self.roi_sampler.num_sampled_rois * local_batch - - box_weights /= normalizer - cls_weights /= normalizer - - box_pred, cls_pred = self._call_rcnn( - rois, - feature_map, - ) - - y_true = { - "rpn_box": rpn_box_targets, - "rpn_classification": rpn_cls_targets, - "box": box_targets, - "classification": cls_targets, - } - y_pred = { - "rpn_box": rpn_box_pred, - "rpn_classification": rpn_cls_pred, - "box": box_pred, - "classification": cls_pred, - } - sample_weights = { - "rpn_box": rpn_box_weights, - "rpn_classification": rpn_cls_weights, - "box": box_weights, - "classification": cls_weights, - } - zero_weights = { - "rpn_box": ops.zeros_like(rpn_box_weights), - "rpn_classification": ops.zeros_like(rpn_cls_weights), - "box": ops.zeros_like(box_weights), - "classification": ops.zeros_like(cls_weights), - } - - sample_weights = ops.cond( - normalizer == 0.0, - lambda: zero_weights, - lambda: sample_weights, - ) - return super().compute_loss( - x=images, y=y_true, y_pred=y_pred, sample_weight=sample_weights - ) - - def train_step(self, *args): - data = args[-1] - args = args[:-1] - x, y = unpack_input(data) - return super().train_step(*args, (x, y)) - - def test_step(self, *args): - data = args[-1] - args = args[:-1] - x, y = 
unpack_input(data) - return super().test_step(*args, (x, y)) - - def predict_step(self, *args): - outputs = super().predict_step(*args) - if type(outputs) is tuple: - return self.decode_predictions(outputs[0], args[-1]), outputs[1] - else: - return self.decode_predictions(outputs, args[-1]) - - @property - def prediction_decoder(self): - return self._prediction_decoder - - @prediction_decoder.setter - def prediction_decoder(self, prediction_decoder): - self._prediction_decoder = prediction_decoder - self.make_predict_function(force=True) - - def decode_predictions(self, predictions, images): - # no-op if default decoder is used. - box_pred, scores_pred = ( - predictions["box"], - predictions["classification"], - ) - box_pred = bounding_box.convert_format( - box_pred, - source=self.bounding_box_format, - target=self.prediction_decoder.bounding_box_format, - images=images, - ) - y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1]) - box_pred = bounding_box.convert_format( - y_pred["boxes"], - source=self.prediction_decoder.bounding_box_format, - target=self.bounding_box_format, - images=images, - ) - y_pred["boxes"] = box_pred - return y_pred - - def get_config(self): - return { - "num_classes": self.num_classes, - "bounding_box_format": self.bounding_box_format, - "backbone": keras.saving.serialize_keras_object(self.backbone), - "anchor_generator": keras.saving.serialize_keras_object( - self.anchor_generator - ), - "label_encoder": keras.saving.serialize_keras_object( - self.rpn_labeler - ), - "prediction_decoder": keras.saving.serialize_keras_object( - self._prediction_decoder - ), - "feature_pyramid": keras.saving.serialize_keras_object( - self.feature_pyramid - ), - "rcnn_head": keras.saving.serialize_keras_object(self.rcnn_head), - } - - @classmethod - def from_config(cls, config): - if "rcnn_head" in config and isinstance(config["rcnn_head"], dict): - config["rcnn_head"] = keras.layers.deserialize(config["rcnn_head"]) - if "feature_pyramid" in config and isinstance( - config["feature_pyramid"], dict - ): - config["feature_pyramid"] = keras.layers.deserialize( - config["feature_pyramid"] - ) - if "prediction_decoder" in config and isinstance( - config["prediction_decoder"], dict - ): - config["prediction_decoder"] = keras.layers.deserialize( - config["prediction_decoder"] - ) - if "label_encoder" in config and isinstance( - config["label_encoder"], dict - ): - config["label_encoder"] = keras.layers.deserialize( - config["label_encoder"] - ) - if "anchor_generator" in config and isinstance( - config["anchor_generator"], dict - ): - config["anchor_generator"] = keras.layers.deserialize( - config["anchor_generator"] - ) - if "backbone" in config and isinstance(config["backbone"], dict): - config["backbone"] = keras.layers.deserialize(config["backbone"]) - return super().from_config(config) - - @classproperty - def presets(cls): - """Dictionary of preset names and configurations.""" - # return copy.deepcopy({**backbone_presets, **fasterrcnn_presets}) - return copy.deepcopy({**backbone_presets}) - - @classproperty - def presets_with_weights(cls): - """Dictionary of preset names and configurations that include - weights.""" - return copy.deepcopy( - # {**backbone_presets_with_weights, **fasterrcnn_presets} - { - **backbone_presets_with_weights, - } - ) - - @classproperty - def backbone_presets(cls): - """Dictionary of preset names and configurations of compatible - backbones.""" - return copy.deepcopy(backbone_presets) - - -def _validate_and_get_loss(loss, loss_name): - if 
isinstance(loss, str): - loss = keras.losses.get(loss) - if loss is None or not isinstance(loss, keras.losses.Loss): - raise ValueError( - f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, " - f"got {loss}" - ) - if loss.reduction != keras.losses.Reduction.SUM: - logging.info( - f"FasterRCNN only accepts `SUM` reduction, got {loss.reduction}, " - "automatically converted." - ) - loss.reduction = keras.losses.Reduction.SUM - return loss diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py deleted file mode 100644 index fa898cb564..0000000000 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_tf.py +++ /dev/null @@ -1,394 +0,0 @@ -class FasterRCNN(keras.Model): - """A Keras model implementing the FasterRCNN architecture. - - Implements the FasterRCNN architecture for object detection. The constructor - requires `num_classes`, `bounding_box_format` and a `backbone`. - - References: - - [FasterRCNN](https://arxiv.org/pdf/1506.01497.pdf) - - Usage: - ```python - retinanet = keras_cv.models.FasterRCNN( - num_classes=20, - bounding_box_format="xywh", - backbone=None, - ) - ``` - - Args: - num_classes: the number of classes in your dataset excluding the - background class. classes should be represented by integers in the - range [0, num_classes). - bounding_box_format: The format of bounding boxes of model output. Refer - [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) - for more details on supported bounding box formats. - backbone: Optional `keras.Model`. Must implement the - `pyramid_level_inputs` property with keys "P2", "P3", "P4", and "P5" - and layer names as values. If `None`, defaults to - `keras_cv.models.ResNet50Backbone()`. - anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. It is - used in the model to match ground truth boxes and labels with - anchors, or with region proposals. By default it uses the sizes and - ratios from the paper, that is optimized for image size between - [640, 800]. The users should pass their own anchor generator if the - input image size differs from paper. For now, only anchor generator - with per level dict output is supported, - label_encoder: (Optional) a keras.Layer that accepts an anchors Tensor, - a bounding box Tensor and a bounding box class Tensor to its - `call()` method, and returns RetinaNet training targets. It returns - box and class targets as well as sample weights. - rcnn_head: (Optional) a `keras.layers.Layer` that takes input feature - map and returns a box delta prediction (in reference to rois) and - multi-class prediction (all foreground classes + one background - class). By default it uses the rcnn head from paper, which is 2 FC - layer with 1024 dimension, 1 box regressor and 1 softmax classifier. - prediction_decoder: (Optional) a `keras.layers.Layer` that takes input - box prediction and softmaxed score prediction, and returns NMSed box - prediction, NMSed softmaxed score prediction, NMSed class - prediction, and NMSed valid detection. 
- """ # noqa: E501 - - def __init__( - self, - num_classes, - bounding_box_format, - backbone=None, - anchor_generator=None, - label_encoder=None, - rcnn_head=None, - prediction_decoder=None, - **kwargs, - ): - self.bounding_box_format = bounding_box_format - super().__init__(**kwargs) - scales = [2**x for x in [0]] - aspect_ratios = [0.5, 1.0, 2.0] - self.anchor_generator = anchor_generator or AnchorGenerator( - bounding_box_format="yxyx", - sizes={ - "P2": 32.0, - "P3": 64.0, - "P4": 128.0, - "P5": 256.0, - "P6": 512.0, - }, - scales=scales, - aspect_ratios=aspect_ratios, - strides={f"P{i}": 2**i for i in range(2, 7)}, - clip_boxes=True, - ) - self.rpn_head = RPNHead( - num_anchors_per_location=len(scales) * len(aspect_ratios) - ) - self.roi_generator = ROIGenerator( - bounding_box_format="yxyx", - nms_score_threshold_train=float("-inf"), - nms_score_threshold_test=float("-inf"), - ) - self.box_matcher = BoxMatcher( - thresholds=[0.0, 0.5], match_values=[-2, -1, 1] - ) - self.roi_sampler = _ROISampler( - bounding_box_format="yxyx", - roi_matcher=self.box_matcher, - background_class=num_classes, - num_sampled_rois=512, - ) - self.roi_pooler = _ROIAligner(bounding_box_format="yxyx") - self.rcnn_head = rcnn_head or RCNNHead(num_classes) - self.backbone = backbone or models.ResNet50Backbone() - extractor_levels = ["P2", "P3", "P4", "P5"] - extractor_layer_names = [ - self.backbone.pyramid_level_inputs[i] for i in extractor_levels - ] - self.feature_extractor = get_feature_extractor( - self.backbone, extractor_layer_names, extractor_levels - ) - self.feature_pyramid = FeaturePyramid() - self.rpn_labeler = label_encoder or _RpnLabelEncoder( - anchor_format="yxyx", - ground_truth_box_format="yxyx", - positive_threshold=0.7, - negative_threshold=0.3, - samples_per_image=256, - positive_fraction=0.5, - box_variance=BOX_VARIANCE, - ) - self._prediction_decoder = ( - prediction_decoder - or cv_layers.MultiClassNonMaxSuppression( - bounding_box_format=bounding_box_format, - from_logits=False, - max_detections_per_class=10, - max_detections=10, - ) - ) - - def _call_rpn(self, images, anchors, training=None): - image_shape = tf.shape(images[0]) - backbone_outputs = self.feature_extractor(images, training=training) - feature_map = self.feature_pyramid(backbone_outputs, training=training) - # [BS, num_anchors, 4], [BS, num_anchors, 1] - rpn_boxes, rpn_scores = self.rpn_head(feature_map, training=training) - # the decoded format is center_xywh, convert to yxyx - decoded_rpn_boxes = _decode_deltas_to_boxes( - anchors=anchors, - boxes_delta=rpn_boxes, - anchor_format="yxyx", - box_format="yxyx", - variance=BOX_VARIANCE, - ) - rois, _ = self.roi_generator( - decoded_rpn_boxes, rpn_scores, training=training - ) - rois = _clip_boxes(rois, "yxyx", image_shape) - rpn_boxes = tf.concat(tf.nest.flatten(rpn_boxes), axis=1) - rpn_scores = tf.concat(tf.nest.flatten(rpn_scores), axis=1) - return rois, feature_map, rpn_boxes, rpn_scores - - def _call_rcnn(self, rois, feature_map, training=None): - feature_map = self.roi_pooler(feature_map, rois) - # [BS, H*W*K, pool_shape*C] - feature_map = tf.reshape( - feature_map, tf.concat([tf.shape(rois)[:2], [-1]], axis=0) - ) - # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] - rcnn_box_pred, rcnn_cls_pred = self.rcnn_head( - feature_map, training=training - ) - return rcnn_box_pred, rcnn_cls_pred - - def call(self, images, training=None): - image_shape = tf.shape(images[0]) - anchors = self.anchor_generator(image_shape=image_shape) - rois, feature_map, _, _ = self._call_rpn( - 
images, anchors, training=training - ) - box_pred, cls_pred = self._call_rcnn( - rois, feature_map, training=training - ) - if not training: - # box_pred is on "center_yxhw" format, convert to target format. - box_pred = _decode_deltas_to_boxes( - anchors=rois, - boxes_delta=box_pred, - anchor_format="yxyx", - box_format=self.bounding_box_format, - variance=[0.1, 0.1, 0.2, 0.2], - ) - - return box_pred, cls_pred - - # TODO(tanzhenyu): Support compile with metrics. - def compile( - self, - box_loss=None, - classification_loss=None, - rpn_box_loss=None, - rpn_classification_loss=None, - weight_decay=0.0001, - loss=None, - **kwargs, - ): - # TODO(tanzhenyu): Add metrics support once COCOMap issue is addressed. - # https://github.com/keras-team/keras-cv/issues/915 - if "metrics" in kwargs.keys(): - raise ValueError( - "`FasterRCNN` does not currently support the use of " - "`metrics` due to performance and distribution concerns. " - "Please use the `PyCOCOCallback` to evaluate COCO metrics." - ) - if loss is not None: - raise ValueError( - "`FasterRCNN` does not accept a `loss` to `compile()`. " - "Instead, please pass `box_loss` and `classification_loss`. " - "`loss` will be ignored during training." - ) - box_loss = _validate_and_get_loss(box_loss, "box_loss") - classification_loss = _validate_and_get_loss( - classification_loss, "classification_loss" - ) - rpn_box_loss = _validate_and_get_loss(rpn_box_loss, "rpn_box_loss") - if rpn_classification_loss == "BinaryCrossentropy": - rpn_classification_loss = keras.losses.BinaryCrossentropy( - from_logits=True, reduction=keras.losses.Reduction.SUM - ) - rpn_classification_loss = _validate_and_get_loss( - rpn_classification_loss, "rpn_cls_loss" - ) - if not rpn_classification_loss.from_logits: - raise ValueError( - "`rpn_classification_loss` must come with `from_logits`=True" - ) - - self.rpn_box_loss = rpn_box_loss - self.rpn_cls_loss = rpn_classification_loss - self.box_loss = box_loss - self.cls_loss = classification_loss - self.weight_decay = weight_decay - losses = { - "box": self.box_loss, - "classification": self.cls_loss, - "rpn_box": self.rpn_box_loss, - "rpn_classification": self.rpn_cls_loss, - } - super().compile(loss=losses, **kwargs) - - def compute_loss(self, images, boxes, classes, training): - local_batch = images.get_shape().as_list()[0] - if tf.distribute.has_strategy(): - num_sync = tf.distribute.get_strategy().num_replicas_in_sync - else: - num_sync = 1 - global_batch = local_batch * num_sync - anchors = self.anchor_generator(image_shape=tuple(images[0].shape)) - ( - rpn_box_targets, - rpn_box_weights, - rpn_cls_targets, - rpn_cls_weights, - ) = self.rpn_labeler( - tf.concat(tf.nest.flatten(anchors), axis=0), boxes, classes - ) - rpn_box_weights /= ( - self.rpn_labeler.samples_per_image * global_batch * 0.25 - ) - rpn_cls_weights /= self.rpn_labeler.samples_per_image * global_batch - rois, feature_map, rpn_box_pred, rpn_cls_pred = self._call_rpn( - images, anchors, training=training - ) - rois = tf.stop_gradient(rois) - ( - rois, - box_targets, - box_weights, - cls_targets, - cls_weights, - ) = self.roi_sampler(rois, boxes, classes) - box_weights /= self.roi_sampler.num_sampled_rois * global_batch * 0.25 - cls_weights /= self.roi_sampler.num_sampled_rois * global_batch - box_pred, cls_pred = self._call_rcnn( - rois, feature_map, training=training - ) - y_true = { - "rpn_box": rpn_box_targets, - "rpn_classification": rpn_cls_targets, - "box": box_targets, - "classification": cls_targets, - } - y_pred = { - "rpn_box": 
rpn_box_pred, - "rpn_classification": rpn_cls_pred, - "box": box_pred, - "classification": cls_pred, - } - weights = { - "rpn_box": rpn_box_weights, - "rpn_classification": rpn_cls_weights, - "box": box_weights, - "classification": cls_weights, - } - return super().compute_loss( - x=images, y=y_true, y_pred=y_pred, sample_weight=weights - ) - - def train_step(self, data): - images, y = unpack_input(data) - - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." - ) - # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere - classes = tf.expand_dims(y["classes"], axis=-1) - with tf.GradientTape() as tape: - total_loss = self.compute_loss( - images, boxes, classes, training=True - ) - reg_losses = [] - if self.weight_decay: - for var in self.trainable_variables: - if "bn" not in var.name: - reg_losses.append( - self.weight_decay * tf.nn.l2_loss(var) - ) - l2_loss = tf.math.add_n(reg_losses) - total_loss += l2_loss - self.optimizer.minimize(total_loss, self.trainable_variables, tape=tape) - return self.compute_metrics(images, {}, {}, sample_weight={}) - - def test_step(self, data): - images, y = unpack_input(data) - - boxes = y["boxes"] - if len(y["classes"].shape) != 2: - raise ValueError( - "Expected 'classes' to be a tf.Tensor of rank 2. " - f"Got y['classes'].shape={y['classes'].shape}." - ) - classes = tf.expand_dims(y["classes"], axis=-1) - self.compute_loss(images, boxes, classes, training=False) - return self.compute_metrics(images, {}, {}, sample_weight={}) - - def make_predict_function(self, force=False): - return predict_utils.make_predict_function(self, force=force) - - @property - def prediction_decoder(self): - return self._prediction_decoder - - @prediction_decoder.setter - def prediction_decoder(self, prediction_decoder): - self._prediction_decoder = prediction_decoder - self.make_predict_function(force=True) - - def decode_predictions(self, predictions, images): - # no-op if default decoder is used. - box_pred, scores_pred = predictions - box_pred = bounding_box.convert_format( - box_pred, - source=self.bounding_box_format, - target=self.prediction_decoder.bounding_box_format, - images=images, - ) - y_pred = self.prediction_decoder(box_pred, scores_pred[..., :-1]) - box_pred = bounding_box.convert_format( - y_pred["boxes"], - source=self.prediction_decoder.bounding_box_format, - target=self.bounding_box_format, - images=images, - ) - y_pred["boxes"] = box_pred - return y_pred - - def get_config(self): - return { - "num_classes": self.num_classes, - "bounding_box_format": self.bounding_box_format, - "backbone": self.backbone, - "anchor_generator": self.anchor_generator, - "label_encoder": self.rpn_labeler, - "prediction_decoder": self._prediction_decoder, - "feature_pyramid": self.feature_pyramid, - "rcnn_head": self.rcnn_head, - } - - -def _validate_and_get_loss(loss, loss_name): - if isinstance(loss, str): - loss = keras.losses.get(loss) - if loss is None or not isinstance(loss, keras.losses.Loss): - raise ValueError( - f"FasterRCNN only accepts `keras.losses.Loss` for {loss_name}, " - f"got {loss}" - ) - if loss.reduction != keras.losses.Reduction.SUM: - logging.info( - f"FasterRCNN only accepts `SUM` reduction, got {loss.reduction}, " - "automatically converted." 
- ) - loss.reduction = keras.losses.Reduction.SUM - return loss From 9c85dfc80dfb6f1018fd179f791559afc195d61d Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Thu, 14 Mar 2024 21:26:35 +0530 Subject: [PATCH 30/32] chore: adding init and compute loss --- demo.py | 51 +-- .../faster_rcnn/faster_rcnn.py | 425 ++++++++++++++---- 2 files changed, 355 insertions(+), 121 deletions(-) diff --git a/demo.py b/demo.py index 1cbfca779e..330eef019e 100644 --- a/demo.py +++ b/demo.py @@ -1,17 +1,16 @@ -import numpy as np - -import keras_cv - import os os.environ['CUDA_VISIBLE_DEVICES'] = '-1' -# Note: We absolutely need this while creating the Backbone -batch_size = 32 +import keras +import keras_cv +from keras_cv.models import FasterRCNN + +batch_size = 1 image_shape = (512, 512, 3) -images = np.ones((batch_size,) + image_shape) +images = keras.ops.ones((batch_size,) + image_shape) labels = { - "boxes": np.array( + "boxes": keras.ops.array( [ [ [0, 0, 100, 100], @@ -19,13 +18,13 @@ [300, 300, 100, 100], ] ], - dtype=np.float32, + dtype="float32", ), - "classes": np.array([[1, 1, 1]], dtype=np.float32), + "classes": keras.ops.array([[1, 1, 1]], dtype="float32"), } -model = keras_cv.models.FasterRCNN( +model = FasterRCNN( batch_size=batch_size, - num_classes=20, + num_classes=2, bounding_box_format="xywh", backbone=keras_cv.models.ResNet50Backbone.from_preset( "resnet50_imagenet", @@ -33,17 +32,19 @@ ), ) -# # Evaluate model without box decoding and NMS -# model(images) - -# # Prediction with box decoding and NMS -# model.predict(images) +# Call the model +outputs = model(images) +print("outputs") +for key, value in outputs.items(): + print(f"{key}: {value.shape}") -# # Train model -# model.compile( -# classification_loss="focal", -# box_loss="smoothl1", -# optimizer=keras.optimizers.SGD(global_clipnorm=10.0), -# jit_compile=False, -# ) -# model.fit(images, labels) +model.compile( + optimizer=keras.optimizers.Adam(), + box_loss=keras.losses.Huber(), + classification_loss=keras.losses.CategoricalCrossentropy(), + rpn_box_loss=keras.losses.Huber(), + rpn_classification_loss=keras.losses.BinaryCrossentropy(from_logits=True), +) +# Train the model +loss = model.compute_loss(x=images, y=labels, y_pred=None, sample_weight=None) +print(loss) \ No newline at end of file diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index 77bc993230..ba79261402 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -14,18 +14,29 @@ from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras +from keras_cv.models.object_detection.__internal__ import unpack_input from keras_cv.bounding_box.converters import _decode_deltas_to_boxes from keras_cv.bounding_box.utils import _clip_boxes from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator from keras_cv.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.layers.object_detection.box_matcher import BoxMatcher from keras_cv.layers.object_detection.roi_align import _ROIAligner +from keras_cv.layers.object_detection.roi_sampler import _ROISampler +from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid from keras_cv.models.object_detection.faster_rcnn import RPNHead +from keras_cv.models.object_detection.faster_rcnn import RCNNHead from keras_cv.models.task 
import Task from keras_cv.utils.train import get_feature_extractor - +import tree BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] +class StopGradientLayer(keras.layers.Layer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def call(self, x): + return keras.ops.stop_gradient(x) # TODO(tanzheny): add more configurations @keras_cv_export("keras_cv.models.FasterRCNN") @@ -38,11 +49,13 @@ def __init__( bounding_box_format, anchor_generator=None, feature_pyramid=None, + rcnn_head=None, + label_encoder=None, *args, **kwargs, ): - # Create the Input Layer + # 1. Create the Input Layer extractor_levels = ["P2", "P3", "P4", "P5"] extractor_layer_names = [ backbone.pyramid_level_inputs[i] for i in extractor_levels @@ -50,39 +63,15 @@ def __init__( feature_extractor = get_feature_extractor( backbone, extractor_layer_names, extractor_levels ) - feature_pyramid = feature_pyramid or FeaturePyramid() + feature_pyramid = feature_pyramid or FeaturePyramid(name="feature_pyramid") image_shape = feature_extractor.input_shape[1:] # exclude the batch size images = keras.layers.Input( - image_shape, batch_size=batch_size, name="images" + image_shape, batch_size=batch_size, name="images", ) - print(f"{image_shape=}") - print(f"{images.shape=}") - # Get the backbone outputs - backbone_outputs = feature_extractor(images) - feature_map = feature_pyramid(backbone_outputs) - print("backbone_outputs") - for key, value in backbone_outputs.items(): - print(f"\t{key}: {value.shape}") - print("feature_map") - for key, value in feature_map.items(): - print(f"\t{key}: {value.shape}") - - # Get the Region Proposal Boxes and Scores + # 2. Create the anchors scales = [2**x for x in [0]] aspect_ratios = [0.5, 1.0, 2.0] - num_anchors_per_location = len(scales) * len(aspect_ratios) - rpn_head = RPNHead(num_anchors_per_location=num_anchors_per_location) - # [BS, num_anchors, 4], [BS, num_anchors, 1] - rpn_boxes, rpn_scores = rpn_head(feature_map) - print("rpn_boxes") - for key, value in rpn_boxes.items(): - print(f"\t{key}: {value.shape}") - print("rpn_scores") - for key, value in rpn_scores.items(): - print(f"\t{key}: {value.shape}") - - # Create the anchors anchor_generator = anchor_generator or AnchorGenerator( bounding_box_format=bounding_box_format, sizes={ @@ -96,15 +85,27 @@ def __init__( aspect_ratios=aspect_ratios, strides={f"P{i}": 2**i for i in range(2, 7)}, clip_boxes=True, + name="anchor_generator", ) # Note: `image_shape` should not be of NoneType # Need to assert before this line anchors = anchor_generator(image_shape=image_shape) - print("anchors") - for key, value in anchors.items(): - print(f"\t{key}: {value.shape}") - # decode the deltas to boxes + ####################################################################### + # Call RPN + ####################################################################### + + # 3. Get the backbone outputs + backbone_outputs = feature_extractor(images) + feature_map = feature_pyramid(backbone_outputs) + + # 4. Get the Region Proposal Boxes and Scores + num_anchors_per_location = len(scales) * len(aspect_ratios) + rpn_head = RPNHead(num_anchors_per_location=num_anchors_per_location, name="rpn_head") + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = rpn_head(feature_map) + + # 5. 
Decode the deltas to boxes decoded_rpn_boxes = _decode_deltas_to_boxes( anchors=anchors, boxes_delta=rpn_boxes, @@ -112,77 +113,309 @@ def __init__( box_format=bounding_box_format, variance=BOX_VARIANCE, ) - print("decoded_rpn_boxes") - for key, value in decoded_rpn_boxes.items(): - print(f"\t{key}: {value.shape}") - # Generate the Region of Interests + # 6. Generate the Region of Interests roi_generator = ROIGenerator( bounding_box_format=bounding_box_format, nms_score_threshold_train=float("-inf"), nms_score_threshold_test=float("-inf"), + name="roi_generator", ) rois, _ = roi_generator(decoded_rpn_boxes, rpn_scores) rois = _clip_boxes(rois, bounding_box_format, image_shape) - print(f"{rois.shape=}") + rpn_box_pred = keras.ops.concatenate(tree.flatten(rpn_boxes), axis=1) + rpn_cls_pred = keras.ops.concatenate(tree.flatten(rpn_scores), axis=1) + + ####################################################################### + # Call RCNN + ####################################################################### - # Using the regions call the rcnn head - roi_pooler = _ROIAligner(bounding_box_format="yxyx") + # 7. Pool the region of interests + roi_pooler = _ROIAligner(bounding_box_format="yxyx", name="roi_pooler") feature_map = roi_pooler(features=feature_map, boxes=rois) - print(f"{feature_map.shape=}") + + # 8. Reshape the feature map [BS, H*W*K] + feature_map = keras.ops.reshape( + feature_map, newshape=keras.ops.shape(rois)[:2] + (-1,), + ) + + # 9. Pass the feature map to RCNN head + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + rcnn_head = rcnn_head or RCNNHead(num_classes=num_classes, name="rcnn_head") + box_pred, cls_pred = rcnn_head(feature_map=feature_map) + + # 10. Create the model using Functional API + inputs = {"images": images} + box_pred = keras.layers.Concatenate(axis=1, name="box")([box_pred]) + cls_pred = keras.layers.Concatenate(axis=1, name="classification")([cls_pred]) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")([rpn_box_pred]) + rpn_cls_pred = keras.layers.Concatenate(axis=1, name="rpn_classification")([rpn_cls_pred]) + outputs = { + "box": box_pred, + "classification": cls_pred, + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + } + + super().__init__(inputs=inputs, outputs=outputs, *args, **kwargs) + + # Define the model parameters + self.bounding_box_format = bounding_box_format + self.anchor_generator = anchor_generator + self.rpn_labeler = label_encoder or _RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format="yxyx", + positive_threshold=0.7, + negative_threshold=0.3, + samples_per_image=256, + positive_fraction=0.5, + box_variance=BOX_VARIANCE, + ) + self.feature_extractor = feature_extractor + self.feature_pyramid = feature_pyramid + self.roi_generator = roi_generator + self.rpn_head = rpn_head + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + ) + self.roi_sampler = _ROISampler( + bounding_box_format="yxyx", + roi_matcher=self.box_matcher, + background_class=num_classes, + num_sampled_rois=512, + ) + self.roi_pooler = roi_pooler + self.rcnn_head = rcnn_head + + def compile( + self, + box_loss=None, + classification_loss=None, + rpn_box_loss=None, + rpn_classification_loss=None, + weight_decay=0.0001, + loss=None, + metrics=None, + **kwargs, + ): + if loss is not None: + raise ValueError( + "`FasterRCNN` does not accept a `loss` to `compile()`. " + "Instead, please pass `box_loss` and `classification_loss`. " + "`loss` will be ignored during training." 
+ ) + box_loss = _parse_box_loss(box_loss) + classification_loss = _parse_classification_loss(classification_loss) + + # if hasattr(box_loss, "bounding_box_format"): + # if box_loss.bounding_box_format != self.bounding_box_format: + # raise ValueError( + # "Wrong `bounding_box_format` passed to `box_loss` in " + # "`RetinaNet.compile()`. Got " + # "`box_loss.bounding_box_format=" + # f"{box_loss.bounding_box_format}`, want " + # "`box_loss.bounding_box_format=" + # f"{self.bounding_box_format}`" + # ) + # if hasattr(classification_loss, "from_logits"): + # if not classification_loss.from_logits: + # raise ValueError( + # "FasterRCNN.compile() expects `from_logits` to be True for " + # "`classification_loss`. Got " + # "`classification_loss.from_logits=" + # f"{classification_loss.from_logits}`" + # ) + + rpn_box_loss = _parse_box_loss(rpn_box_loss) + rpn_classification_loss = _parse_classification_loss(rpn_classification_loss) + + # if hasattr(rpn_box_loss, "bounding_box_format"): + # if rpn_box_loss.bounding_box_format != self.bounding_box_format: + # raise ValueError( + # "Wrong `bounding_box_format` passed to `box_loss` in " + # "`RetinaNet.compile()`. Got " + # "`box_loss.bounding_box_format=" + # f"{box_loss.bounding_box_format}`, want " + # "`box_loss.bounding_box_format=" + # f"{self.bounding_box_format}`" + # ) + # if hasattr(rpn_classification_loss, "from_logits"): + # if not rpn_classification_loss.from_logits: + # raise ValueError( + # "FasterRCNN.compile() expects `from_logits` to be True for " + # "`classification_loss`. Got " + # "`classification_loss.from_logits=" + # f"{classification_loss.from_logits}`" + # ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.weight_decay = weight_decay + losses = { + "box": self.box_loss, + "classification": self.cls_loss, + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + } + self._has_user_metrics = metrics is not None and len(metrics) != 0 + self._user_metrics = metrics + super().compile(loss=losses, **kwargs) + + def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): + # 1. Unpack the inputs + images = x + gt_boxes = y["boxes"] + if keras.ops.ndim(y["classes"]) != 2: + raise ValueError( + "Expected 'classes' to be a Tensor of rank 2. " + f"Got y['classes'].shape={keras.ops.shape(y['classes'])}." 
+ ) + # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere + # gt_classes = keras.ops.expand_dims(y["classes"], axis=-1) + gt_classes = y["classes"] + + # Generate anchors + # image shape must not contain the batch size + local_batch = keras.ops.shape(images)[0] + image_shape = keras.ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) - # - # # Create the anchor generator - # scales = [2**x for x in [0]] - # aspect_ratios = [0.5, 1.0, 2.0] - # anchor_generator = anchor_generator or AnchorGenerator( - # bounding_box_format="yxyx", - # sizes={ - # "P2": 32.0, - # "P3": 64.0, - # "P4": 128.0, - # "P5": 256.0, - # "P6": 512.0, - # }, - # scales=scales, - # aspect_ratios=aspect_ratios, - # strides={f"P{i}": 2**i for i in range(2, 7)}, - # clip_boxes=True, - # ) - - # # Create the Region Proposal Network Head - # num_anchors_per_location = len(scales) * len(aspect_ratios) - # rpn_head = RPNHead(num_anchors_per_location=num_anchors_per_location) - - # # Create the Region of Interest Generator - # roi_generator = ROIGenerator( - # bounding_box_format="yxyx", - # nms_score_threshold_train=float("-inf"), - # nms_score_threshold_test=float("-inf"), - # ) - - # # Create the Box Matcher - # box_matcher = BoxMatcher( - # thresholds=[0.0, 0.5], match_values=[-2, -1, 1] - # ) - - # # Create the Region of Interest Sampler - - # images = None - # box_pred = None - # class_pred = None - # inputs = {"images": images} - # outputs = {"box": box_pred, "classification": class_pred} - # super().__init__(inputs=inputs, outputs=outputs, *args, **kwargs) - - # def train_step(self, *args): - # data = args[-1] - # args = args[:-1] - # x, y = unpack_input(data) - # return super().train_step(*args, (x, y)) - - # def test_step(self, *args): - # data = args[-1] - # args = args[:-1] - # x, y = unpack_input(data) - # return super().test_step(*args, (x, y)) + # 2. Label with the anchors -- exclusive to compute_loss + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.rpn_labeler( + anchors_dict=keras.ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + gt_boxes=gt_boxes, + gt_classes=gt_classes, + ) + + # 3. Computing the weights + rpn_box_weights /= self.rpn_labeler.samples_per_image * local_batch * 0.25 + rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch + + ####################################################################### + # Call RPN + ####################################################################### + + backbone_outputs = self.feature_extractor(images) + feature_map = self.feature_pyramid(backbone_outputs) + + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = self.rpn_head(feature_map) + + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format=self.bounding_box_format, + box_format=self.bounding_box_format, + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, self.bounding_box_format, image_shape) + rpn_box_pred = keras.ops.concatenate(tree.flatten(rpn_boxes), axis=1) + rpn_cls_pred = keras.ops.concatenate(tree.flatten(rpn_scores), axis=1) + + # 4. Stop gradient from flowing into the ROI -- exclusive to compute_loss + rois = keras.ops.stop_gradient(rois) + + # 5. Sample the ROIS -- exclusive to compute_loss -- exclusive to compute loss + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + ) = self.roi_sampler(rois, gt_boxes, gt_classes) + + # 6. 
Box and class weights -- exclusive to compute loss + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * local_batch + + ####################################################################### + # Call RCNN + ####################################################################### + + feature_map = self.roi_pooler(features=feature_map, boxes=rois) + + # [BS, H*W*K] + feature_map = keras.ops.reshape( + feature_map, newshape=keras.ops.shape(rois)[:2] + (-1,), + ) + + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + box_pred, cls_pred = self.rcnn_head(feature_map=feature_map) + + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + } + y_pred = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + } + + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=weights, **kwargs + ) + + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().test_step(*args, (x, y)) + + +def _parse_box_loss(loss): + if not isinstance(loss, str): + # support arbitrary callables + return loss + + # case insensitive comparison + if loss.lower() == "smoothl1": + return keras.losses.SmoothL1Loss(l1_cutoff=1.0, reduction="sum") + if loss.lower() == "huber": + return keras.losses.Huber(reduction="sum") + + raise ValueError( + "Expected `box_loss` to be either a Keras Loss, " + f"callable, or the string 'SmoothL1', 'Huber'. Got loss={loss}." + ) + +def _parse_classification_loss(loss): + if not isinstance(loss, str): + # support arbitrary callables + return loss + + # case insensitive comparison + if loss.lower() == "focal": + return keras.losses.FocalLoss(from_logits=True, reduction="sum") + + raise ValueError( + "Expected `classification_loss` to be either a Keras Loss, " + f"callable, or the string 'Focal'. Got loss={loss}." + ) From e26a8efb644f29e3cfa09d171e90d1ad99e5a745 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Thu, 14 Mar 2024 21:28:04 +0530 Subject: [PATCH 31/32] chore: format --- demo.py | 2 + .../faster_rcnn/faster_rcnn.py | 119 ++++++++---------- 2 files changed, 52 insertions(+), 69 deletions(-) diff --git a/demo.py b/demo.py index 330eef019e..7cd8a2f5b6 100644 --- a/demo.py +++ b/demo.py @@ -1,7 +1,9 @@ import os + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' import keras + import keras_cv from keras_cv.models import FasterRCNN diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py index ba79261402..5079c06c70 100644 --- a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -12,31 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import tree + from keras_cv.api_export import keras_cv_export from keras_cv.backend import keras -from keras_cv.models.object_detection.__internal__ import unpack_input from keras_cv.bounding_box.converters import _decode_deltas_to_boxes from keras_cv.bounding_box.utils import _clip_boxes from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator -from keras_cv.layers.object_detection.roi_generator import ROIGenerator from keras_cv.layers.object_detection.box_matcher import BoxMatcher from keras_cv.layers.object_detection.roi_align import _ROIAligner +from keras_cv.layers.object_detection.roi_generator import ROIGenerator from keras_cv.layers.object_detection.roi_sampler import _ROISampler from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder +from keras_cv.models.object_detection.__internal__ import unpack_input from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid -from keras_cv.models.object_detection.faster_rcnn import RPNHead from keras_cv.models.object_detection.faster_rcnn import RCNNHead +from keras_cv.models.object_detection.faster_rcnn import RPNHead from keras_cv.models.task import Task from keras_cv.utils.train import get_feature_extractor -import tree + BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] -class StopGradientLayer(keras.layers.Layer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def call(self, x): - return keras.ops.stop_gradient(x) # TODO(tanzheny): add more configurations @keras_cv_export("keras_cv.models.FasterRCNN") @@ -63,10 +59,16 @@ def __init__( feature_extractor = get_feature_extractor( backbone, extractor_layer_names, extractor_levels ) - feature_pyramid = feature_pyramid or FeaturePyramid(name="feature_pyramid") - image_shape = feature_extractor.input_shape[1:] # exclude the batch size + feature_pyramid = feature_pyramid or FeaturePyramid( + name="feature_pyramid" + ) + image_shape = feature_extractor.input_shape[ + 1: + ] # exclude the batch size images = keras.layers.Input( - image_shape, batch_size=batch_size, name="images", + image_shape, + batch_size=batch_size, + name="images", ) # 2. Create the anchors @@ -101,7 +103,9 @@ def __init__( # 4. Get the Region Proposal Boxes and Scores num_anchors_per_location = len(scales) * len(aspect_ratios) - rpn_head = RPNHead(num_anchors_per_location=num_anchors_per_location, name="rpn_head") + rpn_head = RPNHead( + num_anchors_per_location=num_anchors_per_location, name="rpn_head" + ) # [BS, num_anchors, 4], [BS, num_anchors, 1] rpn_boxes, rpn_scores = rpn_head(feature_map) @@ -136,20 +140,29 @@ def __init__( # 8. Reshape the feature map [BS, H*W*K] feature_map = keras.ops.reshape( - feature_map, newshape=keras.ops.shape(rois)[:2] + (-1,), + feature_map, + newshape=keras.ops.shape(rois)[:2] + (-1,), ) - + # 9. Pass the feature map to RCNN head # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] - rcnn_head = rcnn_head or RCNNHead(num_classes=num_classes, name="rcnn_head") + rcnn_head = rcnn_head or RCNNHead( + num_classes=num_classes, name="rcnn_head" + ) box_pred, cls_pred = rcnn_head(feature_map=feature_map) # 10. 
Create the model using Functional API inputs = {"images": images} box_pred = keras.layers.Concatenate(axis=1, name="box")([box_pred]) - cls_pred = keras.layers.Concatenate(axis=1, name="classification")([cls_pred]) - rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")([rpn_box_pred]) - rpn_cls_pred = keras.layers.Concatenate(axis=1, name="rpn_classification")([rpn_cls_pred]) + cls_pred = keras.layers.Concatenate(axis=1, name="classification")( + [cls_pred] + ) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + [rpn_box_pred] + ) + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )([rpn_cls_pred]) outputs = { "box": box_pred, "classification": cls_pred, @@ -207,46 +220,10 @@ def compile( box_loss = _parse_box_loss(box_loss) classification_loss = _parse_classification_loss(classification_loss) - # if hasattr(box_loss, "bounding_box_format"): - # if box_loss.bounding_box_format != self.bounding_box_format: - # raise ValueError( - # "Wrong `bounding_box_format` passed to `box_loss` in " - # "`RetinaNet.compile()`. Got " - # "`box_loss.bounding_box_format=" - # f"{box_loss.bounding_box_format}`, want " - # "`box_loss.bounding_box_format=" - # f"{self.bounding_box_format}`" - # ) - # if hasattr(classification_loss, "from_logits"): - # if not classification_loss.from_logits: - # raise ValueError( - # "FasterRCNN.compile() expects `from_logits` to be True for " - # "`classification_loss`. Got " - # "`classification_loss.from_logits=" - # f"{classification_loss.from_logits}`" - # ) - rpn_box_loss = _parse_box_loss(rpn_box_loss) - rpn_classification_loss = _parse_classification_loss(rpn_classification_loss) - - # if hasattr(rpn_box_loss, "bounding_box_format"): - # if rpn_box_loss.bounding_box_format != self.bounding_box_format: - # raise ValueError( - # "Wrong `bounding_box_format` passed to `box_loss` in " - # "`RetinaNet.compile()`. Got " - # "`box_loss.bounding_box_format=" - # f"{box_loss.bounding_box_format}`, want " - # "`box_loss.bounding_box_format=" - # f"{self.bounding_box_format}`" - # ) - # if hasattr(rpn_classification_loss, "from_logits"): - # if not rpn_classification_loss.from_logits: - # raise ValueError( - # "FasterRCNN.compile() expects `from_logits` to be True for " - # "`classification_loss`. Got " - # "`classification_loss.from_logits=" - # f"{classification_loss.from_logits}`" - # ) + rpn_classification_loss = _parse_classification_loss( + rpn_classification_loss + ) self.rpn_box_loss = rpn_box_loss self.rpn_cls_loss = rpn_classification_loss @@ -281,7 +258,7 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): local_batch = keras.ops.shape(images)[0] image_shape = keras.ops.shape(images)[1:] anchors = self.anchor_generator(image_shape=image_shape) - + # 2. Label with the anchors -- exclusive to compute_loss ( rpn_box_targets, @@ -296,9 +273,11 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): gt_boxes=gt_boxes, gt_classes=gt_classes, ) - + # 3. 
Computing the weights - rpn_box_weights /= self.rpn_labeler.samples_per_image * local_batch * 0.25 + rpn_box_weights /= ( + self.rpn_labeler.samples_per_image * local_batch * 0.25 + ) rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch ####################################################################### @@ -348,17 +327,18 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): # [BS, H*W*K] feature_map = keras.ops.reshape( - feature_map, newshape=keras.ops.shape(rois)[:2] + (-1,), + feature_map, + newshape=keras.ops.shape(rois)[:2] + (-1,), ) - + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] box_pred, cls_pred = self.rcnn_head(feature_map=feature_map) y_true = { - "rpn_box": rpn_box_targets, - "rpn_classification": rpn_cls_targets, - "box": box_targets, - "classification": cls_targets, + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, } y_pred = { "rpn_box": rpn_box_pred, @@ -372,7 +352,7 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): "box": box_weights, "classification": cls_weights, } - + return super().compute_loss( x=x, y=y_true, y_pred=y_pred, sample_weight=weights, **kwargs ) @@ -406,6 +386,7 @@ def _parse_box_loss(loss): f"callable, or the string 'SmoothL1', 'Huber'. Got loss={loss}." ) + def _parse_classification_loss(loss): if not isinstance(loss, str): # support arbitrary callables From 5a1f5a7437207fe97e563407149795eba4c8bf4e Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Thu, 14 Mar 2024 21:54:02 +0530 Subject: [PATCH 32/32] chore: demo.py --- demo.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/demo.py b/demo.py index 7cd8a2f5b6..4f156cc214 100644 --- a/demo.py +++ b/demo.py @@ -24,6 +24,8 @@ ), "classes": keras.ops.array([[1, 1, 1]], dtype="float32"), } + +# Initialize the model model = FasterRCNN( batch_size=batch_size, num_classes=2, @@ -40,6 +42,7 @@ for key, value in outputs.items(): print(f"{key}: {value.shape}") +# Compile the model model.compile( optimizer=keras.optimizers.Adam(), box_loss=keras.losses.Huber(), @@ -47,6 +50,27 @@ rpn_box_loss=keras.losses.Huber(), rpn_classification_loss=keras.losses.BinaryCrossentropy(from_logits=True), ) -# Train the model + +# Compute Loss from the model loss = model.compute_loss(x=images, y=labels, y_pred=None, sample_weight=None) -print(loss) \ No newline at end of file +print(loss) + +# Train step +xs = keras.ops.ones((1, 512, 512, 3), "float32") +ys = { + "classes": keras.ops.array([[1, 1, 1]], dtype="float32"), + "boxes": keras.ops.array( + [ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], + dtype="float32", + ), +} +import tensorflow as tf +ds = tf.data.Dataset.from_tensor_slices((xs, ys)) +ds = ds.batch(1, drop_remainder=True) +model.fit(ds, epochs=1) \ No newline at end of file
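
A possible sanity check to append to `demo.py` (a minimal sketch, not part of the patch series; it reuses the `model` and `xs` defined above). The functional graph built in `FasterRCNN.__init__` exposes four raw heads named "box", "classification", "rpn_box", and "rpn_classification"; the "box" tensor still holds encoded R-CNN box deltas, so producing final detections would additionally need a decode-plus-NMS step in the model's `bounding_box_format`, which this patch series does not yet wire up.

```python
# Minimal sketch: inspect the raw head outputs after training.
# `model` and `xs` come from demo.py above; nothing new is trained here.
# The input/output keys match the `inputs`/`outputs` dicts in FasterRCNN.__init__.
preds = model({"images": xs})
for name in ("box", "classification", "rpn_box", "rpn_classification"):
    print(name, preds[name].shape)

# "classification" has num_classes + 1 channels: the extra slot is the
# background class that _ROISampler assigns to unmatched proposals.
# Decoding the "box" deltas into boxes and running NMS is still a TODO here.
```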