Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonlessons committed Mar 15, 2024
2 parents 506c06a + 6ba4798 commit 4d085bc
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 7 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
## [1.2.2] - 2024-03-15
### Changed
- Bug fixed with `loss_info` local variable in `mltu.torch.model.Model` object

### Added
- Added `RandomColorMode` and `RandomZoom` into `mltu.augmentors`


## [1.2.1] - 2024-03-12
### Changed
- Fixed many minor bugs
Expand Down
2 changes: 1 addition & 1 deletion mltu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.2.1"
__version__ = "1.2.2"

from .annotations.images import Image
from .annotations.images import CVImage
Expand Down
131 changes: 127 additions & 4 deletions mltu/augmentors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import Image
from mltu.annotations.audio import Audio
from mltu.annotations.detections import Detections, Detection
from mltu.annotations.detections import Detections, Detection, BboxType

"""
Implemented image augmentors:
Expand All @@ -20,6 +20,8 @@
- RandomFlip
- RandomDropBlock
- RandomMosaic
- RandomZoom
- RandomColorMode
Implemented audio augmentors:
- RandomAudioNoise
Expand Down Expand Up @@ -379,21 +381,22 @@ def __init__(
self,
random_chance: float = 0.5,
log_level: int = logging.INFO,
sigma: typing.Union[int, float] = 0.5,
sigma: typing.Union[int, float] = 1.5,
augment_annotation: bool = False,
) -> None:
""" Randomly erode and dilate image
Args:
random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
sigma (int, float): standard deviation of the Gaussian kernel
sigma (int, float): maximum sigma value for Gaussian blur. Defaults to 1.5.
"""
super(RandomGaussianBlur, self).__init__(random_chance, log_level, augment_annotation)
self.sigma = sigma

def augment(self, image: Image) -> Image:
img = cv2.GaussianBlur(image.numpy(), (0, 0), self.sigma)
sigma = np.random.uniform(0, self.sigma)
img = cv2.GaussianBlur(image.numpy(), (0, 0), sigma)

image.update(img)

Expand Down Expand Up @@ -716,6 +719,126 @@ def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image,
return image, annotation


class RandomZoom(Augmentor):
def __init__(
self,
random_chance: float = 0.5,
log_level: int = logging.INFO,
augment_annotation: bool = True,
object_crop_percentage: float = 0.5,
) -> None:
""" Randomly zoom into an image
Args:
random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
augment_annotation (bool): Whether to augment the annotation. Defaults to False.
object_crop_percentage (float): Percentage of the object allowed to be cropped. Defaults to 0.5.
"""
super(RandomZoom, self).__init__(random_chance, log_level, augment_annotation)
self.object_crop_percentage = object_crop_percentage

@randomness_decorator
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
""" Randomly zoom an image
Args:
image (Image): Image to be used for zoom
annotation (typing.Any): Annotation to be used for zoom
Returns:
image (Image): Zoomed image
annotation (typing.Any): Zoomed annotation if necessary
"""
if isinstance(annotation, Detections) and self._augment_annotation:

dets = np.array([detection.xyxy for detection in annotation])
min_left = np.min(dets[:, 0])
min_top = np.min(dets[:, 1])
max_right = np.max(dets[:, 2])
max_bottom = np.max(dets[:, 3])

# Calculate the size of the object
object_width = max_right - min_left
object_height = max_bottom - min_top

crop_xmin = np.random.uniform(0, min_left + 0.25 * object_width * self.object_crop_percentage)
crop_ymin = np.random.uniform(0, min_top + 0.25 * object_height * self.object_crop_percentage)
crop_xmax = np.random.uniform(max_right - 0.25 * object_width * self.object_crop_percentage, 1)
crop_ymax = np.random.uniform(max_bottom - 0.25 * object_height * self.object_crop_percentage, 1)

crop_min_max = np.array([crop_xmin, crop_ymin, crop_xmax, crop_ymax])
new_xyxy = (crop_min_max * np.array([image.width, image.height, image.width, image.height])).astype(int)
new_image = image.numpy()[new_xyxy[1]:new_xyxy[3], new_xyxy[0]:new_xyxy[2]]
image.update(new_image)

crop_min_ratio = np.array([crop_xmin, crop_ymin, crop_xmin, crop_ymin])
crop_max_ratio = np.array([crop_xmax, crop_ymax, crop_xmax, crop_ymax])
new_dets = (dets - crop_min_ratio) / (crop_max_ratio - crop_min_ratio)

detections = []
for detection, new_det in zip(annotation, new_dets):
new_detection = Detection(
new_det,
label=detection.label,
labels=detection.labels,
confidence=detection.confidence,
image_path=detection.image_path,
width=image.width,
height=image.height,
relative=True,
bbox_type = BboxType.XYXY
)

detections.append(new_detection)

annotation = Detections(
labels=annotation.labels,
width=image.width,
height=image.height,
detections=detections
)

return image, annotation


class RandomColorMode(Augmentor):
def __init__(
self,
random_chance: float = 0.5,
log_level: int = logging.INFO,
augment_annotation: bool = False,
) -> None:
""" Randomly change color mode of an image
Args:
random_chance (float): Float between 0.0 and 1.0 setting bounds for random probability. Defaults to 0.5.
log_level (int): Log level for the augmentor. Defaults to logging.INFO.
augment_annotation (bool): Whether to augment the annotation. Defaults to False.
"""
super(RandomColorMode, self).__init__(random_chance, log_level, augment_annotation)

@randomness_decorator
def __call__(self, image: Image, annotation: typing.Any) -> typing.Tuple[Image, typing.Any]:
""" Randomly change color mode of an image
Args:
image (Image): Image to be used for color mode change
annotation (typing.Any): Annotation to be used for color mode change
Returns:
image (Image): Color mode changed image
annotation (typing.Any): Color mode changed annotation if necessary
"""
color_mode = np.random.choice([cv2.COLOR_BGR2GRAY, cv2.COLOR_BGR2HSV, cv2.COLOR_BGR2LAB, cv2.COLOR_BGR2YCrCb, cv2.COLOR_BGR2RGB])
new_image = cv2.cvtColor(image.numpy(), color_mode)
if color_mode == cv2.COLOR_BGR2GRAY:
new_image = cv2.cvtColor(new_image, cv2.COLOR_GRAY2BGR)
image.update(new_image)

return image, annotation


class RandomAudioNoise(Augmentor):
""" Randomly add noise to audio
Expand Down
8 changes: 6 additions & 2 deletions mltu/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,15 @@ def toDevice(self, data: np.ndarray, target: np.ndarray) -> typing.Tuple[torch.T
def train_step(
self,
data: typing.Union[np.ndarray, torch.Tensor],
target: typing.Union[np.ndarray, torch.Tensor]
target: typing.Union[np.ndarray, torch.Tensor],
loss_info: dict = {}
) -> torch.Tensor:
""" Perform one training step
Args:
data (typing.Union[np.ndarray, torch.Tensor]): training data
target (typing.Union[np.ndarray, torch.Tensor]): training target
loss_info (dict, optional): additional loss information. Defaults to {}.
Returns:
torch.Tensor: loss
Expand Down Expand Up @@ -228,13 +230,15 @@ def train_step(
def test_step(
self,
data: typing.Union[np.ndarray, torch.Tensor],
target: typing.Union[np.ndarray, torch.Tensor]
target: typing.Union[np.ndarray, torch.Tensor],
loss_info: dict = {}
) -> torch.Tensor:
""" Perform one validation step
Args:
data (typing.Union[np.ndarray, torch.Tensor]): validation data
target (typing.Union[np.ndarray, torch.Tensor]): validation target
loss_info (dict, optional): additional loss information. Defaults to {}.
Returns:
torch.Tensor: loss
Expand Down

0 comments on commit 4d085bc

Please sign in to comment.