Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating FasterRCNN to use Task API #2012

Draft
wants to merge 50 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
9bc256f
chore: initial commit
ariG23498 Aug 4, 2023
a8ad7c4
review comments
ariG23498 Aug 17, 2023
d523a32
Merge branch 'master' into aritra/port-rcnn
ariG23498 Aug 17, 2023
ed3337c
chore: train test step modification
ariG23498 Aug 18, 2023
301bb1d
Merge branch 'master' into aritra/port-rcnn
ariG23498 Aug 28, 2023
005f70d
review nits
ariG23498 Aug 28, 2023
da5a01e
chore: adding test
ariG23498 Sep 1, 2023
ea88f2c
Merge branch 'master' into aritra/port-rcnn
ariG23498 Sep 1, 2023
5c7048f
Merge branch 'master' into aritra/port-rcnn
ariG23498 Sep 7, 2023
ac005b8
chore: reformat compute loss
ariG23498 Sep 7, 2023
613e29f
chore: faster rcnn call and predict work
ariG23498 Sep 15, 2023
dcb648a
resolved conflicts
ariG23498 Sep 16, 2023
5bf2bc9
chore: porting roi align to keras core
ariG23498 Sep 16, 2023
7d6ef6f
chore: port roi sampler to keras core
ariG23498 Sep 16, 2023
f1e3e17
chore: port rpn label encoder to keras core
ariG23498 Sep 16, 2023
6478cbf
chore: adding tests and fix lint
ariG23498 Sep 16, 2023
7741edc
fix: lint
ariG23498 Sep 16, 2023
13a26e6
chore: adding copyright to faster rcnn presets script
ariG23498 Sep 16, 2023
0bc4cfa
Merge branch 'master' into aritra/port-rcnn
ariG23498 Sep 19, 2023
3b42ecc
chore: removing tf imports
ariG23498 Sep 21, 2023
be9178b
fix imports
ariG23498 Sep 27, 2023
c3b0cfa
Merge branch 'master' into aritra/port-rcnn
ariG23498 Nov 2, 2023
54fd49c
Merge branch 'master' into aritra/port-rcnn
ariG23498 Nov 6, 2023
e59d2b4
fix: style
ariG23498 Nov 6, 2023
001162c
chore: making the model functional in init
ariG23498 Nov 7, 2023
4889192
Merge branch 'master' into aritra/port-rcnn
ariG23498 Nov 7, 2023
4da5ff1
Merge branch 'master' into aritra/port-rcnn
ariG23498 Nov 22, 2023
6a51562
Merge branch 'master' into aritra/port-rcnn
ariG23498 Dec 4, 2023
36da548
Merge branch 'master' into aritra/port-rcnn
ariG23498 Dec 6, 2023
711c031
Merge branch 'master' into aritra/port-rcnn
ariG23498 Dec 18, 2023
9aab0e9
chore: adding static image shapes to backbone in tests
ariG23498 Dec 18, 2023
49815d1
fix: parameterised input shape in test
ariG23498 Dec 18, 2023
6061f01
fix: reshape
ariG23498 Dec 18, 2023
ef279a9
fix: format and output dict
ariG23498 Dec 18, 2023
134f897
chore: masking sample weights for box labels -1
ariG23498 Dec 19, 2023
e190e1b
chore: fixing sample weights and decode predictions
ariG23498 Dec 19, 2023
70f205c
Merge branch 'master' into aritra/port-rcnn
ariG23498 Jan 2, 2024
821b7aa
chore: porting roi gen to keras 3 ops
ariG23498 Jan 2, 2024
324f7fc
Merge branch 'master' into aritra/port-rcnn
ariG23498 Jan 10, 2024
9227255
chore: port roi gen to keras 3
ariG23498 Jan 10, 2024
345764f
chore: removing asserts for keras 3
ariG23498 Jan 10, 2024
3a714e7
Merge branch 'master' into aritra/port-rcnn
ariG23498 Feb 28, 2024
9e7eea0
chore: adding faster rcnn to kokoro build script
ariG23498 Feb 28, 2024
af47e3f
chore: changing a bunch of things and keeping it commited for reference
ariG23498 Feb 28, 2024
fd20746
Merge branch 'master' into aritra/port-rcnn
ariG23498 Mar 13, 2024
2f5c0a2
chore: update roi align
ariG23498 Mar 13, 2024
9c85dfc
chore: adding init and compute loss
ariG23498 Mar 14, 2024
e26a8ef
chore: format
ariG23498 Mar 14, 2024
5a1f5a7
chore: demo.py
ariG23498 Mar 14, 2024
7d873f6
Merge branch 'master' into aritra/port-rcnn
ariG23498 Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 122 additions & 105 deletions keras_cv/layers/object_detection/roi_align.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion keras_cv/layers/object_detection/roi_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import assert_tf_keras
from keras_cv.backend import ops

