Skip to content

Commit

Permalink
Working out how to handle test_mask in cli
Browse files Browse the repository at this point in the history
  • Loading branch information
IFenton committed Nov 21, 2024
1 parent 727b349 commit 86224bb
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
validate_from_config,
)

test_mask = create_cutout_mask((2, 6, 1, 7), (9, 8))

@pytest.fixture()
def test_mask():
return create_cutout_mask((2, 6, 1, 7), (9, 8))


@pytest.fixture()
Expand All @@ -25,7 +28,7 @@ def model():


@pytest.mark.parametrize("nan_to_num", [True, False])
def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num):
def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_num):
# Create valid dataset
valid_dataset = ValidationSatelliteDataset(
zarr_path=val_sat_zarr_path,
Expand Down Expand Up @@ -84,7 +87,7 @@ def test_calc_mean_metrics():
assert mean_metrics_dict["mse"] == 5


def test_validate(model, val_sat_zarr_path, mocker):
def test_validate(model, val_sat_zarr_path, test_mask, mocker):
# Mock the wandb functions so they aren't run in testing
mocker.patch("wandb.login")
mocker.patch("wandb.init")
Expand All @@ -106,7 +109,7 @@ def test_validate(model, val_sat_zarr_path, mocker):
)


def test_validate_cli(val_sat_zarr_path, mocker):
def test_validate_cli(val_sat_zarr_path, test_mask, mocker):
# write out an example model.py file
with open("model.py", "w") as f:
f.write(
Expand Down Expand Up @@ -139,14 +142,14 @@ def hyperparameters_dict(self):
history_steps: 1
sigma: 0.1
validation:
data_path: {val_sat_zarr_path}
data_path: "{val_sat_zarr_path}"
wandb_project_name: cloudcasting-pytest
wandb_run_name: test_validate
nan_to_num: False
batch_size: 2
num_workers: 0
batch_limit: 4
mask: {test_mask}
mask: "{test_mask}"
"""
)

Expand Down

0 comments on commit 86224bb

Please sign in to comment.