Refactoring of perceptual metrics #291

Draft
wants to merge 8 commits into master
4 changes: 3 additions & 1 deletion piq/__init__.py
@@ -12,7 +12,9 @@
from .isc import IS, inception_score
from .vif import VIFLoss, vif_p
from .brisque import BRISQUELoss, brisque
-from .perceptual import StyleLoss, ContentLoss, LPIPS, DISTS
+from .perceptual import StyleLoss, ContentLoss
+from .lpips import LPIPS
+from .dists import DISTS
from .psnr import psnr
from .fsim import fsim, FSIMLoss
from .vsi import vsi, VSILoss
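
Note that the package-level API is unchanged by this split: LPIPS and DISTS are still re-exported from piq/__init__.py, just from the new piq.lpips and piq.dists modules. A minimal sanity check of the refactored imports (shapes are illustrative; both metrics download pretrained weights on first use):

import torch
from piq import LPIPS, DISTS  # re-exported from the new piq.lpips and piq.dists modules

x = torch.rand(2, 3, 256, 256)
y = torch.rand(2, 3, 256, 256)

print(LPIPS()(x, y))  # scalar tensor, since reduction='mean' by default
print(DISTS()(x, y))  # scalar tensor
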
153 changes: 153 additions & 0 deletions piq/dists.py
@@ -0,0 +1,153 @@

from typing import List, Union

import torch
from torch.nn.modules.loss import _Loss
import torchvision

from piq.utils import _validate_input, _reduce
from piq.functional import similarity_map, L2Pool2d
from piq.perceptual import VGG16_LAYERS, IMAGENET_MEAN, IMAGENET_STD


class DISTS(_Loss):
r"""Deep Image Structure and Texture Similarity metric.

By default expects input to be in range [0, 1], which is then normalized with ImageNet statistics.
If no normalization is required, change `mean` and `std` values accordingly.

Args:
reduction: Specifies the reduction type:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
data_range: Maximum value range of images (usually 1.0 or 255).
mean: List of float values used for data standardization. Default: ImageNet mean.
If there is no need to normalize data, use [0., 0., 0.].
std: List of float values used for data standardization. Default: ImageNet std.
If there is no need to normalize data, use [1., 1., 1.].
enable_grad: Enable gradient computation. Useful when DISTS is used as a loss. Default: ``False``

Examples:
>>> loss = DISTS(enable_grad=True)
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(3, 3, 256, 256)
>>> output = loss(x, y)
>>> output.backward()

References:
Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
Image Quality Assessment: Unifying Structure and Texture Similarity.
https://arxiv.org/abs/2004.07728
https://github.com/dingkeyan93/DISTS
"""
_weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"

def __init__(self, reduction: str = "mean", data_range: Union[int, float] = 1.0,
mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD,
enable_grad: bool = False) -> None:
super().__init__()

dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
channels = [3, 64, 128, 256, 512, 512]

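# The pretrained file stores per-channel 'alpha' and 'beta' coefficients; split them
# to match the channel count of each used stage (raw input plus the five tapped VGG blocks).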
weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
dists_weights.extend(torch.split(weights['beta'], channels, dim=1))

self.model = torchvision.models.vgg16(pretrained=True, progress=False).features
self.layers = [VGG16_LAYERS[l] for l in dists_layers]

self.model = self.replace_pooling(self.model)

# Disable gradients
for param in self.model.parameters():
param.requires_grad_(False)

self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in dists_weights]

self.mean = torch.tensor(mean).view(1, -1, 1, 1)
self.std = torch.tensor(std).view(1, -1, 1, 1)
self.reduction = reduction
self.data_range = data_range
self.enable_grad = enable_grad


def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r"""
Args:
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.

Returns:
Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS`` in range [0, 1].
"""
_validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))

# Rescale to [0, 1] range
x = x / float(self.data_range)
y = y / float(self.data_range)

# Downsample large inputs so that the smaller spatial side becomes 256
_, _, H, W = x.shape
if min(H, W) > 256:
x = torch.nn.functional.interpolate(
x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
y = torch.nn.functional.interpolate(
y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')

_validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))

self.model.to(x)
self.mean, self.std = self.mean.to(x), self.std.to(x)

# Normalize
x, y = (x - self.mean) / self.std, (y - self.mean) / self.std

with torch.autograd.set_grad_enabled(self.enable_grad):
# Add input tensor as an additional feature
x_features, y_features = [x, ], [y, ]
for name, module in self.model._modules.items():
x = module(x)
y = module(y)
if name in self.layers:
x_features.append(x)
y_features.append(y)

# Compute structure similarity between feature maps
EPS = 1e-6 # Small constant for numerical stability

structure_distance, texture_distance = [], []
for x, y in zip(x_features, y_features):
x_mean = x.mean([2, 3], keepdim=True)
y_mean = y.mean([2, 3], keepdim=True)
structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS))

x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))

distances = structure_distance + texture_distance

# Scale distances, average over spatial dimensions, then concatenate and sum over the channel dimension
loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)

return 1 - _reduce(loss, self.reduction)

def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
r"""Turn All MaxPool layers into L2Pool

Args:
module: Module to change MaxPool into L2Pool

Returns:
Module with L2Pool instead of MaxPool
"""
module_output = module
if isinstance(module, torch.nn.MaxPool2d):
module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)

for name, child in module.named_children():
module_output.add_module(name, self.replace_pooling(child))

return module_output
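
For readers unfamiliar with DISTS, the per-stage terms accumulated in the loop above are SSIM-like ratios over per-channel statistics. A small self-contained sketch of one stage, assuming similarity_map follows the usual (2ab + c) / (a^2 + b^2 + c) form used elsewhere in piq (tensor sizes are illustrative):

import torch

EPS = 1e-6
x_feat = torch.rand(1, 8, 32, 32)  # one feature stage of the input
y_feat = torch.rand(1, 8, 32, 32)  # the matching stage of the target

x_mean = x_feat.mean([2, 3], keepdim=True)
y_mean = y_feat.mean([2, 3], keepdim=True)
# 'structure' term in this PR: similarity of per-channel means
structure = (2 * x_mean * y_mean + EPS) / (x_mean ** 2 + y_mean ** 2 + EPS)

x_var = ((x_feat - x_mean) ** 2).mean([2, 3], keepdim=True)
y_var = ((y_feat - y_mean) ** 2).mean([2, 3], keepdim=True)
xy_cov = (x_feat * y_feat).mean([2, 3], keepdim=True) - x_mean * y_mean
# 'texture' term in this PR: similarity of per-channel variances and covariance
texture = (2 * xy_cov + EPS) / (x_var + y_var + EPS)

Each term is bounded above by 1, so the weighted sum in forward() behaves as a similarity and the method returns one minus it as the distance.
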
156 changes: 156 additions & 0 deletions piq/lpips.py
@@ -0,0 +1,156 @@
"""
Implementation of Learned Perceptual Image Patch Similarity (LPIPS) metric
References:
Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
IEEE/CVF Conference on Computer Vision and Pattern Recognition
https://arxiv.org/abs/1801.03924
https://github.com/richzhang/PerceptualSimilarity
"""

from typing import List, Union

import torch
import torchvision
import torch.nn as nn
from torch.nn.modules.loss import _Loss

from piq.utils import _validate_input, _reduce
from piq.perceptual import VGG16_LAYERS, IMAGENET_MEAN, IMAGENET_STD, EPS


class LPIPS(_Loss):
r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.

By default expects input to be in range [0, 1], which is then normalized with ImageNet statistics.
If no normalization is required, change `mean` and `std` values accordingly.

Args:
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
reduction: Specifies the reduction type:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
data_range: Maximum value range of images (usually 1.0 or 255).
mean: List of float values used for data standardization. Default: ImageNet mean.
If there is no need to normalize data, use [0., 0., 0.].
std: List of float values used for data standardization. Default: ImageNet std.
If there is no need to normalize data, use [1., 1., 1.].
enable_grad: Flag to compute gradients. Useful when LPIPS is used as a loss. Default: ``False``.

Examples:
>>> loss = LPIPS(enable_grad=True)
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(3, 3, 256, 256)
>>> output = loss(x, y)
>>> output.backward()

References:
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
A Neural Algorithm of Artistic Style
Association for Research in Vision and Ophthalmology (ARVO)
https://arxiv.org/abs/1508.06576

Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
IEEE/CVF Conference on Computer Vision and Pattern Recognition
https://arxiv.org/abs/1801.03924
https://github.com/richzhang/PerceptualSimilarity
"""
_weights_url = "https://github.com/photosynthesis-team/" + \
"photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"

def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
data_range: Union[int, float] = 1.0, mean: List[float] = IMAGENET_MEAN,
std: List[float] = IMAGENET_STD, enable_grad: bool = False) -> None:
super().__init__()

lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)

self.model = torchvision.models.vgg16(pretrained=True, progress=False).features
self.layers = [VGG16_LAYERS[l] for l in lpips_layers]

if replace_pooling:
self.model = self.replace_pooling(self.model)

# Disable gradients
for param in self.model.parameters():
param.requires_grad_(False)

self.distance = {
"mse": nn.MSELoss,
"mae": nn.L1Loss,
}[distance](reduction='none')

self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in lpips_weights]

assert len(self.layers) == len(self.weights), \
(f'Lengths of provided layers and weights mismatch ({len(self.weights)} weights and '
f'{len(self.layers)} layers), which will cause incorrect results. '
f'Please provide a weight for each layer.')

self.mean = torch.tensor(mean).view(1, -1, 1, 1)
self.std = torch.tensor(std).view(1, -1, 1, 1)
self.reduction = reduction
self.data_range = data_range
self.enable_grad = enable_grad

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r"""LPIPS computation between :math:`x` and :math:`y` tensors.

Args:
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.

Returns:
LPIPS value between inputs.
"""
_validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))

# Rescale to [0, 1] range
x = x / float(self.data_range)
y = y / float(self.data_range)

self.model.to(x)
self.mean, self.std = self.mean.to(x), self.std.to(x)

# Normalize
x, y = (x - self.mean) / self.std, (y - self.mean) / self.std

x_features, y_features = [], []
with torch.autograd.set_grad_enabled(self.enable_grad):
for name, module in self.model._modules.items():
x = module(x)
y = module(y)
if name in self.layers:
# Normalize feature maps in channel direction to unit length.
x_norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
y_norm_factor = torch.sqrt(torch.sum(y ** 2, dim=1, keepdim=True))

x_features.append(x / (x_norm_factor + EPS))
y_features.append(y / (y_norm_factor + EPS))

distances = [self.distance(x, y) for x, y in zip(x_features, y_features)]

# Scale distances, average over spatial dimensions, then concatenate and sum over the channel dimension
loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)

return _reduce(loss, self.reduction)

def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
r"""Turn all MaxPool layers into AveragePool

Args:
module: Module to change MaxPool into AveragePool

Returns:
Module with AveragePool instead of MaxPool

"""
module_output = module
if isinstance(module, torch.nn.MaxPool2d):
module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

for name, child in module.named_children():
module_output.add_module(name, self.replace_pooling(child))
return module_output
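
The constructor options above compose; as a usage sketch (values are illustrative), an L1-based variant with average pooling on inputs that are already normalized could look like:

import torch
from piq import LPIPS

# mean=[0,0,0] / std=[1,1,1] skip the ImageNet normalization, as noted in the docstring
metric = LPIPS(replace_pooling=True, distance='mae',
               mean=[0., 0., 0.], std=[1., 1., 1.], reduction='none')

x = torch.rand(4, 3, 224, 224)
y = torch.rand(4, 3, 224, 224)
scores = metric(x, y)  # shape (4,): one LPIPS value per image with reduction='none'
print(scores)
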