divyashreepathihalli marked this conversation as resolved.
Show resolved Hide resolved

@keras_cv_export("keras_cv.layers.ROIGenerator")
Expand Down Expand Up @@ -148,7 +149,7 @@ def per_level_gen(boxes, scores):
# scores can also be [batch_size, num_boxes, 1]
if len(scores_shape) == 3:
scores = tf.squeeze(scores, axis=-1)
_, num_boxes = scores.get_shape().as_list()
num_boxes = ops.shape(boxes)[1]
level_pre_nms_topk = min(num_boxes, pre_nms_topk)
level_post_nms_topk = min(num_boxes, post_nms_topk)
scores, sorted_indices = tf.nn.top_k(
Expand Down
45 changes: 23 additions & 22 deletions keras_cv/layers/object_detection/roi_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow import keras
import numpy as np # used for newaxis

from keras_cv import bounding_box
from keras_cv.backend import assert_tf_keras
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.bounding_box import iou
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import sampling
Expand Down Expand Up @@ -84,9 +85,9 @@ def __init__(

def call(
self,
rois: tf.Tensor,
gt_boxes: tf.Tensor,
gt_classes: tf.Tensor,
rois,
gt_boxes,
gt_classes,
):
"""
Args:
Expand All @@ -102,8 +103,8 @@ def call(
"""
if self.append_gt_boxes:
# num_rois += num_gt
rois = tf.concat([rois, gt_boxes], axis=1)
num_rois = rois.get_shape().as_list()[1]
rois = ops.concatenate([rois, gt_boxes], axis=1)
num_rois = ops.shape(rois)[1]
if num_rois is None:
raise ValueError(
f"`rois` must have static shape, got {rois.get_shape()}"
Expand All @@ -126,27 +127,27 @@ def call(
# [batch_size, num_rois] | [batch_size, num_rois]
matched_gt_cols, matched_vals = self.roi_matcher(similarity_mat)
# [batch_size, num_rois]
positive_matches = tf.math.equal(matched_vals, 1)
negative_matches = tf.math.equal(matched_vals, -1)
positive_matches = ops.equal(matched_vals, 1)
negative_matches = ops.equal(matched_vals, -1)
self._positives.update_state(
tf.reduce_sum(tf.cast(positive_matches, tf.float32), axis=-1)
ops.sum(ops.cast(positive_matches, "float32"), axis=-1)
)
self._negatives.update_state(
tf.reduce_sum(tf.cast(negative_matches, tf.float32), axis=-1)
ops.sum(ops.cast(negative_matches, "float32"), axis=-1)
)
# [batch_size, num_rois, 1]
background_mask = tf.expand_dims(
tf.logical_not(positive_matches), axis=-1
background_mask = ops.expand_dims(
ops.logical_not(positive_matches), axis=-1
)
# [batch_size, num_rois, 1]
matched_gt_classes = target_gather._target_gather(
gt_classes, matched_gt_cols
)
# also set all background matches to `background_class`
matched_gt_classes = tf.where(
matched_gt_classes = ops.where(
background_mask,
tf.cast(
self.background_class * tf.ones_like(matched_gt_classes),
ops.cast(
self.background_class * ops.ones_like(matched_gt_classes),
gt_classes.dtype,
),
matched_gt_classes,
Expand All @@ -163,9 +164,9 @@ def call(
variance=[0.1, 0.1, 0.2, 0.2],
)
# also set all background matches to 0 coordinates
encoded_matched_gt_boxes = tf.where(
encoded_matched_gt_boxes = ops.where(
background_mask,
tf.zeros_like(matched_gt_boxes),
ops.zeros_like(matched_gt_boxes),
encoded_matched_gt_boxes,
)
# [batch_size, num_rois]
Expand All @@ -176,7 +177,7 @@ def call(
self.positive_fraction,
)
# [batch_size, num_sampled_rois] in the range of [0, num_rois)
sampled_indicators, sampled_indices = tf.math.top_k(
sampled_indicators, sampled_indices = ops.math.top_k(
sampled_indicators, k=self.num_sampled_rois, sorted=True
)
# [batch_size, num_sampled_rois, 4]
Expand All @@ -192,12 +193,12 @@ def call(
# [batch_size, num_sampled_rois, 1]
# all negative samples will be ignored in regression
sampled_box_weights = target_gather._target_gather(
tf.cast(positive_matches[..., tf.newaxis], gt_boxes.dtype),
ops.cast(positive_matches[..., np.newaxis], gt_boxes.dtype),
sampled_indices,
)
# [batch_size, num_sampled_rois, 1]
sampled_indicators = sampled_indicators[..., tf.newaxis]
sampled_class_weights = tf.cast(sampled_indicators, gt_classes.dtype)
sampled_indicators = sampled_indicators[..., np.newaxis]
sampled_class_weights = ops.cast(sampled_indicators, gt_classes.dtype)
return (
sampled_rois,
sampled_gt_boxes,
Expand Down
36 changes: 18 additions & 18 deletions keras_cv/layers/object_detection/rpn_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping

import tensorflow as tf
from tensorflow import keras
import numpy as np # Used for newaxis
import tree

from keras_cv import bounding_box
from keras_cv.backend import assert_tf_keras
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.bounding_box import iou
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import sampling
Expand Down Expand Up @@ -92,9 +92,9 @@ def __init__(

def call(
self,
anchors_dict: Mapping[str, tf.Tensor],
gt_boxes: tf.Tensor,
gt_classes: tf.Tensor,
anchors_dict,
gt_boxes,
gt_classes,
):
"""
Args:
Expand All @@ -112,7 +112,7 @@ def call(
anchors = anchors_dict
if isinstance(anchors, dict):
pack = True
anchors = tf.concat(tf.nest.flatten(anchors), axis=0)
anchors = ops.concatenate(tree.flatten(anchors), axis=0)
anchors = bounding_box.convert_format(
anchors, source=self.anchor_format, target="yxyx"
)
Expand All @@ -126,14 +126,14 @@ def call(
# [num_anchors] or [batch_size, num_anchors]
matched_gt_indices, matched_vals = self.box_matcher(similarity_mat)
# [num_anchors] or [batch_size, num_anchors]
positive_matches = tf.math.equal(matched_vals, 1)
positive_matches = ops.equal(matched_vals, 1)
# currently SyncOnReadVariable does not support `assign_add` in
# cross-replica.
# self._positives.update_state(
# tf.reduce_sum(tf.cast(positive_matches, tf.float32), axis=-1)
# )

negative_matches = tf.math.equal(matched_vals, -1)
negative_matches = ops.equal(matched_vals, -1)
# [num_anchors, 4] or [batch_size, num_anchors, 4]
matched_gt_boxes = target_gather._target_gather(
gt_boxes, matched_gt_indices
Expand All @@ -148,18 +148,18 @@ def call(
variance=self.box_variance,
)
# [num_anchors, 1] or [batch_size, num_anchors, 1]
box_sample_weights = tf.cast(
positive_matches[..., tf.newaxis], gt_boxes.dtype
box_sample_weights = ops.cast(
positive_matches[..., np.newaxis], gt_boxes.dtype
)

# [num_anchors, 1] or [batch_size, num_anchors, 1]
positive_mask = tf.expand_dims(positive_matches, axis=-1)
positive_mask = ops.expand_dims(positive_matches, axis=-1)
# set all negative and ignored matches to 0, and all positive matches to
# 1 [num_anchors, 1] or [batch_size, num_anchors, 1]
positive_classes = tf.ones_like(positive_mask, dtype=gt_classes.dtype)
negative_classes = tf.zeros_like(positive_mask, dtype=gt_classes.dtype)
positive_classes = ops.ones_like(positive_mask, dtype=gt_classes.dtype)
negative_classes = ops.zeros_like(positive_mask, dtype=gt_classes.dtype)
# [num_anchors, 1] or [batch_size, num_anchors, 1]
class_targets = tf.where(
class_targets = ops.where(
positive_mask, positive_classes, negative_classes
)
# [num_anchors] or [batch_size, num_anchors]
Expand All @@ -170,8 +170,8 @@ def call(
self.positive_fraction,
)
# [num_anchors, 1] or [batch_size, num_anchors, 1]
class_sample_weights = tf.cast(
sampled_indicators[..., tf.newaxis], gt_classes.dtype
class_sample_weights = ops.cast(
sampled_indicators[..., np.newaxis], gt_classes.dtype
)
if pack:
encoded_box_targets = self.unpack_targets(
Expand Down
1 change: 1 addition & 0 deletions keras_cv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone
from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone
from keras_cv.models.classification.image_classifier import ImageClassifier
from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import FasterRCNN
from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import (
YOLOV8Backbone,
Expand Down
3 changes: 0 additions & 3 deletions keras_cv/models/legacy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
from keras_cv.models.legacy.mlp_mixer import MLPMixerB16
from keras_cv.models.legacy.mlp_mixer import MLPMixerB32
from keras_cv.models.legacy.mlp_mixer import MLPMixerL16
from keras_cv.models.legacy.object_detection.faster_rcnn.faster_rcnn import (
FasterRCNN,
)
from keras_cv.models.legacy.regnet import RegNetX002
from keras_cv.models.legacy.regnet import RegNetX004
from keras_cv.models.legacy.regnet import RegNetX006
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The KerasCV Authors
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import (
FeaturePyramid,
)
from keras_cv.models.object_detection.faster_rcnn.rcnn_head import RCNNHead
from keras_cv.models.object_detection.faster_rcnn.rpn_head import RPNHead
Loading
Loading