Skip to content

Commit

Permalink
Remove syncing for MultiLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Beitner committed Mar 24, 2022
1 parent 1caf6fe commit 022147c
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pytorch_forecasting/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Implementation of metrics for (mulit-horizon) timeseries forecasting.
"""
from typing import Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings

import scipy.stats
Expand All @@ -11,7 +11,6 @@
import torch.nn.functional as F
from torch.nn.utils import rnn
from torchmetrics import Metric as LightningMetric
from torchmetrics.metric import CompositionalMetric

from pytorch_forecasting.utils import create_mask, unpack_sequence, unsqueeze_like

Expand Down Expand Up @@ -152,6 +151,9 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None) -> None:
# No syncing required here. syncing will be done in metric_a and metric_b
pass

def _wrap_compute(self, compute: Callable) -> Callable:
return compute

def reset(self) -> None:
self.torchmetric.reset()

Expand Down Expand Up @@ -340,6 +342,10 @@ def forward(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs):
def _wrap_compute(self, compute: Callable) -> Callable:
return compute

def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None:
# No syncing required here. syncing will be done in metrics
pass

def reset(self) -> None:
for metric in self.metrics:
metric.reset()
Expand Down

0 comments on commit 022147c

Please sign in to comment.