From 022147c6b3adb7f68f4035557ca14f8bf7058b8b Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Thu, 24 Mar 2022 21:18:20 +0000 Subject: [PATCH] Remove syncing for MultiLoss --- pytorch_forecasting/metrics.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/metrics.py b/pytorch_forecasting/metrics.py index f3a4b185..ec655cb8 100644 --- a/pytorch_forecasting/metrics.py +++ b/pytorch_forecasting/metrics.py @@ -1,7 +1,7 @@ """ Implementation of metrics for (mulit-horizon) timeseries forecasting. """ -from typing import Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import scipy.stats @@ -11,7 +11,6 @@ import torch.nn.functional as F from torch.nn.utils import rnn from torchmetrics import Metric as LightningMetric -from torchmetrics.metric import CompositionalMetric from pytorch_forecasting.utils import create_mask, unpack_sequence, unsqueeze_like @@ -152,6 +151,9 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None) -> None: # No syncing required here. syncing will be done in metric_a and metric_b pass + def _wrap_compute(self, compute: Callable) -> Callable: + return compute + def reset(self) -> None: self.torchmetric.reset() @@ -340,6 +342,10 @@ def forward(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs): def _wrap_compute(self, compute: Callable) -> Callable: return compute + def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None: + # No syncing required here. syncing will be done in metrics + pass + def reset(self) -> None: for metric in self.metrics: metric.reset()