Skip to content

Commit

Permalink
Add PhaseShift Transform (#1714)
Browse files Browse the repository at this point in the history
  • Loading branch information
pearguacamole authored Nov 7, 2024
1 parent 2155ebb commit fd29f1d
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ docs/source/getting_started/resources

#ignore venv
venv

.venv
#ignore pycharm IDE
.idea

Expand Down
1 change: 1 addition & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from lightly.transforms.mmcr_transform import MMCRTransform
from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform
from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform
from lightly.transforms.phase_shift_transform import PhaseShiftTransform
from lightly.transforms.pirl_transform import PIRLTransform
from lightly.transforms.random_frequency_mask_transform import (
RandomFrequencyMaskTransform,
Expand Down
60 changes: 60 additions & 0 deletions lightly/transforms/phase_shift_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Tuple

import torch
from torch import Tensor
from torch.distributions import Uniform
from torch.distributions.bernoulli import Bernoulli


class PhaseShiftTransform:
"""Implementation of phase shifting transformation.
Applies a random phase shift `theta` (positive or negative) to the Fourier spectrum (`freq_image`) of the image and returns the transformed spectrum.
Attributes:
dist:
A uniform distribution in the range `[p, q)` from which the magnitude of the
phase shift `theta` is selected.
include_negatives:
A flag indicating whether negative values of `theta` should be included.
If `True`, both positive and negative shifts are applied.
sign_dist:
A Bernoulli distribution used to decide the sign of `theta`, based on a
given probability `sign_probability`, if negative values are included.
"""

def __init__(
self,
range: Tuple[float, float] = (0.4, 0.7),
include_negatives: bool = False,
sign_probability: float = 0.5,
) -> None:
self.dist = Uniform(range[0], range[1])
self.include_negatives = include_negatives
if include_negatives:
self.sign_dist = Bernoulli(sign_probability)

def __call__(self, freq_image: Tensor) -> Tensor:
# Calculate amplitude and phase
amplitude = torch.sqrt(freq_image.real**2 + freq_image.imag**2)
phase = torch.atan2(freq_image.imag, freq_image.real)

# Sample a random phase shift θ
theta = self.dist.sample().to(freq_image.device)

if self.include_negatives:
# Determine sign for shift: +θ or -θ
sign = self.sign_dist.sample().to(freq_image.device)
# Apply random sign directly to theta
theta = torch.where(sign == 1, theta, -theta)

# Adjust the phase
phase_shifted = phase + theta

# Recreate the complex spectrum with adjusted phase
real = amplitude * torch.cos(phase_shifted)
imag = amplitude * torch.sin(phase_shifted)
output = torch.complex(real, imag)

return output
27 changes: 27 additions & 0 deletions tests/transforms/test_phase_shift_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch

from lightly.transforms import IRFFT2DTransform, PhaseShiftTransform, RFFT2DTransform


# Testing function image -> RFFT -> PhaseShift.
# Compare shapes of source and result.
def test() -> None:
image = torch.randn(3, 64, 64)

rfftTransform = RFFT2DTransform()
rfft = rfftTransform(image)

phaseShiftTf_1 = PhaseShiftTransform()
rescaled_rfft_1 = phaseShiftTf_1(rfft)

phaseShiftTf_2 = PhaseShiftTransform(range=(1.0, 2.0))
rescaled_rfft_2 = phaseShiftTf_2(rfft)

phaseShiftTf_3 = PhaseShiftTransform(
range=(1.0, 2.0), include_negatives=True, sign_probability=0.8
)
rescaled_rfft_3 = phaseShiftTf_3(rfft)

assert rescaled_rfft_1.shape == rfft.shape
assert rescaled_rfft_2.shape == rfft.shape
assert rescaled_rfft_3.shape == rfft.shape

0 comments on commit fd29f1d

Please sign in to comment.