Commit

Fix GT/Pred plotting
Purg committed Nov 6, 2024
1 parent 740fa2a commit 497434f
Showing 1 changed file with 78 additions and 62 deletions.
140 changes: 78 additions & 62 deletions tcn_hpl/callbacks/plot_metrics.py
@@ -1,13 +1,10 @@
from collections import defaultdict
from pathlib import Path
import typing as ty
from typing import Any, Dict, Optional, Tuple

import kwcoco
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
from debugpy.common.timestamp import current
from pytorch_lightning.callbacks import Callback
import seaborn as sns
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -20,19 +17,50 @@
Image = None


def create_video_frame_gt_preds(
all_targets: torch.Tensor,
all_preds: torch.Tensor,
all_source_vids: torch.Tensor,
all_source_frames: torch.Tensor,
) -> Dict[int, Dict[int, Tuple[int, int]]]:
"""
Create a two-layer mapping from video ID to frame ID to a pair of
(gt, pred) class IDs.
:param all_targets: Tensor of all target window class IDs.
:param all_preds: Tensor of all predicted window class IDs.
:param all_source_vids: Tensor of video IDs for each window.
:param all_source_frames: Tensor of video frame numbers for the final
frame of each window.
:return: New mapping.
"""
per_video_frame_gt_preds = defaultdict(dict)
for (gt, pred, source_vid, source_frame) in zip(
all_targets, all_preds, all_source_vids, all_source_frames
):
per_video_frame_gt_preds[source_vid.item()][source_frame.item()] = (gt.item(), pred.item())
return per_video_frame_gt_preds
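
# Sketch of the returned structure (invented values, for illustration only):
# given targets [1, 2], preds [1, 3], video IDs [7, 7], and final frame
# numbers [10, 11], the result maps {7: {10: (1, 1), 11: (2, 3)}},
# i.e. video ID -> frame number -> (gt, pred).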


def plot_gt_vs_preds(
output_dir: Path,
per_video_frame_gt_preds: ty.Dict[ty.Any, ty.Dict[int, ty.Tuple[int, int]]],
per_video_frame_gt_preds: Dict[int, Dict[int, Tuple[int, int]]],
epoch: int,
split="train",
max_items=30,
max_items=np.inf,
) -> None:
"""
Plot activity classification ground truth and predictions as lines over
the course of a video's frames.
Successive calls to this function will overwrite any images in the given
output directory for the given split.
:param output_dir: Base directory into which to save plots.
:param per_video_frame_gt_preds: Mapping of video-to-frames-to-tuple, where
the tuple is the (gt, pred) pair of class IDs for the respective frame.
:param epoch: The current epoch number for the plot title.
:param split: Which train/val/test split the input is for. This will
influence the names of files generated.
:param max_items: Only consider the first N videos in the given
@@ -63,7 +91,7 @@ def plot_gt_vs_preds(
label="Pred",
ax=ax,
).set(
title=f"{split} Step Prediction Per Frame",
title=f"{split} Step Prediction Per Frame (Epoch {epoch})",
xlabel="Index",
ylabel="Step",
)
@@ -74,7 +102,7 @@ def plot_gt_vs_preds(
root_dir.mkdir(parents=True, exist_ok=True)

fig.savefig(root_dir / f"{split}_vid{video:03d}.jpg", pad_inches=5)
plt.close()
plt.close(fig)
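
# A minimal usage sketch (hypothetical tensors `targets`, `preds`, `vids`,
# `frames`; mirrors how the epoch-end hooks below call these helpers):
#
#   plot_gt_vs_preds(
#       Path("./plots"),
#       create_video_frame_gt_preds(targets, preds, vids, frames),
#       epoch=0,
#       split="train",
#   )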


class PlotMetrics(Callback):
@@ -119,23 +147,6 @@ def __init__(
# care for plotting outputs for.
self._has_begun_training = False

# self.topic = topic
#
# # Get Action Names
# mapping_file = f"{self.hparams.data_dir}/{mapping_file_name}"
# actions_dict = dict()
# with open(mapping_file, "r") as file_ptr:
# actions = file_ptr.readlines()
# actions = [a.strip() for a in actions] # drop leading/trailing whitespace
# for a in actions:
# parts = a.split() # split on any number of whitespace
# actions_dict[parts[1]] = int(parts[0])
#
# self.class_ids = list(actions_dict.values())
# self.classes = list(actions_dict.keys())
#
# self.action_id_to_str = dict(zip(self.class_ids, self.classes))

def on_train_batch_end(
self,
trainer: "pl.Trainer",
@@ -166,32 +177,34 @@ def on_train_epoch_end(
curr_acc = pl_module.train_acc.compute()
curr_f1 = pl_module.train_f1.compute()

class_ids = np.arange(all_probs.shape[-1])
num_classes = len(class_ids)

#
# Plot per-video class predictions vs. GT across progressive frames in
# that video.
#
# Build up mapping of truth to preds for each video
per_video_frame_gt_preds = defaultdict(dict)
for (gt, pred, prob, source_vid, source_frame) in zip(
all_targets, all_preds, all_probs, all_source_vids, all_source_frames
):
per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred))

plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="train")
plot_gt_vs_preds(
self.output_dir,
create_video_frame_gt_preds(
all_targets,
all_preds,
all_source_vids,
all_source_frames
),
epoch=current_epoch,
split="train",
)

#
# Create confusion matrix
#
class_ids = np.arange(all_probs.shape[-1])
cm = confusion_matrix(
all_targets.cpu().numpy(),
all_preds.cpu().numpy(),
labels=class_ids,
normalize="true",
)

num_classes = len(class_ids)
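# Note: figsize is specified in inches, so the confusion-matrix figure
# scales one inch per class along each axis.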
fig, ax = plt.subplots(figsize=(num_classes, num_classes))

sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)
@@ -256,32 +269,18 @@ def on_validation_epoch_end(
curr_f1 = pl_module.val_f1.compute()
best_f1 = pl_module.val_f1_best.compute()

class_ids = np.arange(all_probs.shape[-1])
num_classes = len(class_ids)

#
# Plot per-video class predictions vs. GT across progressive frames in
# that video.
#
# Build up mapping of truth to preds for each video
per_video_frame_gt_preds = defaultdict(dict)
for (gt, pred, prob, source_vid, source_frame) in zip(
all_targets, all_preds, all_probs, all_source_vids, all_source_frames
):
per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred))

plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="validation")

#
# Create confusion matrix
#
class_ids = np.arange(all_probs.shape[-1])
cm = confusion_matrix(
all_targets.numpy(),
all_preds.numpy(),
labels=class_ids,
normalize="true",
)

num_classes = len(class_ids)
fig, ax = plt.subplots(figsize=(num_classes, num_classes))

sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)
@@ -302,6 +301,21 @@ def on_validation_epoch_end(
/ f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg",
pad_inches=5,
)
#
# Plot per-video class predictions vs. GT across progressive frames in
# that video.
#
plot_gt_vs_preds(
self.output_dir,
create_video_frame_gt_preds(
all_targets,
all_preds,
all_source_vids,
all_source_frames
),
epoch=current_epoch,
split="validation",
)

plt.close(fig)

@@ -348,21 +362,21 @@ def on_test_epoch_end(
test_acc = pl_module.test_acc.compute()
test_f1 = pl_module.test_f1.compute()

class_ids = np.arange(all_probs.shape[-1])
num_classes = len(class_ids)

#
# Plot per-video class predictions vs. GT across progressive frames in
# that video.
#
# Build up mapping of truth to preds for each video
per_video_frame_gt_preds = defaultdict(dict)
for (gt, pred, prob, source_vid, source_frame) in zip(
all_targets, all_preds, all_probs, all_source_vids, all_source_frames
):
per_video_frame_gt_preds[source_vid][source_frame] = (int(gt), int(pred))

plot_gt_vs_preds(self.output_dir, per_video_frame_gt_preds, split="test")
plot_gt_vs_preds(
self.output_dir,
create_video_frame_gt_preds(
all_targets,
all_preds,
all_source_vids,
all_source_frames
),
epoch=current_epoch,
split="test",
)

# Build a COCO dataset of test results to output.
# TODO: Configure activity test COCO file as input.
@@ -371,13 +385,15 @@ def on_test_epoch_end(
#
# Create confusion matrix
#
class_ids = np.arange(all_probs.shape[-1])
cm = confusion_matrix(
all_targets.cpu().numpy(),
all_preds.cpu().numpy(),
labels=class_ids,
normalize="true",
)

num_classes = len(class_ids)
fig, ax = plt.subplots(figsize=(num_classes, num_classes))

sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)
