Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the Amplitude Rescaling Transform #1694

Merged
merged 8 commits into from
Oct 21, 2024
1 change: 1 addition & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# All Rights Reserved

from lightly.transforms.aim_transform import AIMTransform
from lightly.transforms.amplitude_rescale_transform import AmplitudeRescaleTranform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
Expand Down
34 changes: 34 additions & 0 deletions lightly/transforms/amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Tuple

import numpy as np
import torch
from torch import Tensor
from torch.distributions import Uniform


class AmplitudeRescaleTranform:
"""Implementation of amplitude rescaling transformation.

This transform will rescale the amplitude of the Fourier Spectrum (`freq_image`) of the image and return it.

Attributes:
dist:
Uniform distribution in `[m, n)` from which the scaling value will be selected.
"""

def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None:
self.dist = Uniform(range[0], range[1])

def __call__(self, freq_image: Tensor) -> Tensor:
amplitude = torch.sqrt(freq_image.real**2 + freq_image.imag**2)

phase = torch.atan2(freq_image.imag, freq_image.real)
# p with shape (H, W)
p = self.dist.sample(freq_image.shape[1:]).to(freq_image.device)
# Unsqueeze to add channel dimension.
amplitude *= p.unsqueeze(0)
real = amplitude * torch.cos(phase)
imag = amplitude * torch.sin(phase)
output = torch.complex(real, imag)

return output
26 changes: 26 additions & 0 deletions tests/transforms/test_amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import torch

from lightly.transforms import (
AmplitudeRescaleTranform,
IRFFT2DTransform,
RFFT2DTransform,
)


# Testing function image -> FFT -> AmplitudeRescale.
# Compare shapes of source and result.
def test() -> None:
image = torch.randn(3, 64, 64)

rfftTransform = RFFT2DTransform()
rfft = rfftTransform(image)

ampRescaleTf_1 = AmplitudeRescaleTranform()
rescaled_rfft_1 = ampRescaleTf_1(rfft)

ampRescaleTf_2 = AmplitudeRescaleTranform(range=(1.0, 2.0))
rescaled_rfft_2 = ampRescaleTf_2(rfft)

assert rescaled_rfft_1.shape == rfft.shape
assert rescaled_rfft_2.shape == rfft.shape