diff --git a/tests/test_training.py b/tests/test_training.py index 6a3f253c..408ae829 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -75,7 +75,8 @@ def test_concat_tasks_with_nans(self): tl(date, n_context, n_target) ) # Changing number of targets task = tl(date, n_context, 42) - task["Y_c"][0][:, 0] = np.nan + task["Y_c"][0][:, 0] = np.nan # Add NaN to context + task["Y_t"][0][:, 0] = np.nan # Add NaN to target tasks.append(task) multiple = 50 @@ -100,7 +101,11 @@ def test_training(self): train_tasks = [] for i in range(n_train_tasks): date = np.random.choice(self.da.time.values) - train_tasks.append(tl(date, 10, 10)) + task = tl(date, 10, 10) + task["Y_c"][0][:, 0] = np.nan # Add NaN to context + task["Y_t"][0][:, 0] = np.nan # Add NaN to target + print(task) + train_tasks.append(task) # Train trainer = Trainer(model, lr=5e-5)