Use best F1 instead of accuracy
Purg committed Nov 6, 2024
1 parent 6c3cfff commit 740fa2a
Showing 2 changed files with 27 additions and 20 deletions.
6 changes: 3 additions & 3 deletions tcn_hpl/callbacks/plot_metrics.py
@@ -253,8 +253,8 @@ def on_validation_epoch_end(

current_epoch = pl_module.current_epoch
curr_acc = pl_module.val_acc.compute()
best_acc = pl_module.val_acc_best.compute()
curr_f1 = pl_module.val_f1.compute()
best_f1 = pl_module.val_f1_best.compute()

class_ids = np.arange(all_probs.shape[-1])
num_classes = len(class_ids)
@@ -296,7 +296,7 @@ def on_validation_epoch_end(
if Image is not None:
pl_module.logger.experiment.track(Image(fig), name=f"CM Validation Epoch")

if curr_acc >= best_acc:
if curr_f1 >= best_f1:
fig.savefig(
self.output_dir
/ f"confusion_mat_val_epoch{current_epoch:04d}_acc_{curr_acc:.4f}_f1_{curr_f1:.4f}.jpg",
@@ -380,7 +380,7 @@ def on_test_epoch_end(

fig, ax = plt.subplots(figsize=(num_classes, num_classes))

sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", vmin=0, vmax=1)
sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)

# labels, title and ticks
ax.set_xlabel("Predicted labels")
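As context for the plot_metrics.py changes above: the confusion-matrix figure is now saved when the epoch's validation F1 (rather than accuracy) matches or beats the best seen so far, and the test-set heatmap gains thin cell separators via linewidth=0.5. Below is a minimal standalone sketch of that heatmap call, not the callback itself; the matrix values, y-axis label, and output filename are made-up placeholders, since the code that builds `cm` is not part of this diff.

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Illustrative, row-normalized confusion matrix; the real callback computes
# this from predictions (not shown in this diff).
cm = np.array(
    [[0.90, 0.10, 0.00],
     [0.15, 0.80, 0.05],
     [0.05, 0.20, 0.75]]
)
num_classes = cm.shape[0]

fig, ax = plt.subplots(figsize=(num_classes, num_classes))
# linewidth=0.5 draws thin separators between cells, matching the changed line above.
sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1)
ax.set_xlabel("Predicted labels")
ax.set_ylabel("True labels")  # assumed label text; not shown in the diff
fig.savefig("confusion_mat_val_example.jpg")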
41 changes: 24 additions & 17 deletions tcn_hpl/models/ptg_module.py
@@ -123,8 +123,7 @@ def __init__(
self.test_loss = MeanMetric()

# for tracking best-so-far validation F1 score
self.val_acc_best = MaxMetric()
self.train_acc_best = MaxMetric()
self.val_f1_best = MaxMetric()

def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
"""Perform a forward pass through the model `self.net`.
@@ -141,13 +140,16 @@ def on_train_start(self) -> None:
# so it's worth making sure validation metrics don't store results from these checks
self.val_loss.reset()
self.val_acc.reset()
self.val_acc_best.reset()
self.val_f1.reset()
self.val_recall.reset()
self.val_precision.reset()
self.val_f1_best.reset()

def compute_loss(self, p, y, mask):
"""Compute the total loss for a batch
:param p: The prediction
:param batch_target: The target labels
:param y: The target labels
:param mask: Marks valid input data
:return: The loss
@@ -325,20 +327,19 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]])
all_preds = torch.cat([o['preds'] for o in outputs])
all_targets = torch.cat([o['targets'] for o in outputs])

acc = self.val_acc.compute()
# log `val_acc_best` as a value through `.compute()` return, instead of
# as a metric object otherwise metric would be reset by lightning after
# each epoch.
best_val_acc = self.val_acc_best(acc) # update best so far val acc
self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)

self.val_f1(all_preds, all_targets)
self.val_recall(all_preds, all_targets)
self.val_precision(all_preds, all_targets)
self.log("val/f1", self.val_f1, prog_bar=True, on_epoch=True)
self.log("val/recall", self.val_recall, prog_bar=True, on_epoch=True)
self.log("val/precision", self.val_precision, prog_bar=True, on_epoch=True)

# Log `val_f1_best` as a value via the `.compute()` return rather than as a
# metric object; otherwise the metric would be reset by Lightning after
# each epoch.
self.val_f1_best(self.val_f1.compute())
self.log("val/f1_best", self.val_f1_best.compute(), prog_bar=True, on_epoch=True)

def test_step(
self,
batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
@@ -357,16 +358,10 @@ def test_step(
# update and log metrics
self.test_loss(loss)
self.test_acc(preds, targets[:, -1])
self.test_f1(preds, targets[:, -1])
self.test_recall(preds, targets[:, -1])
self.test_precision(preds, targets[:, -1])
self.log(
"test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True
)
self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True)

# Only retain the truth and source vid/frame IDs for the final window
# frame as this is the ultimately relevant result.
@@ -379,6 +374,18 @@
"source_frame": source_frame[:, -1],
}

def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
all_preds = torch.cat([o['preds'] for o in outputs])
all_targets = torch.cat([o['targets'] for o in outputs])

# update and log metrics
self.test_f1(all_preds, all_targets)
self.test_recall(all_preds, all_targets)
self.test_precision(all_preds, all_targets)
self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True)

def setup(self, stage: Optional[str] = None) -> None:
"""Lightning hook that is called at the beginning of fit (train + validate), validate,
test, or predict.
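For the ptg_module.py changes above: the module now tracks the best validation F1 with a MaxMetric and, per the in-code comment, logs the computed value rather than the metric object so Lightning's end-of-epoch reset does not clear the running best. Below is a minimal sketch of that pattern outside Lightning, with assumed metric classes and made-up data; the module's actual metric construction is not shown in this diff.

import torch
from torchmetrics import MaxMetric
from torchmetrics.classification import MulticlassF1Score

num_classes = 4  # placeholder; the real class count comes from the task config
val_f1 = MulticlassF1Score(num_classes=num_classes)
val_f1_best = MaxMetric()

for epoch in range(3):
    # Random stand-ins for one epoch of validation predictions and targets.
    preds = torch.randint(0, num_classes, (64,))
    targets = torch.randint(0, num_classes, (64,))
    val_f1.update(preds, targets)

    epoch_f1 = val_f1.compute()   # this epoch's F1
    val_f1_best.update(epoch_f1)  # running maximum across epochs
    val_f1.reset()                # Lightning performs this reset automatically

    # Reporting float(val_f1_best.compute()) each epoch mirrors the
    # self.log("val/f1_best", self.val_f1_best.compute(), ...) call above.
    print(epoch, float(epoch_f1), float(val_f1_best.compute()))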
