Skip to content

Commit

Permalink
Merge branch 'master' into malte-lig-5576-tiling-tasks-implement-work…
Browse files Browse the repository at this point in the history
…er-part
  • Loading branch information
MalteEbner authored Nov 18, 2024
2 parents 908c366 + d73fbca commit 67a3501
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# All Rights Reserved

from lightly.transforms.aim_transform import AIMTransform
from lightly.transforms.amplitude_rescale_transform import AmplitudeRescaleTranform
from lightly.transforms.amplitude_rescale_transform import AmplitudeRescaleTransform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
Expand Down
3 changes: 1 addition & 2 deletions lightly/transforms/amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Tuple

import numpy as np
import torch
from torch import Tensor
from torch.distributions import Uniform


class AmplitudeRescaleTranform:
class AmplitudeRescaleTransform:
"""Implementation of amplitude rescaling transformation.
This transform will rescale the amplitude of the Fourier Spectrum (`freq_image`) of the image and return it.
Expand Down
11 changes: 3 additions & 8 deletions tests/transforms/test_amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import numpy as np
import torch

from lightly.transforms import (
AmplitudeRescaleTranform,
IRFFT2DTransform,
RFFT2DTransform,
)
from lightly.transforms import AmplitudeRescaleTransform, RFFT2DTransform


# Testing function image -> FFT -> AmplitudeRescale.
Expand All @@ -16,10 +11,10 @@ def test() -> None:
rfftTransform = RFFT2DTransform()
rfft = rfftTransform(image)

ampRescaleTf_1 = AmplitudeRescaleTranform()
ampRescaleTf_1 = AmplitudeRescaleTransform()
rescaled_rfft_1 = ampRescaleTf_1(rfft)

ampRescaleTf_2 = AmplitudeRescaleTranform(range=(1.0, 2.0))
ampRescaleTf_2 = AmplitudeRescaleTransform(range=(1.0, 2.0))
rescaled_rfft_2 = ampRescaleTf_2(rfft)

assert rescaled_rfft_1.shape == rfft.shape
Expand Down

0 comments on commit 67a3501

Please sign in to comment.