From eaeb38b99c8eeb4e2f6b6f0e9957e709dfec6a89 Mon Sep 17 00:00:00 2001 From: Denis Prokopenko <22414094+denproc@users.noreply.github.com> Date: Tue, 14 Jul 2020 17:25:41 +0300 Subject: [PATCH] Enhancement of SSIM/MS-SSIM and BRISQUE, Refactoring of tests for SSIM/MS-SSIM and BRISQUE (#134) * tests(ssim): fix the randomly failing test Signed-off-by: Sergey Kastryulin * refactoring(ssim): changes to simplify the ssim/ms-ssim * refactoring(test): changes ssim/ms-ssim to check values on real images * refactoring(ssim): minor * refactoring(tests): changes (ms)ssim tests to for better readability * refactoring(tests): fix memory consumption * refactoring(tests): docs * refactoring(tests): fix BRISQUE tests on real images * refactoring(tests): small fix for better utility * refact(ssim/brisque): changes proposed by @snk4tr and @zakajd * refact(tests): Hit 100% coverage for `TVLoss`, 'utls.py'. * docs: groomed to the single format * minor(all): changes after merge * release_commit: v0.5.0 Co-authored-by: Sergey Kastryulin --- piq/__init__.py | 2 +- piq/brisque.py | 231 ++++----- piq/functional/__init__.py | 4 +- piq/functional/colour_conversion.py | 19 +- piq/functional/filters.py | 18 + piq/ssim.py | 389 ++++++-------- piq/utils/common.py | 10 +- tests/test_brisque.py | 95 ++-- tests/test_ssim.py | 751 ++++++++++++---------------- tests/test_tv.py | 28 +- tests/test_utils.py | 9 +- 11 files changed, 658 insertions(+), 898 deletions(-) diff --git a/piq/__init__.py b/piq/__init__.py index ba3b6daa..d2962e13 100644 --- a/piq/__init__.py +++ b/piq/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.1" +__version__ = "0.5.0" from .ssim import ssim, multi_scale_ssim, SSIMLoss, MultiScaleSSIMLoss from .msid import MSID diff --git a/piq/brisque.py b/piq/brisque.py index c768dea4..d1c2a27c 100644 --- a/piq/brisque.py +++ b/piq/brisque.py @@ -13,7 +13,109 @@ from torch.utils.model_zoo import load_url import torch.nn.functional as F from piq.utils import _adjust_dimensions, _validate_input -from piq.functional import rgb2yiq +from piq.functional import rgb2yiq, gaussian_filter + + +def brisque(x: torch.Tensor, + kernel_size: int = 7, kernel_sigma: float = 7 / 6, + data_range: Union[int, float] = 1., reduction: str = 'mean', + interpolation: str = 'nearest') -> torch.Tensor: + r"""Interface of BRISQUE index. + + Args: + x: Batch of images. Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first. + kernel_size: The side-length of the sliding window used in comparison. Must be an odd value. + kernel_sigma: Sigma of normal distribution. + data_range: Value range of input images (usually 1.0 or 255). + reduction: Reduction over samples in batch: "mean"|"sum"|"none". + interpolation: Interpolation to be used for scaling. + + Returns: + Value of BRISQUE index. + + References: + .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain", + https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf + """ + _validate_input(input_tensors=x, allow_5d=False) + x = _adjust_dimensions(input_tensors=x) + + assert data_range >= x.max(), f'Expected data range greater or equal maximum value, got {data_range} and {x.max()}.' + x = x * 255. / data_range + + if x.size(1) == 3: + x = rgb2yiq(x)[:, :1] + features = [] + num_of_scales = 2 + for _ in range(num_of_scales): + features.append(_natural_scene_statistics(x, kernel_size, kernel_sigma)) + x = F.interpolate(x, size=(x.size(2) // 2, x.size(3) // 2), mode=interpolation) + + features = torch.cat(features, dim=-1) + scaled_features = _scale_features(features) + score = _score_svr(scaled_features) + if reduction == 'none': + return score + + return {'mean': score.mean, + 'sum': score.sum + }[reduction](dim=0) + + +class BRISQUELoss(_Loss): + r"""Creates a criterion that measures the BRISQUE score for input :math:`x`. + :math:`x` is tensor of 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first. + The sum operation still operates over all the elements, and divides by :math:`n`. + The division by :math:`n` can be avoided by setting ``reduction = 'sum'``. + + Args: + kernel_size: By default, the mean and covariance of a pixel is obtained + by convolution with given filter_size. + kernel_sigma: Standard deviation for Gaussian kernel. + data_range: The difference between the maximum and minimum of the pixel value, + i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1. + The pixel value interval of both input and output should remain the same. + reduction: Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``. + interpolation: Interpolation to be used for scaling. + + Shape: + - Input: Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first. + + Examples:: + >>> loss = BRISQUELoss() + >>> prediction = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> target = torch.rand(3, 3, 256, 256) + >>> output = loss(prediction) + >>> output.backward() + + References: + .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain", + https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf + """ + def __init__(self, kernel_size: int = 7, kernel_sigma: float = 7 / 6, + data_range: Union[int, float] = 1., reduction: str = 'mean', + interpolation: str = 'nearest') -> None: + super().__init__() + self.reduction = reduction + self.kernel_size = kernel_size + self.kernel_sigma = kernel_sigma + self.data_range = data_range + self.interpolation = interpolation + + def forward(self, prediction: torch.Tensor) -> torch.Tensor: + r"""Computation of BRISQUE score as a loss function. + + Args: + prediction: Tensor of prediction of the network. + + Returns: + Value of BRISQUE loss to be minimized. + """ + return brisque(prediction, reduction=self.reduction, kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, data_range=self.data_range, interpolation=self.interpolation) def _ggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -65,24 +167,8 @@ def _aggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch return solution, left_sigma.squeeze(dim=-1), right_sigma.squeeze(dim=-1) -def _gaussian_kernel2d(kernel_size: int = 7, sigma: float = 7 / 6) -> torch.Tensor: - r"""Returns 2D Gaussian kernel N(0,`sigma`) - Args: - kernel_size: Size - sigma: Sigma - Returns: - gaussian_kernel: 2D kernel with shape (kernel_size x kernel_size) - - """ - x = torch.arange(- (kernel_size // 2), kernel_size // 2 + 1).view(1, kernel_size) - y = torch.arange(- (kernel_size // 2), kernel_size // 2 + 1).view(kernel_size, 1) - kernel = torch.exp(-(x * x + y * y) / (2.0 * sigma ** 2)) - kernel = kernel / torch.sum(kernel) - return kernel - - def _natural_scene_statistics(luma: torch.Tensor, kernel_size: int = 7, sigma: float = 7. / 6) -> torch.Tensor: - kernel = _gaussian_kernel2d(kernel_size=kernel_size, sigma=sigma).view(1, 1, kernel_size, kernel_size).to(luma) + kernel = gaussian_filter(size=kernel_size, sigma=sigma).view(1, 1, kernel_size, kernel_size).to(luma) C = 1 mu = F.conv2d(luma, kernel, padding=kernel_size // 2) mu_sq = mu ** 2 @@ -132,9 +218,7 @@ def _scale_features(features: torch.Tensor) -> torch.Tensor: def _rbf_kernel(features: torch.Tensor, sv: torch.Tensor, gamma: float = 0.05) -> torch.Tensor: - features.unsqueeze_(dim=-1) - sv.unsqueeze_(dim=0) - dist = (features - sv).pow(2).sum(dim=1) + dist = (features.unsqueeze(dim=-1) - sv.unsqueeze(dim=0)).pow(2).sum(dim=1) return torch.exp(- dist * gamma) @@ -151,108 +235,3 @@ def _score_svr(features: torch.Tensor) -> torch.Tensor: kernel_features = _rbf_kernel(features=features, sv=sv, gamma=gamma) score = kernel_features @ sv_coef return score - rho - - -def brisque(x: torch.Tensor, - kernel_size: int = 7, kernel_sigma: float = 7 / 6, - data_range: Union[int, float] = 1., reduction: str = 'mean', - interpolation: str = 'nearest') -> torch.Tensor: - r"""Interface of SBRISQUE index. - Args: - x: Batch of images. Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first. - kernel_size: The side-length of the sliding window used in comparison. Must be an odd value. - kernel_sigma: Sigma of normal distribution. - data_range: Value range of input images (usually 1.0 or 255). - reduction: Reduction over samples in batch: "mean"|"sum"|"none". - interpolation: Interpolation to be used for scaling. - Returns: - Value of BRISQUE index. - References: - .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain", - https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf - """ - _validate_input(input_tensors=x, allow_5d=False) - x = _adjust_dimensions(input_tensors=x) - - assert data_range >= x.max(), f'Expected data range greater or equal maximum value, got {data_range} and {x.max()}.' - x = x * 255. / data_range - - if x.size(1) == 3: - x = rgb2yiq(x)[:, :1] - features = [] - num_of_scales = 2 - for _ in range(num_of_scales): - features.append(_natural_scene_statistics(x, kernel_size, kernel_sigma)) - x = F.interpolate(x, scale_factor=0.5, mode=interpolation) - - features = torch.cat(features, dim=-1) - scaled_features = _scale_features(features) - score = _score_svr(scaled_features) - if reduction == 'none': - return score - - return {'mean': score.mean, - 'sum': score.sum - }[reduction](dim=0) - - -class BRISQUELoss(_Loss): - r"""Creates a criterion that measures the BRISQUE score for input :math:`x`. - - :math:`x` is tensor of 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first. - - The sum operation still operates over all the elements, and divides by :math:`n`. - - The division by :math:`n` can be avoided by setting ``reduction = 'sum'``. - - - Args: - kernel_size: By default, the mean and covariance of a pixel is obtained - by convolution with given filter_size. - kernel_sigma: Standard deviation for Gaussian kernel. - data_range: The difference between the maximum and minimum of the pixel value, - i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1. - The pixel value interval of both input and output should remain the same. - reduction: Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``. - interpolation: Interpolation to be used for scaling. - - Shape: - - Input: Required to be 2D (H, W), 3D (C,H,W) or 4D (N,C,H,W), channels first. - - Examples:: - - >>> loss = BRISQUELoss() - >>> prediction = torch.rand(3, 3, 256, 256, requires_grad=True) - >>> target = torch.rand(3, 3, 256, 256) - >>> output = loss(prediction) - >>> output.backward() - - References: - .. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain", - https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf - """ - def __init__(self, kernel_size: int = 7, kernel_sigma: float = 7 / 6, - data_range: Union[int, float] = 1., reduction: str = 'mean', - interpolation: str = 'nearest') -> None: - super().__init__() - self.reduction = reduction - self.kernel_size = kernel_size - self.kernel_sigma = kernel_sigma - self.data_range = data_range - self.interpolation = interpolation - - def forward(self, prediction: torch.Tensor) -> torch.Tensor: - r"""Computation of BRISQUE score as a loss function. - - Args: - prediction: Tensor of prediction of the network. - - Returns: - Value of BRISQUE loss to be minimized. - """ - - return brisque(prediction, reduction=self.reduction, kernel_size=self.kernel_size, - kernel_sigma=self.kernel_sigma, data_range=self.data_range, interpolation=self.interpolation) diff --git a/piq/functional/__init__.py b/piq/functional/__init__.py index 844f74d4..64b60e8d 100644 --- a/piq/functional/__init__.py +++ b/piq/functional/__init__.py @@ -1,12 +1,12 @@ from piq.functional.base import ifftshift, get_meshgrid, similarity_map, gradient_map from piq.functional.colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq -from piq.functional.filters import hann_filter, scharr_filter, prewitt_filter +from piq.functional.filters import hann_filter, scharr_filter, prewitt_filter, gaussian_filter from piq.functional.layers import L2Pool2d __all__ = [ 'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', - 'hann_filter', 'scharr_filter', 'prewitt_filter', + 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter', 'L2Pool2d', ] diff --git a/piq/functional/colour_conversion.py b/piq/functional/colour_conversion.py index 0f144de6..d61ef820 100644 --- a/piq/functional/colour_conversion.py +++ b/piq/functional/colour_conversion.py @@ -4,8 +4,7 @@ def rgb2lmn(x: torch.Tensor) -> torch.Tensor: - r""" - Convert a batch of RGB images to a batch of LMN images + r"""Convert a batch of RGB images to a batch of LMN images Args: x: Batch of 4D (N x 3 x H x W) images in RGB colour space. @@ -21,8 +20,7 @@ def rgb2lmn(x: torch.Tensor) -> torch.Tensor: def rgb2xyz(x: torch.Tensor) -> torch.Tensor: - r""" - Convert a batch of RGB images to a batch of XYZ images + r"""Convert a batch of RGB images to a batch of XYZ images Args: x: Batch of 4D (N x 3 x H x W) images in RGB colour space. @@ -43,14 +41,14 @@ def rgb2xyz(x: torch.Tensor) -> torch.Tensor: return x_xyz -def xyz2lab(x: torch.Tensor, illuminant='D50', observer='2') -> torch.Tensor: - r""" - Convert a batch of XYZ images to a batch of LAB images +def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor: + r"""Convert a batch of XYZ images to a batch of LAB images Args: x: Batch of 4D (N x 3 x H x W) images in XYZ colour space. illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant. observer: {“2”, “10”}, optional. The aperture angle of the observer. + Returns: Batch of 4D (N x 3 x H x W) images in LAB colour space. """ @@ -88,12 +86,12 @@ def xyz2lab(x: torch.Tensor, illuminant='D50', observer='2') -> torch.Tensor: def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor: - r""" - Convert a batch of RGB images to a batch of LAB images + r"""Convert a batch of RGB images to a batch of LAB images Args: x: Batch of 4D (N x 3 x H x W) images in RGB colour space. data_range: dynamic range of the input image. + Returns: Batch of 4D (N x 3 x H x W) images in LAB colour space. """ @@ -101,8 +99,7 @@ def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tenso def rgb2yiq(x: torch.Tensor) -> torch.Tensor: - r""" - Convert a batch of RGB images to a batch of YIQ images + r"""Convert a batch of RGB images to a batch of YIQ images Args: x: Batch of 4D (N x 3 x H x W) images in RGB colour space. diff --git a/piq/functional/filters.py b/piq/functional/filters.py index 98de5cb1..74aa57b1 100644 --- a/piq/functional/filters.py +++ b/piq/functional/filters.py @@ -13,6 +13,24 @@ def hann_filter(kernel_size) -> torch.Tensor: return kernel.view(1, kernel_size, kernel_size) / kernel.sum() +def gaussian_filter(size: int, sigma: float) -> torch.Tensor: + r"""Returns 2D Gaussian kernel N(0,`sigma`^2) + Args: + size: Size of the lernel + sigma: Std of the distribution + Returns: + gaussian_kernel: 2D kernel with shape (1 x kernel_size x kernel_size) + """ + coords = torch.arange(size).to(dtype=torch.float32) + coords -= (size - 1) / 2. + + g = coords ** 2 + g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp() + + g /= g.sum() + return g.unsqueeze(0) + + # Gradient operator kernels def scharr_filter() -> torch.Tensor: r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction diff --git a/piq/ssim.py b/piq/ssim.py index 72c14622..3850e837 100644 --- a/piq/ssim.py +++ b/piq/ssim.py @@ -9,30 +9,35 @@ from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as f +import torch.nn.functional as F from torch.nn.modules.loss import _Loss from piq.utils import _adjust_dimensions, _validate_input +from piq.functional import gaussian_filter def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5, - data_range: Union[int, float] = 255, reduction: str = 'mean', full: bool = False, + data_range: Union[int, float] = 1., reduction: str = 'mean', full: bool = False, k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Interface of Structural Similarity (SSIM) index. + Args: x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. y: Batch of images. Required to be 2D (H, W), 3D (C,H,W) 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. kernel_size: The side-length of the sliding window used in comparison. Must be an odd value. kernel_sigma: Sigma of normal distribution. data_range: Value range of input images (usually 1.0 or 255). - reduction: Reduction over samples in batch: "mean"|"sum"|"none" - full: Return sc or not. + reduction: Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + full: Return cs map or not. k1: Algorithm parameter, K1 (small constant, see [1]). k2: Algorithm parameter, K2 (small constant, see [1]). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + Returns: Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned as a tensor of size 2. + References: .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image quality assessment: From error visibility to @@ -43,20 +48,21 @@ def ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: """ _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=None) x, y = _adjust_dimensions(input_tensors=(x, y)) - kernel = _fspecial_gauss_1d(kernel_size, kernel_sigma) - kernel = kernel.repeat(x.shape[1], 1, 1, 1) + if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor): + x = x.type(torch.float32) + y = y.type(torch.float32) - _compute_ssim = _ssim_complex if x.dim() == 5 else _ssim - ssim_val, cs = _compute_ssim(x=x, y=y, kernel=kernel, data_range=data_range, full=True, k1=k1, k2=k2) + kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y) + _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel + ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2) + ssim_val = ssim_map.mean(1) + cs = cs_map.mean(1) if reduction != 'none': - ssim_val = {'mean': ssim_val.mean, - 'sum': ssim_val.sum - }[reduction](dim=0) - - cs = {'mean': cs.mean, - 'sum': cs.sum - }[reduction](dim=0) + reduction_operation = {'mean': torch.mean, + 'sum': torch.sum} + ssim_val = reduction_operation[reduction](ssim_val, dim=0) + cs = reduction_operation[reduction](cs, dim=0) if full: return ssim_val, cs @@ -89,9 +95,7 @@ class SSIMLoss(_Loss): of :math:`n` elements each. The sum operation still operates over all the elements, and divides by :math:`n`. - The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. - In case of 5D input tensors, complex value is returned as a tensor of size 2. Args: @@ -113,7 +117,6 @@ class SSIMLoss(_Loss): - Target: Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. Examples:: - >>> loss = SSIMLoss() >>> prediction = torch.rand(3, 3, 256, 256, requires_grad=True) >>> target = torch.rand(3, 3, 256, 256) @@ -128,7 +131,7 @@ class SSIMLoss(_Loss): https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, :DOI:`10.1109/TIP.2003.819861` """ - __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction'] + __constants__ = ['kernel_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction'] def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03, reduction: str = 'mean', data_range: Union[int, float] = 1.) -> None: @@ -144,75 +147,54 @@ def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = self.k2 = k2 self.data_range = data_range - # Cash kernel between calls. - self.kernel = _fspecial_gauss_1d(kernel_size, kernel_sigma) - def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: r"""Computation of Structural Similarity (SSIM) index as a loss function. Args: - prediction: Tensor of prediction of the network. - target: Reference tensor. + prediction: Tensor of prediction of the network. Required to be + 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. + target: Reference tensor. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), + channels first. Returns: - Value of SSIM loss to be minimized. 0 <= SSIM loss <= 1. In case of 5D input tensors, + Value of SSIM loss to be minimized, i.e 1 - `ssim`. 0 <= SSIM loss <= 1. In case of 5D input tensors, complex value is returned as a tensor of size 2. """ - _validate_input(input_tensors=(prediction, target), allow_5d=True, - kernel_size=self.kernel_size, scale_weights=None) - prediction, target = _adjust_dimensions(input_tensors=(prediction, target)) - - return self.compute_metric(prediction, target) - - def compute_metric(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - - kernel = self.kernel.repeat(prediction.shape[1], 1, 1, 1) - kernel = kernel.to(device=prediction.device) - _compute_ssim = _ssim_complex if prediction.dim() == 5 else _ssim - ssim_val = _compute_ssim( - x=prediction, - y=target, - kernel=kernel, - data_range=self.data_range, - full=False, - k1=self.k1, - k2=self.k2 - ) - - loss = 1 - ssim_val - - if self.reduction == 'none': - return loss - - return {'mean': loss.mean, - 'sum': loss.sum - }[self.reduction](dim=0) + score = ssim(x=prediction, y=target, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma, + data_range=self.data_range, reduction=self.reduction, full=False, k1=self.k1, k2=self.k2) + return torch.ones_like(score) - score def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, kernel_sigma: float = 1.5, - data_range: Union[int, float] = 255, reduction: str = 'mean', - scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None, k1=0.01, - k2=0.03) -> torch.Tensor: + data_range: Union[int, float] = 1., reduction: str = 'mean', + scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None, + k1: float = 0.01, k2: float = 0.03) -> torch.Tensor: r""" Interface of Multi-scale Structural Similarity (MS-SSIM) index. + Args: x: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. + The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1. y: Batch of images. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. + The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1. kernel_size: The side-length of the sliding window used in comparison. Must be an odd value. kernel_sigma: Sigma of normal distribution. data_range: Value range of input images (usually 1.0 or 255). - reduction: Reduction over samples in batch: "mean"|"sum"|"none". + reduction: Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, scale_weights: Weights for different scales. If None, default weights from the paper [1] will be used. Default weights: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). k1: Algorithm parameter, K1 (small constant, see [2]). k2: Algorithm parameter, K2 (small constant, see [2]). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + Returns: Value of Multi-scale Structural Similarity (MS-SSIM) index. In case of 5D input tensors, complex value is returned as a tensor of size 2. + References: .. [1] Wang, Z., Simoncelli, E. P., Bovik, A. C. (2003). Multi-scale Structural Similarity for Image Quality Assessment. @@ -228,15 +210,18 @@ def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, ke """ _validate_input(input_tensors=(x, y), allow_5d=True, kernel_size=kernel_size, scale_weights=scale_weights) x, y = _adjust_dimensions(input_tensors=(x, y)) + if isinstance(x, torch.ByteTensor) or isinstance(y, torch.ByteTensor): + x = x.type(torch.float32) + y = y.type(torch.float32) if scale_weights is None: scale_weights_from_ms_ssim_paper = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] scale_weights = scale_weights_from_ms_ssim_paper - scale_weights_tensor = torch.tensor(scale_weights).to(x.device, dtype=x.dtype) - kernel = _fspecial_gauss_1d(kernel_size, kernel_sigma) - kernel = kernel.repeat(x.shape[1], 1, 1, 1) - + scale_weights_tensor = scale_weights if isinstance(scale_weights, torch.Tensor) else torch.tensor(scale_weights) + scale_weights_tensor = scale_weights_tensor.to(y) + kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y) + _compute_msssim = _multi_scale_ssim_complex if x.dim() == 5 else _multi_scale_ssim msssim_val = _compute_msssim( x=x, @@ -251,9 +236,8 @@ def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, ke if reduction == 'none': return msssim_val - return {'mean': msssim_val.mean, - 'sum': msssim_val.sum - }[reduction](dim=0) + return {'mean': torch.mean, + 'sum': torch.sum}[reduction](msssim_val, dim=0) class MultiScaleSSIMLoss(_Loss): @@ -280,14 +264,11 @@ class MultiScaleSSIMLoss(_Loss): :math:`x` and :math:`y` are tensors of arbitrary shapes with a total of :math:`n` elements each. - The sum operation still operates over all the elements, and divides by :math:`n`. - The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. - In case of 5D input tensors, complex value is returned as a tensor of size 2. - Args: + Args: kernel_size: By default, the mean and covariance of a pixel is obtained by convolution with given filter_size. kernel_sigma: Standard deviation for Gaussian kernel. @@ -304,13 +285,13 @@ class MultiScaleSSIMLoss(_Loss): i.e., if for image x it holds min(x) = 0 and max(x) = 1, then data_range = 1. The pixel value interval of both input and output should remain the same. - Shape: - Input: Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. + The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1. - Target: Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. + The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1. Examples:: - >>> loss = MultiScaleSSIMLoss() >>> input = torch.rand(3, 3, 256, 256, requires_grad=True) >>> target = torch.rand(3, 3, 256, 256) @@ -330,7 +311,7 @@ class MultiScaleSSIMLoss(_Loss): https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, :DOI:`10.1109/TIP.2003.819861` """ - __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction'] + __constants__ = ['kernel_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction'] def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03, scale_weights: Optional[Union[Tuple[float], List[float], torch.Tensor]] = None, @@ -344,115 +325,70 @@ def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float = if scale_weights is None: scale_weights_from_ms_ssim_paper = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] scale_weights = scale_weights_from_ms_ssim_paper - self.scale_weights = torch.tensor(scale_weights) + self.scale_weights = scale_weights if isinstance(scale_weights, torch.Tensor) else torch.tensor(scale_weights) self.kernel_size = kernel_size self.kernel_sigma = kernel_sigma self.k1 = k1 self.k2 = k2 self.data_range = data_range - # Cash kernel between calls. - self.kernel = _fspecial_gauss_1d(kernel_size, kernel_sigma) - def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: r"""Computation of Multi-scale Structural Similarity (MS-SSIM) index as a loss function. - Args: - prediction: Tensor of prediction of the network. - target: Reference tensor. + prediction: Tensor of prediction of the network. Required to be + 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), channels first. The size of the image + should be (kernel_size - 1) * 2 ** (levels - 1) + 1. + target: Reference tensor. Required to be 2D (H, W), 3D (C,H,W), 4D (N,C,H,W) or 5D (N,C,H,W,2), + channels first. The size of the image should be (kernel_size - 1) * 2 ** (levels - 1) + 1. Returns: - Value of MS-SSIM loss to be minimized. 0 <= MS-SSIM loss <= 1. In case of 5D tensor, + Value of MS-SSIM loss to be minimized, i.e. 1-`ms_sim`. 0 <= MS-SSIM loss <= 1. In case of 5D tensor, complex value is returned as a tensor of size 2. """ - _validate_input(input_tensors=(prediction, target), allow_5d=True, - kernel_size=self.kernel_size, scale_weights=self.scale_weights) - prediction, target = _adjust_dimensions(input_tensors=(prediction, target)) - - score = self.compute_metric(prediction, target) - return score - - def compute_metric(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - kernel = self.kernel.repeat(prediction.shape[1], 1, 1, 1) - scale_weights_tensor = self.scale_weights.to(device=prediction.device, dtype=prediction.dtype) - - _compute_msssim = _multi_scale_ssim_complex if prediction.dim() == 5 else _multi_scale_ssim - msssim_val = _compute_msssim( - x=prediction, - y=target, - data_range=self.data_range, - kernel=kernel, - scale_weights_tensor=scale_weights_tensor, - k1=self.k1, - k2=self.k2) - - loss = 1 - msssim_val - if self.reduction == 'none': - return loss - - return {'mean': loss.mean, - 'sum': loss.sum - }[self.reduction](dim=0) - - -def _fspecial_gauss_1d(size: int, sigma: float) -> torch.Tensor: - r""" Creates a 1-D gauss kernel. - - Args: - size: The size of gauss kernel. - sigma: Sigma of normal distribution. - - Returns: - 1D Gauss kernel. - """ - coords = torch.arange(size).to(dtype=torch.float) - coords -= size // 2 - g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) - g /= g.sum() + score = multi_scale_ssim(x=prediction, y=target, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma, + data_range=self.data_range, reduction=self.reduction, scale_weights=self.scale_weights, + k1=self.k1, k2=self.k2) + return torch.ones_like(score) - score - return g.unsqueeze(0).unsqueeze(0) - -def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, data_range: Union[float, int] = 255, - k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: +def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, + data_range: Union[float, int] = 1., k1: float = 0.01, + k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Calculate Structural Similarity (SSIM) index for X and Y per channel. - Args: - x: Batch of images, (N,C,H,W). - y: Batch of images, (N,C,H,W). - kernel: 1-D gauss kernel. - data_range: Value range of input images (usually 1.0 or 255). - k1: Algorithm parameter, K1 (small constant, see [1]). - k2: Algorithm parameter, K2 (small constant, see [1]). - Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. - + Args: + x: Batch of images, (N,C,H,W). + y: Batch of images, (N,C,H,W). + kernel: 2D Gaussian kernel. + data_range: Value range of input images (usually 1.0 or 255). + k1: Algorithm parameter, K1 (small constant, see [1]). + k2: Algorithm parameter, K2 (small constant, see [1]). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. Returns: Full Value of Structural Similarity (SSIM) index. """ - if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-1): + if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2): raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. ' f'Kernel size: {kernel.size()}') c1 = (k1 * data_range) ** 2 c2 = (k2 * data_range) ** 2 - - kernel = kernel.to(x.device, dtype=x.dtype) - - mu1 = _gaussian_filter(x, kernel) - mu2 = _gaussian_filter(y, kernel) + n_channels = x.size(1) + mu1 = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2 = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 compensation = 1.0 - sigma1_sq = compensation * (_gaussian_filter(x * x, kernel) - mu1_sq) - sigma2_sq = compensation * (_gaussian_filter(y * y, kernel) - mu2_sq) - sigma12 = compensation * (_gaussian_filter(x * y, kernel) - mu1_mu2) + sigma1_sq = compensation * (F.conv2d(x * x, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq) + sigma2_sq = compensation * (F.conv2d(y * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq) + sigma12 = compensation * (F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2) # Set alpha = beta = gamma = 1. cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) @@ -460,42 +396,26 @@ def _ssim_per_channel(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, da ssim_val = ssim_map.mean(dim=(-1, -2)) cs = cs_map.mean(dim=(-1, -2)) - return ssim_val, cs -def _ssim(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, data_range: Union[float, int] = 255, - full: bool = False, k1: float = 0.01, k2: float = 0.03) \ - -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Calculate Structural Similarity (SSIM) index for X and Y. +def _multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float], kernel: torch.Tensor, + scale_weights_tensor: torch.Tensor, k1: float, k2: float) -> torch.Tensor: + r"""Calculates Multi scale Structural Similarity (MS-SSIM) index for X and Y. Args: x: Batch of images, (N,C,H,W). y: Batch of images, (N,C,H,W). - kernel: 1-D gauss kernel. data_range: Value range of input images (usually 1.0 or 255). - full: Return sc or not. + kernel: 2D Gaussian kernel. + scale_weights_tensor: Weights for scaled SSIM k1: Algorithm parameter, K1 (small constant, see [1]). k2: Algorithm parameter, K2 (small constant, see [1]). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. Returns: - Value of Structural Similarity (SSIM) index. + Value of Multi scale Structural Similarity (MS-SSIM) index. """ - - ssim_map, cs_map = _ssim_per_channel(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2) - - ssim_val = ssim_map.mean(1) - cs = cs_map.mean(1) - - if full: - return ssim_val, cs - - return ssim_val - - -def _multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float], kernel: torch.Tensor, - scale_weights_tensor: torch.Tensor, k1: float, k2: float) -> torch.Tensor: levels = scale_weights_tensor.size(0) min_size = (kernel.size(-1) - 1) * 2 ** (levels - 1) + 1 if x.size(-1) < min_size or x.size(-2) < min_size: @@ -503,14 +423,17 @@ def _multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, f mcs = [] ssim_val = None - for _ in range(levels): + for iteration in range(levels): + if iteration > 0: + padding = (x.shape[2] % 2, x.shape[3] % 2) + x = F.pad(x, pad=[padding[0], 0, padding[1], 0], mode='replicate') + y = F.pad(y, pad=[padding[0], 0, padding[1], 0], mode='replicate') + x = F.avg_pool2d(x, kernel_size=2, padding=0) + y = F.avg_pool2d(y, kernel_size=2, padding=0) + ssim_val, cs = _ssim_per_channel(x, y, kernel=kernel, data_range=data_range, k1=k1, k2=k2) mcs.append(cs) - padding = (x.shape[2] % 2, x.shape[3] % 2) - x = f.avg_pool2d(x, kernel_size=2, padding=padding) - y = f.avg_pool2d(y, kernel_size=2, padding=padding) - # mcs, (level, batch) mcs_ssim = torch.relu(torch.stack(mcs[:-1] + [ssim_val], dim=0)) @@ -520,59 +443,40 @@ def _multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, f return msssim_val -def _gaussian_filter(to_blur: torch.Tensor, window: torch.Tensor) -> torch.Tensor: - r""" Blur input with 1-D kernel. - - Args: - to_blur: A batch of tensors to be blured. - window: 1-D gauss kernel. - - Returns: - A batch of blurred tensors. - """ - _, n_channels, _, _ = to_blur.shape - out = f.conv2d(to_blur, window, stride=1, padding=0, groups=n_channels) - out = f.conv2d(out, window.transpose(2, 3), stride=1, padding=0, groups=n_channels) - return out - - def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, - data_range: Union[float, int] = 255, k1: float = 0.01, + data_range: Union[float, int] = 1., k1: float = 0.01, k2: float = 0.03) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel. - Args: - x: Batch of complex images, (N,C,H,W,2). - y: Batch of complex images, (N,C,H,W,2). - kernel: 1-D gauss kernel. - data_range: Value range of input images (usually 1.0 or 255). - k1: Algorithm parameter, K1 (small constant, see [1]). - k2: Algorithm parameter, K2 (small constant, see [1]). - Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. - + Args: + x: Batch of complex images, (N,C,H,W,2). + y: Batch of complex images, (N,C,H,W,2). + kernel: 2-D gauss kernel. + data_range: Value range of input images (usually 1.0 or 255). + k1: Algorithm parameter, K1 (small constant, see [1]). + k2: Algorithm parameter, K2 (small constant, see [1]). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. Returns: Full Value of Complex Structural Similarity (SSIM) index. """ - - if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-1): + n_channels = x.size(1) + if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2): raise ValueError(f'Kernel size can\'t be greater than actual input size. Input size: {x.size()}. ' f'Kernel size: {kernel.size()}') c1 = (k1 * data_range) ** 2 c2 = (k2 * data_range) ** 2 - kernel = kernel.to(x.device, dtype=x.dtype) - x_real = x[..., 0] x_imag = x[..., 1] y_real = y[..., 0] y_imag = y[..., 1] - mu1_real = _gaussian_filter(x_real, kernel) - mu1_imag = _gaussian_filter(x_imag, kernel) - mu2_real = _gaussian_filter(y_real, kernel) - mu2_imag = _gaussian_filter(y_imag, kernel) + mu1_real = F.conv2d(x_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu1_imag = F.conv2d(x_imag, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_real = F.conv2d(y_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_imag = F.conv2d(y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2) mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2) @@ -586,15 +490,16 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Te x_y_real = x_real * y_real - x_imag * y_imag x_y_imag = x_real * y_imag + x_imag * y_real - sigma1_sq = compensation * _gaussian_filter(x_sq, kernel) - mu1_sq - sigma2_sq = compensation * _gaussian_filter(y_sq, kernel) - mu2_sq - sigma12_real = compensation * _gaussian_filter(x_y_real, kernel) - mu1_mu2_real - sigma12_imag = compensation * _gaussian_filter(x_y_imag, kernel) - mu1_mu2_imag + sigma1_sq = F.conv2d(x_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq + sigma2_sq = F.conv2d(y_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq + sigma12_real = F.conv2d(x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_real + sigma12_imag = F.conv2d(x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_imag sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1) mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1) # Set alpha = beta = gamma = 1. - cs_map = (sigma12 * 2 + c2) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2) - ssim_map = ((mu1_mu2 * 2 + c1) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1)) * cs_map + cs_map = (sigma12 * 2 + c2 * compensation) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2 * compensation) + ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation) + ssim_map = ssim_map * cs_map ssim_val = ssim_map.mean(dim=(-2, -3)) cs = cs_map.mean(dim=(-2, -3)) @@ -602,61 +507,50 @@ def _ssim_per_channel_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Te return ssim_val, cs -def _ssim_complex(x: torch.Tensor, y: torch.Tensor, kernel: torch.Tensor, data_range: Union[float, int] = 255, - full: bool = False, k1: float = 0.01, k2: float = 0.03) \ - -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Calculate Structural Similarity (SSIM) index for Complex X and Y. +def _multi_scale_ssim_complex(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float], + kernel: torch.Tensor, scale_weights_tensor: torch.Tensor, k1: float, + k2: float) -> torch.Tensor: + r"""Calculate Multi scale Structural Similarity (MS-SSIM) index for Complex X and Y. Args: x: Batch of complex images, (N,C,H,W,2). y: Batch of complex images, (N,C,H,W,2). - kernel: 1-D gauss kernel. data_range: Value range of input images (usually 1.0 or 255). - full: Return sc or not. + kernel: 2-D gauss kernel. k1: Algorithm parameter, K1 (small constant, see [1]). k2: Algorithm parameter, K2 (small constant, see [1]). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. Returns: - Value of Complex Structural Similarity (SSIM) index. + Value of Complex Multi scale Structural Similarity (MS-SSIM) index. """ - ssim_map, cs_map = _ssim_per_channel_complex(x=x, y=y, kernel=kernel, data_range=data_range, k1=k1, k2=k2) - - ssim_val = ssim_map.mean(1) - cs = cs_map.mean(1) - - if full: - return ssim_val, cs - - return ssim_val - - -def _multi_scale_ssim_complex(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, float], - kernel: torch.Tensor, scale_weights_tensor: torch.Tensor, k1: float, - k2: float) -> torch.Tensor: levels = scale_weights_tensor.size(0) min_size = (kernel.size(-1) - 1) * 2 ** (levels - 1) + 1 if x.size(-2) < min_size or x.size(-3) < min_size: raise ValueError(f'Invalid size of the input images, expected at least {min_size}x{min_size}.') - mcs = [] ssim_val = None - for _ in range(levels): - ssim_val, cs = _ssim_per_channel_complex(x, y, kernel=kernel, data_range=data_range, k1=k1, k2=k2) - + for iteration in range(levels): x_real = x[..., 0] x_imag = x[..., 1] y_real = y[..., 0] y_imag = y[..., 1] - mcs.append(cs) + if iteration > 0: + padding = (x.size(2) % 2, x.size(3) % 2) + x_real = F.pad(x_real, pad=[padding[0], 0, padding[1], 0], mode='replicate') + x_imag = F.pad(x_imag, pad=[padding[0], 0, padding[1], 0], mode='replicate') + y_real = F.pad(y_real, pad=[padding[0], 0, padding[1], 0], mode='replicate') + y_imag = F.pad(y_imag, pad=[padding[0], 0, padding[1], 0], mode='replicate') + + x_real = F.avg_pool2d(x_real, kernel_size=2, padding=0) + x_imag = F.avg_pool2d(x_imag, kernel_size=2, padding=0) + y_real = F.avg_pool2d(y_real, kernel_size=2, padding=0) + y_imag = F.avg_pool2d(y_imag, kernel_size=2, padding=0) + x = torch.stack((x_real, x_imag), dim=-1) + y = torch.stack((y_real, y_imag), dim=-1) - padding = (x.size(2) % 2, x.size(3) % 2) - x_real = f.avg_pool2d(x_real, kernel_size=2, padding=padding) - x_imag = f.avg_pool2d(x_imag, kernel_size=2, padding=padding) - y_real = f.avg_pool2d(y_real, kernel_size=2, padding=padding) - y_imag = f.avg_pool2d(y_imag, kernel_size=2, padding=padding) - x = torch.stack((x_real, x_imag), dim=-1) - y = torch.stack((y_real, y_imag), dim=-1) + ssim_val, cs = _ssim_per_channel_complex(x, y, kernel=kernel, data_range=data_range, k1=k1, k2=k2) + mcs.append(cs) # mcs, (level, batch) mcs_ssim = torch.relu(torch.stack(mcs[:-1] + [ssim_val], dim=0)) @@ -664,7 +558,7 @@ def _multi_scale_ssim_complex(x: torch.Tensor, y: torch.Tensor, data_range: Unio mcs_ssim_real = mcs_ssim[..., 0] mcs_ssim_imag = mcs_ssim[..., 1] mcs_ssim_abs = (mcs_ssim_real.pow(2) + mcs_ssim_imag.pow(2)).sqrt() - mcs_ssim_deg = torch.atan(mcs_ssim_imag / mcs_ssim_real) + mcs_ssim_deg = torch.atan2(mcs_ssim_imag, mcs_ssim_real) mcs_ssim_pow_abs = mcs_ssim_abs ** scale_weights_tensor.view(-1, 1, 1) mcs_ssim_pow_deg = mcs_ssim_deg * scale_weights_tensor.view(-1, 1, 1) @@ -674,5 +568,4 @@ def _multi_scale_ssim_complex(x: torch.Tensor, y: torch.Tensor, data_range: Unio msssim_val_real = msssim_val_abs * torch.cos(msssim_val_deg) msssim_val_imag = msssim_val_abs * torch.sin(msssim_val_deg) msssim_val = torch.stack((msssim_val_real, msssim_val_imag), dim=-1).mean(dim=1) - return msssim_val diff --git a/piq/utils/common.py b/piq/utils/common.py index e5782b80..bcac39e3 100644 --- a/piq/utils/common.py +++ b/piq/utils/common.py @@ -1,5 +1,5 @@ from typing import Optional, Union, Tuple, List - +import warnings import torch @@ -11,6 +11,8 @@ def _adjust_dimensions(input_tensors: Union[torch.Tensor, Tuple[torch.Tensor, to resized_tensors = [] for tensor in input_tensors: + if not isinstance(tensor, torch.FloatTensor): + warnings.warn(f'Expected input tensor {torch.FloatTensor}, got {tensor.type()}.') tmp = tensor.clone() if tmp.dim() == 2: tmp = tmp.unsqueeze(0) @@ -56,8 +58,10 @@ def _validate_input( if scale_weights is not None: assert isinstance(scale_weights, (list, tuple, torch.Tensor)), \ f'Scale weights must be of type list, tuple or torch.Tensor, got {type(scale_weights)}.' - assert (torch.tensor(scale_weights).dim() == 1), \ - f'Scale weights must be one dimensional, got {torch.tensor(scale_weights).dim()}.' + if isinstance(scale_weights, (list, tuple)): + scale_weights = torch.tensor(scale_weights) + assert (scale_weights.dim() == 1), \ + f'Scale weights must be one dimensional, got {scale_weights.dim()}.' def _validate_features(x: torch.Tensor, y: torch.Tensor) -> None: diff --git a/tests/test_brisque.py b/tests/test_brisque.py index 93651cb5..b05de516 100644 --- a/tests/test_brisque.py +++ b/tests/test_brisque.py @@ -3,96 +3,89 @@ from libsvm import svmutil # noqa: F401 from brisque import BRISQUE from piq import brisque, BRISQUELoss +from skimage.io import imread +from typing import Any @pytest.fixture(scope='module') def prediction_grey() -> torch.Tensor: - return torch.rand(3, 1, 256, 256) + return torch.rand(3, 1, 96, 96) @pytest.fixture(scope='module') def prediction_rgb() -> torch.Tensor: - return torch.rand(3, 3, 256, 256) + return torch.rand(3, 3, 96, 96) # ================== Test function: `brisque` ================== -def test_brisque_if_works_with_grey(prediction_grey: torch.Tensor) -> None: - brisque(prediction_grey) +def test_brisque_if_works_with_grey(prediction_grey: torch.Tensor, device: str) -> None: + brisque(prediction_grey.to(device)) -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.') -def test_brisque_if_works_with_grey_on_gpu(prediction_grey: torch.Tensor) -> None: - prediction_grey = prediction_grey.cuda() - brisque(prediction_grey) +def test_brisque_if_works_with_rgb(prediction_rgb, device: str) -> None: + brisque(prediction_rgb.to(device)) -def test_brisque_if_works_with_rgb(prediction_rgb) -> None: - brisque(prediction_rgb) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.') -def test_brisque_if_works_with_rgb_on_gpu(prediction_rgb) -> None: - prediction_rgb = prediction_rgb.cuda() - brisque(prediction_rgb) - - -def test_brisque_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> None: +def test_brisque_raises_if_wrong_reduction(prediction_grey: torch.Tensor, device: str) -> None: for mode in ['mean', 'sum', 'none']: - brisque(prediction_grey, reduction=mode) + brisque(prediction_grey.to(device), reduction=mode) for mode in [None, 'n', 2]: with pytest.raises(KeyError): - brisque(prediction_grey, reduction=mode) + brisque(prediction_grey.to(device), reduction=mode) -def test_brisque_values_grey(prediction_grey: torch.Tensor) -> None: - score = brisque(prediction_grey, reduction='none', data_range=1.) - score_baseline = torch.tensor([BRISQUE().get_score((img * 255).type(torch.uint8).squeeze().numpy()) - for img in prediction_grey]) - assert torch.isclose(score, score_baseline, atol=1e-1, rtol=1e-3).all(), f'Expected values to be equal to ' \ - f'baseline prediction.' \ - f'got {score} and {score_baseline}' +def test_brisque_values_grey(device: str) -> None: + img = imread('tests/assets/goldhill.gif') + prediction_grey = torch.tensor(img).unsqueeze(0).unsqueeze(0) + score = brisque(prediction_grey.to(device), reduction='none', data_range=255) + score_baseline = BRISQUE().get_score(img) + assert torch.isclose(score, torch.tensor(score_baseline).to(score), rtol=1e-3), \ + f'Expected values to be equal to baseline prediction, got {score.item()} and {score_baseline}' -def test_brisque_values_rgb(prediction_rgb) -> None: - score = brisque(prediction_rgb, reduction='none', data_range=1.) - score_baseline = [BRISQUE().get_score((img * 255).type(torch.uint8).squeeze().permute(1, 2, 0).numpy()[..., ::-1]) - for img in prediction_rgb] - assert torch.isclose(score, - torch.tensor(score_baseline), - atol=1e-1, rtol=1e-3).all(), f'Expected values to be equal to ' \ - f'baseline prediction.' \ - f'got {score} and {score_baseline}' +def test_brisque_values_rgb(device: str) -> None: + img = imread('tests/assets/I01.BMP') + prediction_rgb = (torch.tensor(img, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)) + score = brisque(prediction_rgb.to(device), reduction='none', data_range=255.) + score_baseline = BRISQUE().get_score(prediction_rgb[0].permute(1, 2, 0).numpy()[..., ::-1]) + assert torch.isclose(score, torch.tensor(score_baseline).to(score), rtol=1e-3), \ + f'Expected values to be equal to baseline prediction, got {score.item()} and {score_baseline}' -def test_brisque_all_zeros_or_ones() -> None: - size = (1, 1, 256, 256) - for tensor in [torch.zeros(size), torch.ones(size)]: - with pytest.raises(AssertionError): - brisque(tensor, reduction='mean') +@pytest.mark.parametrize( + "input,expectation", + [(torch.zeros(2, 1, 96, 96), AssertionError), + (torch.ones(2, 1, 96, 96), AssertionError)], +) +def test_brisque_for_special_cases(input: torch.Tensor, expectation: Any, device: str) -> None: + with pytest.raises(expectation): + brisque(input.to(device), reduction='mean') # ================== Test class: `BRISQUELoss` ================== -def test_brisque_loss_if_works_with_grey(prediction_grey: torch.Tensor) -> None: - prediction_grey_grad = prediction_grey.clone() +def test_brisque_loss_if_works_with_grey(prediction_grey: torch.Tensor, device: str) -> None: + prediction_grey_grad = prediction_grey.clone().to(device) prediction_grey_grad.requires_grad_() loss_value = BRISQUELoss()(prediction_grey_grad) loss_value.backward() - assert prediction_grey_grad.grad is not None, 'Expected non None gradient of leaf variable' + assert torch.isfinite(prediction_grey_grad.grad).all(), f'Expected non None gradient of leaf variable, ' \ + f'got {prediction_grey_grad.grad}' -def test_brisque_loss_if_works_with_rgb(prediction_rgb) -> None: - prediction_rgb_grad = prediction_rgb.clone() +def test_brisque_loss_if_works_with_rgb(prediction_rgb: torch.Tensor, device: str) -> None: + prediction_rgb_grad = prediction_rgb.clone().to(device) prediction_rgb_grad.requires_grad_() loss_value = BRISQUELoss()(prediction_rgb_grad) loss_value.backward() - assert prediction_rgb_grad.grad is not None, 'Expected non None gradient of leaf variable' + assert torch.isfinite(prediction_rgb_grad.grad).all(), 'Expected non None gradient of leaf variable, ' \ + f'got {prediction_rgb_grad.grad}' -def test_brisque_loss_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> None: +def test_brisque_loss_raises_if_wrong_reduction(prediction_grey: torch.Tensor, device: str) -> None: for mode in ['mean', 'sum', 'none']: - BRISQUELoss(reduction=mode)(prediction_grey) + BRISQUELoss(reduction=mode)(prediction_grey.to(device)) for mode in [None, 'n', 2]: with pytest.raises(KeyError): - BRISQUELoss(reduction=mode)(prediction_grey) + BRISQUELoss(reduction=mode)(prediction_grey.to(device)) diff --git a/tests/test_ssim.py b/tests/test_ssim.py index 6d566fb8..6542c658 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -2,96 +2,103 @@ import itertools import pytest import tensorflow as tf - from piq import SSIMLoss, MultiScaleSSIMLoss, ssim, multi_scale_ssim +from typing import Tuple, List, Any +from skimage.io import imread +from contextlib import contextmanager + + +@contextmanager +def raise_nothing(enter_result=None): + yield enter_result @pytest.fixture(scope='module') def prediction() -> torch.Tensor: - return torch.rand(3, 3, 256, 256) + return torch.rand(3, 3, 161, 161) @pytest.fixture(scope='module') def target() -> torch.Tensor: - return torch.rand(3, 3, 256, 256) + return torch.rand(3, 3, 161, 161) -@pytest.fixture(scope='module') -def prediction_5d() -> torch.Tensor: - return torch.rand(3, 3, 256, 256, 2) +@pytest.fixture(params=[(3, 3, 161, 161), (3, 3, 161, 161, 2)], scope='module') +def prediction_target_4d_5d(request: Any) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.rand(request.param), torch.rand(request.param) -@pytest.fixture(scope='module') -def target_5d() -> torch.Tensor: - return torch.rand(3, 3, 256, 256, 2) +@pytest.fixture(params=[(3, 3, 161, 161), (3, 3, 161, 161, 2)], scope='module') +def ones_zeros_4d_5d(request: Any) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ones(request.param), torch.zeros(request.param) -# ================== Test function: `ssim` ================== -def test_ssim_symmetry(prediction: torch.Tensor, target: torch.Tensor) -> None: - measure = ssim(prediction, target, data_range=1.) - reverse_measure = ssim(target, prediction, data_range=1.) - assert (measure == reverse_measure).all(), f'Expect: SSIM(a, b) == SSIM(b, a), got {measure} != {reverse_measure}' +@pytest.fixture(scope='module') +def test_images() -> List[Tuple[torch.Tensor, torch.Tensor]]: + prediction_grey = torch.tensor(imread('tests/assets/goldhill_jpeg.gif')).unsqueeze(0).unsqueeze(0) + target_grey = torch.tensor(imread('tests/assets/goldhill.gif')).unsqueeze(0).unsqueeze(0) + prediction_rgb = torch.tensor(imread('tests/assets/I01.BMP')).permute(2, 0, 1).unsqueeze(0) + target_rgb = torch.tensor(imread('tests/assets/i01_01_5.bmp')).permute(2, 0, 1).unsqueeze(0) + return [(prediction_grey, target_grey), (prediction_rgb, target_rgb)] -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_ssim_symmetry_cuda(prediction: torch.Tensor, target: torch.Tensor) -> None: - prediction = prediction.cuda() - target = target.cuda() - test_ssim_symmetry(prediction=prediction, target=target) +@pytest.fixture(params=[[0.0448, 0.2856, 0.3001, 0.2363, 0.1333], [0.0448, 0.2856, 0.3001]], scope='module') +def scale_weights(request: Any) -> List: + return request.param -def test_ssim_symmetry_5d(prediction_5d: torch.Tensor, target_5d: torch.Tensor) -> None: - test_ssim_symmetry(prediction_5d, target_5d) +# ================== Test function: `ssim` ================== +def test_ssim_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + measure = ssim(prediction, target, data_range=1., reduction='none') + reverse_measure = ssim(target, prediction, data_range=1., reduction='none') + assert torch.allclose(measure, reverse_measure), f'Expect: SSIM(a, b) == SSIM(b, a), ' \ + f'got {measure} != {reverse_measure}' -def test_ssim_measure_is_one_for_equal_tensors(target: torch.Tensor) -> None: +def test_ssim_measure_is_one_for_equal_tensors(target: torch.Tensor, device: str) -> None: + target = target.to(device) prediction = target.clone() - measure = ssim(prediction, target, data_range=1.) - measure -= 1. - assert (measure.abs() <= 1e-6).all(), f'If equal tensors are passed SSIM must be equal to 1 ' \ - f'(considering floating point operation error up to 1 * 10^-6), ' \ - f'got {measure + 1}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_ssim_measure_is_one_for_equal_tensors_cuda(target: torch.Tensor) -> None: - target = target.cuda() - test_ssim_measure_is_one_for_equal_tensors(target=target) - - -def test_ssim_measure_is_less_or_equal_to_one() -> None: + measure = ssim(prediction, target, data_range=1., reduction='none') + assert torch.allclose(measure, torch.ones_like(measure)), f'If equal tensors are passed SSIM must be equal to 1 ' \ + f'(considering floating point error up to 1 * 10^-6), '\ + f'got {measure + 1}' + + +@pytest.mark.parametrize( + "reduction,full,expectation", + [('mean', False, raise_nothing()), + ('sum', False, raise_nothing()), + ('none', False, raise_nothing()), + ('none', True, raise_nothing()), + ('reduction', False, pytest.raises(KeyError))] +) +def test_ssim_reduction_and_full(reduction: str, full: bool, expectation: Any, + prediction: torch.Tensor, target: torch.Tensor, device: str) -> None: + prediction = prediction.to(device) + target = target.to(device) + with expectation: + ssim(prediction, target, data_range=1., reduction=reduction, full=full) + + +def test_ssim_measure_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: # Create two maximally different tensors. - ones = torch.ones((3, 3, 256, 256)) - zeros = torch.zeros((3, 3, 256, 256)) - measure = ssim(ones, zeros, data_range=1.) - assert measure <= 1, f'SSIM must be <= 1, got {measure}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_ssim_measure_is_less_or_equal_to_one_cuda() -> None: - ones = torch.ones((3, 3, 256, 256)).cuda() - zeros = torch.zeros((3, 3, 256, 256)).cuda() - measure = ssim(ones, zeros, data_range=1.) - assert measure <= 1, f'SSIM must be <= 1, got {measure}' - - -def test_ssim_measure_is_less_or_equal_to_one_5d() -> None: - ones = torch.ones((3, 3, 256, 256, 2)) - zeros = torch.zeros((3, 3, 256, 256, 2)) - measure = ssim(ones, zeros, data_range=1.) + ones = ones_zeros_4d_5d[0].to(device) + zeros = ones_zeros_4d_5d[1].to(device) + measure = ssim(ones, zeros, data_range=1., reduction='none') assert (measure <= 1).all(), f'SSIM must be <= 1, got {measure}' -def test_ssim_raises_if_tensors_have_different_dimensions() -> None: - custom_prediction = torch.rand(256, 256) - with pytest.raises(AssertionError): - ssim(custom_prediction, custom_prediction.unsqueeze(0)) - - -def test_ssim_raises_if_tensors_have_different_shapes(target: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256]] - for b, c, h, w in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w) +def test_ssim_raises_if_tensors_have_different_shapes(prediction_target_4d_5d: Tuple[torch.Tensor, + torch.Tensor], device) -> None: + target = prediction_target_4d_5d[1].to(device) + dims = [[3], [2, 3], [161, 162], [161, 162]] + if target.dim() == 5: + dims += [[2, 3]] + for size in list(itertools.product(*dims)): + wrong_shape_prediction = torch.rand(size).to(target) if wrong_shape_prediction.size() == target.size(): try: ssim(wrong_shape_prediction, target) @@ -102,20 +109,6 @@ def test_ssim_raises_if_tensors_have_different_shapes(target: torch.Tensor) -> N ssim(wrong_shape_prediction, target) -def test_ssim_raises_if_tensors_have_different_shapes_5d(target_5d: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256], [2, 3]] - for b, c, h, w, d in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w, d) - if wrong_shape_prediction.size() == target_5d.size(): - try: - ssim(wrong_shape_prediction, target_5d) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") - else: - with pytest.raises(AssertionError): - ssim(wrong_shape_prediction, target_5d) - - def test_ssim_raises_if_tensors_have_different_types(target: torch.Tensor) -> None: wrong_type_prediction = list(range(10)) with pytest.raises(AssertionError): @@ -138,124 +131,99 @@ def test_ssim_check_available_dimensions() -> None: custom_target.unsqueeze_(0) -def test_ssim_raises_if_wrong_kernel_size_is_passed(prediction: torch.Tensor, target: torch.Tensor) -> None: - wrong_kernel_sizes = list(range(0, 50, 2)) - for kernel_size in wrong_kernel_sizes: - with pytest.raises(AssertionError): +def test_ssim_check_kernel_size_is_passed(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + kernel_sizes = list(range(0, 50)) + for kernel_size in kernel_sizes: + if kernel_size % 2: ssim(prediction, target, kernel_size=kernel_size) - - -def test_ssim_raises_if_kernel_size_greater_than_image() -> None: - right_kernel_sizes = list(range(1, 52, 2)) - for kernel_size in right_kernel_sizes: - wrong_size_prediction = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - wrong_size_target = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - with pytest.raises(ValueError): - ssim(wrong_size_prediction, wrong_size_target, kernel_size=kernel_size) - - -def test_ssim_raise_if_wrong_value_is_estimated(prediction: torch.Tensor, target: torch.Tensor) -> None: - piq_ssim = ssim(prediction, target, kernel_size=11, kernel_sigma=1.5, data_range=1., reduction='none') - tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) - tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) - tf_ssim = torch.tensor(tf.image.ssim(tf_prediction, tf_target, max_val=1.).numpy()) - assert torch.isclose(piq_ssim, tf_ssim, atol=1e-6).all(), \ - f'The estimated value must be equal to tensorflow provided one' \ - f'(considering floating point operation error up to 1 * 10^-6), ' \ - f'got difference {(piq_ssim - tf_ssim).abs()}' + else: + with pytest.raises(AssertionError): + ssim(prediction, target, kernel_size=kernel_size) + + +def test_ssim_raises_if_kernel_size_greater_than_image(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + kernel_size = 11 + wrong_size_prediction = prediction[:, :, :kernel_size - 1, :kernel_size - 1] + wrong_size_target = target[:, :, :kernel_size - 1, :kernel_size - 1] + with pytest.raises(ValueError): + ssim(wrong_size_prediction, wrong_size_target, kernel_size=kernel_size) + + +def test_ssim_raise_if_wrong_value_is_estimated(test_images: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: + for prediction, target in test_images: + piq_ssim = ssim(prediction.to(device), target.to(device), kernel_size=11, kernel_sigma=1.5, data_range=255, + reduction='none') + tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) + tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) + with tf.device('/CPU'): + tf_ssim = torch.tensor(tf.image.ssim(tf_prediction, tf_target, max_val=255).numpy()).to(piq_ssim) + match_accuracy = 2e-5 + 1e-8 + assert torch.allclose(piq_ssim, tf_ssim, rtol=0, atol=match_accuracy), \ + f'The estimated value must be equal to tensorflow provided one' \ + f'(considering floating point operation error up to {match_accuracy}), ' \ + f'got difference {(piq_ssim - tf_ssim).abs()}' # ================== Test class: `SSIMLoss` ================== -def test_ssim_loss_symmetry(prediction: torch.Tensor, target: torch.Tensor) -> None: +def test_ssim_loss_grad(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + prediction.requires_grad_(True) + loss = SSIMLoss(data_range=1.)(prediction, target).mean() + loss.backward() + assert torch.isfinite(prediction.grad).all(), f'Expected finite gradient values, got {prediction.grad}' + + +def test_ssim_loss_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) loss = SSIMLoss() loss_value = loss(prediction, target) reverse_loss_value = loss(target, prediction) - assert (loss_value == reverse_loss_value).all(), \ - f'Expect: SSIM(a, b) == SSIM(b, a), got {loss_value} != {reverse_loss_value}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_ssim_loss_symmetry_cuda(prediction: torch.Tensor, target: torch.Tensor) -> None: - prediction = prediction.cuda() - target = target.cuda() - test_ssim_loss_symmetry(prediction=prediction, target=target) - - -def test_ssim_loss_symmetry_5d(prediction_5d: torch.Tensor, target_5d: torch.Tensor) -> None: - test_ssim_loss_symmetry(prediction_5d, target_5d) + assert torch.allclose(loss_value, reverse_loss_value), \ + f'Expect: SSIMLoss(a, b) == SSIMLoss(b, a), got {loss_value} != {reverse_loss_value}' -def test_ssim_loss_equality(target: torch.Tensor) -> None: +def test_ssim_loss_equality(target: torch.Tensor, device: str) -> None: + target = target.to(device) prediction = target.clone() loss = SSIMLoss()(prediction, target) - assert (loss.abs() <= 1e-6).all(), f'If equal tensors are passed SSIM loss must be equal to 0 ' \ - f'(considering floating point operation error up to 1 * 10^-6), got {loss}' - + assert torch.allclose(loss, torch.zeros_like(loss)), \ + f'If equal tensors are passed SSIM loss must be equal to 0 '\ + f'(considering floating point operation error up to 1 * 10^-6), got {loss}' -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_ssim_loss_equality_cuda(target: torch.Tensor) -> None: - target = target.cuda() - test_ssim_loss_equality(target=target) - -def test_ssim_loss_is_less_or_equal_to_one() -> None: +def test_ssim_loss_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None: # Create two maximally different tensors. - ones = torch.ones((3, 3, 256, 256)) - zeros = torch.zeros((3, 3, 256, 256)) - loss = SSIMLoss()(ones, zeros) - assert loss <= 1, f'SSIM loss must be <= 1, got {loss}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_ssim_loss_is_less_or_equal_to_one_cuda() -> None: - ones = torch.ones((3, 3, 256, 256)).cuda() - zeros = torch.zeros((3, 3, 256, 256)).cuda() - loss = SSIMLoss()(ones, zeros) - assert loss <= 1, f'SSIM loss must be <= 1, got {loss}' - - -def test_ssim_loss_is_less_or_equal_to_one_5d() -> None: - # Create two maximally different tensors. - ones = torch.ones((3, 3, 256, 256, 2)) - zeros = torch.zeros((3, 3, 256, 256, 2)) + ones = ones_zeros_4d_5d[0].to(device) + zeros = ones_zeros_4d_5d[1].to(device) loss = SSIMLoss()(ones, zeros) assert (loss <= 1).all(), f'SSIM loss must be <= 1, got {loss}' -def test_ssim_loss_raises_if_tensors_have_different_dimensions() -> None: - custom_prediction = torch.rand(256, 256) - with pytest.raises(AssertionError): - SSIMLoss()(custom_prediction, custom_prediction.unsqueeze(0)) - - -def test_ssim_loss_raises_if_tensors_have_different_shapes(target: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256]] - for b, c, h, w in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w) +def test_ssim_loss_raises_if_tensors_have_different_shapes(prediction_target_4d_5d: Tuple[torch.Tensor, + torch.Tensor], + device) -> None: + target = prediction_target_4d_5d[1].to(device) + dims = [[3], [2, 3], [161, 162], [161, 162]] + if target.dim() == 5: + dims += [[2, 3]] + for size in list(itertools.product(*dims)): + wrong_shape_prediction = torch.rand(size).to(target) if wrong_shape_prediction.size() == target.size(): - try: - SSIMLoss()(wrong_shape_prediction, target) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + SSIMLoss()(wrong_shape_prediction, target) else: with pytest.raises(AssertionError): SSIMLoss()(wrong_shape_prediction, target) -def test_ssim_loss_raises_if_tensors_have_different_shapes_5d(target_5d: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256], [2, 3]] - for b, c, h, w, d in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w, d) - if wrong_shape_prediction.size() == target_5d.size(): - try: - SSIMLoss()(wrong_shape_prediction, target_5d) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") - else: - with pytest.raises(AssertionError): - SSIMLoss()(wrong_shape_prediction, target_5d) - - def test_ssim_loss_check_available_dimensions() -> None: custom_prediction = torch.rand(256, 256) custom_target = torch.rand(256, 256) @@ -278,108 +246,82 @@ def test_ssim_loss_raises_if_tensors_have_different_types(target: torch.Tensor) SSIMLoss()(wrong_type_prediction, target) -def test_ssim_loss_raises_if_wrong_kernel_size_is_passed(prediction: torch.Tensor, target: torch.Tensor) -> None: - wrong_kernel_sizes = list(range(0, 50, 2)) - for kernel_size in wrong_kernel_sizes: - with pytest.raises(AssertionError): +def test_ssim_loss_check_kernel_size_is_passed(prediction: torch.Tensor, target: torch.Tensor) -> None: + kernel_sizes = list(range(0, 50)) + for kernel_size in kernel_sizes: + if kernel_size % 2: SSIMLoss(kernel_size=kernel_size)(prediction, target) - - -def test_ssim_loss_raises_if_kernel_size_greater_than_image() -> None: - right_kernel_sizes = list(range(1, 52, 2)) - for kernel_size in right_kernel_sizes: - wrong_size_prediction = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - wrong_size_target = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - with pytest.raises(ValueError): - SSIMLoss(kernel_size=kernel_size)(wrong_size_prediction, wrong_size_target) - - -def test_ssim_loss_raise_if_wrong_value_is_estimated(prediction: torch.Tensor, target: torch.Tensor) -> None: - ssim_loss = SSIMLoss(kernel_size=11, kernel_sigma=1.5, data_range=1.)(prediction, target) - tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) - tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) - tf_ssim = torch.tensor(tf.image.ssim(tf_prediction, tf_target, max_val=1.).numpy()).mean() - assert torch.isclose(ssim_loss, 1 - tf_ssim, atol=1e-6).all(), \ - f'The estimated value must be equal to tensorflow provided one' \ - f'(considering floating point operation error up to 1 * 10^-6), ' \ - f'got difference {(ssim_loss - 1 + tf_ssim).abs()}' + else: + with pytest.raises(AssertionError): + SSIMLoss(kernel_size=kernel_size)(prediction, target) + + +def test_ssim_loss_raises_if_kernel_size_greater_than_image(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + kernel_size = 11 + wrong_size_prediction = prediction[:, :, :kernel_size - 1, :kernel_size - 1] + wrong_size_target = target[:, :, :kernel_size - 1, :kernel_size - 1] + with pytest.raises(ValueError): + SSIMLoss(kernel_size=kernel_size)(wrong_size_prediction, wrong_size_target) + + +def test_ssim_loss_raise_if_wrong_value_is_estimated(test_images: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: + for prediction, target in test_images: + ssim_loss = SSIMLoss(kernel_size=11, kernel_sigma=1.5, data_range=255, reduction='mean')(prediction.to(device), + target.to(device)) + tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) + tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) + with tf.device('/CPU'): + tf_ssim = torch.tensor(tf.image.ssim(tf_prediction, tf_target, max_val=255).numpy()).mean().to(device) + match_accuracy = 2e-5 + 1e-8 + assert torch.isclose(ssim_loss, 1. - tf_ssim, rtol=0, atol=match_accuracy), \ + f'The estimated value must be equal to tensorflow provided one' \ + f'(considering floating point operation error up to {match_accuracy}), ' \ + f'got difference {(ssim_loss - 1. + tf_ssim).abs()}' # ================== Test function: `multi_scale_ssim` ================== -def test_multi_scale_ssim_symmetry(prediction: torch.Tensor, target: torch.Tensor) -> None: - measure = multi_scale_ssim(prediction, target, data_range=1.) - reverse_measure = multi_scale_ssim(target, prediction, data_range=1.) - assert (measure == reverse_measure).all(), f'Expect: SSIM(a, b) == SSIM(b, a), got {measure} != {reverse_measure}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_multi_scale_ssim_symmetry_cuda(prediction: torch.Tensor, target: torch.Tensor) -> None: - prediction = prediction.cuda() - target = target.cuda() - test_multi_scale_ssim_loss_symmetry(prediction=prediction, target=target) - - -def test_multi_scale_ssim_symmetry_5d(prediction_5d: torch.Tensor, target_5d: torch.Tensor) -> None: - measure = multi_scale_ssim(prediction_5d, target_5d, data_range=1., k2=.4) - reverse_measure = multi_scale_ssim(target_5d, prediction_5d, data_range=1., k2=.4) - assert (measure == reverse_measure).all(), f'Expect: SSIM(a, b) == SSIM(b, a), got {measure} != {reverse_measure}' +def test_multi_scale_ssim_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + measure = multi_scale_ssim(prediction, target, data_range=1., reduction='none') + reverse_measure = multi_scale_ssim(target, prediction, data_range=1., reduction='none') + assert torch.allclose(measure, reverse_measure), f'Expect: MS-SSIM(a, b) == MSSSIM(b, a), '\ + f'got {measure} != {reverse_measure}' -def test_multi_scale_ssim_measure_is_one_for_equal_tensors(target: torch.Tensor) -> None: +def test_multi_scale_ssim_measure_is_one_for_equal_tensors(target: torch.Tensor, device: str) -> None: + target = target.to(device) prediction = target.clone() measure = multi_scale_ssim(prediction, target, data_range=1.) - measure -= 1. - assert (measure.abs() <= 1e-6).all(), f'If equal tensors are passed SSIM must be equal to 1 ' \ - f'(considering floating point operation error up to 1 * 10^-6), ' \ - f'got {measure + 1}' + assert torch.allclose(measure, torch.ones_like(measure)), \ + f'If equal tensors are passed MS-SSIM must be equal to 1 ' \ + f'(considering floating point operation error up to 1 * 10^-6), got {measure + 1}' -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_multi_scale_ssim_measure_is_one_for_equal_tensors_cuda(target: torch.Tensor) -> None: - target = target.cuda() - test_multi_scale_ssim_measure_is_one_for_equal_tensors(target=target) - - -def test_multi_scale_ssim_measure_is_less_or_equal_to_one() -> None: - # Create two maximally different tensors. - ones = torch.ones((3, 3, 256, 256)) - zeros = torch.zeros((3, 3, 256, 256)) - measure = multi_scale_ssim(ones, zeros, data_range=1.) - assert measure <= 1, f'SSIM must be <= 1, got {measure}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_multi_scale_ssim_measure_is_less_or_equal_to_one_cuda() -> None: - ones = torch.ones((3, 3, 256, 256)).cuda() - zeros = torch.zeros((3, 3, 256, 256)).cuda() - measure = multi_scale_ssim(ones, zeros, data_range=1.) - assert measure <= 1, f'SSIM must be <= 1, got {measure}' - - -def test_multi_scale_ssim_measure_is_less_or_equal_to_one_5d() -> None: +def test_multi_scale_ssim_measure_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: # Create two maximally different tensors. - ones = torch.ones((3, 3, 256, 256, 2)) - zeros = torch.zeros((3, 3, 256, 256, 2)) + ones = ones_zeros_4d_5d[0].to(device) + zeros = ones_zeros_4d_5d[1].to(device) measure = multi_scale_ssim(ones, zeros, data_range=1.) - assert (measure <= 1).all(), f'SSIM must be <= 1, got {measure}' - - -def test_multi_scale_ssim_raises_if_tensors_have_different_dimensions() -> None: - custom_prediction = torch.rand(256, 256) - with pytest.raises(AssertionError): - multi_scale_ssim(custom_prediction, custom_prediction.unsqueeze(0)) - - -def test_multi_scale_ssim_raises_if_tensors_have_different_shapes(prediction: torch.Tensor, - target: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256]] - for b, c, h, w in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w) + assert (measure <= 1).all(), f'MS-SSIM must be <= 1, got {measure}' + + +def test_multi_scale_ssim_raises_if_tensors_have_different_shapes(prediction_target_4d_5d: Tuple[torch.Tensor, + torch.Tensor], + device: str) -> None: + target = prediction_target_4d_5d[1].to(device) + dims = [[3], [2, 3], [161, 162], [161, 162]] + if target.dim() == 5: + dims += [[2, 3]] + for size in list(itertools.product(*dims)): + wrong_shape_prediction = torch.rand(size).to(target) if wrong_shape_prediction.size() == target.size(): - try: - multi_scale_ssim(wrong_shape_prediction, target) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + multi_scale_ssim(wrong_shape_prediction, target) else: with pytest.raises(AssertionError): multi_scale_ssim(wrong_shape_prediction, target) @@ -388,24 +330,6 @@ def test_multi_scale_ssim_raises_if_tensors_have_different_shapes(prediction: to multi_scale_ssim(prediction, target, scale_weights=scale_weights) -def test_multi_scale_ssim_raises_if_tensors_have_different_shapes_5d(prediction_5d: torch.Tensor, - target_5d: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256], [2, 3]] - for b, c, h, w, d in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w, d) - if wrong_shape_prediction.size() == target_5d.size(): - try: - multi_scale_ssim(wrong_shape_prediction, target_5d) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") - else: - with pytest.raises(AssertionError): - multi_scale_ssim(wrong_shape_prediction, target_5d) - scale_weights = torch.rand(2, 2) - with pytest.raises(AssertionError): - multi_scale_ssim(prediction_5d, target_5d, scale_weights=scale_weights) - - def test_multi_scale_ssim_check_available_dimensions() -> None: custom_prediction = torch.rand(256, 256) custom_target = torch.rand(256, 256) @@ -432,126 +356,96 @@ def test_multi_scale_ssim_raises_if_tensors_have_different_types(prediction: tor multi_scale_ssim(prediction, target, scale_weights=wrong_type_scale_weights) -def test_multi_scale_ssim_raises_if_wrong_kernel_size_is_passed(prediction: torch.Tensor, target: torch.Tensor) -> None: - wrong_kernel_sizes = list(range(0, 50, 2)) - for kernel_size in wrong_kernel_sizes: - with pytest.raises(AssertionError): +def test_multi_scale_ssim_check_kernel_size_is_passed(prediction: torch.Tensor, target: torch.Tensor) -> None: + kernel_sizes = list(range(0, 13)) + for kernel_size in kernel_sizes: + if kernel_size % 2: multi_scale_ssim(prediction, target, kernel_size=kernel_size) - - -def test_multi_scale_ssim_raises_if_kernel_size_greater_than_image() -> None: - right_kernel_sizes = list(range(1, 52, 2)) - for kernel_size in right_kernel_sizes: - wrong_size_prediction = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - wrong_size_target = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - with pytest.raises(ValueError): - multi_scale_ssim(wrong_size_prediction, wrong_size_target, kernel_size=kernel_size) - - -def test_multi_scale_ssim_raise_if_wrong_value_is_estimated(prediction: torch.Tensor, target: torch.Tensor) -> None: - piq_ms_ssim = multi_scale_ssim(prediction, target, kernel_size=11, kernel_sigma=1.5, - data_range=1., reduction='none') - tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) - tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) - tf_ms_ssim = torch.tensor(tf.image.ssim_multiscale(tf_prediction, tf_target, max_val=1.).numpy()) - assert torch.isclose(piq_ms_ssim, tf_ms_ssim, atol=1e-4).all(), \ - f'The estimated value must be equal to tensorflow provided one' \ - f'(considering floating point operation error up to 1 * 10^-4), ' \ - f'got difference {(piq_ms_ssim - tf_ms_ssim).abs()}' - - -def test_multi_scale_ssim_raise_if_wrong_value_is_estimated_custom_weights(prediction: torch.Tensor, - target: torch.Tensor) -> None: - scale_weights = [0.0448, 0.2856, 0.3001] - piq_ms_ssim = multi_scale_ssim(prediction, target, kernel_size=11, kernel_sigma=1.5, - data_range=1., reduction='none', scale_weights=scale_weights) - tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) - tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) - tf_ms_ssim = torch.tensor(tf.image.ssim_multiscale(tf_prediction, tf_target, max_val=1., - power_factors=scale_weights).numpy()) - assert torch.isclose(piq_ms_ssim, tf_ms_ssim, atol=1e-4).all(), \ - f'The estimated value must be equal to tensorflow provided one' \ - f'(considering floating point operation error up to 1 * 10^-4), ' \ - f'got difference {(piq_ms_ssim - tf_ms_ssim).abs()}' + else: + with pytest.raises(AssertionError): + multi_scale_ssim(prediction, target, kernel_size=kernel_size) + + +def test_ms_ssim_raises_if_kernel_size_greater_than_image(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + kernel_size = 11 + levels = 5 + min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1 + wrong_size_prediction = prediction[:, :, :min_size - 1, :min_size - 1] + wrong_size_target = target[:, :, :min_size - 1, :min_size - 1] + with pytest.raises(ValueError): + multi_scale_ssim(wrong_size_prediction, wrong_size_target, kernel_size=kernel_size) + + +def test_multi_scale_ssim_raise_if_wrong_value_is_estimated(test_images: Tuple[torch.Tensor, torch.Tensor], + scale_weights: List, device: str) -> None: + for prediction, target in test_images: + piq_ms_ssim = multi_scale_ssim(prediction.to(device), target.to(device), kernel_size=11, kernel_sigma=1.5, + data_range=255, reduction='none', scale_weights=scale_weights) + tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) + tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) + with tf.device('/CPU'): + tf_ms_ssim = torch.tensor(tf.image.ssim_multiscale(tf_prediction, tf_target, max_val=255, + power_factors=scale_weights).numpy()).to(device) + number_of_weights = 5. + match_accuracy = number_of_weights * 1e-5 + 1e-8 + assert torch.allclose(piq_ms_ssim, tf_ms_ssim, rtol=0, atol=match_accuracy), \ + f'The estimated value must be equal to tensorflow provided one' \ + f'(considering floating point operation error up to {match_accuracy}), ' \ + f'got difference {(piq_ms_ssim - tf_ms_ssim).abs()}' # ================== Test class: `MultiScaleSSIMLoss` ================== -def test_multi_scale_ssim_loss_symmetry(prediction: torch.Tensor, target: torch.Tensor) -> None: +def test_multi_scale_ssim_loss_grad(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + prediction.requires_grad_() + loss = MultiScaleSSIMLoss(data_range=1.)(prediction, target).mean() + loss.backward() + assert torch.isfinite(prediction.grad).all(), f'Expected finite gradient values, got {prediction.grad}' + + +def test_multi_scale_ssim_loss_symmetry(prediction_target_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) loss = MultiScaleSSIMLoss() loss_value = loss(prediction, target) reverse_loss_value = loss(target, prediction) assert (loss_value == reverse_loss_value).all(), \ - f'Expect: SSIM(a, b) == SSIM(b, a), got {loss_value} != {reverse_loss_value}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_multi_scale_ssim_loss_symmetry_cuda(prediction: torch.Tensor, target: torch.Tensor) -> None: - prediction = prediction.cuda() - target = target.cuda() - test_multi_scale_ssim_loss_symmetry(prediction=prediction, target=target) - - -def test_multi_scale_ssim_loss_symmetry_5d(prediction_5d: torch.Tensor, target_5d: torch.Tensor) -> None: - loss = MultiScaleSSIMLoss(k2=.4) - loss_value = loss(prediction_5d, target_5d) - reverse_loss_value = loss(target_5d, prediction_5d) - assert (loss_value == reverse_loss_value).all(), \ - f'Expect: SSIM(a, b) == SSIM(b, a), got {loss_value} != {reverse_loss_value}' + f'Expect: MS-SSIM(a, b) == MS-SSIM(b, a), got {loss_value} != {reverse_loss_value}' -def test_multi_scale_ssim_loss_equality(target: torch.Tensor) -> None: +def test_multi_scale_ssim_loss_equality(target: torch.Tensor, device: str) -> None: + target = target.to(device) prediction = target.clone() loss = MultiScaleSSIMLoss()(prediction, target) assert (loss.abs() <= 1e-6).all(), f'If equal tensors are passed SSIM loss must be equal to 0 ' \ f'(considering floating point operation error up to 1 * 10^-6), got {loss}' -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_multi_scale_ssim_loss_equality_cuda(target: torch.Tensor) -> None: - target = target.cuda() - test_multi_scale_ssim_loss_equality(target=target) - - -def test_multi_scale_ssim_loss_is_less_or_equal_to_one() -> None: - # Create two maximally different tensors. - ones = torch.ones((3, 3, 256, 256)) - zeros = torch.zeros((3, 3, 256, 256)) - loss = MultiScaleSSIMLoss()(ones, zeros) - assert loss <= 1, f'SSIM loss must be <= 1, got {loss}' - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.') -def test_multi_scale_ssim_loss_is_less_or_equal_to_one_cuda() -> None: - ones = torch.ones((3, 3, 256, 256)).cuda() - zeros = torch.zeros((3, 3, 256, 256)).cuda() - loss = MultiScaleSSIMLoss()(ones, zeros) - assert loss <= 1, f'SSIM loss must be <= 1, got {loss}' - - -def test_multi_scale_ssim_loss_is_less_or_equal_to_one_5d() -> None: +def test_multi_scale_ssim_loss_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor], + device: str) -> None: # Create two maximally different tensors. - ones = torch.ones((3, 3, 256, 256, 2)) - zeros = torch.zeros((3, 3, 256, 256, 2)) + ones = ones_zeros_4d_5d[0].to(device) + zeros = ones_zeros_4d_5d[1].to(device) loss = MultiScaleSSIMLoss()(ones, zeros) - assert (loss <= 1).all(), f'SSIM loss must be <= 1, got {loss}' - - -def test_multi_scale_ssim_loss_raises_if_tensors_have_different_dimensions() -> None: - custom_prediction = torch.rand(256, 256) - with pytest.raises(AssertionError): - MultiScaleSSIMLoss()(custom_prediction, custom_prediction.unsqueeze(0)) - - -def test_multi_scale_ssim_loss_raises_if_tensors_have_different_shapes(prediction: torch.Tensor, - target: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256]] - for b, c, h, w in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w) + assert (loss <= 1).all(), f'MS-SSIM loss must be <= 1, got {loss}' + + +def test_multi_scale_ssim_loss_raises_if_tensors_have_different_shapes(prediction_target_4d_5d: Tuple[torch.Tensor, + torch.Tensor], + device: str) -> None: + target = prediction_target_4d_5d[1].to(device) + dims = [[3], [2, 3], [161, 162], [161, 162]] + if target.dim() == 5: + dims += [[2, 3]] + for size in list(itertools.product(*dims)): + wrong_shape_prediction = torch.rand(size).to(target) if wrong_shape_prediction.size() == target.size(): - try: - MultiScaleSSIMLoss()(wrong_shape_prediction, target) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + MultiScaleSSIMLoss()(wrong_shape_prediction, target) else: with pytest.raises(AssertionError): MultiScaleSSIMLoss()(wrong_shape_prediction, target) @@ -560,24 +454,6 @@ def test_multi_scale_ssim_loss_raises_if_tensors_have_different_shapes(predictio MultiScaleSSIMLoss(scale_weights=scale_weights)(prediction, target) -def test_multi_scale_ssim_loss_raises_if_tensors_have_different_shapes_5d(prediction_5d: torch.Tensor, - target_5d: torch.Tensor) -> None: - dims = [[3], [2, 3], [255, 256], [255, 256], [2, 3]] - for b, c, h, w, d in list(itertools.product(*dims)): - wrong_shape_prediction = torch.rand(b, c, h, w, d) - if wrong_shape_prediction.size() == target_5d.size(): - try: - MultiScaleSSIMLoss()(wrong_shape_prediction, target_5d) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") - else: - with pytest.raises(AssertionError): - MultiScaleSSIMLoss()(wrong_shape_prediction, target_5d) - scale_weights = torch.rand(2, 2) - with pytest.raises(AssertionError): - MultiScaleSSIMLoss(scale_weights=scale_weights)(prediction_5d, target_5d) - - def test_multi_scale_ssim_loss_check_available_dimensions() -> None: custom_prediction = torch.rand(256, 256) custom_target = torch.rand(256, 256) @@ -606,44 +482,43 @@ def test_multi_scale_ssim_loss_raises_if_tensors_have_different_types(prediction def test_multi_scale_ssim_loss_raises_if_wrong_kernel_size_is_passed(prediction: torch.Tensor, target: torch.Tensor) -> None: - wrong_kernel_sizes = list(range(0, 50, 2)) - for kernel_size in wrong_kernel_sizes: - with pytest.raises(AssertionError): + kernel_sizes = list(range(0, 13)) + for kernel_size in kernel_sizes: + if kernel_size % 2: MultiScaleSSIMLoss(kernel_size=kernel_size)(prediction, target) - - -def test_multi_scale_ssim_loss_raises_if_kernel_size_greater_than_image() -> None: - right_kernel_sizes = list(range(1, 52, 2)) - for kernel_size in right_kernel_sizes: - wrong_size_prediction = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - wrong_size_target = torch.rand(3, 3, kernel_size - 1, kernel_size - 1) - with pytest.raises(ValueError): - MultiScaleSSIMLoss(kernel_size=kernel_size)(wrong_size_prediction, wrong_size_target) - - -def test_multi_scale_ssim_loss_raise_if_wrong_value_is_estimated(prediction: torch.Tensor, - target: torch.Tensor) -> None: - piq_ms_ssim_loss = MultiScaleSSIMLoss(kernel_size=11, kernel_sigma=1.5, - data_range=1.)(prediction, target) - tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) - tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) - tf_ms_ssim = torch.tensor(tf.image.ssim_multiscale(tf_prediction, tf_target, max_val=1.).numpy()).mean() - assert torch.isclose(piq_ms_ssim_loss, 1 - tf_ms_ssim, atol=1e-4).all(), \ - f'The estimated value must be equal to tensorflow provided one' \ - f'(considering floating point operation error up to 1 * 10^-4), ' \ - f'got difference {(piq_ms_ssim_loss - 1 + tf_ms_ssim).abs()}' - - -def test_multi_scale_ssim_loss_raise_if_wrong_value_is_estimated_custom_weights(prediction: torch.Tensor, - target: torch.Tensor) -> None: - scale_weights = [0.0448, 0.2856, 0.3001] - piq_ms_ssim_loss = MultiScaleSSIMLoss(kernel_size=11, kernel_sigma=1.5, - data_range=1., scale_weights=scale_weights)(prediction, target) - tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) - tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) - tf_ms_ssim = torch.tensor(tf.image.ssim_multiscale(tf_prediction, tf_target, max_val=1., - power_factors=scale_weights).numpy()).mean() - assert torch.isclose(piq_ms_ssim_loss, 1 - tf_ms_ssim, atol=1e-4).all(), \ - f'The estimated value must be equal to tensorflow provided one' \ - f'(considering floating point operation error up to 1 * 10^-4), ' \ - f'got difference {(piq_ms_ssim_loss - 1 + tf_ms_ssim).abs()}' + else: + with pytest.raises(AssertionError): + MultiScaleSSIMLoss(kernel_size=kernel_size)(prediction, target) + + +def test_ms_ssim_loss_raises_if_kernel_size_greater_than_image(prediction_target_4d_5d: Tuple[torch.Tensor, + torch.Tensor], + device: str) -> None: + prediction = prediction_target_4d_5d[0].to(device) + target = prediction_target_4d_5d[1].to(device) + kernel_size = 11 + levels = 5 + min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1 + wrong_size_prediction = prediction[:, :, :min_size - 1, :min_size - 1] + wrong_size_target = target[:, :, :min_size - 1, :min_size - 1] + with pytest.raises(ValueError): + MultiScaleSSIMLoss(kernel_size=kernel_size)(wrong_size_prediction, wrong_size_target) + + +def test_multi_scale_ssim_loss_raise_if_wrong_value_is_estimated(test_images: List, scale_weights: List, + device: str) -> None: + for prediction, target in test_images: + piq_loss = MultiScaleSSIMLoss(kernel_size=11, kernel_sigma=1.5, data_range=255, scale_weights=scale_weights) + piq_ms_ssim_loss = piq_loss(prediction.to(device), target.to(device)) + tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy()) + tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy()) + with tf.device('/CPU'): + tf_ms_ssim = torch.tensor(tf.image.ssim_multiscale(tf_prediction, tf_target, + power_factors=scale_weights, + max_val=255).numpy()).mean().to(device) + number_of_weights = len(piq_loss.scale_weights) + match_accuracy = number_of_weights * 1e-5 + 1e-8 + assert torch.isclose(piq_ms_ssim_loss, 1. - tf_ms_ssim, rtol=0, atol=match_accuracy), \ + f'The estimated value must be equal to tensorflow provided one' \ + f'(considering floating point operation error up to {match_accuracy}), ' \ + f'got difference {(piq_ms_ssim_loss - 1. + tf_ms_ssim).abs()}' diff --git a/tests/test_tv.py b/tests/test_tv.py index 15996ac6..8a9bc794 100644 --- a/tests/test_tv.py +++ b/tests/test_tv.py @@ -12,27 +12,21 @@ def prediction() -> torch.Tensor: # ================== Test method: `total_variation` ================== def test_tv_works(prediction: torch.Tensor) -> None: for mode in ['l2', 'l1', 'l2_squared']: - try: - measure = total_variation(prediction, norm_type=mode) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") - assert measure > 0 + measure = total_variation(prediction, norm_type=mode, reduction='none') + assert (measure > 0).all() + with pytest.raises(ValueError): + wrong_mode = 'DEADBEEF' + total_variation(prediction, norm_type=wrong_mode) # ================== Test class: `TVLoss` ================== def test_tv_loss_init() -> None: - try: - TVLoss() - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + TVLoss() def test_tv_loss_greater_than_zero(prediction: torch.Tensor) -> None: for mode in ['l2', 'l1', 'l2_squared']: - try: - res = TVLoss(norm_type=mode)(prediction) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + res = TVLoss(norm_type=mode)(prediction) assert res > 0 @@ -46,10 +40,7 @@ def test_tv_loss_check_available_dimensions() -> None: custom_prediction = torch.rand(256, 256) for _ in range(10): if custom_prediction.dim() < 5: - try: - TVLoss()(custom_prediction) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + TVLoss()(custom_prediction) else: with pytest.raises(AssertionError): TVLoss()(custom_prediction) @@ -59,6 +50,9 @@ def test_tv_loss_check_available_dimensions() -> None: def test_tv_loss_for_known_answer(): # Tensor with `l1` TV = (10 - 1) * 2 * 2 = 36 prediction = torch.eye(10).reshape((1, 1, 10, 10)) + prediction.requires_grad_() loss = TVLoss(norm_type='l1') measure = loss(prediction) + measure.backward() assert measure == 36., f'TV for this tensors must be 36., got {measure}' + assert torch.isfinite(prediction.grad).all(), f'Expected finite gradient values, got {prediction.grad}' diff --git a/tests/test_utils.py b/tests/test_utils.py index 6b915a23..95d05aa2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import numpy as np -from piq.utils import _validate_input +from piq.utils import _validate_input, _adjust_dimensions @pytest.fixture(scope='module') @@ -129,3 +129,10 @@ def test_breaks_if_scale_weight_wrong_n_dims_provided(tensor_2d: torch.Tensor) - wrong_scale_weights = tensor_2d.clone() with pytest.raises(AssertionError): _validate_input(tensor_2d, allow_5d=False, scale_weights=wrong_scale_weights) + + +# ================== Test function: `_adjust_dimensions` ================== +def test_breaks_if_number_of_dim_greater_five() -> None: + tensor_6d = torch.rand(1, 1, 1, 1, 1, 1) + with pytest.raises(ValueError): + _adjust_dimensions(tensor_6d)