From 8eeb729d9f677e3a99a1d97c24e34d83da243c46 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Mon, 4 Nov 2024 18:10:29 -0500 Subject: [PATCH 01/11] Adding simple 'locsAndConfs' TCN input vector --- configs/data/ptg.yaml | 18 +-- configs/experiment/m2/feat_locsConfs.yaml | 122 +++++++++++++++++++ tcn_hpl/data/tcn_dataset.py | 26 ++-- tcn_hpl/data/vectorize/_data.py | 3 + tcn_hpl/data/vectorize/locs_and_confs.py | 141 ++++++++++++++++++++++ 5 files changed, 292 insertions(+), 18 deletions(-) create mode 100644 configs/experiment/m2/feat_locsConfs.yaml create mode 100644 tcn_hpl/data/vectorize/locs_and_confs.py diff --git a/configs/data/ptg.yaml b/configs/data/ptg.yaml index 0d0cd6974..d3565ab4d 100644 --- a/configs/data/ptg.yaml +++ b/configs/data/ptg.yaml @@ -3,14 +3,16 @@ _target_: tcn_hpl.data.ptg_datamodule.PTGDataModule train_dataset: _target_: tcn_hpl.data.tcn_dataset.TCNDataset window_size: 15 - vectorizer: - _target_: tcn_hpl.data.vectorize.classic.Classic - feat_version: 6 - top_k: 1 - num_classes: 7 - background_idx: 0 - hand_left_idx: 5 - hand_right_idx: 6 + # No vectorizer should be specified here, as there should be no "default". + # Example of a vectorizer: + # vectorizer: + # _target_: tcn_hpl.data.vectorize.classic.Classic + # feat_version: 6 + # top_k: 1 + # num_classes: 7 + # background_idx: 0 + # hand_left_idx: 5 + # hand_right_idx: 6 transform: _target_: torchvision.transforms.Compose transforms: [] diff --git a/configs/experiment/m2/feat_locsConfs.yaml b/configs/experiment/m2/feat_locsConfs.yaml new file mode 100644 index 000000000..3af586543 --- /dev/null +++ b/configs/experiment/m2/feat_locsConfs.yaml @@ -0,0 +1,122 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example +task: "m2" +# feature_version: 6 +topic: "medical" + +defaults: + - override /data: ptg + - override /model: ptg + - override /callbacks: default + - override /trainer: gpu + - override /paths: default + #- override /logger: aim + - override /logger: csv + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +# Change this name to something descriptive and unique for this experiment. +# This will differentiate the run logs and output to be separate from other +# experiments that may have been run under the configured +# Setting this value influences: +# - the name of the directory under `${paths.root_dir}/logs/` in which training +# run files are stored. +# Default is "train" set in the "configs/train.yaml" file. +#task_name: +task_name: cameron_locs_and_confs + +# simply provide checkpoint path to resume training +#ckpt_path: null + +tags: ["m2", "ms_tcn", "debug"] + +seed: 12345 + +trainer: + min_epochs: 50 + max_epochs: 500 + log_every_n_steps: 1 + +model: + compile: false + net: + # Length of feature vector for a single frame. + # Currently derived from feature version and other hyperparameters. + dim: 102 + num_classes: 9 + +data: + coco_train_activities: "${paths.coco_file_root}/TRAIN-activity_truth.coco.json" + coco_train_objects: "${paths.coco_file_root}/TRAIN-object_detections.coco.json" + coco_train_poses: "${paths.coco_file_root}/TRAIN-pose_estimates.coco.json" + + coco_validation_activities: "${paths.coco_file_root}/VALIDATION-activity_truth.coco.json" + coco_validation_objects: "${paths.coco_file_root}/VALIDATION-object_detections.coco.json" + coco_validation_poses: "${paths.coco_file_root}/VALIDATION-pose_estimates.coco.json" + + coco_test_activities: "${paths.coco_file_root}/TEST-activity_truth.coco.json" + coco_test_objects: "${paths.coco_file_root}/TEST-object_detections.coco.json" + coco_test_poses: "${paths.coco_file_root}/TEST-pose_estimates.coco.json" + + batch_size: 16384 + num_workers: 16 + target_framerate: 15 # BBN Hololens2 Framerate + epoch_length: 200000 + + train_dataset: + window_size: 25 + vectorizer: + _target_: tcn_hpl.data.vectorize.locs_and_confs.LocsAndConfs + top_k: 1 + num_classes: 7 + use_joint_confs: True + use_pixel_norm: True + use_hand_obj_offsets: False + background_idx: 0 + transform: + transforms: [] # no transforms +# - _target_: tcn_hpl.data.components.augmentations.MoveCenterPts +# hand_dist_delta: 0.05 +# obj_dist_delta: 0.05 +# joint_dist_delta: 0.025 +# im_w: 1280 +# im_h: 720 +# num_obj_classes: 42 +# feat_version: 2 +# top_k_objects: 1 +# - _target_: tcn_hpl.data.components.augmentations.NormalizePixelPts +# im_w: 1280 +# im_h: 720 +# num_obj_classes: 42 +# feat_version: 2 +# top_k_objects: 1 + val_dataset: + transform: + transforms: [] # no transforms +# - _target_: tcn_hpl.data.components.augmentations.NormalizePixelPts +# im_w: 1280 +# im_h: 720 +# num_obj_classes: 42 +# feat_version: 2 +# top_k_objects: 1 + # Test dataset usually configured the same as val, unless there is some + # different set of transforms that should be used during test/prediction. + +paths: + # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL/" + # root_dir: "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/training_root" + root_dir: "/home/local/KHQ/cameron.johnson/code/TCN_HPL/tcn_hpl/train-TCN-M2_bbn_hololens/training_root" + + # Convenience variable to where your train/val/test split COCO file datasets + # are stored. + # coco_file_root: "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens" + coco_file_root: "/home/local/KHQ/cameron.johnson/code/TCN_HPL/train-TCN-M2_bbn_hololens" + +#exp_name: "tcn_training_revive" +#logger: +# aim: +# experiment: ${task_name} +# capture_terminal_logs: true diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index 059a416c3..bc6878c7e 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -306,6 +306,12 @@ def load_data_offline( ) else: frame_dets = empty_dets + + # Frame height and width should be available. + img_info = activity_coco.index.imgs[img_id] + assert "height" in img_info + assert "width" in img_info + frame_size = (img_info["width"], img_info["height"]) # Only consider annotations that actually have keypoints. # There may be no poses on this frame. @@ -336,7 +342,8 @@ def load_data_offline( ) else: frame_poses = empty_pose - vid_frame_data.append(FrameData(frame_dets, frame_poses)) + # import ipdb; ipdb.set_trace() + vid_frame_data.append(FrameData(frame_dets, frame_poses, frame_size)) # Compose a list of indices into frame_data that this video's # worth of content resides. @@ -639,15 +646,14 @@ def test_dataset_for_input( pose_coco = kwcoco.CocoDataset(pose_coco) # TODO: Some method of configuring which vectorizer to use. - from tcn_hpl.data.vectorize.classic import Classic - vectorizer = Classic( - feat_version=6, - top_k=1, - # M2/R18 object detection class indices - num_classes=7, - background_idx=0, - hand_left_idx=5, - hand_right_idx=6, + from tcn_hpl.data.vectorize.locs_and_confs import LocsAndConfs + vectorizer = LocsAndConfs( + top_k = 1, + num_classes = 7, + use_joint_confs = True, + use_pixel_norm = True, + use_hand_obj_offsets = False, + background_idx = 0 ) dataset = TCNDataset(window_size=window_size, vectorizer=vectorizer) diff --git a/tcn_hpl/data/vectorize/_data.py b/tcn_hpl/data/vectorize/_data.py index 95714cc0d..c17289924 100644 --- a/tcn_hpl/data/vectorize/_data.py +++ b/tcn_hpl/data/vectorize/_data.py @@ -90,6 +90,9 @@ class FrameData: # This may be None, which implies that an object pose estimation was not # run for this frame. poses: tg.Optional[FramePoses] + # FrameSize: Length-2 tuple expected: (width, height). + # This is the video frame's width and height in pixels. + size: tuple def __bool__(self): """ diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py new file mode 100644 index 000000000..0920dea9e --- /dev/null +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -0,0 +1,141 @@ +import functools +import typing as tg + +import numpy as np +from numpy import typing as npt + +from tcn_hpl.data.vectorize._interface import Vectorize, FrameData +from tcn_hpl.data.vectorize_classic import ( + obj_det2d_set_to_feature, + zero_joint_offset, + HAND_STR_LEFT, + HAND_STR_RIGHT, +) + + +class LocsAndConfs(Vectorize): + """ + Previous manual approach to vectorization. + + Arguments: + feat_version: Version number of the feature to produce. + top_k: The number of top per-class examples to use in vector + construction. + """ + + def __init__( + self, + top_k: int = 1, + num_classes: int = 7, + use_joint_confs: bool = True, + use_pixel_norm: bool = True, + use_hand_obj_offsets: bool = False, + background_idx: int = 0 + ): + super().__init__() + + self._top_k = top_k + self._num_classes = num_classes + self._use_joint_confs = use_joint_confs + self._use_pixel_norm = use_pixel_norm + self._use_hand_obj_offsets = use_hand_obj_offsets + self._background_idx = background_idx + + # Get the top "k" object indexes for each object + @staticmethod + def get_top_k_indexes_of_one_obj_type(f_dets, k, label_ind): + """ + Find all instances of a label index in object detections. + Then sort them and return the top K. + Inputs: + - object_dets: + """ + labels = f_dets.labels + scores = f_dets.scores + # Get all labels of an obj type + filtered_idxs = [i for i, e in enumerate(labels) if e == label_ind] + if not filtered_idxs: + return None + filtered_scores = [scores[i] for i in filtered_idxs] + # Sort labels by score values. + sorted_inds = [i[1] for i in sorted(zip(filtered_scores, filtered_idxs))] + return sorted_inds[:k] + @staticmethod + def append_vector(frame_feat, i, number): + frame_feat[i] = number + return frame_feat, i + 1 + + + def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: + + ######################### + # Feature vector + ######################### + # Length: pose confs * 22, pose X's * 22, pose Y's * 22, + # obj confs * num_objects(7 for M2), + # obj X * num_objects(7 for M2), + # obj Y * num_objects(7 for M2) + # obj W * num_objects(7 for M2) + # obj H * num_objects(7 for M2) + # casualty conf * 1 + vector_len = 102 + frame_feat = np.zeros(vector_len) + vector_ind = 0 + if self._use_pixel_norm: + W = data.size[0] + H = data.size[1] + else: + W = 1 + H = 1 + f_dets = data.object_detections + + # Loop through all classes: populate obj conf, obj X, obj Y. + # Assumption: class labels are [0, 1, 2,... num_classes-1]. + for obj_ind in range(0,self._num_classes): + top_k_idxs = self.get_top_k_indexes_of_one_obj_type(f_dets, self._top_k, obj_ind) + if top_k_idxs: # This is None if there were no detections to sort for this class + for idx in top_k_idxs: + # Conf + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.scores[idx]) + # X + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][0] / W) + # Y + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][1] / H) + # W + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][2] / W) + # H + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][3] / H) + else: + for _ in range(0,5): + # 5 Zeros + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) + + f_poses = data.poses + if f_poses: + # Find most confident body detection + confident_pose_idx = np.argmax(f_poses.scores) + num_joints = f_poses.joint_positions.shape[1] + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.scores[confident_pose_idx]) + + for joint_ind in range(0, num_joints): + # Conf + if self._use_joint_confs: + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_scores[confident_pose_idx][joint_ind]) + # X + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_positions[confident_pose_idx][joint_ind][0] / W) + # Y + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_positions[confident_pose_idx][joint_ind][1] / H) + else: + num_joints = f_poses.joint_positions.shape[1] + if self._use_joint_confs: + rows_per_joint = 3 + else: + rows_per_joint = 2 + for _ in range(num_joints * rows_per_joint + 1): + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) + + assert vector_ind == vector_len + + frame_feat = frame_feat.ravel().astype(feat_dtype) + + return frame_feat From 706d6f0c3bf9773b5d7b55cffd6d3738d6452a57 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Tue, 5 Nov 2024 11:43:36 -0500 Subject: [PATCH 02/11] remove custom arg --- configs/experiment/m2/feat_locsConfs.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/experiment/m2/feat_locsConfs.yaml b/configs/experiment/m2/feat_locsConfs.yaml index 3af586543..fe5c95ae2 100644 --- a/configs/experiment/m2/feat_locsConfs.yaml +++ b/configs/experiment/m2/feat_locsConfs.yaml @@ -26,7 +26,6 @@ defaults: # run files are stored. # Default is "train" set in the "configs/train.yaml" file. #task_name: -task_name: cameron_locs_and_confs # simply provide checkpoint path to resume training #ckpt_path: null From 80c5bc85e271b75f6a7177241ba3fae107bead1b Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:16:35 -0500 Subject: [PATCH 03/11] Update tcn_hpl/data/vectorize/_data.py Co-authored-by: Paul Tunison <735270+Purg@users.noreply.github.com> --- tcn_hpl/data/vectorize/_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tcn_hpl/data/vectorize/_data.py b/tcn_hpl/data/vectorize/_data.py index c17289924..8e94256ad 100644 --- a/tcn_hpl/data/vectorize/_data.py +++ b/tcn_hpl/data/vectorize/_data.py @@ -92,7 +92,7 @@ class FrameData: poses: tg.Optional[FramePoses] # FrameSize: Length-2 tuple expected: (width, height). # This is the video frame's width and height in pixels. - size: tuple + size: tg.Tuple[int, int] def __bool__(self): """ From 3b7f2cb537185e955f1da06af5022f3dc0b2bdcf Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Tue, 5 Nov 2024 16:18:46 -0500 Subject: [PATCH 04/11] remove unnecessary import --- tcn_hpl/data/vectorize/locs_and_confs.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 0920dea9e..f5508706d 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -5,13 +5,6 @@ from numpy import typing as npt from tcn_hpl.data.vectorize._interface import Vectorize, FrameData -from tcn_hpl.data.vectorize_classic import ( - obj_det2d_set_to_feature, - zero_joint_offset, - HAND_STR_LEFT, - HAND_STR_RIGHT, -) - class LocsAndConfs(Vectorize): """ @@ -136,6 +129,6 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: assert vector_ind == vector_len - frame_feat = frame_feat.ravel().astype(feat_dtype) + frame_feat = frame_feat.ravel().astype(np.float32) return frame_feat From 65675a6a8c6022b0d4263c641594a910fea6b5b2 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Tue, 5 Nov 2024 16:27:27 -0500 Subject: [PATCH 05/11] Adding arguments documentation for LocsAndConfs --- tcn_hpl/data/vectorize/locs_and_confs.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index f5508706d..46914e1b5 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -11,9 +11,20 @@ class LocsAndConfs(Vectorize): Previous manual approach to vectorization. Arguments: - feat_version: Version number of the feature to produce. top_k: The number of top per-class examples to use in vector construction. + num_classes: the number of classes in the object detector. + use_joint_confs: use the confidence of each pose joint. + (changes the length of the input vector, which needs to + be manually updated if this flag changes.) + use_pixel_norm: Normalize pixel coordinates by dividing by + frame height and width, respectively. Normalized values + are between 0 and 1. Does not change input vector length. + use_joint_obj_offsets: add abs(X and Y offsets) for between joints and + each object. + (changes the length of the input vector, which needs to + be manually updated if this flag changes.) + """ def __init__( @@ -22,7 +33,7 @@ def __init__( num_classes: int = 7, use_joint_confs: bool = True, use_pixel_norm: bool = True, - use_hand_obj_offsets: bool = False, + use_joint_obj_offsets: bool = False, background_idx: int = 0 ): super().__init__() @@ -31,7 +42,7 @@ def __init__( self._num_classes = num_classes self._use_joint_confs = use_joint_confs self._use_pixel_norm = use_pixel_norm - self._use_hand_obj_offsets = use_hand_obj_offsets + self._use_joint_obj_offsets = use_joint_obj_offsets self._background_idx = background_idx # Get the top "k" object indexes for each object From 4d80561b9722c1e30fe5010388f7427a93aaa338 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:28:24 -0500 Subject: [PATCH 06/11] Update tcn_hpl/data/vectorize/locs_and_confs.py Co-authored-by: Paul Tunison <735270+Purg@users.noreply.github.com> --- tcn_hpl/data/vectorize/locs_and_confs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 46914e1b5..c138d0a96 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -64,6 +64,7 @@ def get_top_k_indexes_of_one_obj_type(f_dets, k, label_ind): # Sort labels by score values. sorted_inds = [i[1] for i in sorted(zip(filtered_scores, filtered_idxs))] return sorted_inds[:k] + @staticmethod def append_vector(frame_feat, i, number): frame_feat[i] = number From 9554fd0411f3d74bb8b9533c9882867ad0ac8a49 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:28:38 -0500 Subject: [PATCH 07/11] Update tcn_hpl/data/vectorize/locs_and_confs.py Co-authored-by: Paul Tunison <735270+Purg@users.noreply.github.com> --- tcn_hpl/data/vectorize/locs_and_confs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index c138d0a96..f4283a0e1 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -70,7 +70,6 @@ def append_vector(frame_feat, i, number): frame_feat[i] = number return frame_feat, i + 1 - def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: ######################### From 31d1a159ef947ea9c1dcf924ebff9d384a3f4134 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:00:31 -0500 Subject: [PATCH 08/11] Update tcn_hpl/data/vectorize/locs_and_confs.py Co-authored-by: Paul Tunison <735270+Purg@users.noreply.github.com> --- tcn_hpl/data/vectorize/locs_and_confs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index f4283a0e1..5417a9a42 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -110,7 +110,7 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: # H frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][3] / H) else: - for _ in range(0,5): + for _ in range(0, self._top_k * 5): # 5 Zeros frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) From add6dad284cdf1d26d814754af7a3aee8ac05db3 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Tue, 5 Nov 2024 17:19:38 -0500 Subject: [PATCH 09/11] set vector type at construction. Remove ravel. --- tcn_hpl/data/vectorize/locs_and_confs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 46914e1b5..3426c0389 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -83,7 +83,7 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: # obj H * num_objects(7 for M2) # casualty conf * 1 vector_len = 102 - frame_feat = np.zeros(vector_len) + frame_feat = np.zeros(vector_len, dtype=np.float32) vector_ind = 0 if self._use_pixel_norm: W = data.size[0] @@ -140,6 +140,4 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: assert vector_ind == vector_len - frame_feat = frame_feat.ravel().astype(np.float32) - return frame_feat From e137021b50034246a5677724581edd23a2bed4be Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Wed, 6 Nov 2024 15:41:51 -0500 Subject: [PATCH 10/11] compute the vector length instead of hard-coding it. --- tcn_hpl/data/vectorize/locs_and_confs.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 472c2e7b1..559c7be9e 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -6,6 +6,8 @@ from tcn_hpl.data.vectorize._interface import Vectorize, FrameData +NUM_POSE_JOINTS = 22 + class LocsAndConfs(Vectorize): """ Previous manual approach to vectorization. @@ -69,9 +71,8 @@ def get_top_k_indexes_of_one_obj_type(f_dets, k, label_ind): def append_vector(frame_feat, i, number): frame_feat[i] = number return frame_feat, i + 1 - - def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: - + + def determine_vector_length(self, data: FrameData) -> int: ######################### # Feature vector ######################### @@ -82,7 +83,20 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: # obj W * num_objects(7 for M2) # obj H * num_objects(7 for M2) # casualty conf * 1 - vector_len = 102 + vector_length = 0 + # Joint confidences + if self._use_joint_confs: + vector_length += NUM_POSE_JOINTS + # X and Y for each joint + vector_length += 2 * NUM_POSE_JOINTS + # [Conf, X, Y, W, H] for k instances of each object class. + vector_length = 5 * self._top_k * self._num_classes + return vector_length + + + def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: + + vector_len = self.determine_vector_length(data) frame_feat = np.zeros(vector_len, dtype=np.float32) vector_ind = 0 if self._use_pixel_norm: From 202f9aa3fa608b50ca61843bff9b19ec55544223 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Wed, 6 Nov 2024 15:45:35 -0500 Subject: [PATCH 11/11] add todo's for PR things I'm not getting to yet. --- tcn_hpl/data/vectorize/locs_and_confs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 559c7be9e..99bd0d64a 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -98,6 +98,9 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: vector_len = self.determine_vector_length(data) frame_feat = np.zeros(vector_len, dtype=np.float32) + # TODO: instead of carrying around this vector_ind, we should + # directly compute the offset of each feature we add to the TCN + # input vector. This would be much easier to debug. vector_ind = 0 if self._use_pixel_norm: W = data.size[0] @@ -109,6 +112,7 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: # Loop through all classes: populate obj conf, obj X, obj Y. # Assumption: class labels are [0, 1, 2,... num_classes-1]. + # TODO: this will break if top_k is ever > 1. Fix that. for obj_ind in range(0,self._num_classes): top_k_idxs = self.get_top_k_indexes_of_one_obj_type(f_dets, self._top_k, obj_ind) if top_k_idxs: # This is None if there were no detections to sort for this class