From 86224bbd8365892e9e2503a86ed5c7de2d846b6a Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Thu, 21 Nov 2024 12:08:17 +0000 Subject: [PATCH] Working out how to handle test_mask in cli --- tests/test_validation.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_validation.py b/tests/test_validation.py index 66d3fa5..9342ecb 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -16,7 +16,10 @@ validate_from_config, ) -test_mask = create_cutout_mask((2, 6, 1, 7), (9, 8)) + +@pytest.fixture() +def test_mask(): + return create_cutout_mask((2, 6, 1, 7), (9, 8)) @pytest.fixture() @@ -25,7 +28,7 @@ def model(): @pytest.mark.parametrize("nan_to_num", [True, False]) -def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num): +def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_num): # Create valid dataset valid_dataset = ValidationSatelliteDataset( zarr_path=val_sat_zarr_path, @@ -84,7 +87,7 @@ def test_calc_mean_metrics(): assert mean_metrics_dict["mse"] == 5 -def test_validate(model, val_sat_zarr_path, mocker): +def test_validate(model, val_sat_zarr_path, test_mask, mocker): # Mock the wandb functions so they aren't run in testing mocker.patch("wandb.login") mocker.patch("wandb.init") @@ -106,7 +109,7 @@ def test_validate(model, val_sat_zarr_path, mocker): ) -def test_validate_cli(val_sat_zarr_path, mocker): +def test_validate_cli(val_sat_zarr_path, test_mask, mocker): # write out an example model.py file with open("model.py", "w") as f: f.write( @@ -139,14 +142,14 @@ def hyperparameters_dict(self): history_steps: 1 sigma: 0.1 validation: - data_path: {val_sat_zarr_path} + data_path: "{val_sat_zarr_path}" wandb_project_name: cloudcasting-pytest wandb_run_name: test_validate nan_to_num: False batch_size: 2 num_workers: 0 batch_limit: 4 - mask: {test_mask} + mask: "{test_mask}" """ )