From b155768dd4878922b232a5d9a978f8f1342114ef Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 4 Sep 2024 12:01:08 +0000 Subject: [PATCH 1/5] Make SSIM work with NaNs --- src/cloudcasting/metrics.py | 68 ++++++++++++++++++++++++------------- tests/test_metrics.py | 4 +-- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/src/cloudcasting/metrics.py b/src/cloudcasting/metrics.py index 4d23c3e..a4494d7 100644 --- a/src/cloudcasting/metrics.py +++ b/src/cloudcasting/metrics.py @@ -1,12 +1,19 @@ """Metrics for model output evaluation""" import numpy as np -from skimage.metrics import structural_similarity # type: ignore[import-not-found] +from jaxtyping import Float +from skimage.metrics import structural_similarity +from torch import Tensor -from cloudcasting.types import BatchOutputArray, OutputArray, SampleOutputArray, TimeArray +# Type aliases for clarity + reuse +Array = np.ndarray | Tensor # type: ignore[type-arg] +SingleArray = Float[Array, "channels time height width"] +BatchArray = Float[Array, "batch channels time height width"] +InputArray = SingleArray | BatchArray +TimeArray = Float[Array, "time"] -def mae_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArray: +def mae_single(input: SingleArray, target: SingleArray) -> TimeArray: """Mean absolute error for single (non-batched) image sequences. Args: @@ -21,7 +28,7 @@ def mae_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArray return arr -def mae_batch(input: BatchOutputArray, target: BatchOutputArray) -> TimeArray: +def mae_batch(input: BatchArray, target: BatchArray) -> TimeArray: """Mean absolute error for batched image sequences. Args: @@ -36,7 +43,7 @@ def mae_batch(input: BatchOutputArray, target: BatchOutputArray) -> TimeArray: return arr -def mse_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArray: +def mse_single(input: SingleArray, target: SingleArray) -> TimeArray: """Mean squared error for single (non-batched) image sequences. Args: @@ -51,7 +58,7 @@ def mse_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArray return arr -def mse_batch(input: BatchOutputArray, target: BatchOutputArray) -> TimeArray: +def mse_batch(input: BatchArray, target: BatchArray) -> TimeArray: """Mean squared error for batched image sequences. Args: @@ -66,48 +73,63 @@ def mse_batch(input: BatchOutputArray, target: BatchOutputArray) -> TimeArray: return arr -def ssim_single( - input: SampleOutputArray, target: SampleOutputArray, win_size: int | None = None -) -> TimeArray: - """Structural similarity for single (non-batched) image sequences. +def ssim_single(input: SingleArray, target: SingleArray) -> TimeArray: + """Computes the Structural Similarity (SSIM) index for single (non-batched) image sequences. Args: input: Array of shape [channels, time, height, width] target: Array of shape [channels, time, height, width] - win_size: Side-length of the sliding window for comparison (must be odd) Returns: Array of SSIM values along the time dimension + + References: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). + Image quality assessment: From error visibility to structural similarity. + IEEE Transactions on Image Processing, 13, 600-612. + https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + DOI: 10.1109/TIP.2003.819861 """ + # This function assumes the data will be in the range 0-1 and will give invalid results if not _check_input_target_ranges(input, target) + + # The following param setting match Wang et. al. 2004 + gaussian_weights = True + use_sample_covariance = False + sigma = 1.5 + win_size = 11 + ssim_seq = [] for i_t in range(input.shape[1]): - # Calculate the SSIM array for this time step _, ssim_array = structural_similarity( - input[:, i_t, :, :], - target[:, i_t, :, :], + input[:, i_t], + target[:, i_t], data_range=1, channel_axis=0, full=True, + gaussian_weights=gaussian_weights, + use_sample_covariance=use_sample_covariance, + sigma=sigma, win_size=win_size, - ) - # Take the mean of the SSIM array over channels, height, and width - ssim_seq.append(np.nanmean(ssim_array, axis=(0, 1, 2))) + ) # type: ignore[no-untyped-call] + + # To avoid edge effects from the Gaussian filter we trim off the border + trim_width = (win_size - 1) // 2 + ssim_array = ssim_array[:, trim_width:-trim_width, trim_width:-trim_width] + + ssim_seq.append(np.nanmean(ssim_array)) arr: TimeArray = np.stack(ssim_seq, axis=0) return arr -def ssim_batch( - input: BatchOutputArray, target: BatchOutputArray, win_size: int | None = None -) -> TimeArray: +def ssim_batch(input: BatchArray, target: BatchArray) -> TimeArray: """Structural similarity for batched image sequences. Args: input: Array of shape [batch, channels, time, height, width] target: Array of shape [batch, channels, time, height, width] - win_size: Side-length of the sliding window for comparison (must be odd) Returns: Array of SSIM values along the time dimension @@ -117,12 +139,12 @@ def ssim_batch( ssim_samples = [] for i_b in range(input.shape[0]): - ssim_samples.append(ssim_single(input[i_b], target[i_b], win_size=win_size)) + ssim_samples.append(ssim_single(input[i_b], target[i_b])) arr: TimeArray = np.stack(ssim_samples, axis=0).mean(axis=0) return arr -def _check_input_target_ranges(input: OutputArray, target: OutputArray) -> None: +def _check_input_target_ranges(input: InputArray, target: InputArray) -> None: """Validate input and target arrays are within the 0-1 range. Args: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index cfd72b1..54d9b9e 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -88,7 +88,6 @@ def test_calc_mse_batch(zeros_batch, ones_batch): assert (result == 4).all() -@pytest.mark.skip(reason="Currently unstable with NaNs") def test_calc_ssim_sample(zeros_sample, ones_sample, zeros_missing_sample): result = ssim_single(zeros_sample, zeros_sample) np.testing.assert_almost_equal(result, 1, decimal=4) @@ -99,11 +98,10 @@ def test_calc_ssim_sample(zeros_sample, ones_sample, zeros_missing_sample): result = ssim_single(zeros_sample, ones_sample) np.testing.assert_almost_equal(result, 0, decimal=4) - result = ssim_single(zeros_sample, zeros_missing_sample, win_size=3) + result = ssim_single(zeros_sample, zeros_missing_sample) np.testing.assert_almost_equal(result, 1, decimal=4) -@pytest.mark.skip(reason="Currently unstable with NaNs") def test_calc_ssim_batch(zeros_batch, ones_batch): result = ssim_batch(zeros_batch, zeros_batch) np.testing.assert_almost_equal(result, 1, decimal=4) From 0ab82c116e88bda8f2977bb7d2d31668376f9044 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 4 Sep 2024 12:19:38 +0000 Subject: [PATCH 2/5] fix overwrite --- src/cloudcasting/metrics.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/cloudcasting/metrics.py b/src/cloudcasting/metrics.py index a4494d7..fbac8b6 100644 --- a/src/cloudcasting/metrics.py +++ b/src/cloudcasting/metrics.py @@ -1,19 +1,12 @@ """Metrics for model output evaluation""" import numpy as np -from jaxtyping import Float -from skimage.metrics import structural_similarity -from torch import Tensor +from skimage.metrics import structural_similarity # type: ignore[import-not-found] -# Type aliases for clarity + reuse -Array = np.ndarray | Tensor # type: ignore[type-arg] -SingleArray = Float[Array, "channels time height width"] -BatchArray = Float[Array, "batch channels time height width"] -InputArray = SingleArray | BatchArray -TimeArray = Float[Array, "time"] +from cloudcasting.types import BatchOutputArray, OutputArray, SampleOutputArray, TimeArray -def mae_single(input: SingleArray, target: SingleArray) -> TimeArray: +def mae_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArray: """Mean absolute error for single (non-batched) image sequences. Args: @@ -28,7 +21,7 @@ def mae_single(input: SingleArray, target: SingleArray) -> TimeArray: return arr -def mae_batch(input: BatchArray, target: BatchArray) -> TimeArray: +def mae_batch(input: BatchOutputArray, target: BatchOutputArray) -> TimeArray: """Mean absolute error for batched image sequences. Args: @@ -43,7 +36,7 @@ def mae_batch(input: BatchArray, target: BatchArray) -> TimeArray: return arr -def mse_single(input: SingleArray, target: SingleArray) -> TimeArray: +def mse_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArray: """Mean squared error for single (non-batched) image sequences. Args: @@ -58,7 +51,7 @@ def mse_single(input: SingleArray, target: SingleArray) -> TimeArray: return arr -def mse_batch(input: BatchArray, target: BatchArray) -> TimeArray: +def mse_batch(input: BatchOutputArray, target: BatchOutputArray) -> TimeArray: """Mean squared error for batched image sequences. Args: @@ -73,7 +66,7 @@ def mse_batch(input: BatchArray, target: BatchArray) -> TimeArray: return arr -def ssim_single(input: SingleArray, target: SingleArray) -> TimeArray: +def ssim_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArray: """Computes the Structural Similarity (SSIM) index for single (non-batched) image sequences. Args: @@ -124,12 +117,13 @@ def ssim_single(input: SingleArray, target: SingleArray) -> TimeArray: return arr -def ssim_batch(input: BatchArray, target: BatchArray) -> TimeArray: +def ssim_batch(input: BatchOutputArray, target: BatchOutputArray) -> TimeArray: """Structural similarity for batched image sequences. Args: input: Array of shape [batch, channels, time, height, width] target: Array of shape [batch, channels, time, height, width] + win_size: Side-length of the sliding window for comparison (must be odd) Returns: Array of SSIM values along the time dimension @@ -144,7 +138,7 @@ def ssim_batch(input: BatchArray, target: BatchArray) -> TimeArray: return arr -def _check_input_target_ranges(input: InputArray, target: InputArray) -> None: +def _check_input_target_ranges(input: OutputArray, target: OutputArray) -> None: """Validate input and target arrays are within the 0-1 range. Args: From 17109b70348c870ff1c3884ed3ae9c6a60461fec Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 4 Sep 2024 12:23:02 +0000 Subject: [PATCH 3/5] remove nused 'type: ignore' comment --- src/cloudcasting/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudcasting/metrics.py b/src/cloudcasting/metrics.py index fbac8b6..eeb2a90 100644 --- a/src/cloudcasting/metrics.py +++ b/src/cloudcasting/metrics.py @@ -105,7 +105,7 @@ def ssim_single(input: SampleOutputArray, target: SampleOutputArray) -> TimeArra use_sample_covariance=use_sample_covariance, sigma=sigma, win_size=win_size, - ) # type: ignore[no-untyped-call] + ) # To avoid edge effects from the Gaussian filter we trim off the border trim_width = (win_size - 1) // 2 From f027018de37539537d4baab5d0fa6c2388b095fb Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 4 Sep 2024 12:24:28 +0000 Subject: [PATCH 4/5] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b9e845b..9191f99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "cloudcasting" -version = "0.2.0" +version = "0.2.1" authors = [ { name = "cloudcasting Maintainers", email = "nsimpson@turing.ac.uk" }, ] From 6942c2e4c16d74a800cfafe250e01900ec352840 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Wed, 11 Sep 2024 11:26:31 +0100 Subject: [PATCH 5/5] formatting --- src/cloudcasting/metrics.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/cloudcasting/metrics.py b/src/cloudcasting/metrics.py index 29814f0..c846890 100644 --- a/src/cloudcasting/metrics.py +++ b/src/cloudcasting/metrics.py @@ -66,9 +66,7 @@ def mse_batch(input: BatchOutputArray, target: BatchOutputArray) -> MetricArray: return arr - -def ssim_single( - input: SampleOutputArray, target: SampleOutputArray) -> MetricArray: +def ssim_single(input: SampleOutputArray, target: SampleOutputArray) -> MetricArray: """Computes the Structural Similarity (SSIM) index for single (non-batched) image sequences. Args: @@ -83,7 +81,7 @@ def ssim_single( Image quality assessment: From error visibility to structural similarity. IEEE Transactions on Image Processing, 13, 600-612. https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, - DOI: 10.1109/TIP.2003.819861 + DOI: 10.1109/TIP.2003.819861 """ # This function assumes the data will be in the range 0-1 and will give invalid results if not