From 3082e120876368eddb38019b19291aa402ff8429 Mon Sep 17 00:00:00 2001 From: payo101 <35198092+payo101@users.noreply.github.com> Date: Sat, 12 Oct 2024 20:55:50 +0530 Subject: [PATCH 1/7] Added the Amplitude Rescaling Transform --- lightly/transforms/__init__.py | 1 + .../transforms/amplitude_rescale_transform.py | 42 +++++++++++++++++++ .../test_amplitude_rescale_transform.py | 21 ++++++++++ 3 files changed, 64 insertions(+) create mode 100644 lightly/transforms/amplitude_rescale_transform.py create mode 100644 tests/transforms/test_amplitude_rescale_transform.py diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 949fbe905..efb58da28 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -9,6 +9,7 @@ # All Rights Reserved from lightly.transforms.aim_transform import AIMTransform +from lightly.transforms.amplitude_rescale_transform import AmplitudeRescaleTranform from lightly.transforms.byol_transform import ( BYOLTransform, BYOLView1Transform, diff --git a/lightly/transforms/amplitude_rescale_transform.py b/lightly/transforms/amplitude_rescale_transform.py new file mode 100644 index 000000000..d1dc6a77a --- /dev/null +++ b/lightly/transforms/amplitude_rescale_transform.py @@ -0,0 +1,42 @@ +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor + + +class AmplitudeRescaleTranform: + """ + This transform will rescale the amplitude of the Fourier Spectrum (`input`) of the image and return it. + The scaling value *p* will range within `[m, n)` + ``` + img = torch.randn(3, 64, 64) + + rfft = lightly.transforms.RFFT2DTransform() + rfft_img = rfft(img) + + art = AmplitudeRescaleTransform() + rescaled_img = art(rfft_img) + ``` + + # Intial Arguments + **range**: *Tuple of float_like* + The low `m` and high `n` values such that **p belongs to [m, n)**. + # Parameters: + **input**: _torch.Tensor_ + The 2D Discrete Fourier Tranform of an Image. + # Returns: + **output**:_torch.Tensor_ + The Fourier spectrum of the 2D Image with rescaled Amplitude. + """ + + def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None: + self.m = range[0] + self.n = range[1] + + def __call__(self, input: Tensor) -> Tensor: + p = np.random.uniform(self.m, self.n) + + output = input * p + + return output diff --git a/tests/transforms/test_amplitude_rescale_transform.py b/tests/transforms/test_amplitude_rescale_transform.py new file mode 100644 index 000000000..6ded2746f --- /dev/null +++ b/tests/transforms/test_amplitude_rescale_transform.py @@ -0,0 +1,21 @@ +import numpy as np +import torch + +from lightly.transforms import AmplitudeRescaleTranform, RFFT2DTransform + + +# Testing function image -> FFT -> AmplitudeRescale. +# Compare shapes of source and result. +def test() -> None: + image = torch.randn(3, 64, 64) + + rfftTransform = RFFT2DTransform() + rfft = rfftTransform(image) + + ampRescaleTf_1 = AmplitudeRescaleTranform() + rescaled_rfft_1 = ampRescaleTf_1(rfft) + + ampRescaleTf_2 = AmplitudeRescaleTranform(range=(1.0, 2.0)) + rescaled_rfft_2 = ampRescaleTf_2(rfft) + + assert rescaled_rfft_1.shape == rfft.shape and rescaled_rfft_2.shape == rfft.shape From cf24dfab270772aceccc28ee400636b1e3f1d134 Mon Sep 17 00:00:00 2001 From: Mayur Kawale <122032765+Mefisto04@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:21:55 +0530 Subject: [PATCH 2/7] Update README.md (#1691) --- docs/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/README.md b/docs/README.md index c14c8c5bf..1e8b8b713 100644 --- a/docs/README.md +++ b/docs/README.md @@ -6,7 +6,7 @@ Make sure you installed dev dependencies: pip install -r ../requirements/dev.txt ``` -You may have to set up a clean environment (e.g. with Conda) and use setuptools from the parent directory: +You may need to set up a clean environment (e.g., using Conda) and utilize setuptools from the parent directory: ``` conda create -n lightly python=3.7 conda activate lightly @@ -18,7 +18,7 @@ This isn't handled in requirements because the version you'll need depends on yo [Follow instructions](https://detectron2.readthedocs.io/en/latest/tutorials/install.html) ## Build the Docs -`sphinx` provides a Makefile, so to build the `html` documentation, simply type: +The `sphinx` documentation generator provides a Makefile. To build the `html` documentation, simply execute: ``` make html ``` @@ -28,7 +28,7 @@ To build docs without running python files (tutorials) use make html-noplot ``` -Shortcut to build the docs (with env variables for active-learning tutorial) use: +To create a shortcut for building the documentation with environment variables for the active-learning tutorial, use: ``` LIGHTLY_SERVER_LOCATION='https://api.lightly.ai' LIGHTLY_TOKEN='YOUR_TOKEN' AL_TUTORIAL_DATASET_ID='YOUR_DATASET_ID' make html && python -m http.server 1234 -d build/html ``` @@ -52,8 +52,8 @@ Only Lightly core team members will have access to deploy new docs. We build our code based on the [Google Python Styleguide](). Important notes: -- Always use three double-quotes (`"""`). -- A function must have a docstring, unless it meets all of the following criteria: not externally visible, very short, obvious. +- Always use triple double quotes (`"""`). +- A function must include a docstring unless it meets all the following criteria: it is not externally visible, is very short, and is obvious. - Always use type hints when possible. - Don't overlook the `Raises`. - Use punctuation. @@ -81,7 +81,7 @@ examples. ### Functions -Example: +Example of a function: ```python def fetch_smalltable_rows(table_handle: smalltable.Table, keys: Sequence[Union[bytes, str]], From 15e6475e426011138306cfa46052c517d7352a53 Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee <127598707+snehilchatterjee@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:54:28 +0530 Subject: [PATCH 3/7] Add GaussianMixtureMaskTransform (#1692) --- lightly/transforms/__init__.py | 1 + .../gaussian_mixture_masks_transform.py | 93 +++++++++++++++++++ .../transforms/test_gaussian_mixture_masks.py | 10 ++ 3 files changed, 104 insertions(+) create mode 100644 lightly/transforms/gaussian_mixture_masks_transform.py create mode 100644 tests/transforms/test_gaussian_mixture_masks.py diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index efb58da28..6794c146f 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -19,6 +19,7 @@ from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.fast_siam_transform import FastSiamTransform from lightly.transforms.gaussian_blur import GaussianBlur +from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMask from lightly.transforms.irfft2d_transform import IRFFT2DTransform from lightly.transforms.jigsaw import Jigsaw from lightly.transforms.mae_transform import MAETransform diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py new file mode 100644 index 000000000..6d89d41c6 --- /dev/null +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -0,0 +1,93 @@ +from typing import Tuple + +import torch +import torch.fft +from torch import Tensor + + +class GaussianMixtureMask: + """Applies a Gaussian Mixture Mask in the Fourier domain to an image. + + The mask is created using random Gaussian kernels, which are applied in + the frequency domain. + + Attributes: + num_gaussians: Number of Gaussian kernels to generate in the mixture mask. + std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. + """ + + def __init__( + self, num_gaussians: int = 20, std_range: Tuple[float, float] = (10, 15) + ): + """Initializes GaussianMixtureMasks with the given parameters. + + Args: + num_gaussians: Number of Gaussian kernels to generate in the mixture mask. + std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. + """ + self.num_gaussians = num_gaussians + self.std_range = std_range + + def gaussian_kernel( + self, size: Tuple[int, int], sigma: Tensor, center: Tensor + ) -> Tensor: + """Generates a 2D Gaussian kernel. + + Args: + size: Tuple specifying the dimensions of the Gaussian kernel (H, W). + sigma: Tensor specifying the standard deviation of the Gaussian. + center: Tensor specifying the center of the Gaussian kernel. + + Returns: + A 2D Gaussian kernel tensor. + """ + u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1])) + u = u.to(sigma.device) + v = v.to(sigma.device) + u0, v0 = center + gaussian = torch.exp( + -((u - u0) ** 2 / (2 * sigma[0] ** 2) + (v - v0) ** 2 / (2 * sigma[1] ** 2)) + ) + + return gaussian + + def apply_gaussian_mixture_mask( + self, freq_image: Tensor, num_gaussians: int, std: Tuple[float, float] + ) -> Tensor: + """Applies the Gaussian mixture mask to a frequency-domain image. + + Args: + freq_image: Tensor representing the frequency-domain image of shape (C, H, W//2+1). + num_gaussians: Number of Gaussian kernels to generate in the mask. + std: Tuple specifying the standard deviation range for the Gaussians. + + Returns: + Image tensor in frequency domain after applying the Gaussian mixture mask. + """ + (C, U, V) = freq_image.shape + mask = freq_image.new_ones(freq_image.shape) + + for _ in range(num_gaussians): + u0 = torch.randint(0, U, (1,), device=freq_image.device) + v0 = torch.randint(0, V, (1,), device=freq_image.device) + center = torch.tensor((u0, v0), device=freq_image.device) + sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0] + + g_kernel = self.gaussian_kernel((U, V), sigma, center) + mask *= 1 - g_kernel.unsqueeze(0) + + filtered_freq_image = freq_image * mask + return filtered_freq_image + + def __call__(self, freq_image: Tensor) -> Tensor: + """Applies the Gaussian mixture mask transformation to the input frequency-domain image. + + Args: + freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1). + + Returns: + Image tensor in frequency domain after applying the Gaussian mixture mask. + """ + return self.apply_gaussian_mixture_mask( + freq_image, self.num_gaussians, self.std_range + ) diff --git a/tests/transforms/test_gaussian_mixture_masks.py b/tests/transforms/test_gaussian_mixture_masks.py new file mode 100644 index 000000000..db687000b --- /dev/null +++ b/tests/transforms/test_gaussian_mixture_masks.py @@ -0,0 +1,10 @@ +import torch + +from lightly.transforms import GaussianMixtureMask + + +def test() -> None: + transform = GaussianMixtureMask(20, (10, 15)) + image = torch.rand(3, 32, 17) + output = transform(image) + assert output.shape == image.shape From 9a9665d0025e4bcf7fa036bf8c373da3eb771967 Mon Sep 17 00:00:00 2001 From: ayush22iitbhu Date: Thu, 17 Oct 2024 12:56:33 +0530 Subject: [PATCH 4/7] Add Documentation for lightly/loss subpackage. (#1697) --- lightly/loss/barlow_twins_loss.py | 42 +++++--- lightly/loss/dcl_loss.py | 73 +++++++++----- lightly/loss/dino_loss.py | 31 +++--- lightly/loss/emp_ssl_loss.py | 31 +++++- lightly/loss/hypersphere_loss.py | 41 ++++---- lightly/loss/ibot_loss.py | 18 +++- lightly/loss/koleo_loss.py | 42 +++++--- lightly/loss/mmcr_loss.py | 19 +++- lightly/loss/msn_loss.py | 59 +++++++---- lightly/loss/negative_cosine_similarity.py | 26 +++-- lightly/loss/ntx_ent_loss.py | 49 +++++---- lightly/loss/pmsn_loss.py | 41 ++++++-- lightly/loss/regularizer/__init__.py | 1 - lightly/loss/regularizer/co2.py | 45 +++++---- lightly/loss/swav_loss.py | 43 +++++--- lightly/loss/sym_neg_cos_sim_loss.py | 27 +++-- lightly/loss/tico_loss.py | 42 +++++--- lightly/loss/vicreg_loss.py | 35 +++++-- lightly/loss/vicregl_loss.py | 110 +++++++++++++++++---- lightly/loss/wmse_loss.py | 65 +++++++++--- 20 files changed, 604 insertions(+), 236 deletions(-) diff --git a/lightly/loss/barlow_twins_loss.py b/lightly/loss/barlow_twins_loss.py index 54ee08496..f7783805c 100644 --- a/lightly/loss/barlow_twins_loss.py +++ b/lightly/loss/barlow_twins_loss.py @@ -7,12 +7,11 @@ class BarlowTwinsLoss(torch.nn.Module): """Implementation of the Barlow Twins Loss from Barlow Twins[0] paper. - This code specifically implements the Figure Algorithm 1 from [0]. + This code specifically implements the Figure Algorithm 1 from [0]. [0] Zbontar,J. et.al, 2021, Barlow Twins... https://arxiv.org/abs/2103.03230 - Examples: - + Examples: >>> # initialize loss function >>> loss_fn = BarlowTwinsLoss() >>> @@ -25,19 +24,22 @@ class BarlowTwinsLoss(torch.nn.Module): >>> >>> # calculate loss >>> loss = loss_fn(out0, out1) - """ def __init__(self, lambda_param: float = 5e-3, gather_distributed: bool = False): """Lambda param configuration with default value like in [0] + Initializes the BarlowTwinsLoss with the specified parameters. + Args: lambda_param: Parameter for importance of redundancy reduction term. - Defaults to 5e-3 [0]. gather_distributed: - If True then the cross-correlation matrices from all gpus are + If True, the cross-correlation matrices from all GPUs are gathered and summed before the loss calculation. + + Raises: + ValueError: If gather_distributed is True but torch.distributed is not available. """ super(BarlowTwinsLoss, self).__init__() self.lambda_param = lambda_param @@ -45,22 +47,32 @@ def __init__(self, lambda_param: float = 5e-3, gather_distributed: bool = False) if gather_distributed and not dist.is_available(): raise ValueError( - "gather_distributed is True but torch.distributed is not available. " + "gather_distributed is True but torch.distributed is not available." "Please set gather_distributed=False or install a torch version with " "distributed support." ) def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: - # normalize repr. along the batch dimension + """Computes the Barlow Twins loss for the given projections. + + Args: + z_a: Output projections of the first set of transformed images. + z_b: Output projections of the second set of transformed images. + + Returns: + Computed Barlow Twins Loss. + """ + + # Normalize repr. along the batch dimension z_a_norm, z_b_norm = _normalize(z_a, z_b) N = z_a.size(0) - # cross-correlation matrix + # Compute the cross-correlation matrix c = z_a_norm.T @ z_b_norm c.div_(N) - # sum cross-correlation matrix between multiple gpus + # Aggregate and normalize the cross-correlation matrix between multiple GPUs if self.gather_distributed and dist.is_initialized(): world_size = dist.get_world_size() if world_size > 1: @@ -78,7 +90,10 @@ def _normalize( z_a: torch.Tensor, z_b: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Helper function to normalize tensors along the batch dimension.""" + # Stack tensors along a new dimension combined = torch.stack([z_a, z_b], dim=0) # Shape: 2 x N x D + + # Normalize the stacked tensors along the batch dimension normalized = F.batch_norm( combined.flatten(0, 1), running_mean=None, @@ -87,11 +102,16 @@ def _normalize( bias=None, training=True, ).view_as(combined) + return normalized[0], normalized[1] def _off_diagonal(x): - # return a flattened view of the off-diagonal elements of a square matrix + """Returns a flattened view of the off-diagonal elements of a square matrix.""" + + # Ensure the input is a square matrix n, m = x.shape assert n == m + + # Flatten the matrix and extract off-diagonal elements return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() diff --git a/lightly/loss/dcl_loss.py b/lightly/loss/dcl_loss.py index d491bfcf5..06f3cb6b4 100644 --- a/lightly/loss/dcl_loss.py +++ b/lightly/loss/dcl_loss.py @@ -12,8 +12,7 @@ def negative_mises_fisher_weights( out0: Tensor, out1: Tensor, sigma: float = 0.5 ) -> torch.Tensor: - """Negative Mises-Fisher weighting function as presented in Decoupled - Contrastive Learning [0]. + """Negative Mises-Fisher weighting function as presented in Decoupled Contrastive Learning [0]. The implementation was inspired by [1]. @@ -32,15 +31,15 @@ def negative_mises_fisher_weights( Returns: A tensor with shape (batch_size,) where each entry is the weight for one of the input images. - """ similarity = torch.einsum("nm,nm->n", out0.detach(), out1.detach()) / sigma + + # Return negative Mises-Fisher weights return 2 - out0.shape[0] * nn.functional.softmax(similarity, dim=0) class DCLLoss(nn.Module): - """Implementation of the Decoupled Contrastive Learning Loss from - Decoupled Contrastive Learning [0]. + """Implementation of the Decoupled Contrastive Learning Loss from Decoupled Contrastive Learning [0]. This code implements Equation 6 in [0], including the sum over all images `i` and views `k`. The loss is reduced to a mean loss over the mini-batch. @@ -59,11 +58,10 @@ class DCLLoss(nn.Module): passed to the forward call as input and return a weight tensor. The returned weight tensor must have the same length as the input tensors. gather_distributed: - If True then negatives from all gpus are gathered before the + If True, negatives from all GPUs are gathered before the loss calculation. Examples: - >>> loss_fn = DCLLoss(temperature=0.07) >>> >>> # generate two random transforms of images @@ -80,7 +78,6 @@ class DCLLoss(nn.Module): >>> # you can also add a custom weighting function >>> weight_fn = lambda out0, out1: torch.sum((out0 - out1) ** 2, dim=1) >>> loss_fn = DCLLoss(weight_fn=weight_fn) - """ def __init__( @@ -89,11 +86,30 @@ def __init__( weight_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None, gather_distributed: bool = False, ): + """Initialzes the DCLoss module. + + Args: + temperature: + Similarities are scaled by inverse temperature. + weight_fn: + Weighting function `w` from the paper. Scales the loss between the + positive views (views from the same image). No weighting is performed + if weight_fn is None. The function must take the two input tensors + passed to the forward call as input and return a weight tensor. The + returned weight tensor must have the same length as the input tensors. + gather_distributed: + If True, negatives from all GPUs are gathered before the + loss calculation. + + Raises: + ValuesError: If gather_distributed is True but torch.distributed is not available. + """ super().__init__() self.temperature = temperature self.weight_fn = weight_fn self.gather_distributed = gather_distributed + # Check if distributed gathering is enabled but torch.distributed is not available if gather_distributed and not torch_dist.is_available(): raise ValueError( "gather_distributed is True but torch.distributed is not available. " @@ -119,21 +135,23 @@ def forward( Returns: Mean loss over the mini-batch. """ - # normalize the output to length 1 + # Normalize the output to length 1 out0 = nn.functional.normalize(out0, dim=1) out1 = nn.functional.normalize(out1, dim=1) if self.gather_distributed and dist.world_size() > 1: - # gather representations from other processes if necessary + # Gather representations from other processes if necessary out0_all = torch.cat(dist.gather(out0), 0) out1_all = torch.cat(dist.gather(out1), 0) else: out0_all = out0 out1_all = out1 - # calculate symmetric loss + # Calculate symmetric loss loss0 = self._loss(out0, out1, out0_all, out1_all) loss1 = self._loss(out1, out0, out1_all, out0_all) + + # Return the mean loss over the mini-batch return 0.5 * (loss0 + loss1) def _loss(self, out0, out1, out0_all, out1_all): @@ -154,17 +172,17 @@ def _loss(self, out0, out1, out0_all, out1_all): Output projections of the first set of transformed images from all distributed processes/gpus. Should be equal to out0 in an undistributed setting. - Shape (batch_size * world_size, embedding_size) + Shape: (batch_size * world_size, embedding_size) out1_all: Output projections of the second set of transformed images from - all distributed processes/gpus. Should be equal to out1 in an + all distributed processes/GPUs. Should be equal to out1 in an undistributed setting. - Shape (batch_size * world_size, embedding_size) + Shape: (batch_size * world_size, embedding_size) Returns: Mean loss over the mini-batch. """ - # create diagonal mask that only selects similarities between + # Create diagonal mask that only selects similarities between # representations of the same images batch_size = out0.shape[0] if self.gather_distributed and dist.world_size() > 1: @@ -172,8 +190,7 @@ def _loss(self, out0, out1, out0_all, out1_all): else: diag_mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool) - # calculate similarities - # here n = batch_size and m = batch_size * world_size. + # Calculate similarities (n = batch_size, m = batch_size * world_size) sim_00 = torch.einsum("nc,mc->nm", out0, out0_all) / self.temperature sim_01 = torch.einsum("nc,mc->nm", out0, out1_all) / self.temperature @@ -181,10 +198,11 @@ def _loss(self, out0, out1, out0_all, out1_all): if self.weight_fn: positive_loss = positive_loss * self.weight_fn(out0, out1) - # remove simliarities between same views of the same image + # Remove simliarities between same views of the same image sim_00 = sim_00[~diag_mask].view(batch_size, -1) - # remove similarities between different views of the same images - # this is the key difference compared to NTXentLoss + + # Remove similarities between different views of the same images + # This is the key difference compared to NTXentLoss sim_01 = sim_01[~diag_mask].view(batch_size, -1) negative_loss_00 = torch.logsumexp(sim_00, dim=1) @@ -210,11 +228,10 @@ class DCLWLoss(DCLLoss): Similar to temperature but applies the inverse scaling in the weighting function. gather_distributed: - If True then negatives from all gpus are gathered before the + If True, negatives from all GPUs are gathered before the loss calculation. Examples: - >>> loss_fn = DCLWLoss(temperature=0.07) >>> >>> # generate two random transforms of images @@ -227,7 +244,6 @@ class DCLWLoss(DCLLoss): >>> >>> # calculate loss >>> loss = loss_fn(out0, out1) - """ def __init__( @@ -236,6 +252,17 @@ def __init__( sigma: float = 0.5, gather_distributed: bool = False, ): + """Initializes the DCLWLoss module. + + Args: + temperature: + Similarities are scaled by inverse temperature. + sigma: + Applies inverse scaling in the weighting function. + gather_distributed: + If True, negatives from all GPUs are gathered before the + loss calculation. + """ super().__init__( temperature=temperature, weight_fn=partial(negative_mises_fisher_weights, sigma=sigma), diff --git a/lightly/loss/dino_loss.py b/lightly/loss/dino_loss.py index ec8a60869..72b45b7b0 100644 --- a/lightly/loss/dino_loss.py +++ b/lightly/loss/dino_loss.py @@ -10,8 +10,7 @@ class DINOLoss(Module): - """ - Implementation of the loss described in 'Emerging Properties in + """Implementation of the loss described in 'Emerging Properties in Self-Supervised Vision Transformers'. [0] This implementation follows the code published by the authors. [1] @@ -31,7 +30,7 @@ class DINOLoss(Module): teacher_temp: Final value of the teacher temperature after linear warmup. Values above 0.07 result in unstable behavior in most cases. Can be - slightly increased to improve performance during finetuning. + slightly increased to improve performance during fine-tuning. warmup_teacher_temp_epochs: Number of epochs for the teacher temperature warmup. student_temp: @@ -40,7 +39,6 @@ class DINOLoss(Module): Momentum term for the center calculation. Examples: - >>> # initialize loss function >>> loss_fn = DINOLoss(128) >>> @@ -53,7 +51,6 @@ class DINOLoss(Module): >>> >>> # calculate loss >>> loss = loss_fn([teacher_out], [student_out], epoch=0) - """ def __init__( @@ -66,6 +63,11 @@ def __init__( center_momentum: float = 0.9, center_mode: str = "mean", ): + """Initializes the DINOLoss Module. + + Raises: + ValueError: If an unknown center mode is provided. + """ super().__init__() if center_mode not in CENTER_MODE_TO_FUNCTION: raise ValueError( @@ -97,8 +99,7 @@ def forward( student_out: List[Tensor], epoch: int, ) -> Tensor: - """Cross-entropy between softmax outputs of the teacher and student - networks. + """Cross-entropy between softmax outputs of the teacher and student networks. Args: teacher_out: @@ -117,9 +118,8 @@ def forward( Returns: The average cross-entropy loss. - """ - # get teacher temperature + # Get teacher temperature if epoch < self.warmup_teacher_temp_epochs: teacher_temp = self.teacher_temp_schedule[epoch] else: @@ -131,18 +131,20 @@ def forward( student_out = torch.stack(student_out) s_out = F.log_softmax(student_out / self.student_temp, dim=-1) - # calculate feature similarities where: + # Calculate feature similarities, ignoring the diagonal # b = batch_size, t = n_views_teacher, s = n_views_student, d = output_dim - # the diagonal is ignored as it contains features from the same views loss = -torch.einsum("tbd,sbd->ts", t_out, s_out) loss.fill_diagonal_(0) - # number of loss terms, ignoring the diagonal + # Number of loss terms, ignoring the diagonal n_terms = loss.numel() - loss.diagonal().numel() batch_size = teacher_out.shape[1] + loss = loss.sum() / (n_terms * batch_size) + # Update the center used for the teacher output self.update_center(teacher_out) + return loss @torch.no_grad() @@ -153,9 +155,12 @@ def update_center(self, teacher_out: Tensor) -> None: teacher_out: Tensor with shape (num_views, batch_size, output_dim) containing features from the teacher model. - """ + + # Calculate the batch center using the specified center function batch_center = self._center_fn(x=teacher_out, dim=(0, 1)) + + # Update the center with a moving average self.center = center.center_momentum( center=self.center, batch_center=batch_center, momentum=self.center_momentum ) diff --git a/lightly/loss/emp_ssl_loss.py b/lightly/loss/emp_ssl_loss.py index ec20d7190..156ed63df 100644 --- a/lightly/loss/emp_ssl_loss.py +++ b/lightly/loss/emp_ssl_loss.py @@ -9,7 +9,7 @@ def tcr_loss(z: Tensor, eps: float) -> Tensor: - """Total Coding Rate (TCR) loss. + """Computes the Total Coding Rate (TCR) loss. Args: z: @@ -22,14 +22,17 @@ def tcr_loss(z: Tensor, eps: float) -> Tensor: """ _, batch_size, dim = z.shape diag = torch.eye(dim, device=z.device).unsqueeze(0) - # matmul over batch dimension + # Matrix multiplication over the batch dimension einsum = torch.einsum("vbd,vbe->vde", z, z) + + # Calculate the log determinant logdet = torch.logdet(diag + dim / (batch_size * eps) * einsum) + return 0.5 * logdet.mean() def invariance_loss(z: Tensor) -> Tensor: - """Loss representing the similiarity between the patch embeddings and the average of + """Calculates the invariance loss, representing the similiarity between the patch embeddings and the average of the patch embeddings. Args: @@ -39,7 +42,10 @@ def invariance_loss(z: Tensor) -> Tensor: Similarity loss. """ # z has shape (num_views, batch_size, dim) + + # Calculate the mean of the patch embeddings across the batch dimension z_mean = z.mean(0, keepdim=True) + return -F.cosine_similarity(z, z_mean, dim=-1).mean() @@ -78,11 +84,30 @@ def __init__( tcr_eps: float = 0.2, inv_coef: float = 200.0, ) -> None: + """Initializes the EMPSSLoss module. + + Args: + tcr_eps: + Total coding rate (TCR) epsilon. + inv_coff: + Coefficient for the invariance loss. + """ super().__init__() self.tcr_eps = tcr_eps self.inv_coef = inv_coef def forward(self, z_views: List[Tensor]) -> Tensor: + """Computes the EMP-SSL loss, which is a combination of Total Coding Rate loss and invariance loss. + + Args: + z_views: + List of patch embeddings tensors from different views. + + Returns: + The computed EMP-SSL loss. + """ + # z has shape (num_views, batch_size, dim) z = torch.stack(z_views) + return tcr_loss(z, eps=self.tcr_eps) + self.inv_coef * invariance_loss(z) diff --git a/lightly/loss/hypersphere_loss.py b/lightly/loss/hypersphere_loss.py index 3d1d248a6..7ef2da5b2 100644 --- a/lightly/loss/hypersphere_loss.py +++ b/lightly/loss/hypersphere_loss.py @@ -8,17 +8,16 @@ class HypersphereLoss(torch.nn.Module): - """ - Implementation of the loss described in 'Understanding Contrastive Representation Learning through + """Implementation of the loss described in 'Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere.' [0] [0] Tongzhou Wang. et.al, 2020, ... https://arxiv.org/abs/2005.10242 Note: - In order for this loss to function as advertized, an l1-normalization to the hypersphere is required. - This loss function applies this l1-normalization internally in the loss-layer. + In order for this loss to function as advertized, an L1-normalization to the hypersphere is required. + This loss function applies this L1-normalization internally in the loss layer. However, it is recommended that the same normalization is also applied in your architecture, - considering that this l1-loss is also intended to be applied during inference. + considering that this L1-loss is also intended to be applied during inference. Perhaps there may be merit in leaving it out of the inferrence pathway, but this use has not been tested. Moreover it is recommended that the layers preceeding this loss function are either a linear layer without activation, @@ -43,22 +42,21 @@ class HypersphereLoss(torch.nn.Module): >>> >>> # calculate loss >>> loss = loss_fn(out0, out1) - """ def __init__(self, t=1.0, lam=1.0, alpha=2.0): - """Parameters as described in [0] + """Initializes the HypersphereLoss module with the specified parameters. + + Parameters as described in [0] Args: - t : float - Temperature parameter; - proportional to the inverse variance of the Gaussians used to measure uniformity - lam : float: + t: + Temperature parameter; proportional to the inverse variance of the Gaussians used to measure uniformity. + lam: Weight balancing the alignment and uniformity loss terms - alpha : float + alpha: Power applied to the alignment term of the loss. At its default value of 2, - distances between positive samples are penalized in an l-2 sense. - + distances between positive samples are penalized in an L2 sense. """ super(HypersphereLoss, self).__init__() self.t = t @@ -66,24 +64,29 @@ def __init__(self, t=1.0, lam=1.0, alpha=2.0): self.alpha = alpha def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: - """ + """Computes the Hypersphere loss, which combines alignment and uniformity loss terms. Args: - x (torch.Tensor, [b, d], float) - y (torch.Tensor, [b, d], float) + z_a: + Tensor of shape (batch_size, embedding_dim) for the first set of embeddings. + z_b: + Tensor of shape (batch_size, embedding_dim) for the second set of embeddings. Returns: - Loss (torch.Tensor, [], float) - + The computed loss. """ + # Normalize the input embeddings x = F.normalize(z_a) y = F.normalize(z_b) + # Calculate alignment loss def lalign(x, y): return (x - y).norm(dim=1).pow(self.alpha).mean() + # Calculate uniformity loss def lunif(x): sq_pdist = torch.pdist(x, p=2).pow(2) return sq_pdist.mul(-self.t).exp().mean().log() + # Combine alignment and uniformity loss terms return lalign(x, y) + self.lam * (lunif(x) + lunif(y)) / 2 diff --git a/lightly/loss/ibot_loss.py b/lightly/loss/ibot_loss.py index 888e0a6b1..bceaaa8d4 100644 --- a/lightly/loss/ibot_loss.py +++ b/lightly/loss/ibot_loss.py @@ -36,6 +36,20 @@ def __init__( center_mode: str = "mean", center_momentum: float = 0.9, ) -> None: + """Initializes the iBOTPatchLoss module with the specified parameters. + + Args: + output_dim: + Dimension of the model output. + teacher_temp: + Temperature for the teacher output. + student_temp: + Temperature for the student output. + center_mode: + Mode for center calculation. Only 'mean' is supported. + center_momentum: + Momentum term for the center update. + """ super().__init__() self.teacher_temp = teacher_temp self.student_temperature = student_temp @@ -66,13 +80,13 @@ def forward( True in the mask. Returns: - Loss value. + The loss value. """ # B = batch size, N = sequence length = number of masked tokens, D = embed dim # H = height (in tokens), W = width (in tokens) # Note that N <= H * W depending on how many tokens are masked. - # Calculate cross entropy loss. + # Calculate cross-entropy loss. teacher_softmax = F.softmax( (teacher_out - self.center.value) / self.teacher_temp, dim=-1 ) diff --git a/lightly/loss/koleo_loss.py b/lightly/loss/koleo_loss.py index ba68f5ee0..45a731c5c 100644 --- a/lightly/loss/koleo_loss.py +++ b/lightly/loss/koleo_loss.py @@ -4,28 +4,38 @@ class KoLeoLoss(Module): + """KoLeo loss based on [0]. + + KoLeo loss is a regularizer that encourages a uniform span of the features in a + batch by penalizing the distance between the features and their nearest + neighbors. + + Implementation is based on [1]. + + - [0]: Spreading vectors for similarity search, 2019, https://arxiv.org/abs/1806.03198 + - [1]: https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/koleo_loss.py + + Attributes: + p: + The norm degree for pairwise distance calculation. + eps: + Small value to avoid division by zero. + """ + def __init__( self, p: float = 2, eps: float = 1e-8, ): - """KoLeo loss based on [0]. - - KoLeo loss is a regularizer that encourages a unfirom span of the features in a - batch by penalizing the distance between the features and their nearest - neighbors. - - Implementation is based on [1]. - - - [0]: Spreading vectors for similarity search, 2019, https://arxiv.org/abs/1806.03198 - - [1]: https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/koleo_loss.py + """Initializes the KoLeoLoss module with the specified parameters. - Attributes: + Args: p: The norm degree for pairwise distance calculation. eps: Small value to avoid division by zero. """ + super().__init__() self.p = p self.eps = eps @@ -35,17 +45,23 @@ def forward(self, x: Tensor) -> Tensor: """Forward pass through KoLeo Loss. Args: - x: - Tensor with shape (batch_size, embedding_size). + x: Tensor with shape (batch_size, embedding_size). + Returns: Loss value. """ + # Normalize the input tensor x = functional.normalize(x, p=2, dim=-1, eps=self.eps) + # Calculate cosine similarity. cos_sim = torch.mm(x, x.t()) cos_sim.fill_diagonal_(-2) + # Get nearest neighbors. nn_idx = cos_sim.argmax(dim=1) nn_dist: Tensor = self.pairwise_distance(x, x[nn_idx]) + + # Compute the loss loss = -(nn_dist + self.eps).log().mean() + return loss diff --git a/lightly/loss/mmcr_loss.py b/lightly/loss/mmcr_loss.py index 883b11476..765bf82c3 100644 --- a/lightly/loss/mmcr_loss.py +++ b/lightly/loss/mmcr_loss.py @@ -8,11 +8,9 @@ class MMCRLoss(nn.Module): All hyperparameters are set to the default values from the paper for ImageNet. - [0]: Efficient Coding of Natural Images using Maximum Manifold Capacity - Representations, 2023, https://arxiv.org/pdf/2303.03307.pdf - - Examples: - + Representations, 2023, https://arxiv.org/pdf/2303.03307.pdf + Examples: >>> # initialize loss function >>> loss_fn = MMCRLoss() >>> transform = MMCRTransform(k=2) @@ -27,6 +25,14 @@ class MMCRLoss(nn.Module): """ def __init__(self, lmda: float = 5e-3): + """Initializes the MMCRLoss module with the specified lambda parameter. + + Args: + lmda: The regularization parameter. + + Raises: + ValueError: If lmda is less than 0. + """ super().__init__() if lmda < 0: raise ValueError("lmda must be greater than or equal to 0") @@ -34,7 +40,8 @@ def __init__(self, lmda: float = 5e-3): self.lmda = lmda def forward(self, online: torch.Tensor, momentum: torch.Tensor) -> torch.Tensor: - """ + """Computes the MMCR loss for the online and momentum network outputs. + Args: online: Output of the online network for the current batch. Expected to be @@ -45,6 +52,8 @@ def forward(self, online: torch.Tensor, momentum: torch.Tensor) -> torch.Tensor: of shape (batch_size, k, embedding_size), where k represents the number of randomly augmented views for each sample. + Returns: + The computed loss value. """ assert ( online.shape == momentum.shape diff --git a/lightly/loss/msn_loss.py b/lightly/loss/msn_loss.py index 744714304..b1d7eb31b 100644 --- a/lightly/loss/msn_loss.py +++ b/lightly/loss/msn_loss.py @@ -27,7 +27,6 @@ def prototype_probabilities( Returns: Probability tensor with shape (batch_size, num_prototypes) which sums to 1 along the num_prototypes dimension. - """ return F.softmax(torch.matmul(queries, prototypes.T) / temperature, dim=1) @@ -43,7 +42,6 @@ def sharpen(probabilities: Tensor, temperature: float) -> Tensor: output probabilities are less uniform). Returns: Probabilities tensor with shape (batch_size, dim). - """ probabilities = probabilities ** (1.0 / temperature) probabilities /= torch.sum(probabilities, dim=1, keepdim=True) @@ -69,10 +67,9 @@ def sinkhorn( iterations: Number of iterations of the sinkhorn algorithms. Set to 0 to disable. gather_distributed: - If True then features from all gpus are gathered during normalization. + If True, features from all GPUs are gathered during normalization. Returns: A normalized probabilities tensor. - """ if iterations <= 0: return probabilities @@ -89,14 +86,14 @@ def sinkhorn( probabilities = probabilities / sum_probabilities for _ in range(iterations): - # normalize rows + # Normalize rows row_sum = torch.sum(probabilities, dim=1, keepdim=True) if world_size > 1: dist.all_reduce(row_sum) probabilities /= row_sum probabilities /= num_prototypes - # normalize columns + # Normalize columns probabilities /= torch.sum(probabilities, dim=0, keepdim=True) probabilities /= num_targets @@ -129,8 +126,7 @@ class MSNLoss(nn.Module): gather_distributed: If True, then target probabilities are gathered from all GPUs. - Examples: - + Examples: >>> # initialize loss function >>> loss_fn = MSNLoss() >>> @@ -144,7 +140,6 @@ class MSNLoss(nn.Module): >>> >>> # calculate loss >>> loss = loss_fn(anchors_out, targets_out, prototypes=model.prototypes) - """ def __init__( @@ -155,6 +150,28 @@ def __init__( me_max_weight: Optional[float] = None, gather_distributed: bool = False, ): + """Initializes the MSNLoss module with the specified parameters. + + Args: + temperature: + Similarities between anchors and targets are scaled by the inverse of the temperature. Must be in (0, inf). + sinkhorn_iterations: + Number of sinkhorn normalization iterations on the targets. + regularization_weight: + Weight factor lambda by which the regularization loss is scaled. Set to 0 to disable regularization. + me_max_weight: + Deprecated, use `regularization_weight` instead. Takes precedence over + `regularization_weight` if not None. Weight factor lambda by which the mean + entropy maximization regularization loss is scaled. Set to 0 to disable mean + entropy maximization regularization. + gather_distributed: + If True, then target probabilities are gathered from all GPUs. + + Raises: + ValueError: If temperature is not in (0, inf). + ValueError: If sinkhorn_iterations is less than 0. + ValueError: If gather_distributed is True but torch.distributed is not available. + """ super().__init__() if temperature <= 0: raise ValueError(f"temperature must be in (0, inf) but is {temperature}.") @@ -172,7 +189,7 @@ def __init__( self.temperature = temperature self.sinkhorn_iterations = sinkhorn_iterations self.regularization_weight = regularization_weight - # set regularization_weight to me_max_weight for backwards compatibility + # Set regularization_weight to me_max_weight for backwards compatibility if me_max_weight is not None: warnings.warn( DeprecationWarning( @@ -190,7 +207,7 @@ def forward( prototypes: Tensor, target_sharpen_temperature: float = 0.25, ) -> Tensor: - """Computes the MSN loss for a set of anchors, targets and prototypes. + """Computes the MSN loss for a set of anchors, targets, and prototypes. Args: anchors: @@ -204,19 +221,20 @@ def forward( Returns: Mean loss over all anchors. - """ num_views = anchors.shape[0] // targets.shape[0] + + # Normalize the inputs anchors = F.normalize(anchors, dim=1) targets = F.normalize(targets, dim=1) prototypes = F.normalize(prototypes, dim=1) - # anchor predictions + # Anchor predictions anchor_probs = prototype_probabilities( anchors, prototypes, temperature=self.temperature ) - # target predictions + # Target predictions with torch.no_grad(): target_probs = prototype_probabilities( targets, prototypes, temperature=self.temperature @@ -230,10 +248,10 @@ def forward( ) target_probs = target_probs.repeat((num_views, 1)) - # cross entropy loss + # Cross entropy loss loss = torch.mean(torch.sum(torch.log(anchor_probs ** (-target_probs)), dim=1)) - # regularization loss + # Regularization loss if self.regularization_weight > 0: mean_anchor_probs = torch.mean(anchor_probs, dim=0) reg_loss = self.regularization_loss(mean_anchor_probs=mean_anchor_probs) @@ -242,7 +260,14 @@ def forward( return loss def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor: - """Calculates mean entropy regularization loss.""" + """Calculates mean entropy regularization loss. + + Args: + mean_anchor_probs: The mean anchor probabilities. + + Returns: + The calculated regularization loss. + """ loss = -torch.sum(torch.log(mean_anchor_probs ** (-mean_anchor_probs))) loss += math.log(float(len(mean_anchor_probs))) return loss diff --git a/lightly/loss/negative_cosine_similarity.py b/lightly/loss/negative_cosine_similarity.py index 08fbd32c5..433a4dbbf 100644 --- a/lightly/loss/negative_cosine_similarity.py +++ b/lightly/loss/negative_cosine_similarity.py @@ -10,10 +10,9 @@ class NegativeCosineSimilarity(torch.nn.Module): """Implementation of the Negative Cosine Simililarity used in the SimSiam[0] paper. - [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566 + - [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566 Examples: - >>> # initialize loss function >>> loss_fn = NegativeCosineSimilarity() >>> @@ -27,17 +26,30 @@ class NegativeCosineSimilarity(torch.nn.Module): """ def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: - """Same parameters as in torch.nn.CosineSimilarity + """Initializes the NegativeCosineSimilarity module the specified parameters. + + Same parameters as in torch.nn.CosineSimilarity Args: - dim (int, optional): - Dimension where cosine similarity is computed. Default: 1 - eps (float, optional): - Small value to avoid division by zero. Default: 1e-8 + dim: + Dimension where cosine similarity is computed. + eps: + Small value to avoid division by zero. """ super().__init__() self.dim = dim self.eps = eps def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: + """Computes the negative cosine similarity between two tensors. + + Args: + x0: + First input tensor. + x1: + Second input tensor. + + Returns: + The mean negative cosine similarity. + """ return -cosine_similarity(x0, x1, self.dim, self.eps).mean() diff --git a/lightly/loss/ntx_ent_loss.py b/lightly/loss/ntx_ent_loss.py index 7f0c1d429..338e20d86 100644 --- a/lightly/loss/ntx_ent_loss.py +++ b/lightly/loss/ntx_ent_loss.py @@ -36,7 +36,7 @@ class NTXentLoss(MemoryBankModule): batch stored in the memory bank. Leaving out the feature dimension might lead to errors in distributed training. gather_distributed: - If True then negatives from all gpus are gathered before the + If True then negatives from all GPUs are gathered before the loss calculation. If a memory bank is used and gather_distributed is True, then tensors from all gpus are gathered before the memory bank is updated. @@ -44,7 +44,6 @@ class NTXentLoss(MemoryBankModule): ValueError: If abs(temperature) < 1e-8 to prevent divide by zero. Examples: - >>> # initialize loss function without memory bank >>> loss_fn = NTXentLoss(memory_bank_size=0) >>> @@ -67,6 +66,20 @@ def __init__( memory_bank_size: Union[int, Sequence[int]] = 0, gather_distributed: bool = False, ): + """Initializes the NTXentLoss module with the specified parameters. + + Args: + temperature: + Scale logits by the inverse of the temperature. + memory_bank_size: + Size of the memory bank. + gather_distributed: + If True, negatives from all GPUs are gathered before the loss calculation. + + Raises: + ValueError: If temperature is less than 1e-8 to prevent divide by zero. + ValueError: If gather_distributed is True but torch.distributed is not available. + """ super().__init__(size=memory_bank_size, gather_distributed=gather_distributed) self.temperature = temperature self.gather_distributed = gather_distributed @@ -101,13 +114,12 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor): Returns: Contrastive Cross Entropy Loss value. - """ device = out0.device batch_size, _ = out0.shape - # normalize the output to length 1 + # Normalize the output to length 1 out0 = nn.functional.normalize(out0, dim=1) out1 = nn.functional.normalize(out1, dim=1) @@ -121,12 +133,11 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor): out1, update=out0.requires_grad ) - # We use the cosine similarity, which is a dot product (einsum) here, - # as all vectors are already normalized to unit length. + # Use cosine similarity (dot product) as all vectors are normalized to unit length # Notation in einsum: n = batch_size, c = embedding_size and k = memory_bank_size. if negatives is not None: - # use negatives from memory bank + # Use negatives from memory bank negatives = negatives.to(device) # sim_pos is of shape (batch_size, 1) and sim_pos[i] denotes the similarity @@ -137,50 +148,50 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor): # of the i-th sample to the j-th negative sample sim_neg = torch.einsum("nc,ck->nk", out0, negatives) - # set the labels to the first "class", i.e. sim_pos, - # so that it is maximized in relation to sim_neg + # Set the labels to maximize sim_pos in relation to sim_neg logits = torch.cat([sim_pos, sim_neg], dim=1) / self.temperature labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long) else: - # user other samples from batch as negatives + # Use other samples from batch as negatives # and create diagonal mask that only selects similarities between # views of the same image if self.gather_distributed and dist.world_size() > 1: - # gather hidden representations from other processes + # Gather hidden representations from other processes out0_large = torch.cat(dist.gather(out0), 0) out1_large = torch.cat(dist.gather(out1), 0) diag_mask = dist.eye_rank(batch_size, device=out0.device) else: - # single process + # Single process out0_large = out0 out1_large = out1 diag_mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool) - # calculate similiarities - # here n = batch_size and m = batch_size * world_size - # the resulting vectors have shape (n, m) + # Calculate similiarities + # Here n = batch_size and m = batch_size * world_size + # The resulting vectors have shape (n, m) logits_00 = torch.einsum("nc,mc->nm", out0, out0_large) / self.temperature logits_01 = torch.einsum("nc,mc->nm", out0, out1_large) / self.temperature logits_10 = torch.einsum("nc,mc->nm", out1, out0_large) / self.temperature logits_11 = torch.einsum("nc,mc->nm", out1, out1_large) / self.temperature - # remove simliarities between same views of the same image + # Remove simliarities between same views of the same image logits_00 = logits_00[~diag_mask].view(batch_size, -1) logits_11 = logits_11[~diag_mask].view(batch_size, -1) - # concatenate logits - # the logits tensor in the end has shape (2*n, 2*m-1) + # Concatenate logits + # The logits tensor in the end has shape (2*n, 2*m-1) logits_0100 = torch.cat([logits_01, logits_00], dim=1) logits_1011 = torch.cat([logits_10, logits_11], dim=1) logits = torch.cat([logits_0100, logits_1011], dim=0) - # create labels + # Create labels labels = torch.arange(batch_size, device=device, dtype=torch.long) if self.gather_distributed: labels = labels + dist.rank() * batch_size labels = labels.repeat(2) + # Calculate the cross-entropy loss loss = self.cross_entropy(logits, labels) return loss diff --git a/lightly/loss/pmsn_loss.py b/lightly/loss/pmsn_loss.py index 15e97e2cb..e41a0a349 100644 --- a/lightly/loss/pmsn_loss.py +++ b/lightly/loss/pmsn_loss.py @@ -28,8 +28,7 @@ class PMSNLoss(MSNLoss): gather_distributed: If True, then target probabilities are gathered from all GPUs. - Examples: - + Examples: >>> # initialize loss function >>> loss_fn = PMSNLoss() >>> @@ -53,6 +52,7 @@ def __init__( power_law_exponent: float = 0.25, gather_distributed: bool = False, ): + """Initializes the PMSNLoss module with the specified parameters.""" super().__init__( temperature=temperature, sinkhorn_iterations=sinkhorn_iterations, @@ -62,7 +62,14 @@ def __init__( self.power_law_exponent = power_law_exponent def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor: - """Calculates regularization loss with a power law target distribution.""" + """Calculates the regularization loss with a power law target distribution. + + Args: + mean_anchor_probs: The mean anchor probabilities. + + Returns: + The calculated regularization loss. + """ power_dist = _power_law_distribution( size=mean_anchor_probs.shape[0], exponent=self.power_law_exponent, @@ -98,8 +105,7 @@ class PMSNCustomLoss(MSNLoss): gather_distributed: If True, then target probabilities are gathered from all GPUs. - Examples: - + Examples: >>> # define custom target distribution >>> def my_uniform_distribution(mean_anchor_probabilities: Tensor) -> Tensor: >>> dim = mean_anchor_probabilities.shape[0] @@ -128,6 +134,7 @@ def __init__( regularization_weight: float = 1, gather_distributed: bool = False, ): + """Initializes the PMSNCustomLoss module with the specified parameters.""" super().__init__( temperature=temperature, sinkhorn_iterations=sinkhorn_iterations, @@ -137,7 +144,15 @@ def __init__( self.target_distribution = target_distribution def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor: - """Calculates regularization loss with a custom target distribution.""" + """Calculates regularization loss with a custom target distribution. + + Args: + mean_anchor_probs: + The mean anchor probabilities. + + Returns: + The calculated regularization loss. + """ target_dist = self.target_distribution(mean_anchor_probs).to( mean_anchor_probs.device ) @@ -148,7 +163,19 @@ def regularization_loss(self, mean_anchor_probs: Tensor) -> Tensor: def _power_law_distribution(size: int, exponent: float, device: torch.device) -> Tensor: - """Returns a power law distribution summing up to 1.""" + """Returns a power law distribution summing up to 1. + + Args: + size: + The size of the distribution. + exponent: + The exponent for the power law distribution. + device: + The device to create tensor on. + + Returns: + A power law distribution tensor summing up to 1. + """ k = torch.arange(1, size + 1, device=device) power_dist = k ** (-exponent) power_dist = power_dist / power_dist.sum() diff --git a/lightly/loss/regularizer/__init__.py b/lightly/loss/regularizer/__init__.py index 3ed727248..6041175cf 100644 --- a/lightly/loss/regularizer/__init__.py +++ b/lightly/loss/regularizer/__init__.py @@ -1,6 +1,5 @@ """The lightly.loss.regularizer package provides regularizers for self-supervised learning. """ - # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved diff --git a/lightly/loss/regularizer/co2.py b/lightly/loss/regularizer/co2.py index 7cb52c532..2d302d671 100644 --- a/lightly/loss/regularizer/co2.py +++ b/lightly/loss/regularizer/co2.py @@ -13,7 +13,7 @@ class CO2Regularizer(MemoryBankModule): """Implementation of the CO2 regularizer [0] for self-supervised learning. - [0] CO2, 2021, https://arxiv.org/abs/2010.02217 + - [0] CO2, 2021, https://arxiv.org/abs/2010.02217 Attributes: alpha: @@ -44,7 +44,6 @@ class CO2Regularizer(MemoryBankModule): >>> >>> # calculate loss and apply regularizer >>> loss = loss_fn(out0, out1) + co2(out0, out1) - """ def __init__( @@ -53,8 +52,18 @@ def __init__( t_consistency: float = 0.05, memory_bank_size: Union[int, Sequence[int]] = 0, ): + """Initializes the CO2Regularizer with the specified parameters. + + Args: + alpha: + Weight of the regularization term. + t_consistency: + Temperature used during softmax calculations. + memory_bank_size: + Size of the memory bank. + """ super(CO2Regularizer, self).__init__(size=memory_bank_size) - # try-catch the KLDivLoss construction for backwards compatability + # Try-catch the KLDivLoss construction for backwards compatability self.log_target = True try: self.kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) @@ -76,29 +85,25 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor): Returns: The regularization term multiplied by the weight factor alpha. - """ - # normalize the output to length 1 + # Normalize the output to length 1 out0 = torch.nn.functional.normalize(out0, dim=1) out1 = torch.nn.functional.normalize(out1, dim=1) - # ask memory bank for negative samples and extend it with out1 if - # out1 requires a gradient, otherwise keep the same vectors in the - # memory bank (this allows for keeping the memory bank constant e.g. - # for evaluating the loss on the test set) - # if the memory_bank size is 0, negatives will be None + # Update the memory bank with out1 and get negatives(if memory bank size > 0) + # If the memory_bank size is 0, negatives will be None out1, negatives = super(CO2Regularizer, self).forward(out1, update=True) - # get log probabilities + # Get log probabilities p = self._get_pseudo_labels(out0, out1, negatives) q = self._get_pseudo_labels(out1, out0, negatives) - # calculate symmetrized kullback leibler divergence + # Calculate symmetrized Kullback-Leibler divergence if self.log_target: div = self.kl_div(p, q) + self.kl_div(q, p) else: - # can't use log_target because of early torch version + # Can't use log_target because of early torch version div = self.kl_div(p, torch.exp(q)) + self.kl_div(q, torch.exp(p)) return self.alpha * 0.5 * div @@ -124,30 +129,30 @@ def _get_pseudo_labels( Log probability that a positive samples will classify each negative sample as the positive sample. Shape: bsz x (bsz - 1) or bsz x memory_bank_size - """ batch_size, _ = out0.shape if negatives is None: - # use second batch as negative samples + # Use second batch as negative samples # l_pos has shape bsz x 1 and l_neg has shape bsz x bsz l_pos = torch.einsum("nc,nc->n", [out0, out1]).unsqueeze(-1) l_neg = torch.einsum("nc,ck->nk", [out0, out1.t()]) - # remove elements on the diagonal + + # Remove elements on the diagonal # l_neg has shape bsz x (bsz - 1) l_neg = l_neg.masked_select( ~torch.eye(batch_size, dtype=bool, device=l_neg.device) ).view(batch_size, batch_size - 1) else: - # use memory bank as negative samples + # Use memory bank as negative samples # l_pos has shape bsz x 1 and l_neg has shape bsz x memory_bank_size negatives = negatives.to(out0.device) l_pos = torch.einsum("nc,nc->n", [out0, out1]).unsqueeze(-1) l_neg = torch.einsum("nc,ck->nk", [out0, negatives.clone().detach()]) - # concatenate such that positive samples are at index 0 + # Concatenate such that positive samples are at index 0 logits = torch.cat([l_pos, l_neg], dim=1) - # divide by temperature + # Divide by temperature logits = logits / self.t_consistency - # the input to kl_div is expected to be log(p) + # The input to kl_div is expected to be log(p) return torch.nn.functional.log_softmax(logits, dim=-1) diff --git a/lightly/loss/swav_loss.py b/lightly/loss/swav_loss.py index 01345bd2b..bf4294812 100644 --- a/lightly/loss/swav_loss.py +++ b/lightly/loss/swav_loss.py @@ -17,8 +17,8 @@ def sinkhorn( As outlined in [0] and implemented in [1]. - [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882 - [1]: https://github.com/facebookresearch/swav/ + - [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882 + - [1]: https://github.com/facebookresearch/swav/ Args: out: @@ -33,13 +33,12 @@ def sinkhorn( Returns: Soft codes Q assigning each feature to a prototype. - """ world_size = 1 if gather_distributed and dist.is_initialized(): world_size = dist.get_world_size() - # get the exponential matrix and make it sum to 1 + # Get the exponential matrix and make it sum to 1 Q = torch.exp(out / epsilon).t() sum_Q = torch.sum(Q) if world_size > 1: @@ -49,12 +48,12 @@ def sinkhorn( B = Q.shape[1] * world_size for _ in range(iterations): - # normalize rows + # Normalize rows sum_of_rows = torch.sum(Q, dim=1, keepdim=True) if world_size > 1: dist.all_reduce(sum_of_rows) Q /= sum_of_rows - # normalize columns + # Normalize columns Q /= torch.sum(Q, dim=0, keepdim=True) Q /= B @@ -73,9 +72,8 @@ class SwaVLoss(nn.Module): sinkhorn_epsilon: Temperature parameter used in the sinkhorn algorithm. sinkhorn_gather_distributed: - If True then features from all gpus are gathered to calculate the + If True, features from all GPUs are gathered to calculate the soft codes in the sinkhorn algorithm. - """ def __init__( @@ -85,6 +83,23 @@ def __init__( sinkhorn_epsilon: float = 0.05, sinkhorn_gather_distributed: bool = False, ): + """Initializes the SwaVLoss module with the specified parameters. + + Args: + temperature: + Temperature parameter used for cross-entropy calculations. + sinkhorn_iterations: + Number of iterations of the sinkhorn algorithm. + sinkhorn_epsilon: + Temperature parameter used in the sinkhorn algorithm. + sinkhorn_gather_distributed: + If True, features from all GPUs are gathered to calculate the + soft codes in the sinkhorn algorithm. + + Raises: + ValueError: If sinkhorn_gather_distributed is True but torch.distributed + is not available. + """ super(SwaVLoss, self).__init__() if sinkhorn_gather_distributed and not dist.is_available(): raise ValueError( @@ -109,7 +124,6 @@ def subloss(self, z: torch.Tensor, q: torch.Tensor): Returns: Cross entropy between predictions z and codes q. - """ return -torch.mean( torch.sum(q * F.log_softmax(z / self.temperature, dim=1), dim=1) @@ -123,6 +137,8 @@ def forward( ): """Computes the SwaV loss for a set of high and low resolution outputs. + - [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882 + Args: high_resolution_outputs: List of similarities of features and SwaV prototypes for the @@ -136,16 +152,13 @@ def forward( Returns: Swapping assignments between views loss (SwaV) as described in [0]. - - [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882 - """ n_crops = len(high_resolution_outputs) + len(low_resolution_outputs) - # multi-crop iterations + # Multi-crop iterations loss = 0.0 for i in range(len(high_resolution_outputs)): - # compute codes of i-th high resolution crop + # Compute codes of i-th high resolution crop with torch.no_grad(): outputs = high_resolution_outputs[i].detach() @@ -165,7 +178,7 @@ def forward( if queue_outputs is not None: q = q[: len(high_resolution_outputs[i])] - # compute subloss for each pair of crops + # Compute subloss for each pair of crops subloss = 0.0 for v in range(len(high_resolution_outputs)): if v != i: diff --git a/lightly/loss/sym_neg_cos_sim_loss.py b/lightly/loss/sym_neg_cos_sim_loss.py index 3cb20beec..74c2d960b 100644 --- a/lightly/loss/sym_neg_cos_sim_loss.py +++ b/lightly/loss/sym_neg_cos_sim_loss.py @@ -11,10 +11,9 @@ class SymNegCosineSimilarityLoss(torch.nn.Module): """Implementation of the Symmetrized Loss used in the SimSiam[0] paper. - [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566 + - [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566 Examples: - >>> # initialize loss function >>> loss_fn = SymNegCosineSimilarityLoss() >>> @@ -27,10 +26,14 @@ class SymNegCosineSimilarityLoss(torch.nn.Module): >>> >>> # calculate loss >>> loss = loss_fn(out0, out1) - """ def __init__(self) -> None: + """Initializes the SymNegCosineSimilarityLoss module. + + Note: + SymNegCosineSimilarityLoss will be deprecated in favor of NegativeCosineSimilarity in the future. + """ super().__init__() warnings.warn( Warning( @@ -47,19 +50,16 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor): out0: Output projections of the first set of transformed images. Expects the tuple to be of the form (z0, p0), where z0 is - the output of the backbone and projection mlp, and p0 is the + the output of the backbone and projection MLP, and p0 is the output of the prediction head. out1: Output projections of the second set of transformed images. Expects the tuple to be of the form (z1, p1), where z1 is - the output of the backbone and projection mlp, and p1 is the + the output of the backbone and projection MLP, and p1 is the output of the prediction head. Returns: - Contrastive Cross Entropy Loss value. - - Raises: - ValueError if shape of output is not multiple of batch_size. + Negative Cosine Similarity loss value. """ z0, p0 = out0 z1, p1 = out1 @@ -72,5 +72,14 @@ def forward(self, out0: torch.Tensor, out1: torch.Tensor): return loss def _neg_cosine_simililarity(self, x, y): + """Calculates the negative cosine similarity between two tensors. + + Args: + x: First input tensor. + y: Second input tensor. + + Returns: + Negative cosine similarity value. + """ v = -torch.nn.functional.cosine_similarity(x, y.detach(), dim=-1).mean() return v diff --git a/lightly/loss/tico_loss.py b/lightly/loss/tico_loss.py index b2ae3f328..9dba5d7d3 100644 --- a/lightly/loss/tico_loss.py +++ b/lightly/loss/tico_loss.py @@ -6,14 +6,14 @@ class TiCoLoss(torch.nn.Module): """Implementation of the Tico Loss from Tico[0] paper. + This implementation takes inspiration from the code published by sayannag using Lightly. [1] - [0] Jiachen Zhu et. al, 2022, Tico... https://arxiv.org/abs/2206.10698 - [1] https://github.com/sayannag/TiCo-pytorch + - [0] Jiachen Zhu et. al, 2022, Tico... https://arxiv.org/abs/2206.10698 + - [1] https://github.com/sayannag/TiCo-pytorch Attributes: - Args: beta: Coefficient for the EMA update of the covariance @@ -22,11 +22,10 @@ class TiCoLoss(torch.nn.Module): Weight for the covariance term of the loss Defaults to 8.0 [0]. gather_distributed: - If True then the cross-correlation matrices from all gpus are + If True, the cross-correlation matrices from all GPUs are gathered and summed before the loss calculation. Examples: - >>> # initialize loss function >>> loss_fn = TiCoLoss() >>> @@ -47,6 +46,20 @@ def __init__( rho: float = 8.0, gather_distributed: bool = False, ): + """Initializes the TiCoLoss module with the specified parameters. + + Args: + beta: + Coefficient for the EMA update of the covariance. + rho: + Weight for the covariance term of the loss. + gather_distributed: + If True, the cross-correlation matrices from all GPUs are gathered + and summed before the loss calculation. Default is False. + + Raises: + ValueError: If gather_distributed is True but torch.distributed is not available. + """ super(TiCoLoss, self).__init__() if gather_distributed and not dist.is_available(): raise ValueError( @@ -66,7 +79,9 @@ def forward( z_b: torch.Tensor, update_covariance_matrix: bool = True, ) -> torch.Tensor: - """Tico Loss computation. It maximize the agreement among embeddings of different distorted versions of the same image + """Computes the TiCo loss. + + It maximizes the agreement among embeddings of different distorted versions of the same image while avoiding collapse using Covariance matrix. Args: @@ -78,8 +93,11 @@ def forward( Parameter to update the covariance matrix at each iteration. Returns: - The loss. + The computed loss. + Raises: + AssertionError: If z_a or z_b have a batch size <= 1. + AssertionError: If z_a and z_b do not have the same shape. """ assert ( @@ -96,18 +114,18 @@ def forward( z_a = torch.cat(gather(z_a), dim=0) z_b = torch.cat(gather(z_b), dim=0) - # normalize image + # Normalize image z_a = torch.nn.functional.normalize(z_a, dim=1) z_b = torch.nn.functional.normalize(z_b, dim=1) - # compute auxiliary matrix B + # Compute auxiliary matrix B B = torch.mm(z_a.T, z_a).detach() / z_a.shape[0] - # init covariance matrix + # Initialize covariance matrix if self.C is None: self.C = B.new_zeros(B.shape).detach() - # compute loss + # Compute loss C = self.beta * self.C + (1 - self.beta) * B transformative_invariance_loss = 1.0 - (z_a * z_b).sum(dim=1).mean() @@ -115,7 +133,7 @@ def forward( loss = transformative_invariance_loss + covariance_contrast_loss - # update covariance matrix + # Update covariance matrix if update_covariance_matrix: self.C = C.detach() diff --git a/lightly/loss/vicreg_loss.py b/lightly/loss/vicreg_loss.py index 1b99e6f27..32996d572 100644 --- a/lightly/loss/vicreg_loss.py +++ b/lightly/loss/vicreg_loss.py @@ -22,13 +22,12 @@ class VICRegLoss(torch.nn.Module): nu_param: Scaling coefficient for the covariance term of the loss. gather_distributed: - If True then the cross-correlation matrices from all gpus are gathered and + If True, the cross-correlation matrices from all GPUs are gathered and summed before the loss calculation. eps: Epsilon for numerical stability. Examples: - >>> # initialize loss function >>> loss_fn = VICRegLoss() >>> @@ -51,6 +50,11 @@ def __init__( gather_distributed: bool = False, eps=0.0001, ): + """Initializes the VICRegLoss module with the specified parameters. + + Raises: + ValueError: If gather_distributed is True but torch.distributed is not available. + """ super(VICRegLoss, self).__init__() if gather_distributed and not dist.is_available(): raise ValueError( @@ -73,6 +77,13 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: Tensor with shape (batch_size, ..., dim). z_b: Tensor with shape (batch_size, ..., dim). + + Returns: + The computed VICReg loss. + + Raises: + AssertionError: If z_a or z_b have a batch size <= 1. + AssertionError: If z_a and z_b do not have the same shape. """ assert ( z_a.shape[0] > 1 and z_b.shape[0] > 1 @@ -81,21 +92,23 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: z_a.shape == z_b.shape ), f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}." - # invariance term of the loss + # Invariance term of the loss inv_loss = invariance_loss(x=z_a, y=z_b) - # gather all batches + # Gather all batches if self.gather_distributed and dist.is_initialized(): world_size = dist.get_world_size() if world_size > 1: z_a = torch.cat(gather(z_a), dim=0) z_b = torch.cat(gather(z_b), dim=0) + # Variance and covariance terms of the loss var_loss = 0.5 * ( variance_loss(x=z_a, eps=self.eps) + variance_loss(x=z_b, eps=self.eps) ) cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b) + # Total VICReg loss loss = ( self.lambda_param * inv_loss + self.mu_param * var_loss @@ -112,6 +125,9 @@ def invariance_loss(x: Tensor, y: Tensor) -> Tensor: Tensor with shape (batch_size, ..., dim). y: Tensor with shape (batch_size, ..., dim). + + Returns: + The computed VICReg invariance loss. """ return F.mse_loss(x, y) @@ -124,6 +140,9 @@ def variance_loss(x: Tensor, eps: float = 0.0001) -> Tensor: Tensor with shape (batch_size, ..., dim). eps: Epsilon for numerical stability. + + Returns: + The computed VICReg variance loss. """ std = torch.sqrt(x.var(dim=0) + eps) loss = torch.mean(F.relu(1.0 - std)) @@ -138,15 +157,19 @@ def covariance_loss(x: Tensor) -> Tensor: https://github.com/facebookresearch/VICRegL/blob/803ae4c8cd1649a820f03afb4793763e95317620/main_vicregl.py#L299 Args: - x: - Tensor with shape (batch_size, ..., dim). + x: Tensor with shape (batch_size, ..., dim). + + Returns: + The computed VICReg covariance loss. """ x = x - x.mean(dim=0) batch_size = x.size(0) dim = x.size(-1) # nondiag_mask has shape (dim, dim) with 1s on all non-diagonal entries. nondiag_mask = ~torch.eye(dim, device=x.device, dtype=torch.bool) + # cov has shape (..., dim, dim) cov = torch.einsum("b...c,b...d->...cd", x, x) / (batch_size - 1) + loss = cov[..., nondiag_mask].pow(2).sum(-1) / dim return loss.mean() diff --git a/lightly/loss/vicregl_loss.py b/lightly/loss/vicregl_loss.py index 81466bc1b..ceb14e7ec 100644 --- a/lightly/loss/vicregl_loss.py +++ b/lightly/loss/vicregl_loss.py @@ -33,7 +33,7 @@ class VICRegLLoss(torch.nn.Module): Coefficient to weight global with local loss. The final loss is computed as (self.alpha * global_loss + (1-self.alpha) * local_loss). gather_distributed: - If True then the cross-correlation matrices from all gpus are gathered and + If True, the cross-correlation matrices from all gpus are gathered and summed before the loss calculation. eps: Epsilon for numerical stability. @@ -41,7 +41,6 @@ class VICRegLLoss(torch.nn.Module): Number of local features to match using nearest neighbors. Examples: - >>> # initialize loss function >>> criterion = VICRegLLoss() >>> transform = VICRegLTransform(n_global_views=2, n_local_views=4) @@ -61,7 +60,6 @@ class VICRegLLoss(torch.nn.Module): ... local_view_features=features[2:], ... local_view_grids=grids[2:], ... ) - """ def __init__( @@ -74,6 +72,11 @@ def __init__( eps: float = 0.0001, num_matches: Tuple[int, int] = (20, 4), ): + """Initializes the VICRegL loss module with the specified parameters. + + Raises: + ValueError: If gather_distributed is True but torch.distributed is not available. + """ super(VICRegLLoss, self).__init__() self.alpha = alpha self.num_matches = num_matches @@ -129,6 +132,11 @@ def forward( Returns: Weighted sum of the global and local loss, calculated as: `self.alpha * global_loss + (1-self.alpha) * local_loss`. + + Raises: + ValueError: If the lengths of global_view_features and global_view_grids are not the same. + ValueError: If the lengths of local_view_features and local_view_grids are not the same. + ValueError: If only one of local_view_features or local_view_grids is set. """ if len(global_view_features) != len(global_view_grids): raise ValueError( @@ -147,13 +155,13 @@ def forward( f"None but found {type(local_view_features)} and {type(local_view_grids)}." ) - # calculate loss from global features + # Calculate loss from global features global_loss = self._global_loss( global_view_features=global_view_features, local_view_features=local_view_features, ) - # calculate loss from local features + # Calculate loss from local features local_loss = self._local_loss( global_view_features=global_view_features, global_view_grids=global_view_grids, @@ -169,7 +177,19 @@ def _global_loss( global_view_features: Sequence[Tuple[Tensor, Tensor]], local_view_features: Optional[Sequence[Tuple[Tensor, Tensor]]] = None, ) -> Tensor: - """Returns global features loss.""" + """Returns global features loss. + + Args: + global_view_features: + Sequence of (global_features, local_features) + tuples from the global crop views. + local_view_features: + Sequence of (global_features,local_features) + tuples from the local crop views. + + Returns: + The computed global features loss. + """ inv_loss = self._global_invariance_loss( global_view_features=global_view_features, local_view_features=local_view_features, @@ -189,21 +209,35 @@ def _global_invariance_loss( global_view_features: Sequence[Tuple[Tensor, Tensor]], local_view_features: Optional[Sequence[Tuple[Tensor, Tensor]]] = None, ) -> Tensor: - """Returns invariance loss from global features.""" + """Returns invariance loss from global features. + + Args: + global_view_features: + Sequence of (global_features, local_features) + tuples from the global crop views. + local_view_features: + Sequence of (global_features,local_features) + tuples from the local crop views. + + Returns: + The computed invariance loss from global features. + """ loss = 0 loss_count = 0 + + # Compute invariance loss between global views for global_features_a, _ in global_view_features: - # global views for global_features_b, _ in global_view_features: if global_features_a is not global_features_b: loss += invariance_loss(global_features_a, global_features_b) loss_count += 1 - # local views + # Compute invariance loss between global and local views if local_view_features is not None: for global_features_b, _ in local_view_features: loss += invariance_loss(global_features_a, global_features_b) loss_count += 1 + return loss / loss_count def _global_variance_and_covariance_loss( @@ -211,7 +245,17 @@ def _global_variance_and_covariance_loss( global_view_features: Sequence[Tuple[Tensor, Tensor]], local_view_features: Optional[Sequence[Tuple[Tensor, Tensor]]] = None, ) -> Tuple[Tensor, Tensor]: - """Returns variance and covariance loss from global features.""" + """Returns variance and covariance loss from global features. + + Args: + global_view_features: Sequence of (global_features, local_features) + tuples from the global crop views. + local_view_features: Sequence of (global_features,local_features) + tuples from the local crop views. + + Returns: + The computed variance and covariance loss from global features. + """ view_features = list(global_view_features) if local_view_features is not None: view_features = view_features + list(local_view_features) @@ -252,13 +296,27 @@ def _local_loss( regardless of global or local views: https://github.com/facebookresearch/VICRegL/blob/803ae4c8cd1649a820f03afb4793763e95317620/main_vicregl.py#L329-L334 Our implementation follows the original code and ignores view type. + + Args: + global_view_features: + Sequence of (global_features, local_features) tuples from the global crop views. + global_view_grids: + Sequence of grid tensors from the global crop views. + local_view_features: + Sequence of (global_features,local_features) tuples from the local crop views. + local_view_grids: + Sequence of grid tensors from the local crop views. + + Returns: + The computed loss from local features based on nearest neighbor matching. """ loss = 0 loss_count = 0 + + # Compute the loss for global views for (_, z_a_local_features), grid_a in zip( global_view_features, global_view_grids ): - # global views for (_, z_b_local_features), grid_b in zip( global_view_features, global_view_grids ): @@ -275,7 +333,7 @@ def _local_loss( ) loss_count += 1 - # local views + # Compute the loss for local views if local_view_features is not None and local_view_grids is not None: for (_, z_b_local_features), grid_b in zip( local_view_features, local_view_grids @@ -303,22 +361,29 @@ def _local_l2_loss( Args: z_a: - Local feature tensor with shape (batch_size, heigh, width, dim). + Local feature tensor with shape (batch_size, height, width, dim). z_b: - Local feature tensor with shape (batch_size, heigh, width, dim). + Local feature tensor with shape (batch_size, height, width, dim). + + Returns: + The computed loss for local features. """ - # (batch_size, heigh, width, dim) -> (batch_size, heigh * width, dim) + # (batch_size, height, width, dim) -> (batch_size, height * width, dim) z_a = z_a.flatten(start_dim=1, end_dim=2) z_b = z_b.flatten(start_dim=1, end_dim=2) + # Find nearest neighbours using L2 distance z_a_filtered, z_a_nn = self._nearest_neighbors_on_l2( input_features=z_a, candidate_features=z_b, num_matches=self.num_matches[0] ) z_b_filtered, z_b_nn = self._nearest_neighbors_on_l2( input_features=z_b, candidate_features=z_a, num_matches=self.num_matches[1] ) + + # Compute VICReg losses loss_a = self.vicreg_loss.forward(z_a=z_a_filtered, z_b=z_a_nn) loss_b = self.vicreg_loss.forward(z_a=z_b_filtered, z_b=z_b_nn) + return 0.5 * (loss_a + loss_b) def _local_location_loss( @@ -333,22 +398,28 @@ def _local_location_loss( Args: z_a: - Local feature tensor with shape (batch_size, heigh, width, dim). + Local feature tensor with shape (batch_size, height, width, dim). z_b: - Local feature tensor with shape (batch_size, heigh, width, dim). + Local feature tensor with shape (batch_size, height, width, dim). Note that height and width can be different from z_a. grid_a: Grid tensor with shape (batch_size, height, width, 2). grid_b: Grid tensor with shape (batch_size, height, width, 2). Note that height and width can be different from grid_a. + + Returns: + The computed loss for local features based on nearest neighbour matching. """ - # (batch_size, heigh, width, dim) -> (batch_size, heigh * width, dim) + # (batch_size, height, width, dim) -> (batch_size, height * width, dim) z_a = z_a.flatten(start_dim=1, end_dim=2) z_b = z_b.flatten(start_dim=1, end_dim=2) - # (batch_size, heigh, width, 2) -> (batch_size, heigh * width, 2) + + # (batch_size, height, width, 2) -> (batch_size, height * width, 2) grid_a = grid_a.flatten(start_dim=1, end_dim=2) grid_b = grid_b.flatten(start_dim=1, end_dim=2) + + # Find nearest neighbours based on grid location z_a_filtered, z_a_nn = self._nearest_neighbors_on_grid( input_features=z_a, candidate_features=z_b, @@ -364,6 +435,7 @@ def _local_location_loss( num_matches=self.num_matches[1], ) + # Compute VICReg losses loss_a = self.vicreg_loss.forward(z_a=z_a_filtered, z_b=z_a_nn) loss_b = self.vicreg_loss.forward(z_a=z_b_filtered, z_b=z_b_nn) return 0.5 * (loss_a + loss_b) diff --git a/lightly/loss/wmse_loss.py b/lightly/loss/wmse_loss.py index 3fb458bd5..8d65ed8aa 100644 --- a/lightly/loss/wmse_loss.py +++ b/lightly/loss/wmse_loss.py @@ -16,17 +16,24 @@ def norm_mse_loss(x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: - """Normalized MSE Loss as implemented in https://github.com/htdt/self-supervised.""" + """Normalized MSE Loss as implemented in https://github.com/htdt/self-supervised. + + Args: + x0: First input tensor. + x1: Second input tensor. + + Returns: + The computed normalized MSE loss. + """ x0 = F.normalize(x0) x1 = F.normalize(x1) return torch.sub(input=2, other=(x0 * x1).sum(dim=-1).mean(), alpha=2) class Whitening2d(nn.Module): - """ - Implementation of the whitening layer as described in [0]. + """Implementation of the whitening layer as described in [0]. - [0] W-MSE, 2021, https://arxiv.org/pdf/2007.06346.pdf + - [0] W-MSE, 2021, https://arxiv.org/pdf/2007.06346.pdf """ def __init__( @@ -36,6 +43,22 @@ def __init__( track_running_stats: bool = True, eps: float = 0, ): + """Initializes the Whitening2d module with the specified parameters. + + Args: + num_features: + Number of features in the input. + momentum: + Momentum for the running mean and variance. + track_running_stats: + If True, tracks the running mean and variance. + eps: + Epsilon for numerical stability. + + Raises: + RuntimeError: If torch.linalg.solve_triangular is not available in the PyTorch installation. + """ + super(Whitening2d, self).__init__() if not _SOLVE_TRIANGULAR_AVAILABLE: @@ -59,13 +82,25 @@ def __init__( self.register_buffer("running_variance", torch.eye(self.num_features)) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the Whitening2d layer. + + Args: + x: Input tensor. + + Returns: + Decorrelated output tensor. + + """ x = x.unsqueeze(2).unsqueeze(3) m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1) if not self.training and self.track_running_stats: # for inference m = self.running_mean xn = x - m + # Reshape for covariance computation T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1) + + # Compute covariance matrix f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1) eye = torch.eye(self.num_features).type(f_cov.type()) @@ -83,6 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.num_features, self.num_features, 1, 1 ) + # Decorrelate the features decorrelated = F.conv2d(xn, inv_sqrt) if self.training and self.track_running_stats: @@ -101,8 +137,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class WMSELoss(torch.nn.Module): - """ - Implementation of the loss described in 'Whitening for + """Implementation of the loss described in 'Whitening for Self-Supervised Representation Learning' [0]. - [0] W-MSE, 2021, https://arxiv.org/pdf/2007.06346.pdf @@ -134,7 +169,9 @@ def __init__( loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = norm_mse_loss, num_samples: int = 2, ): - """Parameters as described in [0] + """Initializes the WMSELoss module with the specified parameters. + + Parameters as described in [0]. Args: embedding_dim: @@ -154,7 +191,8 @@ def __init__( num_samples: Number of samples generated by the transforms for each image. - + Raises: + ValueError: If w_size is less than twice the size of embedding_dim. """ super().__init__() self.whitening = Whitening2d( @@ -176,18 +214,15 @@ def __init__( def forward(self, input: torch.Tensor) -> torch.Tensor: """Calculates the W-MSE loss. - Args: - input: - Tensor with shape (batch_size * num_samples, embedding_dim). + Args: + input: Tensor with shape (batch_size * num_samples, embedding_dim). Returns: Aggregate W-MSE loss over all sub-batches. Raises: - RuntimeError: - If the batch size is not divisible by num_samples. - ValueError: - If the batch size is smaller than w_size. + RuntimeError: If the batch size is not divisible by num_samples. + ValueError: If the batch size is smaller than w_size. """ if input.shape[0] % self.num_samples != 0: raise RuntimeError("input batch size must be divisible by num_samples") From 37186b36a0ac147c086bf993f7068e93dfc31e15 Mon Sep 17 00:00:00 2001 From: Harshit Vashisht <120767685+HarshitVashisht11@users.noreply.github.com> Date: Fri, 18 Oct 2024 19:17:06 +0530 Subject: [PATCH 5/7] Cleanup docstrings in lightly/utils subpackage (#1698) --- lightly/utils/bounding_box.py | 60 ++++++++++++++++++++++++++-------- lightly/utils/debug.py | 33 +++++++++++-------- lightly/utils/dependency.py | 34 ++++++++++++++----- lightly/utils/dist.py | 13 +++++--- lightly/utils/embeddings_2d.py | 23 ++++++++----- lightly/utils/hipify.py | 35 +++++++++++++++++++- lightly/utils/lars.py | 11 +++---- lightly/utils/scheduler.py | 14 ++++---- 8 files changed, 161 insertions(+), 62 deletions(-) diff --git a/lightly/utils/bounding_box.py b/lightly/utils/bounding_box.py index 322b13b16..695380b3b 100644 --- a/lightly/utils/bounding_box.py +++ b/lightly/utils/bounding_box.py @@ -1,4 +1,4 @@ -""" Bounding Box Utils """ +"""Bounding Box Utils""" from __future__ import annotations @@ -31,17 +31,26 @@ class BoundingBox: >>> # (x0, y0, x1, y1) = (10, 20, 30, 40) >>> W, H = 100, 100 # get image shape >>> bbox = BoundingBox(10 / W, 20 / H, 30 / W, 40 / H) - """ def __init__( self, x0: float, y0: float, x1: float, y1: float, clip_values: bool = True ): - """ - clip_values: - Set to true to clip the values into [0, 1] instead of raising an error if they lie outside. - """ + """Initializes a BoundingBox object. + + Args: + x0: + x0 coordinate relative to image width. + y0: + y0 coordinate relative to image height. + x1: + x1 coordinate relative to image width. + y1: + y1 coordinate relative to image height. + clip_values: + If True, clips the coordinates to [0, 1]. + """ if clip_values: def clip_to_0_1(value: float) -> float: @@ -60,14 +69,12 @@ def clip_to_0_1(value: float) -> float: if x0 >= x1: raise ValueError( - f"x0 must be smaller than x1 for bounding box " - f"[{x0}, {y0}, {x1}, {y1}]" + f"x0 must be smaller than x1 for bounding box [{x0}, {y0}, {x1}, {y1}]" ) if y0 >= y1: raise ValueError( - "y0 must be smaller than y1 for bounding box " - f"[{x0}, {y0}, {x1}, {y1}]" + f"y0 must be smaller than y1 for bounding box [{x0}, {y0}, {x1}, {y1}]" ) self.x0 = x0 @@ -77,7 +84,20 @@ def clip_to_0_1(value: float) -> float: @classmethod def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox: - """Helper to convert from bounding box format with width and height. + """Creates a BoundingBox from x, y, width, and height. + + Args: + x: + x coordinate of the top-left corner relative to image width. + y: + y coordinate of the top-left corner relative to image height. + w: + Width of the bounding box relative to image width. + h: + Height of the bounding box relative to image height. + + Returns: + BoundingBox: A BoundingBox instance. Examples: >>> bbox = BoundingBox.from_x_y_w_h(0.1, 0.2, 0.2, 0.2) @@ -89,11 +109,23 @@ def from_x_y_w_h(cls, x: float, y: float, w: float, h: float) -> BoundingBox: def from_yolo_label( cls, x_center: float, y_center: float, w: float, h: float ) -> BoundingBox: - """Helper to convert from yolo label format - x_center, y_center, w, h --> x0, y0, x1, y1 + """Creates a BoundingBox from YOLO label format. + + Args: + x_center: + x coordinate of the center relative to image width. + y_center: + y coordinate of the center relative to image height. + w: + Width of the bounding box relative to image width. + h: + Height of the bounding box relative to image height. + + Returns: + BoundingBox: A BoundingBox instance. Examples: - >>> bbox = BoundingBox.from_yolo(0.5, 0.4, 0.2, 0.3) + >>> bbox = BoundingBox.from_yolo_label(0.5, 0.4, 0.2, 0.3) """ return cls( diff --git a/lightly/utils/debug.py b/lightly/utils/debug.py index 7d74d1f68..a2281ab6c 100644 --- a/lightly/utils/debug.py +++ b/lightly/utils/debug.py @@ -15,7 +15,6 @@ "'pip install lightly[matplotlib]'." ) except ImportError as ex: - # Matplotlib import can fail if an incompatible dateutil version is installed. plt = ex @@ -24,9 +23,9 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor: """Calculates the mean of the standard deviation of z along each dimension. This measure was used by [0] to determine the level of collapse of the - learned representations. If the returned number is 0., the outputs z have - collapsed to a constant vector. "If the output z has a zero-mean isotropic - Gaussian distribution" [0], the returned number should be close to 1/sqrt(d) + learned representations. If the returned value is 0., the outputs z have + collapsed to a constant vector. If the output z has a zero-mean isotropic + Gaussian distribution [0], the returned value should be close to 1/sqrt(d), where d is the dimensionality of the output. [0]: https://arxiv.org/abs/2011.10566 @@ -38,9 +37,7 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor: Returns: The mean of the standard deviation of the l2 normalized tensor z along each dimension. - """ - if len(z.shape) != 2: raise ValueError( f"Input tensor must have two dimensions but has {len(z.shape)}!" @@ -53,8 +50,18 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor: def apply_transform_without_normalize( image: Image.Image, transform, -): - """Applies the transform to the image but skips ToTensor and Normalize.""" +) -> Image.Image: + """Applies the transform to the image but skips ToTensor and Normalize. + + Args: + image: + The input PIL image. + transform: + The transformation to apply, excluding ToTensor and Normalize. + + Returns: + The transformed image. + """ skippable_transforms = ( torchvision.transforms.ToTensor, torchvision.transforms.Normalize, @@ -70,10 +77,10 @@ def apply_transform_without_normalize( def generate_grid_of_augmented_images( input_images: List[Image.Image], collate_function: Union[BaseCollateFunction, MultiViewCollateFunction], -): +) -> List[List[Image.Image]]: """Returns a grid of augmented images. Images in a column belong together. - This function ignores the transforms ToTensor and Normalize for visualization purposes. + This function ignores the ToTensor and Normalize transforms for visualization purposes. Args: input_images: @@ -116,9 +123,9 @@ def plot_augmented_images( input_images: List[Image.Image], collate_function: Union[BaseCollateFunction, MultiViewCollateFunction], ): - """Returns a figure showing original images in the left column and augmented images to their right. + """Plots original images and augmented images in a figure. - This function ignores the transforms ToTensor and Normalize for visualization purposes. + This function ignores the ToTensor and Normalize transforms for visualization purposes. Args: input_images: @@ -134,7 +141,6 @@ def plot_augmented_images( MultiViewCollateFunctions all the generated views are shown. """ - _check_matplotlib_available() if len(input_images) == 0: @@ -166,5 +172,6 @@ def plot_augmented_images( def _check_matplotlib_available() -> None: + """Checks if matplotlib is available. Raises an error if not.""" if isinstance(plt, Exception): raise plt diff --git a/lightly/utils/dependency.py b/lightly/utils/dependency.py index bdd36186c..8baeb2a77 100644 --- a/lightly/utils/dependency.py +++ b/lightly/utils/dependency.py @@ -3,24 +3,42 @@ @functools.lru_cache(maxsize=1) def torchvision_vit_available() -> bool: + """Checks if Vision Transformer (ViT) models are available in torchvision. + + This function checks if the `vision_transformer` module is available in torchvision, + which requires torchvision version >= 0.12. It also handles exceptions related to + CUDA version mismatches and installation issues. + + Returns: + True if the Vision Transformer (ViT) models are available in torchvision, + otherwise False. + """ try: - import torchvision.models.vision_transformer # Requires torchvision >=0.12 + import torchvision.models.vision_transformer # Requires torchvision >=0.12. except ( - RuntimeError, # Different CUDA versions for torch and torchvision - OSError, # Different CUDA versions for torch and torchvision (old) - ImportError, # No installation or old version of torchvision + RuntimeError, # Different CUDA versions for torch and torchvision. + OSError, # Different CUDA versions for torch and torchvision (old). + ImportError, # No installation or old version of torchvision. ): return False - else: - return True + return True @functools.lru_cache(maxsize=1) def timm_vit_available() -> bool: + """Checks if Vision Transformer (ViT) models are available in the timm library. + + This function checks if the `vision_transformer` module and `LayerType` from timm + are available, which requires timm version >= 0.3.3 and >= 0.9.9, respectively. + + Returns: + True if the Vision Transformer (ViT) models are available in timm, + otherwise False. + + """ try: import timm.models.vision_transformer # Requires timm >= 0.3.3 from timm.layers import LayerType # Requires timm >= 0.9.9 except ImportError: return False - else: - return True + return True diff --git a/lightly/utils/dist.py b/lightly/utils/dist.py index 7292afaca..5143f3597 100644 --- a/lightly/utils/dist.py +++ b/lightly/utils/dist.py @@ -8,19 +8,19 @@ class GatherLayer(torch.autograd.Function): """Gather tensors from all processes, supporting backward propagation. - This code was taken and adapted from here: + Adapted from the Solo-Learn project: https://github.com/vturrisi/solo-learn/blob/b69b4bd27472593919956d9ac58902a301537a4d/solo/utils/misc.py#L187 """ @staticmethod - def forward(ctx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore + def forward(ctx: FunctionCtx, input: torch.Tensor) -> Tuple[torch.Tensor, ...]: # type: ignore output = [torch.empty_like(input) for _ in range(dist.get_world_size())] dist.all_gather(output, input) return tuple(output) @staticmethod - def backward(ctx, *grads) -> torch.Tensor: # type: ignore + def backward(ctx: FunctionCtx, *grads: torch.Tensor) -> torch.Tensor: # type: ignore all_gradients = torch.stack(grads) dist.all_reduce(all_gradients) grad_out = all_gradients[dist.get_rank()] @@ -38,7 +38,7 @@ def world_size() -> int: def gather(input: torch.Tensor) -> Tuple[torch.Tensor]: - """Gathers this tensor from all processes. Supports backprop.""" + """Gathers a tensor from all processes and supports backpropagation.""" return GatherLayer.apply(input) # type: ignore[no-any-return] @@ -62,6 +62,9 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor: device: Device on which the matrix should be created. + Returns: + A tensor with the appropriate diagonal filled for this rank. + """ rows = torch.arange(n, device=device, dtype=torch.long) cols = rows + rank() * n @@ -74,7 +77,7 @@ def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor: def rank_zero_only(fn: Callable[..., R]) -> Callable[..., Optional[R]]: - """Decorator that only runs the function on the process with rank 0. + """Decorator to ensure the function only runs on the process with rank 0. Example: >>> @rank_zero_only diff --git a/lightly/utils/embeddings_2d.py b/lightly/utils/embeddings_2d.py index f3e277a71..c10d4bc3a 100644 --- a/lightly/utils/embeddings_2d.py +++ b/lightly/utils/embeddings_2d.py @@ -1,4 +1,4 @@ -""" Transform embeddings to two-dimensional space for visualization. """ +"""Transforms embeddings to two-dimensional space for visualization.""" # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved @@ -21,13 +21,18 @@ class PCA(object): Number of principal components to keep. eps: Epsilon for numerical stability. + mean: + Mean of the data. + w: + Eigenvectors of the covariance matrix. + """ def __init__(self, n_components: int = 2, eps: float = 1e-10): self.n_components = n_components + self.eps = eps self.mean: Optional[NDArray[np.float32]] = None self.w: Optional[NDArray[np.float32]] = None - self.eps = eps def fit(self, X: NDArray[np.float32]) -> PCA: """Fits PCA to data in X. @@ -37,7 +42,7 @@ def fit(self, X: NDArray[np.float32]) -> PCA: Datapoints stored in numpy array of size n x d. Returns: - PCA object to transform datapoints. + PCA: The fitted PCA object to transform data points. """ X = X.astype(np.float32) @@ -46,7 +51,7 @@ def fit(self, X: NDArray[np.float32]) -> PCA: X = X - self.mean + self.eps cov = np.cov(X.T) / X.shape[0] v, w = np.linalg.eig(cov) - idx = v.argsort()[::-1] + idx = v.argsort()[::-1] # Sort eigenvalues in descending order v, w = v[idx], w[:, idx] self.w = w return self @@ -62,10 +67,13 @@ def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]: Numpy array of n x p datapoints where p <= d. Raises: - ValueError: If PCA was not fitted before. + ValueError: + If PCA is not fitted before calling this method. + """ if self.mean is None or self.w is None: raise ValueError("PCA not fitted yet. Call fit() before transform().") + X = X.astype(np.float32) X = X - self.mean + self.eps transformed: NDArray[np.float32] = X.dot(self.w)[:, : self.n_components] @@ -77,7 +85,7 @@ def fit_pca( n_components: int = 2, fraction: Optional[float] = None, ) -> PCA: - """Fits PCA to randomly selected subset of embeddings. + """Fits PCA to a randomly selected subset of embeddings. For large datasets, it can be unfeasible to perform PCA on the whole data. This method can fit a PCA on a fraction of the embeddings in order to save @@ -101,8 +109,7 @@ def fit_pca( """ if fraction is not None: if fraction < 0.0 or fraction > 1.0: - msg = f"fraction must be in [0, 1] but was {fraction}." - raise ValueError(msg) + raise ValueError(f"fraction must be in [0, 1] but was {fraction}.") N = embeddings.shape[0] n = N if fraction is None else min(N, int(N * fraction)) diff --git a/lightly/utils/hipify.py b/lightly/utils/hipify.py index 37fbaf8a1..389295d0f 100644 --- a/lightly/utils/hipify.py +++ b/lightly/utils/hipify.py @@ -4,6 +4,8 @@ class bcolors: + """ANSI escape sequences for colored terminal output.""" + HEADER = "\033[95m" OKBLUE = "\033[94m" OKGREEN = "\033[92m" @@ -15,6 +17,18 @@ class bcolors: def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning) -> None: + """Prints a warning message with custom formatting. + + Temporarily overrides the default warning format to apply custom styling, then + restores the original formatting after the warning is printed. + + Args: + message: + The warning message to print. + warning_class: + The type of warning to raise. + + """ old_format = copy.copy(warnings.formatwarning) warnings.formatwarning = _custom_formatwarning warnings.warn(message, warning_class) @@ -28,5 +42,24 @@ def _custom_formatwarning( lineno: int, line: Optional[str] = None, ) -> str: - # ignore everything except the message + """Custom format for warning messages. + + Only the warning message is printed, with additional styling applied. + + Args: + message: + The warning message or warning object. + category: + The warning class. + filename: + The file where the warning originated. + lineno: + The line number where the warning occurred. + line: + The line of code that triggered the warning (if available). + + Returns: + str: The formatted warning message. + + """ return f"{bcolors.WARNING}{message}{bcolors.WARNING}\n" diff --git a/lightly/utils/lars.py b/lightly/utils/lars.py index 315f14559..036178977 100644 --- a/lightly/utils/lars.py +++ b/lightly/utils/lars.py @@ -36,7 +36,6 @@ class LARS(Optimizer): >>> input = torch.Tensor(10) >>> target = torch.Tensor([1.]) >>> loss_fn = lambda input, target: (input - target) ** 2 - >>> # >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9) >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() @@ -99,11 +98,10 @@ def __init__( def __setstate__(self, state: Dict[str, Any]) -> None: super().__setstate__(state) - for group in self.param_groups: group.setdefault("nesterov", False) - # Type ignore for overloads is required for Python 3.7 + # Type ignore for overloads is required for Python 3.7. @overload # type: ignore[override] def step(self, closure: None = None) -> None: ... @@ -125,7 +123,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] with torch.enable_grad(): loss = closure() - # exclude scaling for params with 0 weight decay + # Exclude scaling for params with 0 weight decay. for group in self.param_groups: weight_decay = group["weight_decay"] momentum = group["momentum"] @@ -140,7 +138,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] p_norm = torch.norm(p.data) g_norm = torch.norm(p.grad.data) - # lars scaling + weight decay part + # Apply Lars scaling and weight decay. if weight_decay != 0: if p_norm != 0 and g_norm != 0: lars_lr = p_norm / ( @@ -151,7 +149,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] d_p = d_p.add(p, alpha=weight_decay) d_p *= lars_lr - # sgd part + # Apply momentum. if momentum != 0: param_state = self.state[p] if "momentum_buffer" not in param_state: @@ -159,6 +157,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] else: buf = param_state["momentum_buffer"] buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: d_p = d_p.add(buf, alpha=momentum) else: diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index f262245da..ef913a78b 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -32,9 +32,9 @@ def cosine_schedule( """ if step < 0: - raise ValueError(f"Current step number {step} can't be negative") + raise ValueError(f"Current step number {step} can't be negative.") if max_steps < 1: - raise ValueError(f"Total step number {max_steps} must be >= 1") + raise ValueError(f"Total step number {max_steps} must be >= 1.") if period is None and step > max_steps: warnings.warn( f"Current step number {step} exceeds max_steps {max_steps}.", @@ -102,9 +102,9 @@ def cosine_warmup_schedule( Cosine decay value. """ if warmup_steps < 0: - raise ValueError(f"Warmup steps {warmup_steps} can't be negative") + raise ValueError(f"Warmup steps {warmup_steps} can't be negative.") if warmup_steps > max_steps: - raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps") + raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps.") if step > max_steps: warnings.warn( f"Current step number {step} exceeds max_steps {max_steps}.", @@ -157,7 +157,7 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR): Target learning rate for warmup. Defaults to start_value. Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index - can be used. The naming follows the Pytorch convention to use `epoch` for the steps + can be used. The naming follows the PyTorch convention to use `epoch` for the steps in the scheduler. """ @@ -181,6 +181,7 @@ def __init__( self.period = period self.warmup_start_value = warmup_start_value self.warmup_end_value = warmup_end_value + super().__init__( optimizer=optimizer, lr_lambda=self.scale_lr, @@ -189,8 +190,7 @@ def __init__( ) def scale_lr(self, epoch: int) -> float: - """ - Scale learning rate according to the current epoch number. + """Scale learning rate according to the current epoch number. Args: epoch: From f196ad9648a560c0ea14fce2b9596fbce35884d9 Mon Sep 17 00:00:00 2001 From: ayush22iitbhu Date: Fri, 18 Oct 2024 19:38:31 +0530 Subject: [PATCH 6/7] Add Documentation for lightly/models/modules (#1700) --- lightly/models/modules/center.py | 23 ++- lightly/models/modules/heads.py | 214 +++++++++++++++++++-------- lightly/models/modules/heads_timm.py | 60 ++++++++ 3 files changed, 236 insertions(+), 61 deletions(-) diff --git a/lightly/models/modules/center.py b/lightly/models/modules/center.py index 55eee220d..21be7b9bf 100644 --- a/lightly/models/modules/center.py +++ b/lightly/models/modules/center.py @@ -31,6 +31,11 @@ def __init__( mode: str = "mean", momentum: float = 0.9, ) -> None: + """Initializes the Center module with the specified parameters. + + Raises: + ValueError: If an unknown mode is provided. + """ super().__init__() center_fn = CENTER_MODE_TO_FUNCTION.get(mode) @@ -49,8 +54,10 @@ def __init__( @property def value(self) -> Tensor: - """The current value of the center. Use this property to do any operations based - on the center.""" + """The current value of the center. + + Use this property to do any operations based on the center. + """ return self.center @torch.no_grad() @@ -75,7 +82,17 @@ def _center_mean(self, x: Tensor) -> Tensor: @torch.no_grad() def center_mean(x: Tensor, dim: Tuple[int, ...]) -> Tensor: - """Returns the center of the input tensor by calculating the mean.""" + """Returns the center of the input tensor by calculating the mean. + + Args: + x: + Input tensor. + dim: + Dimensions along which the mean is calculated. + + Returns: + The center of the input tensor. + """ batch_center = torch.mean(x, dim=dim, keepdim=True) if dist.is_available() and dist.is_initialized(): dist.all_reduce(batch_center) diff --git a/lightly/models/modules/heads.py b/lightly/models/modules/heads.py index d9dcb6989..1bfb7ccbb 100644 --- a/lightly/models/modules/heads.py +++ b/lightly/models/modules/heads.py @@ -29,7 +29,6 @@ class ProjectionHead(nn.Module): >>> (256, 256, nn.BatchNorm1d(256), nn.ReLU()), >>> (256, 128, None, None) >>> ]) - """ def __init__( @@ -41,6 +40,7 @@ def __init__( ], ], ) -> None: + """Initializes the ProjectionHead module with the specified blocks.""" super().__init__() layers: List[nn.Module] = [] @@ -60,7 +60,6 @@ def forward(self, x: Tensor) -> Tensor: Args: x: Input of shape bsz x num_ftrs. - """ projection: Tensor = self.layers(x) return projection @@ -73,13 +72,22 @@ class BarlowTwinsProjectionHead(ProjectionHead): units. The first two layers of the projector are followed by a batch normalization layer and rectified linear units." [0] - [0]: 2021, Barlow Twins, https://arxiv.org/abs/2103.03230 - + - [0]: 2021, Barlow Twins, https://arxiv.org/abs/2103.03230 """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 8192, output_dim: int = 8192 ): + """Initializes the BarlowTwinsProjectionHead with the specified dimensions. + + Args: + input_dim: + Dimensionality of the input features. + hidden_dim: + Dimensionality of the hidden layers. + output_dim: + Dimensionality of the output features. + """ super(BarlowTwinsProjectionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -96,13 +104,13 @@ class BYOLProjectionHead(ProjectionHead): batch normalization, rectified linear units (ReLU), and a final linear layer with output dimension 256." [0] - [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733 - + - [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733 """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256 ): + """Initializes the BYOLProjectionHead with the specified dimensions.""" super(BYOLProjectionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -118,8 +126,7 @@ class BYOLPredictionHead(ProjectionHead): batch normalization, rectified linear units (ReLU), and a final linear layer with output dimension 256." [0] - [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733 - + - [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733 """ def __init__( @@ -143,9 +150,9 @@ class MoCoProjectionHead(ProjectionHead): hidden layers of both MLPs are 4096-d and are with ReLU; the output layers of both MLPs are 256-d, without ReLU. In MoCo v3, all layers in both MLPs have BN" [2] - [0]: MoCo v1, 2020, https://arxiv.org/abs/1911.05722 - [1]: MoCo v2, 2020, https://arxiv.org/abs/2003.04297 - [2]: MoCo v3, 2021, https://arxiv.org/abs/2104.02057 + - [0]: MoCo v1, 2020, https://arxiv.org/abs/1911.05722 + - [1]: MoCo v2, 2020, https://arxiv.org/abs/2003.04297 + - [2]: MoCo v3, 2021, https://arxiv.org/abs/2104.02057 """ def __init__( @@ -159,12 +166,16 @@ def __init__( """Initialize a new MoCoProjectionHead instance. Args: - input_dim: Number of input dimensions. - hidden_dim: Number of hidden dimensions (2048 for v2, 4096 for v3). - output_dim: Number of output dimensions (128 for v2, 256 for v3). - num_layers: Number of hidden layers (2 for v2, 3 for v3). - batch_norm: Whether or not to use batch norms. - (False for v2, True for v3) + input_dim: + Number of input dimensions. + hidden_dim: + Number of hidden dimensions (2048 for v2, 4096 for v3). + output_dim: + Number of output dimensions (128 for v2, 256 for v3). + num_layers: + Number of hidden layers (2 for v2, 3 for v3). + batch_norm: + Whether or not to use batch norms. (False for v2, True for v3). """ layers: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]] = [] layers.append( @@ -204,13 +215,22 @@ class NNCLRProjectionHead(ProjectionHead): layers are followed by batch-normalization [36]. All the batch-norm layers except the last layer are followed by ReLU activation." [0] - [0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548 - + - [0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548 """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 256 ): + """Initializes the NNCLRProjectionHead with the specified dimensions. + + Args: + input_dim: + Dimensionality of the input features. + hidden_dim: + Dimensionality of the hidden layers. + output_dim: + Dimensionality of the output features. + """ super(NNCLRProjectionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -227,8 +247,7 @@ class NNCLRPredictionHead(ProjectionHead): of size [4096,d]. The hidden layer of the prediction MLP is followed by batch-norm and ReLU. The last layer has no batch-norm or activation." [0] - [0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548 - + - [0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548 """ def __init__( @@ -265,11 +284,16 @@ def __init__( """Initialize a new SimCLRProjectionHead instance. Args: - input_dim: Number of input dimensions. - hidden_dim: Number of hidden dimensions. - output_dim: Number of output dimensions. - num_layers: Number of hidden layers (2 for v1, 3+ for v2). - batch_norm: Whether or not to use batch norms. + input_dim: + Number of input dimensions. + hidden_dim: + Number of hidden dimensions. + output_dim: + Number of output dimensions. + num_layers: + Number of hidden layers (2 for v1, 3+ for v2). + batch_norm: + Whether or not to use batch norms. """ layers: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]] = [] layers.append( @@ -307,8 +331,7 @@ class SimSiamProjectionHead(ProjectionHead): layer, including its output fc. Its output fc has no ReLU. The hidden fc is 2048-d. This MLP has 3 layers." [0] - [0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566 - + - [0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566 """ def __init__( @@ -329,13 +352,21 @@ def __init__( class SMoGPrototypes(nn.Module): - """SMoG prototypes module for synchronous momentum grouping.""" + """SMoG prototypes module for synchronous momentum grouping. + + Args: + group_features: + Tensor containing the group features. + beta: + Beta parameter for momentum updating. + """ def __init__( self, group_features: Tensor, beta: float, ): + """Initializes the SMoGPrototypes module with the specified parameter.""" super(SMoGPrototypes, self).__init__() self.group_features = nn.Parameter(group_features, requires_grad=False) self.beta = beta @@ -354,8 +385,7 @@ def forward( Temperature parameter for calculating the logits. Returns: - The logits. - + The computed logits. """ x = torch.nn.functional.normalize(x, dim=1) group_features = torch.nn.functional.normalize(group_features, dim=1) @@ -371,7 +401,6 @@ def get_updated_group_features(self, x: Tensor) -> Tensor: Returns: The updated group features. - """ assignments = self.assign_groups(x) group_features = torch.clone(self.group_features.data) @@ -392,11 +421,11 @@ def assign_groups(self, x: Tensor) -> Tensor: """Assigns each representation in x to a group based on cosine similarity. Args: - Tensor of shape bsz x dim. + x: + Tensor of shape (bsz, dim). Returns: - Tensor of shape bsz indicating group assignments. - + Tensor of shape (bsz,) indicating group assignments. """ return torch.argmax(self.forward(x, self.group_features), dim=-1) @@ -408,13 +437,22 @@ class SMoGProjectionHead(ProjectionHead): followed by a BatchNorm [28] and an activation function. (...) The output layer of projection head also has BN" [0] - [0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf - + - [0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128 ): + """Initializes the SMoGProjectionHead with the specified dimensions. + + Args: + input_dim: + Dimensionality of the input features. + hidden_dim: + Dimensionality of the hidden layers. + output_dim: + Dimensionality of the output features. + """ super(SMoGProjectionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -435,13 +473,23 @@ class SMoGPredictionHead(ProjectionHead): followed by a BatchNorm [28] and an activation function. (...) The output layer of projection head also has BN" [0] - [0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf - + - [0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf """ def __init__( self, input_dim: int = 128, hidden_dim: int = 2048, output_dim: int = 128 ): + """Initializes the SMoGPredictionHead with the specified dimensions. + + Args: + input_dim: + Dimensionality of the input features. + hidden_dim: + Dimensionality of the hidden layers. + output_dim: + Dimensionality of the output features. + """ + super(SMoGPredictionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -456,13 +504,22 @@ class SimSiamPredictionHead(ProjectionHead): "The prediction MLP (h) has BN applied to its hidden fc layers. Its output fc does not have BN (...) or ReLU. This MLP has 2 layers." [0] - [0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566 - + - [0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566 """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 512, output_dim: int = 2048 ): + """Initializes the SimSiamPredictionHead with the specified dimensions. + + Args: + input_dim: + Dimensionality of the input features. + hidden_dim: + Dimensionality of the hidden layers. + output_dim: + Dimensionality of the output features. + """ super(SimSiamPredictionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -474,12 +531,13 @@ def __init__( class SwaVProjectionHead(ProjectionHead): """Projection head used for SwaV. - [0]: SwAV, 2020, https://arxiv.org/abs/2006.09882 + - [0]: SwAV, 2020, https://arxiv.org/abs/2006.09882 """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128 ): + """Initializes the SwaVProjectionHead with the specified dimensions.""" super(SwaVProjectionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -513,7 +571,6 @@ class SwaVPrototypes(nn.Module): >>> >>> # logits has shape bsz x 512 >>> logits = prototypes(features) - """ def __init__( @@ -522,7 +579,9 @@ def __init__( n_prototypes: Union[List[int], int] = 3000, n_steps_frozen_prototypes: int = 0, ): + """Intializes the SwaVPrototypes module with the specified parameters""" super(SwaVPrototypes, self).__init__() + # Default to a list of 1 if n_prototypes is an int. self.n_prototypes = ( n_prototypes if isinstance(n_prototypes, list) else [n_prototypes] @@ -536,6 +595,18 @@ def __init__( def forward( self, x: Tensor, step: Optional[int] = None ) -> Union[Tensor, List[Tensor]]: + """Forward pass of the SwaVPrototypes module. + + Args: + x: + Input tensor. + step: + Current training step. + + Returns: + The logits after passing through the prototype heads. Returns a single tensor + if there's one prototype head, otherwise returns a list of tensors. + """ self._freeze_prototypes_if_required(step) out = [] for layer in self.heads: @@ -548,6 +619,7 @@ def normalize(self) -> None: utils.normalize_weight(layer.weight) def _freeze_prototypes_if_required(self, step: Optional[int] = None) -> None: + """Freezes the prototypes if the specified number of steps has been reached.""" if self.n_steps_frozen_prototypes > 0: if step is None: raise ValueError( @@ -588,7 +660,6 @@ class DINOProjectionHead(ProjectionHead): Whether or not to weight normalize the last layer of the DINO head. Not normalizing leads to better performance but can make the training unstable. - """ def __init__( @@ -601,6 +672,7 @@ def __init__( freeze_last_layer: int = -1, norm_last_layer: bool = True, ): + """Initializes the DINOProjectionHead with the specified dimensions.""" bn = nn.BatchNorm1d(hidden_dim) if batch_norm else None super().__init__( @@ -672,16 +744,24 @@ def __init__( """Initialize a new MMCRProjectionHead instance. Args: - input_dim: Number of input dimensions. - hidden_dim: Number of hidden dimensions. - output_dim: Number of output dimensions. - num_layers: Number of hidden layers. - batch_norm: Whether or not to use batch norms. - use_bias: Whether or not to use bias in the linear layers. + input_dim: + Number of input dimensions. + hidden_dim: + Number of hidden dimensions. + output_dim: + Number of output dimensions. + num_layers: + Number of hidden layers. + batch_norm: + Whether or not to use batch norms. + use_bias: + Whether or not to use bias in the linear layers. """ layers: List[ Tuple[int, int, Optional[nn.Module], Optional[nn.Module], bool] ] = [] + + # Add the first layer layers.append( ( input_dim, @@ -691,6 +771,8 @@ def __init__( use_bias, ) ) + + # Add the hidden layers for _ in range(num_layers - 1): layers.append( ( @@ -701,6 +783,8 @@ def __init__( use_bias, ) ) + + # Add the output layer layers.append((hidden_dim, output_dim, None, None, use_bias)) super().__init__(layers) @@ -710,6 +794,7 @@ class MSNProjectionHead(ProjectionHead): "We train with a 3-layer projection head with output dimension 256 and batch-normalization at the input and hidden layers.." [0] + Code inspired by [1]. - [0]: Masked Siamese Networks, 2022, https://arxiv.org/abs/2204.07141 @@ -730,6 +815,7 @@ def __init__( hidden_dim: int = 2048, output_dim: int = 256, ): + """Initializes the MSNProjectionHead with the specified dimensions.""" super().__init__( blocks=[ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.GELU()), @@ -746,13 +832,13 @@ class TiCoProjectionHead(ProjectionHead): batch normalization, rectified linear units (ReLU), and a final linear layer with output dimension 256." [0] - [0]: TiCo, 2022, https://arxiv.org/pdf/2206.10698.pdf - + - [0]: TiCo, 2022, https://arxiv.org/pdf/2206.10698.pdf """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256 ): + """Initializes the TiCoProjectionHead with the specified dimensions.""" super(TiCoProjectionHead, self).__init__( [ (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), @@ -768,8 +854,7 @@ class VICRegProjectionHead(ProjectionHead): units. The first two layers of the projector are followed by a batch normalization layer and rectified linear units." [0] - [0]: 2022, VICReg, https://arxiv.org/pdf/2105.04906.pdf - + - [0]: 2022, VICReg, https://arxiv.org/pdf/2105.04906.pdf """ def __init__( @@ -779,6 +864,18 @@ def __init__( output_dim: int = 8192, num_layers: int = 3, ): + """Initializes the VICRegProjectionHead with the specified dimensions. + + Args: + input_dim: + Dimensionality of the input features. + hidden_dim: + Dimensionality of the hidden layers. + output_dim: + Dimensionality of the output features. + num_layers: + Number of layers in the projection head. + """ hidden_layers = [ (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()) for _ in range(num_layers - 2) # Exclude first and last layer. @@ -795,16 +892,16 @@ def __init__( class VicRegLLocalProjectionHead(ProjectionHead): """Projection head used for the local head of VICRegL. - The projector network has three linear layers. The first two layers of the projector - are followed by a batch normalization layer and rectified linear units. - - 2022, VICRegL, https://arxiv.org/abs/2210.01571 + "The projector network has three linear layers. The first two layers of the projector + are followed by a batch normalization layer and rectified linear units." [0] + - [0]: 2022, VICRegL, https://arxiv.org/abs/2210.01571 """ def __init__( self, input_dim: int = 2048, hidden_dim: int = 8192, output_dim: int = 8192 ): + """Initializes the VicRegLLocalProjectionHead with the specified dimensions.""" super(VicRegLLocalProjectionHead, self).__init__( [ (input_dim, hidden_dim, nn.LayerNorm(hidden_dim), nn.ReLU()), @@ -826,6 +923,7 @@ class DenseCLProjectionHead(ProjectionHead): def __init__( self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128 ): + """Initializes the DenseCLProjectionHead with the specified dimensions.""" super().__init__( [ (input_dim, hidden_dim, None, nn.ReLU()), diff --git a/lightly/models/modules/heads_timm.py b/lightly/models/modules/heads_timm.py index 7ea20de21..44ae6a4ca 100644 --- a/lightly/models/modules/heads_timm.py +++ b/lightly/models/modules/heads_timm.py @@ -9,6 +9,22 @@ class AIMPredictionHeadBlock(Module): """Prediction head block for AIM [0]. - [0]: AIM, 2024, https://arxiv.org/abs/2401.08541 + + Args: + input_dim: + Dimensionality of the input features. + output_dim: + Dimensionality of the output features. + mlp_ratio: + Ratio used to determine the hidden layer size in the MLP. + proj_drop: + Dropout rate for the projection layer. + act_layer: + Activation layer to use. + norm_layer: + Normalization layer to use. + mlp_layer: + MLP layer to use. """ def __init__( @@ -21,6 +37,8 @@ def __init__( norm_layer: Type[Module] = LayerNorm, mlp_layer: Type[Module] = Mlp, ) -> None: + """Initializes the AIMPredictionHeadBlock module with the specified parameters.""" + super().__init__() self.norm = norm_layer(input_dim) # type: ignore[call-arg] self.mlp = mlp_layer( # type: ignore[call-arg] @@ -33,6 +51,15 @@ def __init__( ) def forward(self, x: Tensor) -> Tensor: + """Forward pass of the AIMPredictionHeadBlock. + + Args: + x: + Input tensor. + + Returns: + Output tensor after applying the MLP and normalization. + """ x = x + self.mlp(self.norm(x)) return x @@ -41,6 +68,28 @@ class AIMPredictionHead(Module): """Prediction head for AIM [0]. - [0]: AIM, 2024, https://arxiv.org/abs/2401.08541 + + Args: + input_dim: + Dimensionality of the input features. + output_dim: + Dimensionality of the output features. + hidden_dim: + Dimensionality of the hidden layer. + num_blocks: + Number of blocks in the prediction head. + mlp_ratio: + Ratio used to determine the hidden layer size in the MLP. + proj_drop: + Dropout rate for the projection layer. + act_layer: + Activation layer to use. + norm_layer: + Normalization layer to use. + mlp_layer: + MLP layer to use. + block_fn: + Block function to use for the prediction head. """ def __init__( @@ -56,6 +105,8 @@ def __init__( mlp_layer: Type[Module] = Mlp, block_fn: Type[Module] = AIMPredictionHeadBlock, ) -> None: + """Initializes the AIMPredictionHead module with the specified parameters.""" + super().__init__() self.blocks = Sequential( # Linear layer to project the input dimension to the hidden dimension. @@ -79,5 +130,14 @@ def __init__( ) def forward(self, x: Tensor) -> Tensor: + """Forward pass of the AIMPredictionHead. + + Args: + x: + Input tensor. + + Returns: + Output tensor after processing through the prediction head blocks. + """ x = self.blocks(x) return x From e82cb9136dc9978cfc7002cfafaad7d387a9f764 Mon Sep 17 00:00:00 2001 From: payo101 <35198092+payo101@users.noreply.github.com> Date: Sun, 20 Oct 2024 18:59:27 +0530 Subject: [PATCH 7/7] Made requested changes to Amplitude rescale transform and test --- .../transforms/amplitude_rescale_transform.py | 50 ++++++++----------- .../test_amplitude_rescale_transform.py | 9 +++- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/lightly/transforms/amplitude_rescale_transform.py b/lightly/transforms/amplitude_rescale_transform.py index d1dc6a77a..e09128ced 100644 --- a/lightly/transforms/amplitude_rescale_transform.py +++ b/lightly/transforms/amplitude_rescale_transform.py @@ -3,40 +3,32 @@ import numpy as np import torch from torch import Tensor +from torch.distributions import Uniform class AmplitudeRescaleTranform: - """ - This transform will rescale the amplitude of the Fourier Spectrum (`input`) of the image and return it. - The scaling value *p* will range within `[m, n)` - ``` - img = torch.randn(3, 64, 64) - - rfft = lightly.transforms.RFFT2DTransform() - rfft_img = rfft(img) - - art = AmplitudeRescaleTransform() - rescaled_img = art(rfft_img) - ``` - - # Intial Arguments - **range**: *Tuple of float_like* - The low `m` and high `n` values such that **p belongs to [m, n)**. - # Parameters: - **input**: _torch.Tensor_ - The 2D Discrete Fourier Tranform of an Image. - # Returns: - **output**:_torch.Tensor_ - The Fourier spectrum of the 2D Image with rescaled Amplitude. - """ + """Implementation of amplitude rescaling transformation. - def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None: - self.m = range[0] - self.n = range[1] + This transform will rescale the amplitude of the Fourier Spectrum (`freq_image`) of the image and return it. - def __call__(self, input: Tensor) -> Tensor: - p = np.random.uniform(self.m, self.n) + Attributes: + dist: + Uniform distribution in `[m, n)` from which the scaling value will be selected. + """ - output = input * p + def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None: + self.dist = Uniform(range[0], range[1]) + + def __call__(self, freq_image: Tensor) -> Tensor: + amplitude = torch.sqrt(freq_image.real**2 + freq_image.imag**2) + + phase = torch.atan2(freq_image.imag, freq_image.real) + # p with shape (H, W) + p = self.dist.sample(freq_image.shape[1:]).to(freq_image.device) + # Unsqueeze to add channel dimension. + amplitude *= p.unsqueeze(0) + real = amplitude * torch.cos(phase) + imag = amplitude * torch.sin(phase) + output = torch.complex(real, imag) return output diff --git a/tests/transforms/test_amplitude_rescale_transform.py b/tests/transforms/test_amplitude_rescale_transform.py index 6ded2746f..0b8d38eaa 100644 --- a/tests/transforms/test_amplitude_rescale_transform.py +++ b/tests/transforms/test_amplitude_rescale_transform.py @@ -1,7 +1,11 @@ import numpy as np import torch -from lightly.transforms import AmplitudeRescaleTranform, RFFT2DTransform +from lightly.transforms import ( + AmplitudeRescaleTranform, + IRFFT2DTransform, + RFFT2DTransform, +) # Testing function image -> FFT -> AmplitudeRescale. @@ -18,4 +22,5 @@ def test() -> None: ampRescaleTf_2 = AmplitudeRescaleTranform(range=(1.0, 2.0)) rescaled_rfft_2 = ampRescaleTf_2(rfft) - assert rescaled_rfft_1.shape == rfft.shape and rescaled_rfft_2.shape == rfft.shape + assert rescaled_rfft_1.shape == rfft.shape + assert rescaled_rfft_2.shape == rfft.shape