From ae69f3f32d0add5a47db1edefaf8de36129ff58d Mon Sep 17 00:00:00 2001 From: Kasper Hintz Date: Wed, 11 Dec 2024 11:02:28 +0000 Subject: [PATCH] remove artifacts from earlier merging/rebase --- neural_lam/models/ar_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index b63e1c4..1bc706c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -95,7 +95,6 @@ def __init__( # Store constant per-variable std.-dev. weighting # NOTE that this is the inverse of the multiplicative weighting # in wMSE/wMAE - # TODO: Do we need param_weights for this? self.register_buffer( "per_var_std", self.diff_std / torch.sqrt(self.feature_weights), @@ -262,7 +261,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states): pred_std_list, dim=1 ) # (B, pred_steps, num_grid_nodes, d_f) else: - pred_std = self.diff_std # (d_f,) + pred_std = self.per_var_std # (d_f,) return prediction, pred_std