diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index a82196cf..a12f5da6 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -950,6 +950,18 @@ def log_interval(self) -> float:
         else:
             return self.hparams.log_val_interval
 
+    def _logger_supports(self, method: str) -> bool:
+        """Whether logger supports method.
+
+        Returns
+        -------
+        supports_method : bool
+            True if attribute self.logger.experiment.method exists, False otherwise.
+        """
+        if not hasattr(self, "logger") or not hasattr(self.logger, "experiment"):
+            return False
+        return hasattr(self.logger.experiment, method)
+
     def log_prediction(
         self, x: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor], batch_idx: int, **kwargs
     ) -> None:
@@ -976,6 +988,10 @@ def log_prediction(
             if not mpl_available:
                 return None  # don't log matplotlib plots if not available
 
+            # Don't log figures if add_figure is not available
+            if not self._logger_supports("add_figure"):
+                return None
+
             for idx in log_indices:
                 fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
                 tag = f"{self.current_stage} prediction"
@@ -1146,7 +1162,8 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
 
         mpl_available = _check_matplotlib("log_gradient_flow", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         import matplotlib.pyplot as plt
diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py
index 149b4fbc..6a4b5caf 100644
--- a/pytorch_forecasting/models/nbeats/__init__.py
+++ b/pytorch_forecasting/models/nbeats/__init__.py
@@ -265,7 +265,8 @@ def log_interpretation(self, x, out, batch_idx):
         """
         mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         label = ["val", "train"][self.training]
diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index 68816f22..94f85cfa 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -541,7 +541,8 @@ def log_interpretation(self, x, out, batch_idx):
         """
         mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         label = ["val", "train"][self.training]
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index cc506612..983b5a72 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -818,7 +818,8 @@ def log_interpretation(self, outputs):
 
         mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         import matplotlib.pyplot as plt
@@ -857,6 +858,11 @@ def log_embeddings(self):
         """
         Log embeddings to tensorboard
         """
+
+        # Don't log embeddings if add_embedding is not available
+        if not self._logger_supports("add_embedding"):
+            return None
+
         for name, emb in self.input_embeddings.items():
             labels = self.hparams.embedding_labels[name]
             self.logger.experiment.add_embedding(