diff --git a/src/cloudcasting/utils.py b/src/cloudcasting/utils.py index 8a51d05..2abf46c 100644 --- a/src/cloudcasting/utils.py +++ b/src/cloudcasting/utils.py @@ -158,7 +158,7 @@ def numpy_validation_collate_fn( def create_cutout_mask( mask_size: tuple[int, int, int, int], image_size: tuple[int, int], -) -> NDArray[np.int8]: +) -> NDArray[np.float64]: """Create a mask with a cutout in the center. Args: x: x-coordinate of the center of the cutout @@ -173,7 +173,7 @@ def create_cutout_mask( height, width = image_size min_x, max_x, min_y, max_y = mask_size - mask = np.empty((height, width), dtype=np.int8) + mask = np.empty((height, width), dtype=np.float64) mask[:] = np.nan mask[min_y:max_y, min_x:max_x] = 1 return mask diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index fd4e2dd..b859b20 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -192,7 +192,7 @@ def score_model_on_all_metrics( batch_limit: int | None = None, metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"), metric_kwargs: dict[str, dict[str, Any]] | None = None, - mask: NDArray[np.int8] = CUTOUT_MASK, + mask: NDArray[np.float64] = CUTOUT_MASK, ) -> tuple[dict[str, MetricArray], list[str]]: """Calculate the scoreboard metrics for the given model on the validation dataset. @@ -368,7 +368,7 @@ def validate( batch_size: int = 1, num_workers: int = 0, batch_limit: int | None = None, - mask: NDArray[np.int8] = CUTOUT_MASK, + mask: NDArray[np.float64] = CUTOUT_MASK, ) -> None: """Run the full validation procedure on the model and log the results to wandb.