diff --git a/classy_vision/hooks/tensorboard_plot_hook.py b/classy_vision/hooks/tensorboard_plot_hook.py index 63700bcba..2dada2c66 100644 --- a/classy_vision/hooks/tensorboard_plot_hook.py +++ b/classy_vision/hooks/tensorboard_plot_hook.py @@ -164,16 +164,16 @@ def on_phase_end(self, task) -> None: f"Parameters/{name}", parameter, global_step=phase_type_idx ) - if torch.cuda.is_available() and task.train: + if torch.cuda.is_available(): self.tb_writer.add_scalar( - "Memory/peak_allocated", + f"Memory/{phase_type}/peak_allocated", torch.cuda.max_memory_allocated(), global_step=phase_type_idx, ) loss_avg = sum(task.losses) / batches - loss_key = "Losses/{phase_type}".format(phase_type=task.phase_type) + loss_key = f"Losses/{phase_type}" self.tb_writer.add_scalar(loss_key, loss_avg, global_step=phase_type_idx) # plot meters which return a dict diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 1a894e3d3..d95cb7826 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -1296,7 +1296,7 @@ def on_phase_start(self): self.phase_start_time_train = time.perf_counter() def on_phase_end(self): - self.log_phase_end("train") + self.log_phase_end(self.phase_type) if self.train: self.optimizer.on_epoch(where=self.where) @@ -1315,7 +1315,7 @@ def on_phase_end(self): hook.on_phase_end(self) self.perf_log = [] - self.log_phase_end("total") + self.log_phase_end(f"{self.phase_type}_total") if hasattr(self.datasets[self.phase_type], "on_phase_end"): self.datasets[self.phase_type].on_phase_end() @@ -1325,12 +1325,9 @@ def on_end(self): hook.on_end(self) def log_phase_end(self, tag): - if not self.train: - return - start_time = ( self.phase_start_time_train - if tag == "train" + if tag == self.phase_type else self.phase_start_time_total ) phase_duration = time.perf_counter() - start_time @@ -1341,7 +1338,6 @@ def log_phase_end(self, tag): { "tag": tag, "phase_idx": self.train_phase_idx, - "epoch_duration": phase_duration, "im_per_sec": im_per_sec, } )