From c9fe57a5db911239ec6fbe95859bcfe66503e234 Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee Date: Mon, 14 Oct 2024 19:17:45 +0530 Subject: [PATCH] Implementing requested changes on GMM --- lightly/transforms/__init__.py | 2 +- .../gaussian_mixture_masks_transform.py | 54 +++++++++++-------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 2cdb45963..338def43a 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -18,7 +18,7 @@ from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.fast_siam_transform import FastSiamTransform from lightly.transforms.gaussian_blur import GaussianBlur -from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMasks +from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMask from lightly.transforms.irfft2d_transform import IRFFT2DTransform from lightly.transforms.jigsaw import Jigsaw from lightly.transforms.mae_transform import MAETransform diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py index 4ab98aeaf..e911ed19f 100644 --- a/lightly/transforms/gaussian_mixture_masks_transform.py +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -8,19 +8,21 @@ from lightly.transforms.rfft2d_transform import RFFT2DTransform -class GaussianMixtureMasks: - """Applies a Gaussian Mixture Mask in the Fourier domain to RGB images. +class GaussianMixtureMask: + """Applies a Gaussian Mixture Mask in the Fourier domain to a single-channel image. The mask is created using random Gaussian kernels, which are applied in the frequency domain via RFFT2D, and then the IRFFT2D is used to return - to the spatial domain. The transformation is applied to each RGB channel separately. + to the spatial domain. The transformation is applied to each image channel separately. Attributes: num_gaussians: Number of Gaussian kernels to generate in the mixture mask. std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. """ - def __init__(self, num_gaussians: int = 20, std_range: Tuple[int, int] = (10, 15)): + def __init__( + self, num_gaussians: int = 20, std_range: Tuple[float, float] = (10, 15) + ): """Initializes GaussianMixtureMasks with the given parameters. Args: @@ -29,6 +31,7 @@ def __init__(self, num_gaussians: int = 20, std_range: Tuple[int, int] = (10, 15 """ self.rfft2d_transform = RFFT2DTransform() + self.num_gaussians = num_gaussians self.std_range = std_range @@ -38,7 +41,7 @@ def gaussian_kernel( """Generates a 2D Gaussian kernel. Args: - size: Tuple specifying the dimensions of the Gaussian kernel (C, H, W). + size: Tuple specifying the dimensions of the Gaussian kernel (H, W). sigma: Tensor specifying the standard deviation of the Gaussian. center: Tensor specifying the center of the Gaussian kernel. @@ -46,6 +49,8 @@ def gaussian_kernel( Tensor: A 2D Gaussian kernel. """ u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1])) + u = u.to(sigma.device) + v = v.to(sigma.device) u0, v0 = center gaussian = torch.exp( -((u - u0) ** 2 / (2 * sigma[0] ** 2) + (v - v0) ** 2 / (2 * sigma[1] ** 2)) @@ -54,50 +59,53 @@ def gaussian_kernel( return gaussian def apply_gaussian_mixture_mask( - self, image_channel: Tensor, num_gaussians: int, std: Tuple[int, int] + self, freq_image: Tensor, num_gaussians: int, std: Tuple[int, int] ) -> Tensor: - """Applies the Gaussian mixture mask to a single channel in the frequency domain. + """Applies the Gaussian mixture mask to a frequency-domain image. Args: - image_channel: Tensor representing a single channel of the image. + freq_image: Tensor representing the frequency-domain image of shape (C, H, W//2+1). num_gaussians: Number of Gaussian kernels to generate in the mask. std: Tuple specifying the standard deviation range for the Gaussians. Returns: Tensor: Image after applying the Gaussian mixture mask. """ - image_size = image_channel[0].shape + image_size = freq_image.shape[1:] + original_height = image_size[0] + original_width = 2 * (image_size[1] - 1) + + original_shape = (original_height, original_width) - self.irfft2d_transform = IRFFT2DTransform((image_size[0], image_size[1])) - f_transform = self.rfft2d_transform(image_channel) + self.irfft2d_transform = IRFFT2DTransform(original_shape) - size = f_transform[0].shape + size = freq_image[0].shape - mask = torch.ones(size) + mask = freq_image.new_ones(freq_image.shape) for _ in range(num_gaussians): - u0 = torch.randint(0, size[0], (1,)) - v0 = torch.randint(0, size[1], (1,)) - center = torch.tensor((u0, v0)) - sigma = torch.rand(2) * 5 + 10 + u0 = torch.randint(0, size[0], (1,), device=freq_image.device) + v0 = torch.randint(0, size[1], (1,), device=freq_image.device) + center = torch.tensor((u0, v0), device=freq_image.device) + sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0] g_kernel = self.gaussian_kernel((size[0], size[1]), sigma, center) mask -= g_kernel - filtered_f_transform = f_transform * mask - filtered_image = self.irfft2d_transform(filtered_f_transform).abs() + filtered_freq_image = freq_image * mask + filtered_image = self.irfft2d_transform(filtered_freq_image).abs() return filtered_image - def __call__(self, image_tensor: Tensor) -> Tensor: - """Applies the Gaussian mixture mask transformation to the input image. + def __call__(self, freq_image: Tensor) -> Tensor: + """Applies the Gaussian mixture mask transformation to the input frequency-domain image. Args: - image_tensor: Tensor representing an RGB image of shape (C, H, W). + freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1). Returns: Tensor: The transformed image after applying the Gaussian mixture mask. """ transformed_channel: Tensor = self.apply_gaussian_mixture_mask( - image_tensor, self.num_gaussians, self.std_range + freq_image, self.num_gaussians, self.std_range ) return transformed_channel