From 388b18d5c654fdb31b97caf338f0959855a792da Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Mon, 28 Oct 2024 15:31:51 -0400 Subject: [PATCH 1/5] Fix issue with computing weights vector for limited class representation --- tcn_hpl/data/tcn_dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index 939bfd98c..be634ccd4 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -371,7 +371,11 @@ def load_data_offline( # windows, which is the truth value for the window as a whole. window_final_class_ids = self._window_truth[:, -1] cls_ids, cls_counts = np.unique(window_final_class_ids, return_counts=True) - cls_weights = 1.0 / cls_counts + # Some classes may not be represented in the truth, so initialize the + # weights vector separately, and then assign weight values based on + # which class IDs were actually represented. + cls_weights = np.zeros(len(activity_coco.cats)) + cls_weights[cls_ids] = 1.0 / cls_counts self._window_weights = cls_weights[window_final_class_ids] # Check if there happens to be a cache file of pre-computed window @@ -422,9 +426,6 @@ def load_data_offline( # Pre-vectorize data for iteration efficiency during training. window_vectors: List[npt.NDArray[float]] = [] itable = (self._vectorize_window(d) for d in window_data) - # pool = ThreadPoolExecutor() - # pool = ProcessPoolExecutor(max_workers=4) - # itable = pool.map(self._vectorize_window, window_data) for one_vector in tqdm( itable, desc="Windows vectorized", From b9fce69b6692685151916649e14620d6ec4f9429 Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Mon, 28 Oct 2024 15:56:26 -0400 Subject: [PATCH 2/5] Add Vectorize interface and classic mode implementation Validated that output from this transition produces the same results as was being produced by the previous method of invoking the obj_det2d_set_to_feature function, but now in a way where we can parameterize the vectorization functor without pointing to config files. --- tcn_hpl/data/tcn_dataset.py | 2 +- tcn_hpl/data/vectorize/__init__.py | 6 + tcn_hpl/data/vectorize/_data.py | 77 ++++++++++ tcn_hpl/data/vectorize/_interface.py | 28 ++++ tcn_hpl/data/vectorize/classic.py | 136 ++++++++++++++++++ tcn_hpl/data/vectorize_classic.py | 7 +- .../{vectorize.py => vectorize_window.py} | 97 +------------ 7 files changed, 259 insertions(+), 94 deletions(-) create mode 100644 tcn_hpl/data/vectorize/__init__.py create mode 100644 tcn_hpl/data/vectorize/_data.py create mode 100644 tcn_hpl/data/vectorize/_interface.py create mode 100644 tcn_hpl/data/vectorize/classic.py rename tcn_hpl/data/{vectorize.py => vectorize_window.py} (58%) diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index be634ccd4..8c84a8e46 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -18,7 +18,7 @@ from torch.utils.data import Dataset from tqdm import tqdm -from tcn_hpl.data.vectorize import ( +from tcn_hpl.data.vectorize_window import ( FrameData, FrameObjectDetections, FramePoses, diff --git a/tcn_hpl/data/vectorize/__init__.py b/tcn_hpl/data/vectorize/__init__.py new file mode 100644 index 000000000..20f008d18 --- /dev/null +++ b/tcn_hpl/data/vectorize/__init__.py @@ -0,0 +1,6 @@ +from ._data import ( + FrameObjectDetections, + FramePoses, + FrameData, +) +from ._interface import Vectorize diff --git a/tcn_hpl/data/vectorize/_data.py b/tcn_hpl/data/vectorize/_data.py new file mode 100644 index 000000000..f8ec55dea --- /dev/null +++ b/tcn_hpl/data/vectorize/_data.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass +import typing as tg + +import numpy.typing as npt + + +__all__ = [ + "FrameObjectDetections", + "FramePoses", + "FrameData", +] + + +@dataclass +class FrameObjectDetections: + """ + Representation of object detection predictions for a single image. + All sequences should be the same length. + """ + + # Detection 2D bounding boxes in xywh format for the left, top, width and + # height measurements respectively. Shape: (num_detections, 4) + boxes: npt.NDArray[float] + # Object category ID of the most confident class. Shape: (num_detections,) + labels: npt.NDArray[int] + # Vectorized detection confidence value of the most confidence class. + # Shape: (num_detections,) + scores: npt.NDArray[float] + + def __post_init__(self): + assert self.boxes.ndim == 2 + assert self.labels.ndim == 1 + assert self.scores.ndim == 1 + assert self.boxes.shape[0] == self.labels.shape[0] + assert self.boxes.shape[0] == self.scores.shape[0] + assert self.boxes.shape[1] == 4 + + def __bool__(self): + return bool(self.boxes.size) + + +@dataclass +class FramePoses: + """ + Represents pose estimations for a single image. + + We currently assume that all poses will be composed of the same number of + keypoints. + """ + + # Array of scores for each pose. Ostensibly the bbox score. Shape: (num_poses,) + scores: npt.NDArray[float] + # Pose join 2D positions in ascending joint ID order. If the joint is not + # present, 0s are used. Shape: (num_poses, num_joints, 2) + joint_positions: npt.NDArray[float] + # Poise joint scores. Shape: (num_poses, num_joints) + joint_scores: npt.NDArray[float] + + def __post_init__(self): + assert self.scores.ndim == 1 + assert self.joint_positions.ndim == 3 + assert self.joint_scores.ndim == 2 + assert self.scores.shape[0] == self.joint_positions.shape[0] + assert self.scores.shape[0] == self.joint_scores.shape[0] + assert self.joint_positions.shape[1] == self.joint_scores.shape[1] + + def __bool__(self): + return bool(self.scores.size) + + +@dataclass +class FrameData: + object_detections: tg.Optional[FrameObjectDetections] + poses: tg.Optional[FramePoses] + + def __bool__(self): + return bool(self.object_detections) or bool(self.poses) diff --git a/tcn_hpl/data/vectorize/_interface.py b/tcn_hpl/data/vectorize/_interface.py new file mode 100644 index 000000000..f106161e3 --- /dev/null +++ b/tcn_hpl/data/vectorize/_interface.py @@ -0,0 +1,28 @@ +import abc +import numpy as np +import numpy.typing as npt + +from ._data import FrameData + + +__all__ = [ + "Vectorize", +] + + +class Vectorize(metaclass=abc.ABCMeta): + """ + Interface for a functor that will vectorize input data into an embedding + space for use in TCN training and inference. + """ + + @abc.abstractmethod + def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: + """ + Perform vectorization of the input data into an embedding space. + + :param data: Input data to generate an embedding from. + """ + + def __call__(self, data: FrameData) -> npt.NDArray[np.float32]: + return self.vectorize(data) diff --git a/tcn_hpl/data/vectorize/classic.py b/tcn_hpl/data/vectorize/classic.py new file mode 100644 index 000000000..aebeb4cf1 --- /dev/null +++ b/tcn_hpl/data/vectorize/classic.py @@ -0,0 +1,136 @@ +import functools +import typing as tg + +import numpy as np +from numpy import typing as npt + +from tcn_hpl.data.vectorize._interface import Vectorize, FrameData +from tcn_hpl.data.vectorize_classic import ( + obj_det2d_set_to_feature, + zero_joint_offset, + HAND_STR_LEFT, + HAND_STR_RIGHT, +) + + +class Classic(Vectorize): + """ + Previous manual approach to vectorization. + + Arguments: + feat_version: Version number of the feature to produce. + top_k: The number of top per-class examples to use in vector + construction. + """ + + def __init__( + self, + feat_version: int = 6, + top_k: int = 1, + num_classes: int = 7, + background_idx: int = 0, + hand_left_idx: int = 5, + hand_right_idx: int = 6, + ): + self._feat_version = feat_version + self._top_k = top_k + # The classic means of vectorization required some inputs that involved + # string forms of object class labels. + # * The first vector of string labels is the sequence of predicted + # classes but as the string semantic labels. This is only ever used + # to index into a second mapping structure to get the numerical index + # of that class. + # * The second structure is a mapping of string class labels to some + # zero-based index. The indices represented in this mapping must + # start at 0 and consecutively increase. There is expected to be some + # background class index in raw object predictions that will need to + # be appropriately excluded. This mapping is known to be checked for + # special left and right hand names (HAND_STR_LEFT & HAND_STR_RIGHT). + # + self._num_classes = num_classes + self._bg_idx = background_idx + self._h_l_idx = hand_left_idx + self._h_r_idx = hand_right_idx + # Construct a vector of "labels" that can be mapped to via object + # detection preds, specifically injecting the special names for hands + # at the specified indices. + det_class_labels = list(map(str, range(num_classes))) + det_class_labels[background_idx] = None + det_class_labels[hand_left_idx] = HAND_STR_LEFT + det_class_labels[hand_right_idx] = HAND_STR_RIGHT + self._det_class_labels = det_class_labels = tuple(det_class_labels) + self._det_class2idx_map = _class_labels_to_map(det_class_labels) + + def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: + det_class_labels = self._det_class_labels + obj_label_to_ind = self._det_class2idx_map + + f_dets = data.object_detections + if f_dets is not None: + # extract object detection xywh as 4 component vectors. + det_xs = f_dets.boxes.T[0] + det_ys = f_dets.boxes.T[1] + det_ws = f_dets.boxes.T[2] + det_hs = f_dets.boxes.T[3] + # Other vectors + det_lbls = [det_class_labels[lbl] for lbl in f_dets.labels] + det_scores = f_dets.scores + else: + det_lbls = [] + det_xs = [] + det_ys = [] + det_ws = [] + det_hs = [] + det_scores = [] + + # There may be zero or multiple poses predicted on a frame. + # If multiple poses, select the most confident "patient" pose. + # If there was no pose on this frame, provide a list of 0's equal in + # length to the number of joints. + f_poses = data.poses + if f_poses is not None and f_poses.scores.size: + best_pose_idx = np.argmax(f_poses.scores) + pose_kps = [ + {"xy": joint_pt} + for joint_pt in f_poses.joint_positions[best_pose_idx] + ] + else: + # special value for the classic method to indicate no pose joints. + pose_kps = zero_joint_offset + + frame_feat = obj_det2d_set_to_feature( + label_vec=det_lbls, + xs=det_xs, + ys=det_ys, + ws=det_ws, + hs=det_hs, + label_confidences=det_scores, + pose_keypoints=pose_kps, + obj_label_to_ind=obj_label_to_ind, + version=self._feat_version, + top_k_objects=self._top_k, + ).ravel().astype(np.float32) + + return frame_feat + + +@functools.lru_cache() +def _class_labels_to_map( + class_labels: tg.Sequence[tg.Optional[str]] +) -> tg.Dict[str, int]: + """ + Transform a sequence of class label strings into a mapping from label name + to index. + + The output mapping will map labels to 0-based indices based on the order of + the labels provided. + """ + lbl_to_idx = {lbl: i for i, lbl in enumerate(class_labels) if lbl is not None} + # Determine min index to subtract. + min_cat = min(lbl_to_idx.values()) + for k in lbl_to_idx: + lbl_to_idx[k] -= min_cat + assert ( + set(lbl_to_idx.values()) == set(range(len(lbl_to_idx))) + ), "Resulting category indices must start at 0 and be contiguous." + return lbl_to_idx diff --git a/tcn_hpl/data/vectorize_classic.py b/tcn_hpl/data/vectorize_classic.py index eb845ba48..2b7de4d13 100644 --- a/tcn_hpl/data/vectorize_classic.py +++ b/tcn_hpl/data/vectorize_classic.py @@ -27,6 +27,9 @@ random_colors = list(mcolors.CSS4_COLORS.keys()) random.shuffle(random_colors) +HAND_STR_LEFT = "hand (left)" +HAND_STR_RIGHT = "hand (right)" + def tlbr_to_xywh( top: npt.ArrayLike, @@ -814,12 +817,12 @@ def dist_to_center(center1, center2): ######################### # Find the right hand (right_hand_idx, right_hand_bbox, right_hand_conf, right_hand_center) = find_hand( - "hand (right)" + HAND_STR_RIGHT ) # Find the left hand (left_hand_idx, left_hand_bbox, left_hand_conf, left_hand_center) = find_hand( - "hand (left)" + HAND_STR_LEFT ) right_left_hand_kwboxes = det_class_kwboxes[0, [right_hand_idx, left_hand_idx]] diff --git a/tcn_hpl/data/vectorize.py b/tcn_hpl/data/vectorize_window.py similarity index 58% rename from tcn_hpl/data/vectorize.py rename to tcn_hpl/data/vectorize_window.py index c947524be..e70a84682 100644 --- a/tcn_hpl/data/vectorize.py +++ b/tcn_hpl/data/vectorize_window.py @@ -2,108 +2,23 @@ Logic and utilities to perform the vectorization of input data into an embedding space used for TCN training and prediction. """ -import functools -from dataclasses import dataclass import typing as tg import numpy as np import numpy.typing as npt -from jedi.plugins.stdlib import functools_partial +from tcn_hpl.data.vectorize import ( + FrameObjectDetections, + FramePoses, + FrameData, +) +from tcn_hpl.data.vectorize.classic import _class_labels_to_map # noqa from tcn_hpl.data.vectorize_classic import ( obj_det2d_set_to_feature, zero_joint_offset, ) -@dataclass -class FrameObjectDetections: - """ - Representation of object detection predictions for a single image. - All sequences should be the same length. - """ - - # Detection 2D bounding boxes in xywh format for the left, top, width and - # height measurements respectively. Shape: (num_detections, 4) - boxes: npt.NDArray[float] - # Object category ID of the most confident class. Shape: (num_detections,) - labels: npt.NDArray[int] - # Vectorized detection confidence value of the most confidence class. - # Shape: (num_detections,) - scores: npt.NDArray[float] - - def __post_init__(self): - assert self.boxes.ndim == 2 - assert self.labels.ndim == 1 - assert self.scores.ndim == 1 - assert self.boxes.shape[0] == self.labels.shape[0] - assert self.boxes.shape[0] == self.scores.shape[0] - assert self.boxes.shape[1] == 4 - - def __bool__(self): - return bool(self.boxes.size) - - -@dataclass -class FramePoses: - """ - Represents pose estimations for a single image. - - We currently assume that all poses will be composed of the same number of - keypoints. - """ - - # Array of scores for each pose. Ostensibly the bbox score. Shape: (num_poses,) - scores: npt.NDArray[float] - # Pose join 2D positions in ascending joint ID order. If the joint is not - # present, 0s are used. Shape: (num_poses, num_joints, 2) - joint_positions: npt.NDArray[float] - # Poise joint scores. Shape: (num_poses, num_joints) - joint_scores: npt.NDArray[float] - - def __post_init__(self): - assert self.scores.ndim == 1 - assert self.joint_positions.ndim == 3 - assert self.joint_scores.ndim == 2 - assert self.scores.shape[0] == self.joint_positions.shape[0] - assert self.scores.shape[0] == self.joint_scores.shape[0] - assert self.joint_positions.shape[1] == self.joint_scores.shape[1] - - def __bool__(self): - return bool(self.scores.size) - - -@dataclass -class FrameData: - object_detections: tg.Optional[FrameObjectDetections] - poses: tg.Optional[FramePoses] - - def __bool__(self): - return bool(self.object_detections) or bool(self.poses) - - -@functools.lru_cache() -def _class_labels_to_map( - class_labels: tg.Sequence[tg.Optional[str]] -) -> tg.Dict[str, int]: - """ - Transform a sequence of class label strings into a mapping from label name - to index. - - The output mapping will map labels to 0-based indices based on the order of - the labels provided. - """ - lbl_to_idx = {lbl: i for i, lbl in enumerate(class_labels) if lbl is not None} - # Determine min index to subtract. - min_cat = min(lbl_to_idx.values()) - for k in lbl_to_idx: - lbl_to_idx[k] -= min_cat - assert ( - set(lbl_to_idx.values()) == set(range(len(lbl_to_idx))) - ), "Resulting category indices must start at 0 and be contiguous." - return lbl_to_idx - - def vectorize_window( frame_data: tg.Sequence[FrameData], det_class_labels: tg.Sequence[tg.Optional[str]], From b843cb1429c9e17df40e9d350566b803546d1155 Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Mon, 28 Oct 2024 18:50:12 -0400 Subject: [PATCH 3/5] Update TCNDataset to use input Vectorize functor --- tcn_hpl/data/tcn_dataset.py | 120 +++++++++++--------------- tcn_hpl/data/vectorize/_interface.py | 27 ++++++ tcn_hpl/data/vectorize/classic.py | 39 +++++---- tcn_hpl/data/vectorize_window.py | 121 --------------------------- 4 files changed, 99 insertions(+), 208 deletions(-) delete mode 100644 tcn_hpl/data/vectorize_window.py diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index 8c84a8e46..ee19f04b8 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -1,5 +1,6 @@ import logging from hashlib import sha256 +import json from pathlib import Path import time from typing import Callable @@ -18,11 +19,11 @@ from torch.utils.data import Dataset from tqdm import tqdm -from tcn_hpl.data.vectorize_window import ( - FrameData, +from tcn_hpl.data.vectorize import ( FrameObjectDetections, FramePoses, - vectorize_window, + FrameData, + Vectorize, ) @@ -52,8 +53,9 @@ class TCNDataset(Dataset): window_size: The size of the sliding window used to collect inputs from either a real-time or offline source. - feature_version: - Integer version ID of the feature vector to generate. + vectorizer: + Vectorization functor to convert frame data into an embedding + space. transform: Optional feature vector transformation/augmentation function. """ @@ -61,11 +63,11 @@ class TCNDataset(Dataset): def __init__( self, window_size: int, - feature_version: int, + vectorizer: Vectorize, transform: Optional[Callable] = None, ): self.window_size = window_size - self.feature_version = feature_version + self.vectorizer = vectorizer self.transform = transform # For offline mode, pre-cut videos into clips according to window @@ -92,25 +94,23 @@ def __init__( # Optionally defined set of pre-computed window vectors. self._window_vectors: Optional[npt.NDArray[float]] = None - # Sequence of object detection category semantic names in their - # relative category ID order. For use with classic vectorization logic. - # This attribute should be set by the `load_*` methods. - # This will need to have enough indices such that any object detection - # prediction category ID ("label") will map to something. It is - # generally assumed that the first index (0) must be for the background - # class despite such a class not being represented predictions. - self._det_label_vec: Sequence[Optional[str]] = [] - # Constant 1's mask value to re-use during get-item. self._ones_mask: npt.NDArray[int] = np.ones(window_size, dtype=int) @property def window_weights(self) -> npt.NDArray[float]: + """ + Get per-index weights to use with a weighted sampler. + + :return: Array of per-index weight floats. + """ if self._window_weights is None: raise RuntimeError("No class weights calculated for this dataset.") return self._window_weights - def _vectorize_window(self, window_data: Sequence[FrameData]): + def _vectorize_window( + self, window_data: Sequence[FrameData] + ) -> npt.NDArray[np.float32]: """ Vectorize a single window of data. :param window_data: Window of data to vectorize. Must be window-size @@ -118,14 +118,8 @@ def _vectorize_window(self, window_data: Sequence[FrameData]): :return: Transformed vector. """ assert len(window_data) == self.window_size - tcn_vector = vectorize_window( - frame_data=window_data, - # The following arguments may be specific to the "classic" version - # feature construction. - det_class_labels=self._det_label_vec, - feat_version=self.feature_version, - ) - return tcn_vector + v = self.vectorizer + return np.asarray([v(d) for d in window_data]) def load_data_offline( self, @@ -241,20 +235,6 @@ def load_data_offline( np.ndarray(shape=(0, num_pose_keypoints)), ) - # Save object detection category names in relative ID order. - # This is ultimately being saved for the classic version of - # vectorization which merely wants - # The +1 is here because we know the background category is not being - # included in object detection output that has historically been fed - # into here. - det_label_vec = [None] * (max(dets_coco.cats) + 1) - for c in dets_coco.cats.values(): - det_label_vec[c["id"]] = c["name"] - assert ( - det_label_vec[0] is None - ), "Not expecting input dataset categories to include the background class." - self._det_label_vec = tuple(det_label_vec) - # # Collect per-frame data first per-video, then slice into windows. # @@ -396,23 +376,19 @@ def load_data_offline( and dets_coco_fpath.is_file() and pose_coco_fpath.is_file() ): - # Make this into a function? + csum = sha256() with open(activity_coco_fpath, "rb") as f: - act_sha256 = sha256(f.read()).hexdigest() + csum.update(f.read()) with open(dets_coco_fpath, "rb") as f: - det_sha256 = sha256(f.read()).hexdigest() + csum.update(f.read()) with open(pose_coco_fpath, "rb") as f: - pos_sha256 = sha256(f.read()).hexdigest() + csum.update(f.read()) + csum.update(f"{target_framerate:0.{framerate_round_decimals}f}".encode()) + csum.update(f"{self.window_size:d}".encode()) + csum.update(json.dumps(self.vectorizer.hparams()).encode()) # Include vectorization variables in the name of the file. # Note the "z" in the name, expecting to use savez_compressed. - cache_filename = "{}_{}_{}_{:.2f}_{:d}_{:d}.npz".format( - act_sha256, - det_sha256, - pos_sha256, - target_framerate, - self.window_size, - self.feature_version, - ) + cache_filename = "{}.npz".format(csum.hexdigest()) cache_filepath = Path(cache_dir) / cache_filename has_vector_cache = cache_filepath.is_file() @@ -424,7 +400,7 @@ def load_data_offline( logger.info("Loading window vectors from cache... Done") else: # Pre-vectorize data for iteration efficiency during training. - window_vectors: List[npt.NDArray[float]] = [] + window_vectors: List[npt.NDArray[np.float32]] = [] itable = (self._vectorize_window(d) for d in window_data) for one_vector in tqdm( itable, @@ -448,7 +424,6 @@ def load_data_offline( def load_data_online( self, window_data: Sequence[FrameData], - det_class_label_vec: Sequence[Optional[str]], ) -> None: """ Receive data from a streaming runtime to yield from __getitem__. @@ -458,12 +433,6 @@ def load_data_online( Args: window_data: Per-frame data to compose the solitary window. - det_class_label_vec: - Sequence of string labels mapping predicted object detection - class label integers into strings. This is generally all - categories that the detector may predict, in index order. - The background class should be left out as a string in - sequence, but instead be represented by a None value. """ # Just load one windows worth of stuff so only __getitem__(0) makes # sense. @@ -473,8 +442,6 @@ class label integers into strings. This is generally all f"({len(window_data)} != {self.window_size})." ) - self._det_label_vec = tuple(det_class_label_vec) - # Assign a single window of frame data. self._window_data = [list(window_data)] # The following are undefined for online mode, so we're just filling in @@ -486,7 +453,7 @@ class label integers into strings. This is generally all def __getitem__( self, index: int ) -> Tuple[ - npt.NDArray[float], + npt.NDArray[np.float32], npt.NDArray[int], npt.NDArray[int], npt.NDArray[int], @@ -549,19 +516,32 @@ def __len__(self): # Example usage: activity_coco = kwcoco.CocoDataset( - "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example/TEST-activity_truth.coco.json" - # "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example/activity_truth.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/activity_truth.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-activity_truth.coco.json" + "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-activity_truth-vid_1.coco.json" ) dets_coco = kwcoco.CocoDataset( - "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example/TEST-object_detections.coco.json" - # "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example/all_object_detections.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/all_object_detections.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-object_detections.coco.json" + "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-object_detections-vid_1.coco.json" ) pose_coco = kwcoco.CocoDataset( - "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example/TEST-pose_estimates.coco.json" - # "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example/all_poses.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/all_poses.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-pose_estimates.coco.json" + "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-pose_estimates-vid_1.coco.json" ) - dataset = TCNDataset(window_size=25, feature_version=6) + from tcn_hpl.data.vectorize.classic import Classic + + vectorizer = Classic( + feat_version=6, + top_k=1, + num_classes=7, # M2 object detection classes + background_idx=0, + hand_left_idx=5, + hand_right_idx=6, + ) + dataset = TCNDataset(window_size=25, vectorizer=vectorizer) dataset.load_data_offline( activity_coco, dets_coco, @@ -582,7 +562,7 @@ def __len__(self): count = 0 s = time.time() - for index, batch in tqdm( + for idx, batch in tqdm( enumerate(data_loader), desc="Iterating batches of features", unit="batches", diff --git a/tcn_hpl/data/vectorize/_interface.py b/tcn_hpl/data/vectorize/_interface.py index f106161e3..42fc10d1e 100644 --- a/tcn_hpl/data/vectorize/_interface.py +++ b/tcn_hpl/data/vectorize/_interface.py @@ -1,6 +1,10 @@ import abc +import inspect +from typing import Any, Mapping + import numpy as np import numpy.typing as npt +from pytorch_lightning.utilities.parsing import collect_init_args from ._data import FrameData @@ -16,6 +20,29 @@ class Vectorize(metaclass=abc.ABCMeta): space for use in TCN training and inference. """ + def __init__(self): + # Collect parameters to checksum until we are out of the __init__ + # stack. + init_args = {} + # merge init args from the bottom up: higher stacks override. + for local_args in collect_init_args(inspect.currentframe().f_back, []): + init_args.update(local_args) + # Instead of keeping around hyperparameter values forever, just + # checksum now. This should be fine because, even if we retain them, + # runtime + self.__init_args = init_args + + def hparams(self) -> Mapping[str, Any]: + """ + Return a deterministic checksum of hyperparameters this instance was + constructed with. + + This may need to be overwritten if hyperparameters bed + + :returns: Hexadecimal digest of the SHA256 checksum of hyperparameters. + """ + return self.__init_args + @abc.abstractmethod def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: """ diff --git a/tcn_hpl/data/vectorize/classic.py b/tcn_hpl/data/vectorize/classic.py index aebeb4cf1..8dcfbad5f 100644 --- a/tcn_hpl/data/vectorize/classic.py +++ b/tcn_hpl/data/vectorize/classic.py @@ -32,6 +32,8 @@ def __init__( hand_left_idx: int = 5, hand_right_idx: int = 6, ): + super().__init__() + self._feat_version = feat_version self._top_k = top_k # The classic means of vectorization required some inputs that involved @@ -91,32 +93,35 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: if f_poses is not None and f_poses.scores.size: best_pose_idx = np.argmax(f_poses.scores) pose_kps = [ - {"xy": joint_pt} - for joint_pt in f_poses.joint_positions[best_pose_idx] + {"xy": joint_pt} for joint_pt in f_poses.joint_positions[best_pose_idx] ] else: # special value for the classic method to indicate no pose joints. pose_kps = zero_joint_offset - frame_feat = obj_det2d_set_to_feature( - label_vec=det_lbls, - xs=det_xs, - ys=det_ys, - ws=det_ws, - hs=det_hs, - label_confidences=det_scores, - pose_keypoints=pose_kps, - obj_label_to_ind=obj_label_to_ind, - version=self._feat_version, - top_k_objects=self._top_k, - ).ravel().astype(np.float32) + frame_feat = ( + obj_det2d_set_to_feature( + label_vec=det_lbls, + xs=det_xs, + ys=det_ys, + ws=det_ws, + hs=det_hs, + label_confidences=det_scores, + pose_keypoints=pose_kps, + obj_label_to_ind=obj_label_to_ind, + version=self._feat_version, + top_k_objects=self._top_k, + ) + .ravel() + .astype(np.float32) + ) return frame_feat @functools.lru_cache() def _class_labels_to_map( - class_labels: tg.Sequence[tg.Optional[str]] + class_labels: tg.Sequence[tg.Optional[str]], ) -> tg.Dict[str, int]: """ Transform a sequence of class label strings into a mapping from label name @@ -130,7 +135,7 @@ def _class_labels_to_map( min_cat = min(lbl_to_idx.values()) for k in lbl_to_idx: lbl_to_idx[k] -= min_cat - assert ( - set(lbl_to_idx.values()) == set(range(len(lbl_to_idx))) + assert set(lbl_to_idx.values()) == set( + range(len(lbl_to_idx)) ), "Resulting category indices must start at 0 and be contiguous." return lbl_to_idx diff --git a/tcn_hpl/data/vectorize_window.py b/tcn_hpl/data/vectorize_window.py deleted file mode 100644 index e70a84682..000000000 --- a/tcn_hpl/data/vectorize_window.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -Logic and utilities to perform the vectorization of input data into an -embedding space used for TCN training and prediction. -""" -import typing as tg - -import numpy as np -import numpy.typing as npt - -from tcn_hpl.data.vectorize import ( - FrameObjectDetections, - FramePoses, - FrameData, -) -from tcn_hpl.data.vectorize.classic import _class_labels_to_map # noqa -from tcn_hpl.data.vectorize_classic import ( - obj_det2d_set_to_feature, - zero_joint_offset, -) - - -def vectorize_window( - frame_data: tg.Sequence[FrameData], - det_class_labels: tg.Sequence[tg.Optional[str]], - feat_version: int = 6, - top_k_objects: int = 1, -) -> npt.NDArray[float]: - """ - Construct an embedding vector for some window of data to be used for - training and predicting against the TCN model. - - The length of the input data list is interpreted as the window size. - - Args: - frame_data: - Component data to use for constructing the embedding. - det_class_labels: - Sequence of string labels corresponding to object detection - classes, in index order. Some indices may be None which means there - is no class associated with that index. Any possible - `FrameObjectDetections.labels` value must map to a label in this - input. Other unused intermediate indices may be set to None. - feat_version: - Integer version number of the feature to compute. - (historical concept) - top_k_objects: - Use this many most confident instances of every object type in the - feature vector. - - Returns: - Embedding vector matrix for the input window of data. - Shape: (len(frame_data), embedding_dim) - """ - # Discover feature dimension on first successful call to the per-frame - # feature generation function. - feat_dim: tg.Optional[int] = None - feat_dtype = np.float32 - - # Inverse mapping to the input index-to-label sequence for the classic - # vectorizer. - obj_label_to_ind = _class_labels_to_map(det_class_labels) - - f_vecs: tg.List[tg.Optional[npt.NDArray[float]]] = [None] * len(frame_data) - for i, f_data in enumerate(frame_data): - f_dets = f_data.object_detections - if f_dets is None: - # Cannot proceed with classic vector computation without object - # detections. - continue - - # extract object detection xywh as 4 component vectors. - det_xs = f_dets.boxes.T[0] - det_ys = f_dets.boxes.T[1] - det_ws = f_dets.boxes.T[2] - det_hs = f_dets.boxes.T[3] - - # There may be zero or multiple poses predicted on a frame. - # If multiple poses, select the most confident "patient" pose. - # If there was no pose on this frame, provide a list of 0's equal in - # length to the number of joints. - f_poses = f_data.poses - if f_poses is not None and f_poses.scores.size: - best_pose_idx = np.argmax(f_poses.scores) - pose_kps = [ - {"xy": joint_pt} - for joint_pt in f_poses.joint_positions[best_pose_idx] - ] - else: - # special value for the classic method to indicate no pose joints. - pose_kps = zero_joint_offset - - frame_feat = obj_det2d_set_to_feature( - label_vec=[det_class_labels[lbl] for lbl in f_dets.labels], - xs=det_xs, - ys=det_ys, - ws=det_ws, - hs=det_hs, - label_confidences=f_dets.scores, - pose_keypoints=pose_kps, - obj_label_to_ind=obj_label_to_ind, - version=feat_version, - top_k_objects=top_k_objects, - ).ravel().astype(feat_dtype) - feat_dim = frame_feat.size - f_vecs[i] = frame_feat - - # If a caller is getting this, we could start to throw a more specific - # error, and the caller could safely catch it to consider this window as - # whatever the "background" class is. - assert ( - feat_dim is not None - ), "No features computed for any frame this window?" - - # If a feature fails to be generated for a frame: - # * insert zero-vector matching dimensionality. - empty_vec = np.zeros(shape=(feat_dim,), dtype=feat_dtype) - for i in range(len(f_vecs)): - if f_vecs[i] is None: - f_vecs[i] = empty_vec - - return np.asarray(f_vecs) From e3ade12dd8a06a183312064352c74982725a50eb Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Mon, 28 Oct 2024 19:06:53 -0400 Subject: [PATCH 4/5] Update training experiment configuration to configure vectorizer --- configs/data/ptg.yaml | 11 +++++++++-- configs/experiment/m2/feat_v6.yaml | 23 ++++++++++++++++++----- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/configs/data/ptg.yaml b/configs/data/ptg.yaml index df01c42ec..0d0cd6974 100644 --- a/configs/data/ptg.yaml +++ b/configs/data/ptg.yaml @@ -3,7 +3,14 @@ _target_: tcn_hpl.data.ptg_datamodule.PTGDataModule train_dataset: _target_: tcn_hpl.data.tcn_dataset.TCNDataset window_size: 15 - feature_version: 1 + vectorizer: + _target_: tcn_hpl.data.vectorize.classic.Classic + feat_version: 6 + top_k: 1 + num_classes: 7 + background_idx: 0 + hand_left_idx: 5 + hand_right_idx: 6 transform: _target_: torchvision.transforms.Compose transforms: [] @@ -11,7 +18,7 @@ train_dataset: val_dataset: _target_: tcn_hpl.data.tcn_dataset.TCNDataset window_size: ${data.train_dataset.window_size} - feature_version: ${data.train_dataset.feature_version} + vectorizer: ${data.train_dataset.vectorizer} transform: _target_: torchvision.transforms.Compose transforms: [] diff --git a/configs/experiment/m2/feat_v6.yaml b/configs/experiment/m2/feat_v6.yaml index 5fa154233..d6f1e2f4b 100644 --- a/configs/experiment/m2/feat_v6.yaml +++ b/configs/experiment/m2/feat_v6.yaml @@ -18,12 +18,18 @@ defaults: # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters -# Set this value influences: +# Change this name to something descriptive and unique for this experiment. +# This will differentiate the run logs and output to be separate from other +# experiments that may have been run under the configured +# Setting this value influences: # - the name of the directory under `${paths.root_dir}/logs/` in which training # run files are stored. # Default is "train" set in the "configs/train.yaml" file. #task_name: +# simply provide checkpoint path to resume training +#ckpt_path: null + tags: ["m2", "ms_tcn", "debug"] seed: 12345 @@ -61,7 +67,14 @@ data: train_dataset: window_size: 25 - feature_version: 6 + vectorizer: + _target_: tcn_hpl.data.vectorize.classic.Classic + feat_version: 6 + top_k: 1 + num_classes: 7 + background_idx: 0 + hand_left_idx: 5 + hand_right_idx: 6 transform: transforms: [] # no transforms # - _target_: tcn_hpl.data.components.augmentations.MoveCenterPts @@ -93,14 +106,14 @@ data: paths: # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL/" - root_dir: "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example/training_root" + root_dir: "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/training_root" # Convenience variable to where your train/val/test split COCO file datasets # are stored. - coco_file_root: "/home/local/KHQ/paul.tunison/data/darpa-ptg/tcn_training_example" + coco_file_root: "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens" #exp_name: "tcn_training_revive" #logger: # aim: -# experiment: ${exp_name} +# experiment: ${task_name} # capture_terminal_logs: true From 548560d7ed5664e329f6583c239ec4bbd85c3975 Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Mon, 28 Oct 2024 19:58:51 -0400 Subject: [PATCH 5/5] Make use of Dataloader parallelization in pre-vectorization Torch folks are smarter than me. Use their work. --- tcn_hpl/data/tcn_dataset.py | 82 +++++++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index ee19f04b8..26cef5c4b 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -1,4 +1,5 @@ import logging +import os from hashlib import sha256 import json from pathlib import Path @@ -16,7 +17,7 @@ import numpy as np import numpy.typing as npt import torch -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader from tqdm import tqdm from tcn_hpl.data.vectorize import ( @@ -49,7 +50,7 @@ class TCNDataset(Dataset): cache these vectors if a cache directory is provided to the `load_data_offline` method. - Arguments: + Args: window_size: The size of the sliding window used to collect inputs from either a real-time or offline source. @@ -102,7 +103,8 @@ def window_weights(self) -> npt.NDArray[float]: """ Get per-index weights to use with a weighted sampler. - :return: Array of per-index weight floats. + Returns: + Array of per-index weight floats. """ if self._window_weights is None: raise RuntimeError("No class weights calculated for this dataset.") @@ -113,9 +115,13 @@ def _vectorize_window( ) -> npt.NDArray[np.float32]: """ Vectorize a single window of data. - :param window_data: Window of data to vectorize. Must be window-size - in length. - :return: Transformed vector. + + Args: + window_data: Window of data to vectorize. Must be window-size + in length. + + Returns: + Transformed vector. """ assert len(window_data) == self.window_size v = self.vectorizer @@ -129,6 +135,7 @@ def load_data_offline( target_framerate: float, # probably 15 framerate_round_decimals: int = 1, pre_vectorize: bool = True, + pre_vectorize_cores: int = os.cpu_count(), cache_dir: Optional[Union[str, Path]] = None, ) -> None: """ @@ -137,7 +144,10 @@ def load_data_offline( We will pre-compute window vectors to save time during training. We will attempt to cache these vectors if a cache directory is provided. - Arguments: + Vector caching also requires that the input COCO datasets have an + associated filepath that exists. + + Args: activity_coco: COCO dataset of per-frame activity classification ground truth. This dataset also serves as the authority for data processing @@ -158,8 +168,11 @@ def load_data_offline( pre_vectorize: If we should pre-compute window vectors, possibly caching the results, as part of this load. + pre_vectorize_cores: + Number of cores to utilize when pre-computing window vectors. cache_dir: - Optional directory for cache file storage and retrieval. + Optional directory for cache file storage and retrieval. If + this is not specified, no caching will occur. """ # The data coverage for all the input datasets must be congruent. logger.info("Checking dataset video/image congruency") @@ -400,18 +413,35 @@ def load_data_offline( logger.info("Loading window vectors from cache... Done") else: # Pre-vectorize data for iteration efficiency during training. + # * Creating a mini Dataset/Dataloader situation to efficiently + # generate vectors. + vectorize_window = self._vectorize_window window_vectors: List[npt.NDArray[np.float32]] = [] - itable = (self._vectorize_window(d) for d in window_data) - for one_vector in tqdm( - itable, + + class VecDset(Dataset): + def __getitem__(self, item): + return vectorize_window(window_data[item]) + + def __len__(self): + return len(window_data) + + # Using larger batch sizes than 1 did not show any particular + # increase in throughput. This may require increasing + # `ulimit -n`, though. + dloader = DataLoader( + VecDset(), + batch_size=1, + num_workers=pre_vectorize_cores, + ) + + for batch in tqdm( + dloader, desc="Windows vectorized", - total=len(window_data), unit="windows", ): - # Pre-allocate matrix on first compute which will give us - # the vector shape. - window_vectors.append(one_vector) + window_vectors.extend(batch.numpy()) self._window_vectors = window_vectors + if cache_filepath is not None: logger.info("Saving window vectors to cache...") cache_filepath.parent.mkdir(parents=True, exist_ok=True) @@ -467,7 +497,6 @@ def __getitem__( Returns: Series of 5 numpy arrays: - * Embedding Vector, shape: (window_size, n_dims) * per-frame truth, shape: (window_size,) * per-frame applicability mask, shape: (window_size,) @@ -489,6 +518,9 @@ def __getitem__( # to random aspects that augmentation can be configured to have during # training. if self.transform is not None: + # TODO: Augment using a helper on the vectorizer? I'm imaging that + # augmentations might be specific to which vectorizer is + # used. tcn_vector = self.transform(tcn_vector) return ( @@ -517,18 +549,18 @@ def __len__(self): # Example usage: activity_coco = kwcoco.CocoDataset( # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/activity_truth.coco.json" - # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-activity_truth.coco.json" - "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-activity_truth-vid_1.coco.json" + "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-activity_truth.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-activity_truth-vid_1.coco.json" ) dets_coco = kwcoco.CocoDataset( # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/all_object_detections.coco.json" - # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-object_detections.coco.json" - "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-object_detections-vid_1.coco.json" + "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-object_detections.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-object_detections-vid_1.coco.json" ) pose_coco = kwcoco.CocoDataset( # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/all_poses.coco.json" - # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-pose_estimates.coco.json" - "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-pose_estimates-vid_1.coco.json" + "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TEST-pose_estimates.coco.json" + # "/home/local/KHQ/paul.tunison/data/darpa-ptg/train-TCN-M2_bbn_hololens/TRAIN-pose_estimates-vid_1.coco.json" ) from tcn_hpl.data.vectorize.classic import Classic @@ -543,11 +575,7 @@ def __len__(self): ) dataset = TCNDataset(window_size=25, vectorizer=vectorizer) dataset.load_data_offline( - activity_coco, - dets_coco, - pose_coco, - target_framerate=15, - cache_dir="/home/local/KHQ/paul.tunison/dev/darpa-ptg/angel_system/python-tpl/TCN_HPL/test_cache", + activity_coco, dets_coco, pose_coco, target_framerate=15, cache_dir=None ) print(f"dataset: {len(dataset)}")