Skip to content

Commit

Permalink
feat: 17 rework yolox data augmentation (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
picsalex authored Jun 10, 2024
1 parent 3f29608 commit 284cd0d
Show file tree
Hide file tree
Showing 10 changed files with 600 additions and 63 deletions.
3 changes: 3 additions & 0 deletions yolox-detection/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
FROM picsellia/cuda:11.7.1-cudnn8-ubuntu20.04-python3.10

RUN apt-get -y update
RUN apt-get -y install git

COPY ./yolox-detection/requirements.txt .

ARG REBUILD_ALL
Expand Down
42 changes: 17 additions & 25 deletions yolox-detection/experiment/YOLOX/yolox/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, exp: Exp, args):

# Picsellia
self.picsellia_experiment = args.picsellia_experiment
self.metrics_dict = {}

if self.rank == 0:
os.makedirs(self.file_name, exist_ok=True)
Expand Down Expand Up @@ -130,6 +131,10 @@ def train_one_iter(self):
**outputs,
)

metrics_to_record = ["lr", "total_loss", "iou_loss", "conf_loss", "cls_loss"]
for metric in metrics_to_record:
self.metrics_dict.setdefault(metric, []).append(self.meter[metric].latest)

def before_train(self):
logger.info("args: {}".format(self.args))
logger.info("exp value:\n{}".format(self.exp))
Expand Down Expand Up @@ -208,17 +213,15 @@ def after_train(self):
def before_epoch(self):
logger.info("---> start train epoch{}".format(self.epoch + 1))

if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
logger.info("--->No mosaic aug now!")
self.train_loader.close_mosaic()
logger.info("--->Add additional L1 loss now!")
if self.is_distributed:
self.model.module.head.use_l1 = True
else:
self.model.head.use_l1 = True
self.exp.eval_interval = 1
if not self.no_aug:
self.save_ckpt(ckpt_name="last_mosaic_epoch")
# if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs:
# logger.info("--->No aug now!")
# os.environ["disable_aug"] = "1"
#
# logger.info("--->Add additional L1 loss now!")
# if self.is_distributed:
# self.model.module.head.use_l1 = True
# else:
# self.model.head.use_l1 = True

def after_epoch(self):
self.save_ckpt(ckpt_name="latest")
Expand All @@ -227,28 +230,17 @@ def after_epoch(self):
all_reduce_norm(self.model)
self.evaluate_and_save_model()

loss_meter = self.meter.get_filtered_meter("loss")
for k, v in loss_meter.items():
for k, v in self.metrics_dict.items():
try:
self.picsellia_experiment.log(
name="train/" + k, type=LogType.LINE, data=float(v.latest)
name="train/" + k, type=LogType.LINE, data=float(v[0])
)
except Exception as e:
logger.info(
f"Couldn't log metric {'train/' + k} to Picsellia because: {str(e)}"
)
try:
self.picsellia_experiment.log(
name="train/lr",
type=LogType.LINE,
data=float(self.meter["lr"].latest),
)
except Exception as e:
logger.info(
f"Couldn't log metric 'train/lr' to Picsellia because: {str(e)}"
)

self.meter.clear_meters()
self.metrics_dict = {}

def before_iter(self):
pass
Expand Down
12 changes: 7 additions & 5 deletions yolox-detection/experiment/YOLOX/yolox/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from .data_augment import TrainTransform, ValTransform
from .data_prefetcher import DataPrefetcher
from .dataloading import DataLoader, get_yolox_datadir, worker_init_reset_seed
from .datasets import *
from .samplers import InfiniteSampler, YoloBatchSampler
from .data_augment import TrainTransform, ValTransform # noqa
from .data_augment_v2 import TrainTransformV2, ValTransformV2 # noqa
from .data_augment_v3 import TrainTransformV3, ValTransformV3 # noqa
from .data_prefetcher import DataPrefetcher # noqa
from .dataloading import DataLoader, get_yolox_datadir, worker_init_reset_seed # noqa
from .datasets import * # noqa
from .samplers import InfiniteSampler, YoloBatchSampler # noqa
21 changes: 20 additions & 1 deletion yolox-detection/experiment/YOLOX/yolox/data/data_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import random

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from YOLOX.yolox.utils import xyxy2cxcywh

