From eb9851c402940063bae336f56fec32c85b004322 Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Fri, 1 Nov 2024 17:49:26 -0400 Subject: [PATCH 01/11] Move frame data out of vectorize interface land --- tcn_hpl/data/{vectorize/_data.py => frame_data.py} | 0 tcn_hpl/data/tcn_dataset.py | 4 ++-- tcn_hpl/data/vectorize/__init__.py | 5 ----- tcn_hpl/data/vectorize/_interface.py | 2 +- 4 files changed, 3 insertions(+), 8 deletions(-) rename tcn_hpl/data/{vectorize/_data.py => frame_data.py} (100%) diff --git a/tcn_hpl/data/vectorize/_data.py b/tcn_hpl/data/frame_data.py similarity index 100% rename from tcn_hpl/data/vectorize/_data.py rename to tcn_hpl/data/frame_data.py diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index bc6878c7e..fddb96a7f 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -21,12 +21,12 @@ from torch.utils.data import Dataset, DataLoader from tqdm import tqdm -from tcn_hpl.data.vectorize import ( +from tcn_hpl.data.frame_data import ( FrameObjectDetections, FramePoses, FrameData, - Vectorize, ) +from tcn_hpl.data.vectorize import Vectorize logger = logging.getLogger(__name__) diff --git a/tcn_hpl/data/vectorize/__init__.py b/tcn_hpl/data/vectorize/__init__.py index 20f008d18..aeb0f10e7 100644 --- a/tcn_hpl/data/vectorize/__init__.py +++ b/tcn_hpl/data/vectorize/__init__.py @@ -1,6 +1 @@ -from ._data import ( - FrameObjectDetections, - FramePoses, - FrameData, -) from ._interface import Vectorize diff --git a/tcn_hpl/data/vectorize/_interface.py b/tcn_hpl/data/vectorize/_interface.py index 42fc10d1e..64fec8ccb 100644 --- a/tcn_hpl/data/vectorize/_interface.py +++ b/tcn_hpl/data/vectorize/_interface.py @@ -6,7 +6,7 @@ import numpy.typing as npt from pytorch_lightning.utilities.parsing import collect_init_args -from ._data import FrameData +from tcn_hpl.data.frame_data import FrameData __all__ = [ From 522151afa37d89e56c5521fec37fd266563d7d5c Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Mon, 4 Nov 2024 18:22:11 -0500 Subject: [PATCH 02/11] Initial FrameData dropout augmentation transform --- tcn_hpl/data/frame_data_aug/__init__.py | 0 .../frame_data_aug/window_frame_dropout.py | 177 ++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 tcn_hpl/data/frame_data_aug/__init__.py create mode 100644 tcn_hpl/data/frame_data_aug/window_frame_dropout.py diff --git a/tcn_hpl/data/frame_data_aug/__init__.py b/tcn_hpl/data/frame_data_aug/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py new file mode 100644 index 000000000..20d1f2bc5 --- /dev/null +++ b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py @@ -0,0 +1,177 @@ +from typing import Sequence, List, Optional + +import numpy as np +import torch + +from tcn_hpl.data.frame_data import FrameData + + +class DropoutFrameDataTransform(torch.nn.Module): + """ + Augmentation of a FrameData window that will drop out object detections or + pose estimations for some frames as if they were never computed for those + frames. + + This aims to simulate how a live system cannot keep up with predicting + these estimations on all input data in a streaming system. + + Args: + frame_rate: + The frame rate in Hz of the input video or sequence from which a + window of data is associated with. + dets_throughput_mean: + Rate in Hz at which object detection predictions should be + represented in the window. + pose_throughput_mean: + Rate in Hz at which pose estimation predictions should be + represented in the window. + dets_latency: + Optional separate latency in seconds for object detection + predictions. If not provided, we will interpret latency to be the + inverse of throughput. It many be useful to provide a specific + latency value if processing conditions are specialized beyond the + naive consideration that windows will end in the latest observable + image frame. + pose_latency: + Optional separate latency in seconds for pose estimation + predictions. If not provided, we will interpret latency to be the + inverse of throughput. It many be useful to provide a specific + latency value if processing conditions are specialized beyond the + naive consideration that windows will end in the latest observable + image frame. + """ + + def __init__( + self, + frame_rate: float, + dets_throughput_mean: float, + pose_throughput_mean: float, + dets_latency: Optional[float] = None, + pose_latency: Optional[float] = None, + ): + super().__init__() + self.frame_rate = frame_rate + self.dets_throughput_mean = dets_throughput_mean + self.pose_throughput_mean = pose_throughput_mean + # If no separate latency, then just assume inverse of throughput. + self.dets_latency = dets_latency if dets_latency is not None else 1. / dets_throughput_mean + self.pose_latency = pose_latency if pose_latency is not None else 1. / pose_throughput_mean + + def forward(self, window: Sequence[FrameData]) -> List[FrameData]: + # Starting from some latency back from the end of the window, start + # dropping out detections and poses as if they were not produced for + # that frame. Do this separately for poses and detections as their + # agents can operate at different rates. + + n_frames = len(window) + one_frame_time = 1.0 / self.frame_rate + + # Vector of frame time offsets starting from the oldest frame. + # Time progresses from the first frame (0 seconds) to the + # last frame in the window (increasing by one_frame_time for each frame). + frame_times = np.arange(n_frames) * one_frame_time + max_frame_time = frame_times[-1] + + # Define processing intervals (how often a frame is processed) + # TODO: Vectorize this, adding random variation by utilizing + # `torch.normal(mean, std)`. + dets_interval = 1.0 / self.dets_throughput_mean + pose_interval = 1.0 / self.pose_throughput_mean + + # Initialize end time trackers for processing detections and poses. + # Simulate that agents may already be part-way through processing a + # frame before the beginning of this window. + dets_processing_end = np.full(n_frames + 1, torch.rand(1) * dets_interval) + pose_processing_end = np.full(n_frames + 1, torch.rand(1) * pose_interval) + + # Boolean arrays to keep track of whether a frame can be processed + dets_mask = np.zeros(n_frames, dtype=bool) + pose_mask = np.zeros(n_frames, dtype=bool) + + # Simulate realistic processing behavior + for idx in range(n_frames): + # Processing can occur on this frame if the processing for the + # previous frame finishes before the frame after this arrives + # (represented by the `+ one_frame_time`), since the "current" + # frame for the agent would still be this frame. + + # Object detection processing + if frame_times[idx] + one_frame_time > dets_processing_end[idx]: + # Agent finishes processing before the next frame would come it + # so it processes this frame. + dets_mask[idx] = True + dets_processing_end[idx + 1:] = ( + dets_processing_end[idx] + dets_interval + if dets_processing_end[idx] >= frame_times[idx] + else frame_times[idx] + dets_interval + ) + + # Pose processing + if frame_times[idx] + one_frame_time > pose_processing_end[idx]: + # Agent finishes processing before the next frame would come it + # so it processes this frame. + pose_mask[idx] = True + pose_processing_end[idx + 1:] = ( + pose_processing_end[idx] + pose_interval + if pose_processing_end[idx] >= frame_times[idx] + else frame_times[idx] + pose_interval + ) + + # Mask out the ends based on configured latencies. This ensures we do + # not mark a window frame as "processed" when its completion would + # occur *after* the final frame is received. + dets_mask &= (frame_times + self.dets_latency) <= max_frame_time + pose_mask &= (frame_times + self.pose_latency) <= max_frame_time + + # Create the modified sequence + modified_sequence = [ + FrameData( + object_detections=( + window[idx].object_detections if dets_mask[idx] else None + ), + poses=(window[idx].poses if pose_mask[idx] else None), + ) + for idx in range(n_frames) + ] + + return modified_sequence + + +def test(): + from tcn_hpl.data.frame_data import FrameObjectDetections, FramePoses + + frame1 = FrameData( + object_detections=FrameObjectDetections( + boxes=np.array([[10, 20, 30, 40], [50, 60, 70, 80]]), + labels=np.array([1, 2]), + scores=np.array([0.9, 0.75]), + ), + poses=FramePoses( + scores=np.array([0.8]), + joint_positions=np.array([[[10, 20], [30, 40], [50, 60]]]), + joint_scores=np.array([[0.9, 0.85, 0.8]]), + ), + ) + sequence = [frame1] * 25 + transform = DropoutFrameDataTransform( + frame_rate=1, + dets_throughput_mean=15, + pose_throughput_mean=0.66, + ) + # transform = DropoutFrameDataTransform( + # frame_rate=15, + # dets_throughput=14.7771, + # dets_latency=0, + # pose_throughput=10, + # pose_latency=1/10, # (1/10)-(1/14.7771), + # ) + modified_sequence = transform(sequence) + + for idx, frame in enumerate(modified_sequence): + print( + f"Frame {idx}: Object Detections: {frame.object_detections is not None}, Poses: {frame.poses is not None}" + ) + + +if __name__ == "__main__": + test() From fb98c120fc332403e5c25def12bd2a81668eecad Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Mon, 4 Nov 2024 20:10:45 -0500 Subject: [PATCH 03/11] Add throughput std-dev parameter for variation --- .../frame_data_aug/window_frame_dropout.py | 105 +++++++++++++----- 1 file changed, 80 insertions(+), 25 deletions(-) diff --git a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py index 20d1f2bc5..0c45ad1d0 100644 --- a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py +++ b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py @@ -39,6 +39,10 @@ class DropoutFrameDataTransform(torch.nn.Module): latency value if processing conditions are specialized beyond the naive consideration that windows will end in the latest observable image frame. + dets_throughput_std: + Standard deviation of the throughput rate for object detections. + pose_throughput_std: + Standard deviation of the throughput rate for pose estimations. """ def __init__( @@ -48,20 +52,36 @@ def __init__( pose_throughput_mean: float, dets_latency: Optional[float] = None, pose_latency: Optional[float] = None, + dets_throughput_std: float = 0.0, + pose_throughput_std: float = 0.0, ): super().__init__() self.frame_rate = frame_rate self.dets_throughput_mean = dets_throughput_mean self.pose_throughput_mean = pose_throughput_mean + self.dets_throughput_std = dets_throughput_std + self.pose_throughput_std = pose_throughput_std # If no separate latency, then just assume inverse of throughput. - self.dets_latency = dets_latency if dets_latency is not None else 1. / dets_throughput_mean - self.pose_latency = pose_latency if pose_latency is not None else 1. / pose_throughput_mean + self.dets_latency = ( + dets_latency if dets_latency is not None else 1.0 / dets_throughput_mean + ) + self.pose_latency = ( + pose_latency if pose_latency is not None else 1.0 / pose_throughput_mean + ) def forward(self, window: Sequence[FrameData]) -> List[FrameData]: # Starting from some latency back from the end of the window, start # dropping out detections and poses as if they were not produced for # that frame. Do this separately for poses and detections as their # agents can operate at different rates. + # + # NOTE: This method makes use of numpy for most vector operations + # because it's just faster than torch during testing, however torch's + # random operators are utilized to align with training system's setting + # seeds to torch and not numpy. + # Local machine testing: + # * numpy operations: ~80 μs + # * torch equivalent: ~1100 µs n_frames = len(window) one_frame_time = 1.0 / self.frame_rate @@ -73,16 +93,34 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]: max_frame_time = frame_times[-1] # Define processing intervals (how often a frame is processed) - # TODO: Vectorize this, adding random variation by utilizing - # `torch.normal(mean, std)`. - dets_interval = 1.0 / self.dets_throughput_mean - pose_interval = 1.0 / self.pose_throughput_mean + # This cursed formatting is because of `black`. + dets_interval = ( + 1.0 + / torch.normal( + mean=self.dets_throughput_mean, + std=self.dets_throughput_std, + size=(n_frames,), + ).numpy() + ) + pose_interval = ( + 1.0 + / torch.normal( + mean=self.pose_throughput_mean, + std=self.pose_throughput_std, + size=(n_frames,), + ).numpy() + ) # Initialize end time trackers for processing detections and poses. # Simulate that agents may already be part-way through processing a - # frame before the beginning of this window. - dets_processing_end = np.full(n_frames + 1, torch.rand(1) * dets_interval) - pose_processing_end = np.full(n_frames + 1, torch.rand(1) * pose_interval) + # frame before the beginning of this window, utilizing the first value + # from respective interval vectors. + dets_processing_end = np.full( + n_frames + 1, torch.rand(1).item() * dets_interval[0] + ) + pose_processing_end = np.full( + n_frames + 1, torch.rand(1).item() * pose_interval[0] + ) # Boolean arrays to keep track of whether a frame can be processed dets_mask = np.zeros(n_frames, dtype=bool) @@ -94,16 +132,23 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]: # previous frame finishes before the frame after this arrives # (represented by the `+ one_frame_time`), since the "current" # frame for the agent would still be this frame. + # + # Assignment back into *_processing_end vectors assigns to the + # remainder of indices because we want the end time to carry into + # future frames in case the processing time for an agent is larger + # than 1 frame's worth of time. Otherwise, the next frame "resets" + # and an agent will skip at most one frame even though it should be + # skipping more. # Object detection processing if frame_times[idx] + one_frame_time > dets_processing_end[idx]: # Agent finishes processing before the next frame would come it # so it processes this frame. dets_mask[idx] = True - dets_processing_end[idx + 1:] = ( - dets_processing_end[idx] + dets_interval + dets_processing_end[idx + 1 :] = ( + dets_processing_end[idx] + dets_interval[idx] if dets_processing_end[idx] >= frame_times[idx] - else frame_times[idx] + dets_interval + else frame_times[idx] + dets_interval[idx] ) # Pose processing @@ -111,10 +156,10 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]: # Agent finishes processing before the next frame would come it # so it processes this frame. pose_mask[idx] = True - pose_processing_end[idx + 1:] = ( - pose_processing_end[idx] + pose_interval + pose_processing_end[idx + 1 :] = ( + pose_processing_end[idx] + pose_interval[idx] if pose_processing_end[idx] >= frame_times[idx] - else frame_times[idx] + pose_interval + else frame_times[idx] + pose_interval[idx] ) # Mask out the ends based on configured latencies. This ensures we do @@ -138,6 +183,7 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]: def test(): + import numpy as np from tcn_hpl.data.frame_data import FrameObjectDetections, FramePoses frame1 = FrameData( @@ -153,18 +199,22 @@ def test(): ), ) sequence = [frame1] * 25 - transform = DropoutFrameDataTransform( - frame_rate=1, - dets_throughput_mean=15, - pose_throughput_mean=0.66, - ) # transform = DropoutFrameDataTransform( - # frame_rate=15, - # dets_throughput=14.7771, - # dets_latency=0, - # pose_throughput=10, - # pose_latency=1/10, # (1/10)-(1/14.7771), + # frame_rate=1, + # dets_throughput_mean=15, + # dets_throughput_std=0.1, + # pose_throughput_mean=0.66, + # pose_throughput_std=0.02, # ) + transform = DropoutFrameDataTransform( + frame_rate=15, + dets_throughput_mean=14.5, + dets_throughput_std=0.2, + dets_latency=0, + pose_throughput_mean=10, + pose_throughput_std=0.2, + pose_latency=(1 / 10) - (1 / 14.5), + ) modified_sequence = transform(sequence) for idx, frame in enumerate(modified_sequence): @@ -172,6 +222,11 @@ def test(): f"Frame {idx}: Object Detections: {frame.object_detections is not None}, Poses: {frame.poses is not None}" ) + from IPython import get_ipython + + ipython = get_ipython() + ipython.run_line_magic("timeit", "transform(sequence)") + if __name__ == "__main__": test() From 1c7a52d21c2f9498786e4ed1d438989ae12cbb8b Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Tue, 5 Nov 2024 10:41:16 -0500 Subject: [PATCH 04/11] Log validation f1/recall/precision --- tcn_hpl/models/ptg_module.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index 1de1f8006..d52f49f85 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -90,6 +90,9 @@ def __init__( self.val_f1 = F1Score( num_classes=num_classes, average="none", task="multiclass" ) + self.val_f1_avg = F1Score( + num_classes=num_classes, average="weighted", task="multiclass" + ) self.test_f1 = F1Score( num_classes=num_classes, average="none", task="multiclass" ) @@ -97,6 +100,9 @@ def __init__( self.val_recall = Recall( num_classes=num_classes, average="none", task="multiclass" ) + self.val_recall_avg = Recall( + num_classes=num_classes, average="weighted", task="multiclass" + ) self.test_recall = Recall( num_classes=num_classes, average="none", task="multiclass" ) @@ -104,6 +110,9 @@ def __init__( self.val_precision = Precision( num_classes=num_classes, average="none", task="multiclass" ) + self.val_precision_avg = Precision( + num_classes=num_classes, average="weighted", task="multiclass" + ) self.test_precision = Precision( num_classes=num_classes, average="none", task="multiclass" ) @@ -347,24 +356,22 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) all_targets = torch.cat([o['targets'] for o in outputs]) acc = self.val_acc.compute() - current_best_val_acc = self.val_acc_best.value # log `val_acc_best` as a value through `.compute()` return, instead of # as a metric object otherwise metric would be reset by lightning after # each epoch. best_val_acc = self.val_acc_best(acc) # update best so far val acc + self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True) - if best_val_acc > current_best_val_acc: - val_f1_score = self.val_f1(all_preds, all_targets) - val_recall_score = self.val_recall(all_preds, all_targets) - val_precision_score = self.val_precision(all_preds, all_targets) - - # print(f"preds: {all_preds}") - # print(f"all_targets: {all_targets}") - print(f"validation f1 score: {val_f1_score}") - print(f"validation recall score: {val_recall_score}") - print(f"validation precision score: {val_precision_score}") + val_f1_score = self.val_f1(all_preds, all_targets) + val_recall_score = self.val_recall(all_preds, all_targets) + val_precision_score = self.val_precision(all_preds, all_targets) + self.val_f1_avg(all_preds, all_targets) + self.val_recall_avg(all_preds, all_targets) + self.val_precision_avg(all_preds, all_targets) + self.log("val/f1", self.val_f1_avg, prog_bar=True, on_epoch=True) + self.log("val/recall", self.val_recall_avg, prog_bar=True, on_epoch=True) + self.log("val/prevision", self.val_precision_avg, prog_bar=True, on_epoch=True) - self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True) def test_step( self, From d216f85bc6e39366bce81c6180159bd56146859f Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Tue, 5 Nov 2024 11:30:58 -0500 Subject: [PATCH 05/11] Use transform in dataset, remove vector caching --- .../frame_data_aug/window_frame_dropout.py | 6 +- tcn_hpl/data/ptg_datamodule.py | 10 +- tcn_hpl/data/tcn_dataset.py | 233 +++++------------- tcn_hpl/data/vectorize/locs_and_confs.py | 17 +- 4 files changed, 72 insertions(+), 194 deletions(-) diff --git a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py index 0c45ad1d0..5d2870c5a 100644 --- a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py +++ b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py @@ -209,11 +209,11 @@ def test(): transform = DropoutFrameDataTransform( frame_rate=15, dets_throughput_mean=14.5, - dets_throughput_std=0.2, - dets_latency=0, pose_throughput_mean=10, + dets_latency=0, + pose_latency=1/10, # (1 / 10) - (1 / 14.5), + dets_throughput_std=0.2, pose_throughput_std=0.2, - pose_latency=(1 / 10) - (1 / 14.5), ) modified_sequence = transform(sequence) diff --git a/tcn_hpl/data/ptg_datamodule.py b/tcn_hpl/data/ptg_datamodule.py index ad6ab4936..b632ac3e7 100644 --- a/tcn_hpl/data/ptg_datamodule.py +++ b/tcn_hpl/data/ptg_datamodule.py @@ -5,7 +5,7 @@ import kwcoco from pytorch_lightning import LightningDataModule import torch -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from torchvision.transforms import transforms from typing import Any, Optional @@ -104,7 +104,6 @@ def __init__( coco_test_activities: str, coco_test_objects: str, coco_test_poses: str, - vector_cache_dir: str, batch_size: int, num_workers: int, target_framerate: float, @@ -166,24 +165,18 @@ def setup(self, stage: Optional[str] = None) -> None: kwcoco.CocoDataset(self.hparams.coco_train_objects), kwcoco.CocoDataset(self.hparams.coco_train_poses), self.hparams.target_framerate, - pre_vectorize=True, - cache_dir=self.hparams.vector_cache_dir, ) self.data_val.load_data_offline( kwcoco.CocoDataset(self.hparams.coco_validation_activities), kwcoco.CocoDataset(self.hparams.coco_validation_objects), kwcoco.CocoDataset(self.hparams.coco_validation_poses), self.hparams.target_framerate, - pre_vectorize=True, - cache_dir=self.hparams.vector_cache_dir, ) self.data_test.load_data_offline( kwcoco.CocoDataset(self.hparams.coco_test_activities), kwcoco.CocoDataset(self.hparams.coco_test_objects), kwcoco.CocoDataset(self.hparams.coco_test_poses), self.hparams.target_framerate, - pre_vectorize=True, - cache_dir=self.hparams.vector_cache_dir, ) def train_dataloader(self) -> DataLoader[Any]: @@ -210,7 +203,6 @@ def val_dataloader(self) -> DataLoader[Any]: :return: The validation dataloader. """ - return DataLoader( dataset=self.data_val, batch_size=self.hparams.batch_size, diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index fddb96a7f..701d2193b 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -1,8 +1,6 @@ import click import logging import os -from hashlib import sha256 -import json from pathlib import Path import time from typing import Callable @@ -12,12 +10,10 @@ from typing import Sequence from typing import Set from typing import Tuple -from typing import Union import kwcoco import numpy as np import numpy.typing as npt -import torch.multiprocessing from torch.utils.data import Dataset, DataLoader from tqdm import tqdm @@ -55,22 +51,26 @@ class TCNDataset(Dataset): window_size: The size of the sliding window used to collect inputs from either a real-time or offline source. - vectorizer: + transform_frame_data: + Optional augmentation function that operates on a window of + FrameData before being input to vectorization. Such an augmentation + function should not modify the input FrameData. + vectorize: Vectorization functor to convert frame data into an embedding space. - transform: - Optional feature vector transformation/augmentation function. """ def __init__( self, window_size: int, - vectorizer: Vectorize, - transform: Optional[Callable] = None, + vectorize: Vectorize, + transform_frame_data: Optional[ + Callable[[Sequence[FrameData]], Sequence[FrameData]] + ] = None, ): self.window_size = window_size - self.vectorizer = vectorizer - self.transform = transform + self.vectorize = vectorize + self.transform_frame_data = transform_frame_data # For offline mode, pre-cut videos into clips according to window # size for easy batching. @@ -98,11 +98,6 @@ def __init__( # weighted random sampling during training. This should only be # available when there is truth available, i.e. during offline mode. self._window_weights: Optional[npt.NDArray[float]] = None - # Optionally defined set of pre-computed vectors for each frame. - # Congruent index association with self._frame_data, so - # self._window_data_idx values may be used here. - # Shape: (n_frames, feat_dim) # see self._frame_data - self._frame_vectors: Optional[npt.NDArray[np.float32]] = None # Constant 1's mask value to re-use during get-item. self._ones_mask: npt.NDArray[int] = np.ones(window_size, dtype=int) @@ -126,9 +121,6 @@ def load_data_offline( pose_coco: kwcoco.CocoDataset, target_framerate: float, # probably 15 framerate_round_decimals: int = 1, - pre_vectorize: bool = True, - pre_vectorize_cores: int = os.cpu_count(), - cache_dir: Optional[Union[str, Path]] = None, ) -> None: """ Load data from filesystem resources for use during training. @@ -139,6 +131,11 @@ def load_data_offline( Vector caching also requires that the input COCO datasets have an associated filepath that exists. + Assumptions: + * This assumes that input pose predictions only contain a single + class that has pose keypoints associated with it. The current + PoseData structure has no slot for + Args: activity_coco: COCO dataset of per-frame activity classification ground truth. @@ -157,14 +154,6 @@ def load_data_offline( framerate_round_decimals: Number of floating-point decimals to round to when considering frame-rates. - pre_vectorize: - If we should pre-compute window vectors, possibly caching the - results, as part of this load. - pre_vectorize_cores: - Number of cores to utilize when pre-computing window vectors. - cache_dir: - Optional directory for cache file storage and retrieval. If - this is not specified, no caching will occur. """ # The data coverage for all the input datasets must be congruent. logger.info("Checking dataset video/image congruency") @@ -269,7 +258,7 @@ def load_data_offline( # cache frequently called module functions np_asarray = np.asarray - for vid_id in tqdm(activity_coco.videos()): + for vid_id in tqdm(activity_coco.videos(), unit="video"): vid_id: int vid_images = activity_coco.images(video_id=vid_id) vid_img_ids: List[int] = list(vid_images) @@ -306,7 +295,7 @@ def load_data_offline( ) else: frame_dets = empty_dets - + # Frame height and width should be available. img_info = activity_coco.index.imgs[img_id] assert "height" in img_info @@ -396,99 +385,6 @@ def load_data_offline( cls_weights[cls_ids] = 1.0 / cls_counts self._window_weights = cls_weights[window_final_class_ids] - # Check if there happens to be a cache file of pre-computed window - # vectors available to load. - # - # Caching is even possible if: - # * given a directory home for cache files - # * input COCO dataset filepaths are real and can be checksum'ed. - has_vector_cache = False - cache_filepath = None - activity_coco_fpath = Path(activity_coco.fpath) - dets_coco_fpath = Path(dets_coco.fpath) - pose_coco_fpath = Path(pose_coco.fpath) - if ( - pre_vectorize - and cache_dir is not None - and activity_coco_fpath.is_file() - and dets_coco_fpath.is_file() - and pose_coco_fpath.is_file() - ): - csum = sha256() - with open(activity_coco_fpath, "rb") as f: - csum.update(f.read()) - with open(dets_coco_fpath, "rb") as f: - csum.update(f.read()) - with open(pose_coco_fpath, "rb") as f: - csum.update(f.read()) - csum.update(f"{target_framerate:0.{framerate_round_decimals}f}".encode()) - csum.update(f"{self.window_size:d}".encode()) - csum.update(self.vectorizer.__class__.__module__.encode()) - csum.update(self.vectorizer.__class__.__name__.encode()) - csum.update(json.dumps(self.vectorizer.hparams()).encode()) - # Include vectorization variables in the name of the file. - # Note the "z" in the name, expecting to use savez_compressed. - cache_filename = "{}.npz".format(csum.hexdigest()) - cache_filepath = Path(cache_dir) / cache_filename - has_vector_cache = cache_filepath.is_file() - - if pre_vectorize: - if has_vector_cache: - logger.info("Loading frame vectors from cache...") - with np.load(cache_filepath) as data: - self._frame_vectors = data["frame_vectors"] - logger.info("Loading frame vectors from cache... Done") - else: - # Pre-vectorize data for iteration efficiency during training. - # * Creating a mini Dataset/Dataloader situation to efficiently - # generate vectors. - - # Set the sharing strategy to filesystem for the duration of - # this operation, and then restoring the existing strategy - # after we're done. - current_sharing_strategy = torch.multiprocessing.get_sharing_strategy() - - try: - # This iteration seems to go twice as fast when utilizing - # the file-system strategy. - torch.multiprocessing.set_sharing_strategy("file_system") - - vec_dset = _VectorizationDataset(self.vectorizer, frame_data) - - # Using larger batch sizes than 1 did not show any particular - # increase in throughput. This may require increasing - # `ulimit -n`, though. - dloader = DataLoader( - vec_dset, - batch_size=1, - num_workers=pre_vectorize_cores, - # Required, especially for large dataset sizes, so the - # dataloader multiprocessing iteration does not exhaust - # shared memory. - pin_memory=True, - ) - - frame_vectors: List[npt.NDArray[np.float32]] = [] - for batch in tqdm( - dloader, - desc="Frame data vectorized", - unit="frames", - ): - frame_vectors.extend(batch.numpy()) - finally: - torch.multiprocessing.set_sharing_strategy(current_sharing_strategy) - - self._frame_vectors = np.asarray(frame_vectors) - - if cache_filepath is not None: - logger.info("Saving window vectors to cache...") - cache_filepath.parent.mkdir(parents=True, exist_ok=True) - np.savez_compressed( - cache_filepath, - frame_vectors=frame_vectors, - ) - logger.info("Saving window vectors to cache... Done") - def load_data_online( self, window_data: Sequence[FrameData], @@ -550,26 +446,15 @@ def __getitem__( window_vid = self._window_vid[index] window_frames = self._window_frames[index] - frame_vectors = self._frame_vectors - if frame_vectors is not None: - window_mat = frame_vectors[window_data_idx] - else: - vectorizer = self.vectorizer - window_mat = np.asarray( - [vectorizer(frame_data[idx]) for idx in window_data_idx] - ) + window_frame_data = [frame_data[idx] for idx in window_data_idx] + + if self.transform_frame_data is not None: + window_frame_data = self.transform_frame_data(window_frame_data) - # Augmentation has to happen on the fly and cannot be pre-computed due - # to random aspects that augmentation can be configured to have during - # training. - if self.transform is not None: - # TODO: Augment using a helper on the vectorizer? I'm imaging that - # augmentations might be specific to which vectorizer is - # used. - window_mat = self.transform(window_mat) + window_vectors = np.asarray([self.vectorize(d) for d in window_frame_data]) return ( - window_mat, + window_vectors, window_truth, # Under the current operation of this dataset, the mask should always # consist of 1's. This may be removed in the future. @@ -588,23 +473,6 @@ def __len__(self): return len(self._window_data_idx) if self._window_data_idx is not None else 0 -class _VectorizationDataset(Dataset): - """ - Helper dataset for iterating over individual frames of data and producing - embedding vectors. - """ - - def __init__(self, vectorize: Vectorize, frame_data: Sequence[FrameData]): - self.vectorize = vectorize - self.frame_data = frame_data - - def __len__(self): - return len(self.frame_data) - - def __getitem__(self, item): - return self.vectorize(self.frame_data[item]) - - @click.command() @click.help_option("-h", "--help") @click.argument("activity_coco", type=click.Path(path_type=Path)) @@ -622,19 +490,12 @@ def __getitem__(self, item): default=15, show_default=True, ) -@click.option( - "--pre-vectorize", - is_flag=True, - help="Run pre-vectorization or not.", - show_default=True, -) def test_dataset_for_input( activity_coco: Path, detections_coco: Path, pose_coco: Path, window_size: int, target_framerate: float, - pre_vectorize: bool, ): """ Test the TCN Dataset iteration over some test data. @@ -647,22 +508,42 @@ def test_dataset_for_input( # TODO: Some method of configuring which vectorizer to use. from tcn_hpl.data.vectorize.locs_and_confs import LocsAndConfs - vectorizer = LocsAndConfs( + + vectorize = LocsAndConfs( top_k = 1, num_classes = 7, use_joint_confs = True, use_pixel_norm = True, - use_hand_obj_offsets = False, - background_idx = 0 + use_joint_obj_offsets = False, + background_idx = 0, ) - dataset = TCNDataset(window_size=window_size, vectorizer=vectorizer) + # TODO: Some method of configuring which augmentations to use. + from tcn_hpl.data.frame_data_aug.window_frame_dropout import DropoutFrameDataTransform + import torchvision.transforms + + transform_frame_data = torchvision.transforms.Compose([ + DropoutFrameDataTransform( + frame_rate=15, + dets_throughput_mean=14.5, + pose_throughput_mean=10, + dets_latency=0, + pose_latency=1/10, # (1 / 10) - (1 / 14.5), + dets_throughput_std=0.2, + pose_throughput_std=0.2, + ) + ]) + + dataset = TCNDataset( + window_size=window_size, + vectorize=vectorize, + transform_frame_data=transform_frame_data, + ) dataset.load_data_offline( activity_coco, dets_coco, pose_coco, target_framerate=target_framerate, - pre_vectorize=pre_vectorize, ) logger.info(f"Number of windows: {len(dataset)}") @@ -673,12 +554,13 @@ def test_dataset_for_input( # Test that we can iterate over the dataset using a DataLoader with # shuffling. - batch_size = 512 # 16 + batch_size = 1 # 16384 # 512 # 16 data_loader = DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count(), + # Pin is required for large quantities of batches here. pin_memory=True, ) count = 0 @@ -688,12 +570,13 @@ def test_dataset_for_input( desc="Iterating batches of features", unit="batches", ): - count += 1 + count += len(batch[0]) duration = time.time() - s logger.info(f"Iterated over the full TCN Dataset in {duration:.2f} s.") + logger.info(f"Windows per-second: {count / duration}") # Test creating online mode with subset of data from above. - dset_online = TCNDataset(window_size=window_size, vectorizer=vectorizer) + dset_online = TCNDataset(window_size=window_size, vectorize=vectorize) dset_online.load_data_online(dataset._frame_data[:window_size]) # noqa assert len(dset_online) == 1, "Online dataset should be size 1" _ = dset_online[0] @@ -704,9 +587,11 @@ def test_dataset_for_input( except IndexError: failed_index_error = False assert not failed_index_error, "Should have had an index error at [1]" - assert ( - (dataset[0][0] == dset_online[0][0]).all() # noqa - ), "Online should have produced same window matrix as offline version." + assert ( # noqa + dataset[0][0] == dset_online[0][0] + ).all(), ( + "Online should have produced same window matrix as offline version." + ) if __name__ == "__main__": diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 99bd0d64a..7f911203d 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -6,8 +6,10 @@ from tcn_hpl.data.vectorize._interface import Vectorize, FrameData + NUM_POSE_JOINTS = 22 + class LocsAndConfs(Vectorize): """ Previous manual approach to vectorization. @@ -20,13 +22,12 @@ class LocsAndConfs(Vectorize): (changes the length of the input vector, which needs to be manually updated if this flag changes.) use_pixel_norm: Normalize pixel coordinates by dividing by - frame height and width, respectively. Normalized values + frame height and width, respectively. Normalized values are between 0 and 1. Does not change input vector length. use_joint_obj_offsets: add abs(X and Y offsets) for between joints and each object. (changes the length of the input vector, which needs to be manually updated if this flag changes.) - """ def __init__( @@ -46,7 +47,7 @@ def __init__( self._use_pixel_norm = use_pixel_norm self._use_joint_obj_offsets = use_joint_obj_offsets self._background_idx = background_idx - + # Get the top "k" object indexes for each object @staticmethod def get_top_k_indexes_of_one_obj_type(f_dets, k, label_ind): @@ -71,14 +72,14 @@ def get_top_k_indexes_of_one_obj_type(f_dets, k, label_ind): def append_vector(frame_feat, i, number): frame_feat[i] = number return frame_feat, i + 1 - + def determine_vector_length(self, data: FrameData) -> int: ######################### # Feature vector ######################### # Length: pose confs * 22, pose X's * 22, pose Y's * 22, - # obj confs * num_objects(7 for M2), - # obj X * num_objects(7 for M2), + # obj confs * num_objects(7 for M2), + # obj X * num_objects(7 for M2), # obj Y * num_objects(7 for M2) # obj W * num_objects(7 for M2) # obj H * num_objects(7 for M2) @@ -131,7 +132,7 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: for _ in range(0, self._top_k * 5): # 5 Zeros frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) - + f_poses = data.poses if f_poses: # Find most confident body detection @@ -155,7 +156,7 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: rows_per_joint = 2 for _ in range(num_joints * rows_per_joint + 1): frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) - + assert vector_ind == vector_len return frame_feat From a1da7bfcf9e3197d8f6d775fb81df3c9937323bc Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Tue, 5 Nov 2024 12:02:19 -0500 Subject: [PATCH 06/11] Update configuration files appropriately --- configs/data/ptg.yaml | 14 ++++----- configs/experiment/r18/feat_v6.yaml | 48 ++++++++++++++++------------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/configs/data/ptg.yaml b/configs/data/ptg.yaml index d3565ab4d..8c728b944 100644 --- a/configs/data/ptg.yaml +++ b/configs/data/ptg.yaml @@ -3,9 +3,10 @@ _target_: tcn_hpl.data.ptg_datamodule.PTGDataModule train_dataset: _target_: tcn_hpl.data.tcn_dataset.TCNDataset window_size: 15 - # No vectorizer should be specified here, as there should be no "default". - # Example of a vectorizer: - # vectorizer: + # A vectorizer is required to complete construction of a TCN Dataset. + # We are not providing a default here given how hydra merged hyperparameters. + # For example: + #vectorize: # _target_: tcn_hpl.data.vectorize.classic.Classic # feat_version: 6 # top_k: 1 @@ -13,15 +14,15 @@ train_dataset: # background_idx: 0 # hand_left_idx: 5 # hand_right_idx: 6 - transform: + transform_frame_data: _target_: torchvision.transforms.Compose transforms: [] val_dataset: _target_: tcn_hpl.data.tcn_dataset.TCNDataset window_size: ${data.train_dataset.window_size} - vectorizer: ${data.train_dataset.vectorizer} - transform: + vectorize: ${data.train_dataset.vectorize} + transform_frame_data: _target_: torchvision.transforms.Compose transforms: [] @@ -36,7 +37,6 @@ coco_validation_poses: "" coco_test_activities: "" coco_test_objects: "" coco_test_poses: "" -vector_cache_dir: "${paths.coco_file_root}/dataset_vector_cache" batch_size: 128 num_workers: 0 target_framerate: 15 diff --git a/configs/experiment/r18/feat_v6.yaml b/configs/experiment/r18/feat_v6.yaml index 3d736b73a..185a21f52 100644 --- a/configs/experiment/r18/feat_v6.yaml +++ b/configs/experiment/r18/feat_v6.yaml @@ -58,14 +58,19 @@ data: coco_test_objects: "${paths.coco_file_root}/TEST-object_detections.coco.json" coco_test_poses: "${paths.coco_file_root}/TEST-pose_estimations.coco.json" - batch_size: 16384 + # Lower batch size than previously now that we are augmenting and cannot have + # window vectorization cached. This value provided for a good balance of + # maximizing CPU load with GPU load averages (16 cores, ~14 load avg., ~80% + # GPU utilization, ~10.35 GB VRAM). + batch_size: 56 num_workers: 16 target_framerate: 15 # BBN Hololens2 Framerate + # This is a little more than the number of windows in the training dataset. epoch_length: 80000 train_dataset: window_size: 25 - vectorizer: + vectorize: _target_: tcn_hpl.data.vectorize.classic.Classic feat_version: 6 top_k: 1 @@ -73,26 +78,25 @@ data: background_idx: 0 hand_left_idx: 5 hand_right_idx: 6 - transform: - transforms: [] # no transforms -# - _target_: tcn_hpl.data.components.augmentations.MoveCenterPts -# hand_dist_delta: 0.05 -# obj_dist_delta: 0.05 -# joint_dist_delta: 0.025 -# im_w: 1280 -# im_h: 720 -# num_obj_classes: 42 -# feat_version: 2 -# top_k_objects: 1 -# - _target_: tcn_hpl.data.components.augmentations.NormalizePixelPts -# im_w: 1280 -# im_h: 720 -# num_obj_classes: 42 -# feat_version: 2 -# top_k_objects: 1 + # Augmentations on windows of frame data before performing vectorization. + transform_frame_data: + transforms: + - _target_: tcn_hpl.data.frame_data_aug.window_frame_dropout.DropoutFrameDataTransform + # These parameters are a fudge for now to experiment. Window presence + # looks qualitatively right with what we're seeing live. + frame_rate: ${data.target_framerate} + dets_throughput_mean: 14.5 + pose_throughput_mean: 10 + dets_latency: 0 + pose_latency: 0.1 + dets_throughput_std: 0.2 + pose_throughput_std: 0.2 val_dataset: - transform: - transforms: [] # no transforms + # Augmentations on windows of frame data before performing vectorization. + # Sharing transform with training dataset as it is only the drop-out aug to + # simulate stream processing dropout the same. + transform_frame_data: ${data.train_dataset.transform_frame_data} +# transforms: [] # no transforms # - _target_: tcn_hpl.data.components.augmentations.NormalizePixelPts # im_w: 1280 # im_h: 720 @@ -104,7 +108,7 @@ data: paths: # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL/" - root_dir: "/data/paul.tunison/data/darpa-ptg/train-TCN-R18_bbn_hololens-yolo_v7-mmpose" + root_dir: "/data/paul.tunison/data/darpa-ptg/train-TCN-R18_bbn_hololens-yolo_v7-mmpose-window_dropout" # Convenience variable to where your train/val/test split COCO file datasets # are stored. From 846375df2d6da95150aedeb9be2fed67ea5bd869 Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Wed, 6 Nov 2024 09:52:38 -0500 Subject: [PATCH 07/11] Update metrics measured in PTG Module --- tcn_hpl/callbacks/plot_metrics.py | 21 ++++-- tcn_hpl/models/ptg_module.py | 104 ++++++++++-------------------- 2 files changed, 49 insertions(+), 76 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 49265d96a..062c17f23 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -81,6 +81,9 @@ class PlotMetrics(Callback): """ Various on-stage-end plotting functionalities. + This will currently only work with training a PTGLitModule due to metric + access. + Args: output_dir: Directory into which to output plots. @@ -160,6 +163,8 @@ def on_train_epoch_end( all_source_frames = torch.cat(self._train_all_source_frames) # shape: #frames current_epoch = pl_module.current_epoch + curr_acc = pl_module.train_acc.compute() + curr_f1 = pl_module.train_f1.compute() class_ids = np.arange(all_probs.shape[-1]) num_classes = len(class_ids) @@ -194,7 +199,7 @@ def on_train_epoch_end( # labels, title and ticks ax.set_xlabel("Predicted labels") ax.set_ylabel("True labels") - ax.set_title(f"CM Training Epoch {current_epoch}") + ax.set_title(f"CM Training Epoch {current_epoch}, Accuracy: {curr_acc:.4f}, F1: {curr_f1:.4f}") ax.xaxis.set_ticklabels(class_ids, rotation=25) ax.yaxis.set_ticklabels(class_ids, rotation=0) @@ -202,7 +207,7 @@ def on_train_epoch_end( pl_module.logger.experiment.track(Image(fig), name=f"CM Training Epoch") fig.savefig( - self.output_dir / f"confusion_mat_train_epoch{current_epoch:04d}.jpg", + self.output_dir / f"confusion_mat_train_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg", pad_inches=5, ) @@ -249,6 +254,7 @@ def on_validation_epoch_end( current_epoch = pl_module.current_epoch curr_acc = pl_module.val_acc.compute() best_acc = pl_module.val_acc_best.compute() + curr_f1 = pl_module.val_f1.compute() class_ids = np.arange(all_probs.shape[-1]) num_classes = len(class_ids) @@ -283,7 +289,7 @@ def on_validation_epoch_end( # labels, title and ticks ax.set_xlabel("Predicted labels") ax.set_ylabel("True labels") - ax.set_title(f"CM Validation Epoch {current_epoch}, Accuracy: {curr_acc:.4f}") + ax.set_title(f"CM Validation Epoch {current_epoch}, Accuracy: {curr_acc:.4f}, F1: {curr_f1:.4f}") ax.xaxis.set_ticklabels(class_ids, rotation=25) ax.yaxis.set_ticklabels(class_ids, rotation=0) @@ -293,7 +299,7 @@ def on_validation_epoch_end( if curr_acc >= best_acc: fig.savefig( self.output_dir - / f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}.jpg", + / f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg", pad_inches=5, ) @@ -339,6 +345,8 @@ def on_test_epoch_end( all_source_frames = torch.cat(self._val_all_source_frames) # shape: #frames current_epoch = pl_module.current_epoch + test_acc = pl_module.test_acc.compute() + test_f1 = pl_module.test_f1.compute() class_ids = np.arange(all_probs.shape[-1]) num_classes = len(class_ids) @@ -377,13 +385,12 @@ def on_test_epoch_end( # labels, title and ticks ax.set_xlabel("Predicted labels") ax.set_ylabel("True labels") - ax.set_title(f"CM Test Epoch {current_epoch}") + ax.set_title(f"CM Test Epoch {current_epoch}, Accuracy: {test_acc:.4f}, F1: {test_f1:.4f}") ax.xaxis.set_ticklabels(class_ids, rotation=25) ax.yaxis.set_ticklabels(class_ids, rotation=0) - test_acc = pl_module.test_acc.compute() fig.savefig( - self.output_dir / f"confusion_mat_test_acc_{test_acc:0.2f}.jpg", + self.output_dir / f"confusion_mat_test_acc_{test_acc:0.2f}_f1_{test_f1:.4f}.jpg", pad_inches=5, ) diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index d52f49f85..969b43fee 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -87,34 +87,34 @@ def __init__( task="multiclass", average="weighted", num_classes=num_classes ) - self.val_f1 = F1Score( - num_classes=num_classes, average="none", task="multiclass" + self.train_f1 = F1Score( + num_classes=num_classes, average="weighted", task="multiclass" ) - self.val_f1_avg = F1Score( + self.val_f1 = F1Score( num_classes=num_classes, average="weighted", task="multiclass" ) self.test_f1 = F1Score( - num_classes=num_classes, average="none", task="multiclass" + num_classes=num_classes, average="weighted", task="multiclass" ) - self.val_recall = Recall( - num_classes=num_classes, average="none", task="multiclass" + self.train_recall = Recall( + num_classes=num_classes, average="weighted", task="multiclass" ) - self.val_recall_avg = Recall( + self.val_recall = Recall( num_classes=num_classes, average="weighted", task="multiclass" ) self.test_recall = Recall( - num_classes=num_classes, average="none", task="multiclass" + num_classes=num_classes, average="weighted", task="multiclass" ) - self.val_precision = Precision( - num_classes=num_classes, average="none", task="multiclass" + self.train_precision = Precision( + num_classes=num_classes, average="weighted", task="multiclass" ) - self.val_precision_avg = Precision( + self.val_precision = Precision( num_classes=num_classes, average="weighted", task="multiclass" ) self.test_precision = Precision( - num_classes=num_classes, average="none", task="multiclass" + num_classes=num_classes, average="weighted", task="multiclass" ) # for averaging loss across batches @@ -126,18 +126,6 @@ def __init__( self.val_acc_best = MaxMetric() self.train_acc_best = MaxMetric() - self.validation_step_outputs_prob = [] - self.validation_step_outputs_pred = [] - self.validation_step_outputs_target = [] - self.validation_step_outputs_source_vid = [] - self.validation_step_outputs_source_frame = [] - - self.training_step_outputs_target = [] - self.training_step_outputs_source_vid = [] - self.training_step_outputs_source_frame = [] - self.training_step_outputs_pred = [] - self.training_step_outputs_prob = [] - def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor: """Perform a forward pass through the model `self.net`. @@ -289,6 +277,17 @@ def training_step( "source_frame": source_frame[:, -1], } + def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + all_preds = torch.cat([o["preds"] for o in outputs]) + all_targets = torch.cat([o['targets'] for o in outputs]) + + self.train_f1(all_preds, all_targets) + self.train_recall(all_preds, all_targets) + self.train_precision(all_preds, all_targets) + self.log("train/f1", self.train_f1, prog_bar=True, on_epoch=True) + self.log("train/recall", self.train_recall, prog_bar=True, on_epoch=True) + self.log("train/precision", self.train_precision, prog_bar=True, on_epoch=True) + def validation_step( self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], @@ -322,35 +321,6 @@ def validation_step( "source_frame": source_frame[:, -1], } - # Why is this stage specifically only using these special index - # conditions? - # # print(f"preds: {preds.shape}, targets: {targets.shape}") - # # print(f"mask: {mask.shape}, {mask[0,:]}") - # ys = targets[:, -1] - # # print(f"y: {ys.shape}") - # # print(f"y: {ys}") - # windowed_preds, windowed_ys = [], [] - # window_size = 15 - # center = 7 - # inds = [] - # for i in range(preds.shape[0] - window_size + 1): - # y = ys[i : i + window_size].tolist() - # # print(f"y: {y}") - # # print(f"len of set: {len(list(set(y)))}") - # if len(list(set(y))) == 1: - # inds.append(i + center - 1) - # windowed_preds.append(preds[i + center - 1]) - # windowed_ys.append(ys[i + center - 1]) - # - # windowed_preds = torch.tensor(windowed_preds).to(targets) - # windowed_ys = torch.tensor(windowed_ys).to(targets) - # - # self.validation_step_outputs_target.append(targets[inds, -1]) - # self.validation_step_outputs_source_vid.append(source_vid[inds, -1]) - # self.validation_step_outputs_source_frame.append(source_frame[inds, -1]) - # self.validation_step_outputs_pred.append(preds[inds]) - # self.validation_step_outputs_prob.append(probs[inds]) - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: all_preds = torch.cat([o['preds'] for o in outputs]) all_targets = torch.cat([o['targets'] for o in outputs]) @@ -362,16 +332,12 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) best_val_acc = self.val_acc_best(acc) # update best so far val acc self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True) - val_f1_score = self.val_f1(all_preds, all_targets) - val_recall_score = self.val_recall(all_preds, all_targets) - val_precision_score = self.val_precision(all_preds, all_targets) - self.val_f1_avg(all_preds, all_targets) - self.val_recall_avg(all_preds, all_targets) - self.val_precision_avg(all_preds, all_targets) - self.log("val/f1", self.val_f1_avg, prog_bar=True, on_epoch=True) - self.log("val/recall", self.val_recall_avg, prog_bar=True, on_epoch=True) - self.log("val/prevision", self.val_precision_avg, prog_bar=True, on_epoch=True) - + self.val_f1(all_preds, all_targets) + self.val_recall(all_preds, all_targets) + self.val_precision(all_preds, all_targets) + self.log("val/f1", self.val_f1, prog_bar=True, on_epoch=True) + self.log("val/recall", self.val_recall, prog_bar=True, on_epoch=True) + self.log("val/precision", self.val_precision, prog_bar=True, on_epoch=True) def test_step( self, @@ -391,16 +357,16 @@ def test_step( # update and log metrics self.test_loss(loss) self.test_acc(preds, targets[:, -1]) + self.test_f1(preds, targets[:, -1]) + self.test_recall(preds, targets[:, -1]) + self.test_precision(preds, targets[:, -1]) self.log( "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True ) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) - - self.validation_step_outputs_target.append(targets[:, -1]) - self.validation_step_outputs_source_vid.append(source_vid[:, -1]) - self.validation_step_outputs_source_frame.append(source_frame[:, -1]) - self.validation_step_outputs_pred.append(preds) - self.validation_step_outputs_prob.append(probs) + self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True) # Only retain the truth and source vid/frame IDs for the final window # frame as this is the ultimately relevant result. From 399dc8b531ae167f663402333ea7bd553338cc0c Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Wed, 6 Nov 2024 13:23:10 -0500 Subject: [PATCH 08/11] Use best F1 instead of accuracy --- tcn_hpl/callbacks/plot_metrics.py | 6 ++--- tcn_hpl/models/ptg_module.py | 41 ++++++++++++++++++------------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 062c17f23..894b129c3 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -253,8 +253,8 @@ def on_validation_epoch_end( current_epoch = pl_module.current_epoch curr_acc = pl_module.val_acc.compute() - best_acc = pl_module.val_acc_best.compute() curr_f1 = pl_module.val_f1.compute() + best_f1 = pl_module.val_f1_best.compute() class_ids = np.arange(all_probs.shape[-1]) num_classes = len(class_ids) @@ -296,7 +296,7 @@ def on_validation_epoch_end( if Image is not None: pl_module.logger.experiment.track(Image(fig), name=f"CM Validation Epoch") - if curr_acc >= best_acc: + if curr_f1 >= best_f1: fig.savefig( self.output_dir / f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg", @@ -380,7 +380,7 @@ def on_test_epoch_end( fig, ax = plt.subplots(figsize=(num_classes, num_classes)) - sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", vmin=0, vmax=1) + sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) # labels, title and ticks ax.set_xlabel("Predicted labels") diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index 969b43fee..5de396dd0 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -123,8 +123,7 @@ def __init__( self.test_loss = MeanMetric() # for tracking best so far validation accuracy - self.val_acc_best = MaxMetric() - self.train_acc_best = MaxMetric() + self.val_f1_best = MaxMetric() def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor: """Perform a forward pass through the model `self.net`. @@ -141,13 +140,16 @@ def on_train_start(self) -> None: # so it's worth to make sure validation metrics don't store results from these checks self.val_loss.reset() self.val_acc.reset() - self.val_acc_best.reset() + self.val_f1.reset() + self.val_recall.reset() + self.val_precision.reset() + self.val_f1_best.reset() def compute_loss(self, p, y, mask): """Compute the total loss for a batch :param p: The prediction - :param batch_target: The target labels + :param y: The target labels :param mask: Marks valid input data :return: The loss @@ -325,13 +327,6 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) all_preds = torch.cat([o['preds'] for o in outputs]) all_targets = torch.cat([o['targets'] for o in outputs]) - acc = self.val_acc.compute() - # log `val_acc_best` as a value through `.compute()` return, instead of - # as a metric object otherwise metric would be reset by lightning after - # each epoch. - best_val_acc = self.val_acc_best(acc) # update best so far val acc - self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True) - self.val_f1(all_preds, all_targets) self.val_recall(all_preds, all_targets) self.val_precision(all_preds, all_targets) @@ -339,6 +334,12 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) self.log("val/recall", self.val_recall, prog_bar=True, on_epoch=True) self.log("val/precision", self.val_precision, prog_bar=True, on_epoch=True) + # log `val_f1_best` as a value through `.compute()` return, instead of + # as a metric object otherwise metric would be reset by lightning after + # each epoch. + self.val_f1_best(self.val_f1.compute()) + self.log("val/f1_best", self.val_f1_best.compute(), prog_bar=True, on_epoch=True) + def test_step( self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], @@ -357,16 +358,10 @@ def test_step( # update and log metrics self.test_loss(loss) self.test_acc(preds, targets[:, -1]) - self.test_f1(preds, targets[:, -1]) - self.test_recall(preds, targets[:, -1]) - self.test_precision(preds, targets[:, -1]) self.log( "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True ) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True) # Only retain the truth and source vid/frame IDs for the final window # frame as this is the ultimately relevant result. @@ -379,6 +374,18 @@ def test_step( "source_frame": source_frame[:, -1], } + def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: + all_preds = torch.cat([o['preds'] for o in outputs]) + all_targets = torch.cat([o['targets'] for o in outputs]) + + # update and log metrics + self.test_f1(all_preds, all_targets) + self.test_recall(all_preds, all_targets) + self.test_precision(all_preds, all_targets) + self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True) + def setup(self, stage: Optional[str] = None) -> None: """Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict. From 9ed88381b37ec1c0691dd5e24dae5b1f2d13222e Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Wed, 6 Nov 2024 16:06:57 -0500 Subject: [PATCH 09/11] Fix GT/Pred plotting --- tcn_hpl/callbacks/plot_metrics.py | 140 +++++++++++++++++------------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 894b129c3..7fdd99309 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -1,13 +1,10 @@ from collections import defaultdict from pathlib import Path -import typing as ty from typing import Any, Dict, Optional, Tuple -import kwcoco import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl -from debugpy.common.timestamp import current from pytorch_lightning.callbacks import Callback import seaborn as sns from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -20,19 +17,50 @@ Image = None +def create_video_frame_gt_preds( + all_targets: torch.Tensor, + all_preds: torch.Tensor, + all_source_vids: torch.Tensor, + all_source_frames: torch.Tensor, +) -> Dict[int, Dict[int, Tuple[int, int]]]: + """ + Create a two-layer mapping from video ID to frame ID to pair of (gt, pred) + class IDs. + + :param all_targets: Tensor of all target window class IDs. + :param all_preds: Tensor of all predicted window class IDs. + :param all_source_vids: Tensor of video IDs for the window. + :param all_source_frames: Tensor of video frame number for the final + frame of windows. + + :return: New mapping. + """ + per_video_frame_gt_preds = defaultdict(dict) + for (gt, pred, source_vid, source_frame) in zip( + all_targets, all_preds, all_source_vids, all_source_frames + ): + per_video_frame_gt_preds[source_vid.item()][source_frame.item()] = (gt.item(), pred.item()) + return per_video_frame_gt_preds + + def plot_gt_vs_preds( output_dir: Path, - per_video_frame_gt_preds: ty.Dict[ty.Any, ty.Dict[int, ty.Tuple[int, int]]], + per_video_frame_gt_preds: Dict[int, Dict[int, Tuple[int, int]]], + epoch: int, split="train", - max_items=30, + max_items=np.inf, ) -> None: """ Plot activity classification truth and predictions through the course of a video's frames as lines. + Successive calls to this function will overwrite any images in the given + output directory for the given split. + :param output_dir: Base directory into which to save plots. :param per_video_frame_gt_preds: Mapping of video-to-frames-to-tuple, where the tuple is the (gt, pred) pair of class IDs for the respective frame. + :param epoch: The current epoch number for the plot title. :param split: Which train/val/test split the input is for. This will influence the names of files generated. :param max_items: Only consider the first N videos in the given @@ -63,7 +91,7 @@ def plot_gt_vs_preds( label="Pred", ax=ax, ).set( - title=f"{split} Step Prediction Per Frame", + title=f"{split} Step Prediction Per Frame (Epoch {epoch})", xlabel="Index", ylabel="Step", ) @@ -74,7 +102,7 @@ def plot_gt_vs_preds( root_dir.mkdir(parents=True, exist_ok=True) fig.savefig(root_dir / f"{split}_vid{video:03d}.jpg", pad_inches=5) - plt.close() + plt.close(fig) class PlotMetrics(Callback): @@ -119,23 +147,6 @@ def __init__( # care for plotting outputs for. self._has_begun_training = False - # self.topic = topic - # - # # Get Action Names - # mapping_file = f"{self.hparams.data_dir}/{mapping_file_name}" - # actions_dict = dict() - # with open(mapping_file, "r") as file_ptr: - # actions = file_ptr.readlines() - # actions = [a.strip() for a in actions] # drop leading/trailing whitespace - # for a in actions: - # parts = a.split() # split on any number of whitespace - # actions_dict[parts[1]] = int(parts[0]) - # - # self.class_ids = list(actions_dict.values()) - # self.classes = list(actions_dict.keys()) - # - # self.action_id_to_str = dict(zip(self.class_ids, self.classes)) - def on_train_batch_end( self, trainer: "pl.Trainer", @@ -166,25 +177,26 @@ def on_train_epoch_end( curr_acc = pl_module.train_acc.compute() curr_f1 = pl_module.train_f1.compute() - class_ids = np.arange(all_probs.shape[-1]) - num_classes = len(class_ids) - # # Plot per-video class predictions vs. GT across progressive frames in # that video. # - # Build up mapping of truth to preds for each video - per_video_frame_gt_preds = defaultdict(dict) - for (gt, pred, prob, source_vid, source_frame) in zip( - all_targets, all_preds, all_probs, all_source_vids, all_source_frames - ): - per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred)) - - plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="train") + plot_gt_vs_preds( + self.output_dir, + create_video_frame_gt_preds( + all_targets, + all_preds, + all_source_vids, + all_source_frames + ), + epoch=current_epoch, + split="train", + ) # # Create confusion matrix # + class_ids = np.arange(all_probs.shape[-1]) cm = confusion_matrix( all_targets.cpu().numpy(), all_preds.cpu().numpy(), @@ -192,6 +204,7 @@ def on_train_epoch_end( normalize="true", ) + num_classes = len(class_ids) fig, ax = plt.subplots(figsize=(num_classes, num_classes)) sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) @@ -256,25 +269,10 @@ def on_validation_epoch_end( curr_f1 = pl_module.val_f1.compute() best_f1 = pl_module.val_f1_best.compute() - class_ids = np.arange(all_probs.shape[-1]) - num_classes = len(class_ids) - - # - # Plot per-video class predictions vs. GT across progressive frames in - # that video. - # - # Build up mapping of truth to preds for each video - per_video_frame_gt_preds = defaultdict(dict) - for (gt, pred, prob, source_vid, source_frame) in zip( - all_targets, all_preds, all_probs, all_source_vids, all_source_frames - ): - per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred)) - - plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="validation") - # # Create confusion matrix # + class_ids = np.arange(all_probs.shape[-1]) cm = confusion_matrix( all_targets.numpy(), all_preds.numpy(), @@ -282,6 +280,7 @@ def on_validation_epoch_end( normalize="true", ) + num_classes = len(class_ids) fig, ax = plt.subplots(figsize=(num_classes, num_classes)) sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) @@ -302,6 +301,21 @@ def on_validation_epoch_end( / f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg", pad_inches=5, ) + # + # Plot per-video class predictions vs. GT across progressive frames in + # that video. + # + plot_gt_vs_preds( + self.output_dir, + create_video_frame_gt_preds( + all_targets, + all_preds, + all_source_vids, + all_source_frames + ), + epoch=current_epoch, + split="validation", + ) plt.close(fig) @@ -348,21 +362,21 @@ def on_test_epoch_end( test_acc = pl_module.test_acc.compute() test_f1 = pl_module.test_f1.compute() - class_ids = np.arange(all_probs.shape[-1]) - num_classes = len(class_ids) - # # Plot per-video class predictions vs. GT across progressive frames in # that video. # - # Build up mapping of truth to preds for each video - per_video_frame_gt_preds = defaultdict(dict) - for (gt, pred, prob, source_vid, source_frame) in zip( - all_targets, all_preds, all_probs, all_source_vids, all_source_frames - ): - per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred)) - - plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="test") + plot_gt_vs_preds( + self.output_dir, + create_video_frame_gt_preds( + all_targets, + all_preds, + all_source_vids, + all_source_frames + ), + epoch=current_epoch, + split="test", + ) # Built a COCO dataset of test results to output. # TODO: Configure activity test COCO file as input. @@ -371,6 +385,7 @@ def on_test_epoch_end( # # Create confusion matrix # + class_ids = np.arange(all_probs.shape[-1]) cm = confusion_matrix( all_targets.cpu().numpy(), all_preds.cpu().numpy(), @@ -378,6 +393,7 @@ def on_test_epoch_end( normalize="true", ) + num_classes = len(class_ids) fig, ax = plt.subplots(figsize=(num_classes, num_classes)) sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) From 925296f21a8d56154981e09b887bfc5b0260eb4c Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Wed, 6 Nov 2024 16:54:38 -0500 Subject: [PATCH 10/11] Fixes for rebase with new vectorizer --- .../frame_data_aug/window_frame_dropout.py | 2 + tcn_hpl/data/vectorize/locs_and_confs.py | 65 ++++++++++--------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py index 5d2870c5a..f3e40d096 100644 --- a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py +++ b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py @@ -175,6 +175,7 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]: window[idx].object_detections if dets_mask[idx] else None ), poses=(window[idx].poses if pose_mask[idx] else None), + size=window[idx].size, # forward existing value ) for idx in range(n_frames) ] @@ -197,6 +198,7 @@ def test(): joint_positions=np.array([[[10, 20], [30, 40], [50, 60]]]), joint_scores=np.array([[0.9, 0.85, 0.8]]), ), + size=(500, 500), ) sequence = [frame1] * 25 # transform = DropoutFrameDataTransform( diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 7f911203d..5a7308ba8 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -1,6 +1,3 @@ -import functools -import typing as tg - import numpy as np from numpy import typing as npt @@ -73,7 +70,7 @@ def append_vector(frame_feat, i, number): frame_feat[i] = number return frame_feat, i + 1 - def determine_vector_length(self, data: FrameData) -> int: + def determine_vector_length(self) -> int: ######################### # Feature vector ######################### @@ -85,19 +82,21 @@ def determine_vector_length(self, data: FrameData) -> int: # obj H * num_objects(7 for M2) # casualty conf * 1 vector_length = 0 + # [Conf, X, Y, W, H] for k instances of each object class. + vector_length += 5 * self._top_k * self._num_classes + # Pose confidence score + vector_length += 1 # Joint confidences if self._use_joint_confs: vector_length += NUM_POSE_JOINTS # X and Y for each joint vector_length += 2 * NUM_POSE_JOINTS - # [Conf, X, Y, W, H] for k instances of each object class. - vector_length = 5 * self._top_k * self._num_classes return vector_length def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: - vector_len = self.determine_vector_length(data) + vector_len = self.determine_vector_length() frame_feat = np.zeros(vector_len, dtype=np.float32) # TODO: instead of carrying around this vector_ind, we should # directly compute the offset of each feature we add to the TCN @@ -109,29 +108,34 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: else: W = 1 H = 1 - f_dets = data.object_detections - # Loop through all classes: populate obj conf, obj X, obj Y. - # Assumption: class labels are [0, 1, 2,... num_classes-1]. - # TODO: this will break if top_k is ever > 1. Fix that. - for obj_ind in range(0,self._num_classes): - top_k_idxs = self.get_top_k_indexes_of_one_obj_type(f_dets, self._top_k, obj_ind) - if top_k_idxs: # This is None if there were no detections to sort for this class - for idx in top_k_idxs: - # Conf - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.scores[idx]) - # X - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][0] / W) - # Y - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][1] / H) - # W - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][2] / W) - # H - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][3] / H) - else: - for _ in range(0, self._top_k * 5): - # 5 Zeros - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) + f_dets = data.object_detections + if f_dets: + # Loop through all classes: populate obj conf, obj X, obj Y. + # Assumption: class labels are [0, 1, 2,... num_classes-1]. + # TODO: this will break if top_k is ever > 1. Fix that. + for obj_ind in range(0,self._num_classes): + top_k_idxs = self.get_top_k_indexes_of_one_obj_type(f_dets, self._top_k, obj_ind) + if top_k_idxs: # This is None if there were no detections to sort for this class + for idx in top_k_idxs: + # Conf + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.scores[idx]) + # X + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][0] / W) + # Y + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][1] / H) + # W + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][2] / W) + # H + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][3] / H) + else: + for _ in range(0, self._top_k * 5): + # 5 Zeros + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) + else: + # No detections, fill in appropriate amount of zeros. + for _ in range(self._num_classes * self._top_k * 5): + frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) f_poses = data.poses if f_poses: @@ -149,12 +153,11 @@ def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: # Y frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_positions[confident_pose_idx][joint_ind][1] / H) else: - num_joints = f_poses.joint_positions.shape[1] if self._use_joint_confs: rows_per_joint = 3 else: rows_per_joint = 2 - for _ in range(num_joints * rows_per_joint + 1): + for _ in range(NUM_POSE_JOINTS * rows_per_joint + 1): frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) assert vector_ind == vector_len From efee4906d09738715a756ef50dd334dc3b39eb1e Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Wed, 6 Nov 2024 20:57:55 -0500 Subject: [PATCH 11/11] Optimization and edge case fixes to LocsConfs vectorizer --- tcn_hpl/data/tcn_dataset.py | 22 ++-- tcn_hpl/data/vectorize/locs_and_confs.py | 135 ++++++++++++----------- 2 files changed, 86 insertions(+), 71 deletions(-) diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index 701d2193b..e89599ac7 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -513,7 +513,7 @@ def test_dataset_for_input( top_k = 1, num_classes = 7, use_joint_confs = True, - use_pixel_norm = True, + use_pixel_norm = False, use_joint_obj_offsets = False, background_idx = 0, ) @@ -554,7 +554,7 @@ def test_dataset_for_input( # Test that we can iterate over the dataset using a DataLoader with # shuffling. - batch_size = 1 # 16384 # 512 # 16 + batch_size = 32 # 512 data_loader = DataLoader( dataset, batch_size=batch_size, @@ -576,7 +576,11 @@ def test_dataset_for_input( logger.info(f"Windows per-second: {count / duration}") # Test creating online mode with subset of data from above. - dset_online = TCNDataset(window_size=window_size, vectorize=vectorize) + dset_online = TCNDataset( + window_size=window_size, + vectorize=vectorize, + transform_frame_data=transform_frame_data, + ) dset_online.load_data_online(dataset._frame_data[:window_size]) # noqa assert len(dset_online) == 1, "Online dataset should be size 1" _ = dset_online[0] @@ -587,11 +591,13 @@ def test_dataset_for_input( except IndexError: failed_index_error = False assert not failed_index_error, "Should have had an index error at [1]" - assert ( # noqa - dataset[0][0] == dset_online[0][0] - ).all(), ( - "Online should have produced same window matrix as offline version." - ) + # With augmentation, this can no longer be expected because of random + # variation per access. + # assert ( # noqa + # dataset[0][0] == dset_online[0][0] + # ).all(), ( + # "Online should have produced same window matrix as offline version." + # ) if __name__ == "__main__": diff --git a/tcn_hpl/data/vectorize/locs_and_confs.py b/tcn_hpl/data/vectorize/locs_and_confs.py index 5a7308ba8..00ec274ba 100644 --- a/tcn_hpl/data/vectorize/locs_and_confs.py +++ b/tcn_hpl/data/vectorize/locs_and_confs.py @@ -1,6 +1,9 @@ +from typing import List + import numpy as np from numpy import typing as npt +from tcn_hpl.data.frame_data import FrameObjectDetections from tcn_hpl.data.vectorize._interface import Vectorize, FrameData @@ -34,7 +37,7 @@ def __init__( use_joint_confs: bool = True, use_pixel_norm: bool = True, use_joint_obj_offsets: bool = False, - background_idx: int = 0 + background_idx: int = 0, ): super().__init__() @@ -45,25 +48,26 @@ def __init__( self._use_joint_obj_offsets = use_joint_obj_offsets self._background_idx = background_idx - # Get the top "k" object indexes for each object @staticmethod - def get_top_k_indexes_of_one_obj_type(f_dets, k, label_ind): + def get_top_k_indexes_of_one_obj_type( + f_dets: FrameObjectDetections, + k: int, + label_ind: int, + ) -> List[int]: """ Find all instances of a label index in object detections. Then sort them and return the top K. Inputs: - object_dets: """ - labels = f_dets.labels scores = f_dets.scores # Get all labels of an obj type - filtered_idxs = [i for i, e in enumerate(labels) if e == label_ind] - if not filtered_idxs: - return None + filtered_idxs = [i for i, e in enumerate(f_dets.labels) if e == label_ind] + # Sort filtered indices return by highest score filtered_scores = [scores[i] for i in filtered_idxs] - # Sort labels by score values. - sorted_inds = [i[1] for i in sorted(zip(filtered_scores, filtered_idxs))] - return sorted_inds[:k] + return [ + i[1] for i in sorted(zip(filtered_scores, filtered_idxs), reverse=True)[:k] + ] @staticmethod def append_vector(frame_feat, i, number): @@ -93,73 +97,78 @@ def determine_vector_length(self) -> int: vector_length += 2 * NUM_POSE_JOINTS return vector_length - def vectorize(self, data: FrameData) -> npt.NDArray[np.float32]: + # I tried utilizing range assignment into frame_feat, but this was + # empirically not as fast as this method in the context of being run + # within a torch DataLoader. + # E.g. instead of + # for i, det_idx in enumerate(top_det_idxs): + # topk_offset = obj_offset + (i * 5) + # frame_feat[topk_offset + 0] = f_dets.scores[det_idx] + # frame_feat[topk_offset + 1] = f_dets.boxes[det_idx][0] / w + # frame_feat[topk_offset + 2] = f_dets.boxes[det_idx][1] / h + # frame_feat[topk_offset + 3] = f_dets.boxes[det_idx][2] / w + # frame_feat[topk_offset + 4] = f_dets.boxes[det_idx][3] / h + # doing: + # obj_end_idx = obj_offset + (len(top_det_idxs) * 5) + # frame_feat[obj_offset + 0:obj_end_idx:5] = f_dets.scores[top_det_idxs] + # frame_feat[obj_offset + 1:obj_end_idx:5] = f_dets.boxes[top_det_idxs, 0] / w + # frame_feat[obj_offset + 2:obj_end_idx:5] = f_dets.boxes[top_det_idxs, 1] / h + # frame_feat[obj_offset + 3:obj_end_idx:5] = f_dets.boxes[top_det_idxs, 2] / w + # frame_feat[obj_offset + 4:obj_end_idx:5] = f_dets.boxes[top_det_idxs, 3] / h + # Was *slower* in the context of batched computation. vector_len = self.determine_vector_length() frame_feat = np.zeros(vector_len, dtype=np.float32) - # TODO: instead of carrying around this vector_ind, we should - # directly compute the offset of each feature we add to the TCN - # input vector. This would be much easier to debug. - vector_ind = 0 + if self._use_pixel_norm: - W = data.size[0] - H = data.size[1] + w = data.size[0] + h = data.size[1] else: - W = 1 - H = 1 + w = 1 + h = 1 + + obj_num_classes = self._num_classes + obj_top_k = self._top_k + + # Indices into the feature vector where components start + objs_start_offset = 0 + pose_start_offset = obj_num_classes * obj_top_k * 5 f_dets = data.object_detections if f_dets: - # Loop through all classes: populate obj conf, obj X, obj Y. - # Assumption: class labels are [0, 1, 2,... num_classes-1]. - # TODO: this will break if top_k is ever > 1. Fix that. - for obj_ind in range(0,self._num_classes): - top_k_idxs = self.get_top_k_indexes_of_one_obj_type(f_dets, self._top_k, obj_ind) - if top_k_idxs: # This is None if there were no detections to sort for this class - for idx in top_k_idxs: - # Conf - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.scores[idx]) - # X - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][0] / W) - # Y - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][1] / H) - # W - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][2] / W) - # H - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_dets.boxes[idx][3] / H) - else: - for _ in range(0, self._top_k * 5): - # 5 Zeros - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) - else: - # No detections, fill in appropriate amount of zeros. - for _ in range(self._num_classes * self._top_k * 5): - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) + for obj_ind in range(obj_num_classes): + obj_offset = objs_start_offset + (obj_ind * obj_top_k * 5) + top_det_idxs = self.get_top_k_indexes_of_one_obj_type( + f_dets, obj_top_k, obj_ind + ) + for i, det_idx in enumerate(top_det_idxs): + topk_offset = obj_offset + (i * 5) + frame_feat[topk_offset + 0] = f_dets.scores[det_idx] + frame_feat[topk_offset + 1] = f_dets.boxes[det_idx][0] / w + frame_feat[topk_offset + 2] = f_dets.boxes[det_idx][1] / h + frame_feat[topk_offset + 3] = f_dets.boxes[det_idx][2] / w + frame_feat[topk_offset + 4] = f_dets.boxes[det_idx][3] / h + # If there are less than top_k indices returned, the vector was + # already initialized to zero so nothing else to do. f_poses = data.poses if f_poses: # Find most confident body detection confident_pose_idx = np.argmax(f_poses.scores) - num_joints = f_poses.joint_positions.shape[1] - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.scores[confident_pose_idx]) - - for joint_ind in range(0, num_joints): - # Conf + frame_feat[pose_start_offset] = f_poses.scores[confident_pose_idx] + pose_kp_offset = pose_start_offset + 1 + for joint_ind in range(NUM_POSE_JOINTS): + joint_offset = pose_kp_offset + (joint_ind * 3) if self._use_joint_confs: - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_scores[confident_pose_idx][joint_ind]) - # X - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_positions[confident_pose_idx][joint_ind][0] / W) - # Y - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, f_poses.joint_positions[confident_pose_idx][joint_ind][1] / H) - else: - if self._use_joint_confs: - rows_per_joint = 3 - else: - rows_per_joint = 2 - for _ in range(NUM_POSE_JOINTS * rows_per_joint + 1): - frame_feat, vector_ind = self.append_vector(frame_feat, vector_ind, 0) - - assert vector_ind == vector_len + frame_feat[joint_offset] = f_poses.joint_scores[ + confident_pose_idx, joint_ind + ] + frame_feat[joint_offset + 1] = ( + f_poses.joint_positions[confident_pose_idx, joint_ind, 0] / w + ) + frame_feat[joint_offset + 2] = ( + f_poses.joint_positions[confident_pose_idx, joint_ind, 1] / h + ) return frame_feat