Skip to content

Commit

Permalink
Correcting type
Browse files Browse the repository at this point in the history
  • Loading branch information
IFenton committed Nov 20, 2024
1 parent ca15eb8 commit 727b349
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/cloudcasting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def numpy_validation_collate_fn(
def create_cutout_mask(
mask_size: tuple[int, int, int, int],
image_size: tuple[int, int],
) -> NDArray[np.int8]:
) -> NDArray[np.float64]:
"""Create a mask with a cutout in the center.
Args:
x: x-coordinate of the center of the cutout
Expand All @@ -173,7 +173,7 @@ def create_cutout_mask(
height, width = image_size
min_x, max_x, min_y, max_y = mask_size

mask = np.empty((height, width), dtype=np.int8)
mask = np.empty((height, width), dtype=np.float64)
mask[:] = np.nan
mask[min_y:max_y, min_x:max_x] = 1
return mask
4 changes: 2 additions & 2 deletions src/cloudcasting/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def score_model_on_all_metrics(
batch_limit: int | None = None,
metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"),
metric_kwargs: dict[str, dict[str, Any]] | None = None,
mask: NDArray[np.int8] = CUTOUT_MASK,
mask: NDArray[np.float64] = CUTOUT_MASK,
) -> tuple[dict[str, MetricArray], list[str]]:
"""Calculate the scoreboard metrics for the given model on the validation dataset.
Expand Down Expand Up @@ -368,7 +368,7 @@ def validate(
batch_size: int = 1,
num_workers: int = 0,
batch_limit: int | None = None,
mask: NDArray[np.int8] = CUTOUT_MASK,
mask: NDArray[np.float64] = CUTOUT_MASK,
) -> None:
"""Run the full validation procedure on the model and log the results to wandb.
Expand Down

0 comments on commit 727b349

Please sign in to comment.