Expand Down Expand Up @@ -185,12 +187,14 @@ def __call__(self, image, targets, input_dim):

if random.random() < self.hsv_prob:
augment_hsv(image)

image_t, boxes = _mirror(image, boxes, self.flip_prob)

height, width, _ = image_t.shape
image_t, r_ = preproc(image_t, input_dim)
# boxes [xyxy] 2 [cx,cy,w,h]
boxes = xyxy2cxcywh(boxes)
boxes *= r_
boxes = boxes * r_

mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1
boxes_t = boxes[mask_b]
Expand All @@ -212,6 +216,21 @@ def __call__(self, image, targets, input_dim):
padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
return image_t, padded_labels

def visualize_image_with_boxes(self, image, boxes):
    """Debug helper: show ``image`` with one red rectangle per box.

    Each box is read as ``(cx, cy, w, h)``: the rectangle's corner is
    placed at ``(cx - w / 2, cy - h / 2)`` and it spans ``w`` by ``h``.
    Calls ``plt.show()``; nothing is returned.
    """
    _, axis = plt.subplots(1)
    axis.imshow(image)
    for box in boxes:
        center_x, center_y = box[0], box[1]
        box_w, box_h = box[2], box[3]
        corner = (center_x - box_w / 2, center_y - box_h / 2)
        outline = patches.Rectangle(
            corner,
            box_w,
            box_h,
            linewidth=2,
            edgecolor="r",
            facecolor="none",
        )
        axis.add_patch(outline)
    plt.show()


