From fd29f1d7ca5fc24027a7174f25cd0596b1929764 Mon Sep 17 00:00:00 2001 From: Kushagra Dwivedi <100153737+pearguacamole@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:39:03 +0530 Subject: [PATCH] Add PhaseShift Transform (#1714) --- .gitignore | 2 +- lightly/transforms/__init__.py | 1 + lightly/transforms/phase_shift_transform.py | 60 +++++++++++++++++++ .../transforms/test_phase_shift_transform.py | 27 +++++++++ 4 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 lightly/transforms/phase_shift_transform.py create mode 100644 tests/transforms/test_phase_shift_transform.py diff --git a/.gitignore b/.gitignore index 83d4d1fa2..07753b4fa 100644 --- a/.gitignore +++ b/.gitignore @@ -22,7 +22,7 @@ docs/source/getting_started/resources #ignore venv venv - +.venv #ignore pycharm IDE .idea diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index c6401ee53..422b64cac 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -26,6 +26,7 @@ from lightly.transforms.mmcr_transform import MMCRTransform from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform +from lightly.transforms.phase_shift_transform import PhaseShiftTransform from lightly.transforms.pirl_transform import PIRLTransform from lightly.transforms.random_frequency_mask_transform import ( RandomFrequencyMaskTransform, diff --git a/lightly/transforms/phase_shift_transform.py b/lightly/transforms/phase_shift_transform.py new file mode 100644 index 000000000..e2c37864c --- /dev/null +++ b/lightly/transforms/phase_shift_transform.py @@ -0,0 +1,60 @@ +from typing import Tuple + +import torch +from torch import Tensor +from torch.distributions import Uniform +from torch.distributions.bernoulli import Bernoulli + + +class PhaseShiftTransform: + """Implementation of phase shifting transformation. + + + Applies a random phase shift `theta` (positive or negative) to the Fourier spectrum (`freq_image`) of the image and returns the transformed spectrum. + + Attributes: + dist: + A uniform distribution in the range `[p, q)` from which the magnitude of the + phase shift `theta` is selected. + include_negatives: + A flag indicating whether negative values of `theta` should be included. + If `True`, both positive and negative shifts are applied. + sign_dist: + A Bernoulli distribution used to decide the sign of `theta`, based on a + given probability `sign_probability`, if negative values are included. + """ + + def __init__( + self, + range: Tuple[float, float] = (0.4, 0.7), + include_negatives: bool = False, + sign_probability: float = 0.5, + ) -> None: + self.dist = Uniform(range[0], range[1]) + self.include_negatives = include_negatives + if include_negatives: + self.sign_dist = Bernoulli(sign_probability) + + def __call__(self, freq_image: Tensor) -> Tensor: + # Calculate amplitude and phase + amplitude = torch.sqrt(freq_image.real**2 + freq_image.imag**2) + phase = torch.atan2(freq_image.imag, freq_image.real) + + # Sample a random phase shift θ + theta = self.dist.sample().to(freq_image.device) + + if self.include_negatives: + # Determine sign for shift: +θ or -θ + sign = self.sign_dist.sample().to(freq_image.device) + # Apply random sign directly to theta + theta = torch.where(sign == 1, theta, -theta) + + # Adjust the phase + phase_shifted = phase + theta + + # Recreate the complex spectrum with adjusted phase + real = amplitude * torch.cos(phase_shifted) + imag = amplitude * torch.sin(phase_shifted) + output = torch.complex(real, imag) + + return output diff --git a/tests/transforms/test_phase_shift_transform.py b/tests/transforms/test_phase_shift_transform.py new file mode 100644 index 000000000..5c3ae1f17 --- /dev/null +++ b/tests/transforms/test_phase_shift_transform.py @@ -0,0 +1,27 @@ +import torch + +from lightly.transforms import IRFFT2DTransform, PhaseShiftTransform, RFFT2DTransform + + +# Testing function image -> RFFT -> PhaseShift. +# Compare shapes of source and result. +def test() -> None: + image = torch.randn(3, 64, 64) + + rfftTransform = RFFT2DTransform() + rfft = rfftTransform(image) + + phaseShiftTf_1 = PhaseShiftTransform() + rescaled_rfft_1 = phaseShiftTf_1(rfft) + + phaseShiftTf_2 = PhaseShiftTransform(range=(1.0, 2.0)) + rescaled_rfft_2 = phaseShiftTf_2(rfft) + + phaseShiftTf_3 = PhaseShiftTransform( + range=(1.0, 2.0), include_negatives=True, sign_probability=0.8 + ) + rescaled_rfft_3 = phaseShiftTf_3(rfft) + + assert rescaled_rfft_1.shape == rfft.shape + assert rescaled_rfft_2.shape == rfft.shape + assert rescaled_rfft_3.shape == rfft.shape