From 497434fca6335a1a9d685051d3c07702328c8d1e Mon Sep 17 00:00:00 2001
From: Paul Tunison
Date: Wed, 6 Nov 2024 16:06:57 -0500
Subject: [PATCH] Fix GT/Pred plotting

---
 tcn_hpl/callbacks/plot_metrics.py | 140 +++++++++++++++++-------
 1 file changed, 78 insertions(+), 62 deletions(-)

diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py
index 894b129c3..7fdd99309 100644
--- a/tcn_hpl/callbacks/plot_metrics.py
+++ b/tcn_hpl/callbacks/plot_metrics.py
@@ -1,13 +1,10 @@
 from collections import defaultdict
 from pathlib import Path
-import typing as ty
 from typing import Any, Dict, Optional, Tuple
 
-import kwcoco
 import matplotlib.pyplot as plt
 import numpy as np
 import pytorch_lightning as pl
-from debugpy.common.timestamp import current
 from pytorch_lightning.callbacks import Callback
 import seaborn as sns
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -20,19 +17,50 @@
 Image = None
 
+def create_video_frame_gt_preds(
+    all_targets: torch.Tensor,
+    all_preds: torch.Tensor,
+    all_source_vids: torch.Tensor,
+    all_source_frames: torch.Tensor,
+) -> Dict[int, Dict[int, Tuple[int, int]]]:
+    """
+    Create a two-layer mapping from video ID to frame ID to the pair of
+    (gt, pred) class IDs.
+
+    :param all_targets: Tensor of all target window class IDs.
+    :param all_preds: Tensor of all predicted window class IDs.
+    :param all_source_vids: Tensor of the source video ID for each window.
+    :param all_source_frames: Tensor of the video frame number of the final
+        frame of each window.
+
+    :return: New mapping.
+    """
+    per_video_frame_gt_preds = defaultdict(dict)
+    for (gt, pred, source_vid, source_frame) in zip(
+        all_targets, all_preds, all_source_vids, all_source_frames
+    ):
+        per_video_frame_gt_preds[source_vid.item()][source_frame.item()] = (gt.item(), pred.item())
+    return per_video_frame_gt_preds
+
+
 def plot_gt_vs_preds(
     output_dir: Path,
-    per_video_frame_gt_preds: ty.Dict[ty.Any, ty.Dict[int, ty.Tuple[int, int]]],
+    per_video_frame_gt_preds: Dict[int, Dict[int, Tuple[int, int]]],
+    epoch: int,
     split="train",
-    max_items=30,
+    max_items=np.inf,
 ) -> None:
     """
     Plot activity classification truth and predictions through the course of
     a video's frames as lines.
 
+    Successive calls to this function will overwrite any images in the given
+    output directory for the given split.
+
     :param output_dir: Base directory into which to save plots.
     :param per_video_frame_gt_preds: Mapping of video-to-frames-to-tuple,
         where the tuple is the (gt, pred) pair of class IDs for the
         respective frame.
+    :param epoch: The current epoch number, used in the plot title.
     :param split: Which train/val/test split the input is for. This will
         influence the names of files generated.
     :param max_items: Only consider the first N videos in the given
@@ -63,7 +91,7 @@ def plot_gt_vs_preds(
             label="Pred",
             ax=ax,
         ).set(
-            title=f"{split} Step Prediction Per Frame",
+            title=f"{split} Step Prediction Per Frame (Epoch {epoch})",
             xlabel="Index",
             ylabel="Step",
         )
@@ -74,7 +102,7 @@ def plot_gt_vs_preds(
         root_dir.mkdir(parents=True, exist_ok=True)
         fig.savefig(root_dir / f"{split}_vid{video:03d}.jpg", pad_inches=5)
 
-        plt.close()
+        plt.close(fig)
 
 
 class PlotMetrics(Callback):
@@ -119,23 +147,6 @@ def __init__(
         # care for plotting outputs for.
         self._has_begun_training = False
 
-        # self.topic = topic
-        #
-        # # Get Action Names
-        # mapping_file = f"{self.hparams.data_dir}/{mapping_file_name}"
-        # actions_dict = dict()
-        # with open(mapping_file, "r") as file_ptr:
-        #     actions = file_ptr.readlines()
-        #     actions = [a.strip() for a in actions]  # drop leading/trailing whitespace
-        #     for a in actions:
-        #         parts = a.split()  # split on any number of whitespace
-        #         actions_dict[parts[1]] = int(parts[0])
-        #
-        # self.class_ids = list(actions_dict.values())
-        # self.classes = list(actions_dict.keys())
-        #
-        # self.action_id_to_str = dict(zip(self.class_ids, self.classes))
-
     def on_train_batch_end(
         self,
         trainer: "pl.Trainer",
@@ -166,25 +177,26 @@ def on_train_epoch_end(
         curr_acc = pl_module.train_acc.compute()
         curr_f1 = pl_module.train_f1.compute()
 
-        class_ids = np.arange(all_probs.shape[-1])
-        num_classes = len(class_ids)
-
         #
         # Plot per-video class predictions vs. GT across progressive frames in
         # that video.
         #
-        # Build up mapping of truth to preds for each video
-        per_video_frame_gt_preds = defaultdict(dict)
-        for (gt, pred, prob, source_vid, source_frame) in zip(
-            all_targets, all_preds, all_probs, all_source_vids, all_source_frames
-        ):
-            per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred))
-
-        plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="train")
+        plot_gt_vs_preds(
+            self.output_dir,
+            create_video_frame_gt_preds(
+                all_targets,
+                all_preds,
+                all_source_vids,
+                all_source_frames
+            ),
+            epoch=current_epoch,
+            split="train",
+        )
 
         #
         # Create confusion matrix
         #
+        class_ids = np.arange(all_probs.shape[-1])
         cm = confusion_matrix(
             all_targets.cpu().numpy(),
             all_preds.cpu().numpy(),
@@ -192,6 +204,7 @@ def on_train_epoch_end(
             normalize="true",
         )
 
+        num_classes = len(class_ids)
         fig, ax = plt.subplots(figsize=(num_classes, num_classes))
         sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)
 
@@ -256,25 +269,10 @@ def on_validation_epoch_end(
         curr_f1 = pl_module.val_f1.compute()
         best_f1 = pl_module.val_f1_best.compute()
 
-        class_ids = np.arange(all_probs.shape[-1])
-        num_classes = len(class_ids)
-
-        #
-        # Plot per-video class predictions vs. GT across progressive frames in
-        # that video.
-        #
-        # Build up mapping of truth to preds for each video
-        per_video_frame_gt_preds = defaultdict(dict)
-        for (gt, pred, prob, source_vid, source_frame) in zip(
-            all_targets, all_preds, all_probs, all_source_vids, all_source_frames
-        ):
-            per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred))
-
-        plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="validation")
-
         #
         # Create confusion matrix
         #
+        class_ids = np.arange(all_probs.shape[-1])
         cm = confusion_matrix(
             all_targets.numpy(),
             all_preds.numpy(),
@@ -282,6 +280,7 @@ def on_validation_epoch_end(
             normalize="true",
         )
 
+        num_classes = len(class_ids)
         fig, ax = plt.subplots(figsize=(num_classes, num_classes))
         sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)
 
@@ -302,6 +301,21 @@ def on_validation_epoch_end(
             / f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg",
             pad_inches=5,
         )
+        #
+        # Plot per-video class predictions vs. GT across progressive frames in
+        # that video.
+        #
+        plot_gt_vs_preds(
+            self.output_dir,
+            create_video_frame_gt_preds(
+                all_targets,
+                all_preds,
+                all_source_vids,
+                all_source_frames
+            ),
+            epoch=current_epoch,
+            split="validation",
+        )
 
         plt.close(fig)
 
@@ -348,21 +362,21 @@ def on_test_epoch_end(
         test_acc = pl_module.test_acc.compute()
         test_f1 = pl_module.test_f1.compute()
 
-        class_ids = np.arange(all_probs.shape[-1])
-        num_classes = len(class_ids)
-
         #
         # Plot per-video class predictions vs. GT across progressive frames in
         # that video.
         #
-        # Build up mapping of truth to preds for each video
-        per_video_frame_gt_preds = defaultdict(dict)
-        for (gt, pred, prob, source_vid, source_frame) in zip(
-            all_targets, all_preds, all_probs, all_source_vids, all_source_frames
-        ):
-            per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred))
-
-        plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="test")
+        plot_gt_vs_preds(
+            self.output_dir,
+            create_video_frame_gt_preds(
+                all_targets,
+                all_preds,
+                all_source_vids,
+                all_source_frames
+            ),
+            epoch=current_epoch,
+            split="test",
+        )
 
         # Built a COCO dataset of test results to output.
         # TODO: Configure activity test COCO file as input.
 
@@ -371,6 +385,7 @@ def on_test_epoch_end(
         #
         # Create confusion matrix
         #
+        class_ids = np.arange(all_probs.shape[-1])
         cm = confusion_matrix(
             all_targets.cpu().numpy(),
             all_preds.cpu().numpy(),
@@ -378,6 +393,7 @@ def on_test_epoch_end(
             normalize="true",
         )
 
+        num_classes = len(class_ids)
         fig, ax = plt.subplots(figsize=(num_classes, num_classes))
         sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)
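
Reviewer note: a minimal sketch of how the new create_video_frame_gt_preds
helper and the updated plot_gt_vs_preds signature compose outside of a
Lightning run. Factoring the mapping construction into the helper removes the
loop that was triplicated across the train/val/test hooks and drops the unused
probs from the zip. The dummy tensors and the "./plots" output directory below
are illustrative assumptions, not part of this patch.

    import torch
    from pathlib import Path

    from tcn_hpl.callbacks.plot_metrics import (
        create_video_frame_gt_preds,
        plot_gt_vs_preds,
    )

    # Illustrative window results: two videos, three window-final frames each.
    all_targets = torch.tensor([0, 1, 1, 2, 2, 0])              # GT class IDs
    all_preds = torch.tensor([0, 1, 2, 2, 2, 0])                # predicted class IDs
    all_source_vids = torch.tensor([1, 1, 1, 2, 2, 2])          # video ID per window
    all_source_frames = torch.tensor([10, 11, 12, 10, 11, 12])  # final frame per window

    # Two-layer mapping: video ID -> frame ID -> (gt, pred) class ID pair.
    mapping = create_video_frame_gt_preds(
        all_targets, all_preds, all_source_vids, all_source_frames
    )
    assert mapping[1][12] == (1, 2)

    # Writes one "{split}_vidNNN.jpg" line plot per video, titled with the
    # epoch, beneath the given base output directory.
    plot_gt_vs_preds(Path("./plots"), mapping, epoch=0, split="train")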