From a7f6491a5b4940c99d8b0f3631733286867e195f Mon Sep 17 00:00:00 2001 From: hirwa Date: Tue, 10 Dec 2024 00:22:14 +0530 Subject: [PATCH] fix issue when there is mps fallback enabled during training --- pytorch_forecasting/models/base_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 3068f558..e40e7cd7 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -828,6 +828,9 @@ def step( loss = self.loss(prediction, y) else: loss = None + # ensure that loss has require_grad + if loss is not None and loss.device.type == "mps": + loss.requires_grad_(True) self.log( f"{self.current_stage}_loss", loss,