class ValTransform:
"""
Expand Down
239 changes: 239 additions & 0 deletions yolox-detection/experiment/YOLOX/yolox/data/data_augment_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import cv2
import imgaug.augmenters as iaa
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from YOLOX.yolox.utils import xyxy2cxcywh
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage


def preproc(img, input_size, swap=(2, 0, 1)):
    """Resize ``img`` to fit ``input_size`` and pad the remainder with 114.

    The image is scaled by a single ratio (aspect ratio preserved) so that it
    fits inside ``input_size``, pasted into the top-left corner of a canvas
    filled with the value 114, then its axes are reordered with ``swap``.

    Args:
        img: image array; ``img.shape[0]`` / ``img.shape[1]`` are used as
            height / width.
        input_size: ``(height, width)`` of the output canvas.
        swap: axis permutation applied to the padded canvas.
            NOTE(review): the default has three axes, so a 2-D input
            presumably needs a matching two-axis ``swap`` - confirm.

    Returns:
        ``(padded_img, r)``: the contiguous float32 canvas and the resize
        ratio that was applied.
    """
    target_h, target_w = input_size[0], input_size[1]
    src_h, src_w = img.shape[0], img.shape[1]

    # 3-D inputs always get a 3-channel canvas; otherwise the canvas has
    # exactly the requested input size.
    canvas_shape = (target_h, target_w, 3) if len(img.shape) == 3 else input_size
    canvas = np.full(canvas_shape, 114, dtype=np.uint8)

    # One ratio for both axes keeps the aspect ratio.
    r = min(target_h / src_h, target_w / src_w)
    new_h, new_w = int(src_h * r), int(src_w * r)

    resized = cv2.resize(
        img, (new_w, new_h), interpolation=cv2.INTER_LINEAR
    ).astype(np.uint8)
    canvas[:new_h, :new_w] = resized

    padded_img = np.ascontiguousarray(canvas.transpose(swap), dtype=np.float32)
    return padded_img, r


def get_transformed_image_and_targets(image, targets, input_dim, max_labels, aug):
    """Augment an image with its boxes, then letterbox it to ``input_dim``.

    Args:
        image: input image array (passed to ``aug`` and then to ``preproc``).
        targets: array whose rows are ``(x1, y1, x2, y2, label)``.
        input_dim: ``(height, width)`` handed to ``preproc``.
        max_labels: number of rows in the returned, zero-padded target array.
        aug: imgaug augmenter called as ``aug(image=..., bounding_boxes=...)``,
            or ``None`` to skip augmentation.

    Returns:
        ``(image_t, padded_labels)``: the preprocessed image and a float32
        array of shape ``(max_labels, 5)`` whose filled rows are
        ``(label, cx, cy, w, h)`` expressed in the coordinates of the
        preprocessed image. Unused rows are zero.
    """
    empty_targets = np.zeros((max_labels, 5), dtype=np.float32)

    if len(targets) == 0:
        image_t, _ = preproc(image, input_dim)
        return image_t, empty_targets

    # Attach each class label to its box so that the pairing survives any
    # box being dropped during augmentation or clipping. Matching labels to
    # boxes by position afterwards is unsafe: clip_out_of_image() already
    # discards boxes that left the image, which shifts the indices.
    bbs = BoundingBoxesOnImage(
        [BoundingBox(*row[:4], label=row[4]) for row in targets],
        shape=image.shape,
    )

    if aug is not None:
        image_aug, bbs_aug = aug(image=image, bounding_boxes=bbs)
    else:
        image_aug, bbs_aug = image, bbs

    # Drop boxes that ended up entirely outside the image, then trim the
    # remaining ones to the image borders.
    bbs_aug = bbs_aug.remove_out_of_image().clip_out_of_image()

    image_t, r_ = preproc(image_aug, input_dim)

    # All the boxes have disappeared due to the data augmentation.
    kept = bbs_aug.bounding_boxes
    if len(kept) == 0:
        return image_t, empty_targets

    boxes = np.array([[bb.x1, bb.y1, bb.x2, bb.y2] for bb in kept])
    labels = np.array([bb.label for bb in kept])

    # boxes [xyxy] 2 [cx,cy,w,h]
    boxes = xyxy2cxcywh(boxes)
    # preproc() resized the image by r_, so the boxes must be scaled by the
    # same ratio to stay aligned with the returned image.
    boxes = boxes * r_

    # Discard boxes whose shorter side is at most one pixel.
    mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1
    boxes_t = boxes[mask_b]
    labels_t = labels[mask_b]

    if len(boxes_t) == 0:
        return image_t, empty_targets

    targets_t = np.hstack((np.expand_dims(labels_t, 1), boxes_t))

    # Keep at most max_labels rows; the remaining rows stay zero.
    n_kept = min(len(targets_t), max_labels)
    padded_labels = np.zeros((max_labels, 5))
    padded_labels[:n_kept] = targets_t[:n_kept]
    padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
    return image_t, padded_labels


def visualize_image_with_boxes(image, boxes):
    """Debug helper: show ``image`` with one red rectangle per box.

    Each box is read as ``(cx, cy, w, h)``: the rectangle's corner is
    placed at ``(cx - w / 2, cy - h / 2)`` and it spans ``w`` by ``h``.
    Calls ``plt.show()``; nothing is returned.
    """
    _, axis = plt.subplots(1)
    axis.imshow(image)
    for box in boxes:
        center_x, center_y = box[0], box[1]
        box_w, box_h = box[2], box[3]
        corner = (center_x - box_w / 2, center_y - box_h / 2)
        outline = patches.Rectangle(
            corner,
            box_w,
            box_h,
            linewidth=2,
            edgecolor="r",
            facecolor="none",
        )
        axis.add_patch(outline)
    plt.show()


class TrainTransformV2:
    """Training transform backed by a randomized imgaug pipeline.

    Calling an instance forwards the image and targets to
    ``get_transformed_image_and_targets`` together with the pipeline built
    in ``_load_custom_augmentations``.
    """

    def __init__(self, max_labels=50):
        # Number of target rows produced per image.
        self.max_labels = max_labels
        self.aug = self._load_custom_augmentations()

    def _load_custom_augmentations(self):
        """Create the sequence of imgaug augmentations used for training."""
        often = 0.5  # probability of the "sometimes" stages
        seldom = 0.03  # probability of the "rarely" stages

        # Horizontally flip 50% of all images.
        flip = iaa.Fliplr(0.5)

        # percent range (-0.05, 0.1): negative values crop, positive values
        # pad; padding uses a constant value of 0.
        crop_or_pad = iaa.Sometimes(
            often,
            iaa.CropAndPad(percent=(-0.05, 0.1), pad_mode="constant", pad_cval=0),
        )

        affine = iaa.Sometimes(
            often,
            iaa.Affine(
                # scale to 90-110% of the size, individually per axis
                scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
                # translate by -20 to +20 percent (per axis)
                translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
                # shear by -5 to +5 degrees
                shear=(-5, 5),
                # nearest neighbour or bilinear interpolation (fast)
                order=[0, 1],
                # newly exposed pixels are filled with the constant 0
                cval=0,
                mode="constant",
            ),
        )

        cutout = iaa.Sometimes(
            seldom, iaa.Cutout(nb_iterations=(1, 4), size=(0.05, 0.1))
        )

        # Exactly one of the three blur variants.
        blur = iaa.OneOf(
            [
                # gaussian blur with a sigma between 2 and 4
                iaa.GaussianBlur((2, 4)),
                # local-mean blur, kernel size range (3, 8)
                iaa.AverageBlur(k=(3, 8)),
                # local-median blur, kernel size range (1, 5)
                iaa.MedianBlur(k=(1, 5)),
            ]
        )

        # Apply between 0 and 5 of these lighter augmenters per image, in
        # random order; applying all of them would often be too strong.
        photometric = iaa.SomeOf(
            (0, 5),
            [
                blur,
                # sharpen images
                iaa.Sharpen(alpha=(0, 0.6), lightness=(0.9, 1.2)),
                # add gaussian noise to images
                iaa.AdditiveGaussianNoise(scale=(0.0, 0.06 * 255)),
                # shift brightness by a value in [-15, 15]
                iaa.Add((-15, 15)),
                iaa.AddToHueAndSaturation((-10, 10)),
                # smaller shift, sampled per channel for 50% of the images
                iaa.Add((-8, 8), per_channel=0.5),
            ],
            random_order=True,
        )

        # Rare heavy effects: a cloud layer, rain or strong motion blur.
        heavy_effect = iaa.Sometimes(
            seldom,
            iaa.OneOf(
                [
                    iaa.CloudLayer(
                        intensity_mean=(220, 255),
                        intensity_freq_exponent=(-2.0, -1.5),
                        intensity_coarse_scale=2,
                        alpha_min=(0.7, 0.9),
                        alpha_multiplier=0.3,
                        alpha_size_px_max=(2, 8),
                        alpha_freq_exponent=(-4.0, -2.0),
                        sparsity=0.9,
                        density_multiplier=(0.3, 0.6),
                    ),
                    iaa.Rain(nb_iterations=1, drop_size=0.05, speed=0.2),
                    iaa.MotionBlur(k=20),
                ]
            ),
        )

        # The top-level stages are themselves applied in random order.
        return iaa.Sequential(
            [flip, crop_or_pad, affine, cutout, photometric, heavy_effect],
            random_order=True,
        )

    def __call__(self, image, targets, input_dim):
        """Augment and preprocess one sample; returns ``(image, targets)``."""
        return get_transformed_image_and_targets(
            image=image,
            targets=targets,
            input_dim=input_dim,
            max_labels=self.max_labels,
            aug=self.aug,
        )


class ValTransformV2:
    """Validation/test transform: resize and pad only, no augmentation.

    Arguments:
        swap ((int,int,int)): axis order applied to the padded image by
            ``preproc``; the default ``(2, 0, 1)`` moves the last axis first.
    Returns:
        transform (transform) : callable transform to be applied to test/val
        data
    """

    def __init__(self, swap=(2, 0, 1)):
        self.swap = swap

    # assume input is cv2 img for now
    def __call__(self, image, targets, input_dim):
        """Preprocess ``image`` for evaluation.

        ``targets`` is accepted for interface compatibility but not used:
        the returned target array is always a ``(1, 5)`` array of zeros.
        """
        # Call preproc directly so the configured ``swap`` is honoured and
        # no bounding-box work is done for targets that are discarded anyway.
        image_t, _ = preproc(image, input_dim, self.swap)
        return image_t, np.zeros((1, 5))
Loading

0 comments on commit 284cd0d

Please sign in to comment.