Skip to content

Commit

Permalink
Add NaNs in Y_t to unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Sep 19, 2023
1 parent 95e709a commit 9207599
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def test_concat_tasks_with_nans(self):
tl(date, n_context, n_target)
) # Changing number of targets
task = tl(date, n_context, 42)
task["Y_c"][0][:, 0] = np.nan
task["Y_c"][0][:, 0] = np.nan # Add NaN to context
task["Y_t"][0][:, 0] = np.nan # Add NaN to target
tasks.append(task)

multiple = 50
Expand All @@ -100,7 +101,11 @@ def test_training(self):
train_tasks = []
for i in range(n_train_tasks):
date = np.random.choice(self.da.time.values)
train_tasks.append(tl(date, 10, 10))
task = tl(date, 10, 10)
task["Y_c"][0][:, 0] = np.nan # Add NaN to context
task["Y_t"][0][:, 0] = np.nan # Add NaN to target
print(task)
train_tasks.append(task)

# Train
trainer = Trainer(model, lr=5e-5)
Expand Down

0 comments on commit 9207599

Please sign in to comment.