From 284cd0d46ec08cb8671a6b9eac3c002f0f4bc2fa Mon Sep 17 00:00:00 2001
From: Alexis S <132259399+picsalex@users.noreply.github.com>
Date: Mon, 10 Jun 2024 18:05:52 +0200
Subject: [PATCH] feat: 17 rework yolox data augmentation (#151)

---
 yolox-detection/Dockerfile                    |   3 +
 .../experiment/YOLOX/yolox/core/trainer.py    |  42 ++-
 .../experiment/YOLOX/yolox/data/__init__.py   |  12 +-
 .../YOLOX/yolox/data/data_augment.py          |  21 +-
 .../YOLOX/yolox/data/data_augment_v2.py       | 239 ++++++++++++++++++
 .../YOLOX/yolox/data/data_augment_v3.py       | 222 ++++++++++++++++
 .../experiment/YOLOX/yolox/exp/yolox_base.py  |  48 ++--
 yolox-detection/experiment/main.py            |   9 +-
 .../experiment/test_augmentations.py          |  64 +++++
 yolox-detection/requirements.txt              |   3 +-
 10 files changed, 600 insertions(+), 63 deletions(-)
 create mode 100644 yolox-detection/experiment/YOLOX/yolox/data/data_augment_v2.py
 create mode 100644 yolox-detection/experiment/YOLOX/yolox/data/data_augment_v3.py
 create mode 100644 yolox-detection/experiment/test_augmentations.py

diff --git a/yolox-detection/Dockerfile b/yolox-detection/Dockerfile
index 02d4f6fc..7e5b915f 100644
--- a/yolox-detection/Dockerfile
+++ b/yolox-detection/Dockerfile
@@ -1,5 +1,8 @@
 FROM picsellia/cuda:11.7.1-cudnn8-ubuntu20.04-python3.10

+RUN apt-get -y update
+RUN apt-get -y install git
+
 COPY ./yolox-detection/requirements.txt .

 ARG REBUILD_ALL
diff --git a/yolox-detection/experiment/YOLOX/yolox/core/trainer.py b/yolox-detection/experiment/YOLOX/yolox/core/trainer.py
index cfc7fb36..db886ef0 100755
--- a/yolox-detection/experiment/YOLOX/yolox/core/trainer.py
+++ b/yolox-detection/experiment/YOLOX/yolox/core/trainer.py
@@ -63,6 +63,7 @@ def __init__(self, exp: Exp, args):

         # Picsellia
         self.picsellia_experiment = args.picsellia_experiment
+        self.metrics_dict = {}

         if self.rank == 0:
             os.makedirs(self.file_name, exist_ok=True)
@@ -130,6 +131,10 @@ def train_one_iter(self):
             **outputs,
         )

+        metrics_to_record = ["lr", "total_loss", "iou_loss", "conf_loss", "cls_loss"]
+        for metric in metrics_to_record:
+            self.metrics_dict.setdefault(metric, []).append(self.meter[metric].latest)
+
     def before_train(self):
         logger.info("args: {}".format(self.args))
         logger.info("exp value:\n{}".format(self.exp))
@@ -208,17 +213,15 @@ def after_train(self):

     def before_epoch(self):
         logger.info("---> start train epoch{}".format(self.epoch + 1))
-        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
-            logger.info("--->No mosaic aug now!")
-            self.train_loader.close_mosaic()
-            logger.info("--->Add additional L1 loss now!")
-            if self.is_distributed:
-                self.model.module.head.use_l1 = True
-            else:
-                self.model.head.use_l1 = True
-            self.exp.eval_interval = 1
-            if not self.no_aug:
-                self.save_ckpt(ckpt_name="last_mosaic_epoch")
+        # if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs:
+        #     logger.info("--->No aug now!")
+        #     os.environ["disable_aug"] = "1"
+        #
+        #     logger.info("--->Add additional L1 loss now!")
+        #     if self.is_distributed:
+        #         self.model.module.head.use_l1 = True
+        #     else:
+        #         self.model.head.use_l1 = True

     def after_epoch(self):
         self.save_ckpt(ckpt_name="latest")
@@ -227,28 +230,17 @@ def after_epoch(self):
             all_reduce_norm(self.model)
             self.evaluate_and_save_model()

-        loss_meter = self.meter.get_filtered_meter("loss")
-        for k, v in loss_meter.items():
+        for k, v in self.metrics_dict.items():
             try:
                 self.picsellia_experiment.log(
-                    name="train/" + k, type=LogType.LINE, data=float(v.latest)
+                    name="train/" + k, type=LogType.LINE, data=float(sum(v) / len(v))
                 )
             except Exception as e:
                 logger.info(
                    f"Couldn't log metric {'train/' + k} to Picsellia because: {str(e)}"
                 )
-            try:
-                self.picsellia_experiment.log(
-                    name="train/lr",
-                    type=LogType.LINE,
-                    data=float(self.meter["lr"].latest),
-                )
-            except Exception as e:
-                logger.info(
-                    f"Couldn't log metric 'train/lr' to Picsellia because: {str(e)}"
-                )
-        self.meter.clear_meters()
+        self.metrics_dict = {}

     def before_iter(self):
         pass
diff --git a/yolox-detection/experiment/YOLOX/yolox/data/__init__.py b/yolox-detection/experiment/YOLOX/yolox/data/__init__.py
index aeaf4f93..f3dc5da7 100644
--- a/yolox-detection/experiment/YOLOX/yolox/data/__init__.py
+++ b/yolox-detection/experiment/YOLOX/yolox/data/__init__.py
@@ -2,8 +2,10 @@
 # -*- coding:utf-8 -*-
 # Copyright (c) Megvii, Inc. and its affiliates.

-from .data_augment import TrainTransform, ValTransform
-from .data_prefetcher import DataPrefetcher
-from .dataloading import DataLoader, get_yolox_datadir, worker_init_reset_seed
-from .datasets import *
-from .samplers import InfiniteSampler, YoloBatchSampler
+from .data_augment import TrainTransform, ValTransform  # noqa
+from .data_augment_v2 import TrainTransformV2, ValTransformV2  # noqa
+from .data_augment_v3 import TrainTransformV3, ValTransformV3  # noqa
+from .data_prefetcher import DataPrefetcher  # noqa
+from .dataloading import DataLoader, get_yolox_datadir, worker_init_reset_seed  # noqa
+from .datasets import *  # noqa
+from .samplers import InfiniteSampler, YoloBatchSampler  # noqa
diff --git a/yolox-detection/experiment/YOLOX/yolox/data/data_augment.py b/yolox-detection/experiment/YOLOX/yolox/data/data_augment.py
index c1fdde69..598e98e4 100644
--- a/yolox-detection/experiment/YOLOX/yolox/data/data_augment.py
+++ b/yolox-detection/experiment/YOLOX/yolox/data/data_augment.py
@@ -13,6 +13,8 @@
 import random

 import cv2
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
 import numpy as np

 from YOLOX.yolox.utils import xyxy2cxcywh
@@ -185,12 +187,14 @@ def __call__(self, image, targets, input_dim):
         if random.random() < self.hsv_prob:
             augment_hsv(image)

+        image_t, boxes = _mirror(image, boxes, self.flip_prob)
+
         height, width, _ = image_t.shape
         image_t, r_ = preproc(image_t, input_dim)
         # boxes [xyxy] 2 [cx,cy,w,h]
         boxes = xyxy2cxcywh(boxes)
-        boxes *= r_
+        boxes = boxes * r_

         mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1
         boxes_t = boxes[mask_b]
@@ -212,6 +216,21 @@ def __call__(self, image, targets, input_dim):
         padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
         return image_t, padded_labels

+    def visualize_image_with_boxes(self, image, boxes):
+        fig, ax = plt.subplots(1)
+        ax.imshow(image)
+        for box in boxes:
+            rect = patches.Rectangle(
+                (box[0] - box[2] / 2, box[1] - box[3] / 2),
+                box[2],
+                box[3],
+                linewidth=2,
+                edgecolor="r",
+                facecolor="none",
+            )
+            ax.add_patch(rect)
+        plt.show()
+

 class ValTransform:
     """
diff --git a/yolox-detection/experiment/YOLOX/yolox/data/data_augment_v2.py b/yolox-detection/experiment/YOLOX/yolox/data/data_augment_v2.py
new file mode 100644
index 00000000..945c14f6
--- /dev/null
+++ b/yolox-detection/experiment/YOLOX/yolox/data/data_augment_v2.py
@@ -0,0 +1,239 @@
+import cv2
+import imgaug.augmenters as iaa
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np
+from YOLOX.yolox.utils import xyxy2cxcywh
+from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
+
+
+def preproc(img, input_size, swap=(2, 0, 1)):
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
+
+
+def get_transformed_image_and_targets(image, targets, input_dim, max_labels, aug):
+    boxes = targets[:, :4].copy()
+    labels = targets[:, 4].copy()
+
+    if len(boxes) == 0:
+        targets = np.zeros((max_labels, 5), dtype=np.float32)
+        image, _ = preproc(image, input_dim)
+        return image, targets
+
+    bbs = BoundingBoxesOnImage([BoundingBox(*box) for box in boxes], shape=image.shape)
+
+    if aug is not None:
+        image_aug, bbs_aug = aug(image=image, bounding_boxes=bbs)
+    else:
+        image_aug = image
+        bbs_aug = bbs
+
+    bbs_aug = bbs_aug.clip_out_of_image()
+
+    valid_indices = [
+        i
+        for i, bbox in enumerate(bbs_aug.bounding_boxes)
+        if bbox.is_fully_within_image(image)
+    ]
+    labels = labels[valid_indices]
+    bbs_aug = bbs_aug.remove_out_of_image()
+
+    boxes = np.array(
+        [[bbox.x1, bbox.y1, bbox.x2, bbox.y2] for bbox in bbs_aug.bounding_boxes]
+    )
+
+    height, width, _ = image_aug.shape
+    image_t, r_ = preproc(image_aug, input_dim)
+
+    # All the bbox have disappeared due to a data augmentation
+    if len(boxes) == 0:
+        targets = np.zeros((max_labels, 5), dtype=np.float32)
+        return image_t, targets
+
+    # boxes [xyxy] 2 [cx,cy,w,h]
+    boxes = xyxy2cxcywh(boxes)
+    # scale boxes by the same ratio preproc used to resize the image
+    boxes = boxes * r_
+
+    mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1
+    boxes_t = boxes[mask_b]
+    labels_t = labels[mask_b]
+
+    if len(boxes_t) == 0:
+        targets = np.zeros((max_labels, 5), dtype=np.float32)
+        return image_t, targets
+
+    labels_t = np.expand_dims(labels_t, 1)
+
+    targets_t = np.hstack((labels_t, boxes_t))
+    padded_labels = np.zeros((max_labels, 5))
+    padded_labels[range(len(targets_t))[:max_labels]] = targets_t[:max_labels]
+    padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
+    return image_t, padded_labels
+
+
+def visualize_image_with_boxes(image, boxes):
+    fig, ax = plt.subplots(1)
+    ax.imshow(image)
+    for box in boxes:
+        rect = patches.Rectangle(
+            (box[0] - box[2] / 2, box[1] - box[3] / 2),
+            box[2],
+            box[3],
+            linewidth=2,
+            edgecolor="r",
+            facecolor="none",
+        )
+        ax.add_patch(rect)
+    plt.show()
+
+
+class TrainTransformV2:
+    def __init__(self, max_labels=50):
+        self.max_labels = max_labels
+        self.aug = self._load_custom_augmentations()
+
+    def _load_custom_augmentations(self):
+        """Create a sequence of imgaug augmentations"""
+
+        def sometimes(aug):
+            return iaa.Sometimes(0.5, aug)
+
+        def rarely(aug):
+            return iaa.Sometimes(0.03, aug)
+
+        return iaa.Sequential(
+            [
+                # apply the following augmenters to most images
+                iaa.Fliplr(0.5),  # horizontally flip 50% of all images
+                # crop images by -5% to 10% of their height/width
+                sometimes(
+                    iaa.CropAndPad(
+                        percent=(-0.05, 0.1), pad_mode="constant", pad_cval=0
+                    )
+                ),
+                sometimes(
+                    iaa.Affine(
+                        # scale images to 90-110% of their size, individually per axis
+                        scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
+                        # translate by -20 to +20 percent (per axis)
+                        translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
+                        shear=(-5, 5),  # shear by -5 to +5 degrees
+                        # use nearest neighbour or bilinear interpolation (fast)
+                        order=[0, 1],
+                        # if mode is constant, use a cval between 0 and 255
+                        cval=0,
+                        # fill newly created pixels with a constant value
+                        mode="constant",
+                    )
+                ),
+                rarely(iaa.Cutout(nb_iterations=(1, 4), size=(0.05, 0.1))),
+                # execute 0 to 5 of the following (less important) augmenters per
+                # image. Don't execute all of them, as that would often be way too
+                # strong
+                iaa.SomeOf(
+                    (0, 5),
+                    [
+                        # blur images with one of three blur techniques
+                        iaa.OneOf(
+                            [
+                                # blur images with a sigma between 2 and 4
+                                iaa.GaussianBlur((2, 4)),
+                                # blur image using local means with kernel sizes
+                                # between 3 and 8
+                                iaa.AverageBlur(k=(3, 8)),
+                                # blur image using local medians with kernel sizes
+                                # between 1 and 5
+                                iaa.MedianBlur(k=(1, 5)),
+                            ]
+                        ),
+                        iaa.Sharpen(
+                            alpha=(0, 0.6), lightness=(0.9, 1.2)
+                        ),  # sharpen images
+                        # add gaussian noise to images
+                        iaa.AdditiveGaussianNoise(scale=(0.0, 0.06 * 255)),
+                        # change brightness of images (by -15 to 15 of original value)
+                        iaa.Add((-15, 15)),
+                        iaa.AddToHueAndSaturation((-10, 10)),
+                        iaa.Add((-8, 8), per_channel=0.5),
+                    ],
+                    random_order=True,
+                ),
+                rarely(
+                    iaa.OneOf(
+                        [
+                            iaa.CloudLayer(
+                                intensity_mean=(220, 255),
+                                intensity_freq_exponent=(-2.0, -1.5),
+                                intensity_coarse_scale=2,
+                                alpha_min=(0.7, 0.9),
+                                alpha_multiplier=0.3,
+                                alpha_size_px_max=(2, 8),
+                                alpha_freq_exponent=(-4.0, -2.0),
+                                sparsity=0.9,
+                                density_multiplier=(0.3, 0.6),
+                            ),
+                            iaa.Rain(nb_iterations=1, drop_size=0.05, speed=0.2),
+                            iaa.MotionBlur(k=20),
+                        ]
+                    )
+                ),
+            ],
+            random_order=True,
+        )
+
+    def __call__(self, image, targets, input_dim):
+        return get_transformed_image_and_targets(
+            image=image,
+            targets=targets,
+            input_dim=input_dim,
+            max_labels=self.max_labels,
+            aug=self.aug,
+        )
+
+
+class ValTransformV2:
+    """
+    Defines the transformations that should be applied to test PIL image
+    for input into the network
+
+    dimension -> tensorize -> color adj
+
+    Arguments:
+        resize (int): input dimension to SSD
+        rgb_means ((int,int,int)): average RGB of the dataset
+            (104,117,123)
+        swap ((int,int,int)): final order of channels
+
+    Returns:
+        transform (transform) : callable transform to be applied to test/val
+        data
+    """
+
+    def __init__(self, swap=(2, 0, 1)):
+        self.swap = swap
+
+    # assume input is cv2 img for now
+    def __call__(self, image, targets, input_dim):
+        image_t, _ = get_transformed_image_and_targets(
+            image=image,
+            targets=targets,
+            input_dim=input_dim,
+            max_labels=150,
+            aug=None,
+        )
+
+        return image_t, np.zeros((1, 5))
diff --git a/yolox-detection/experiment/YOLOX/yolox/data/data_augment_v3.py b/yolox-detection/experiment/YOLOX/yolox/data/data_augment_v3.py
new file mode 100644
index 00000000..cf527b81
--- /dev/null
+++ b/yolox-detection/experiment/YOLOX/yolox/data/data_augment_v3.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+"""
+Data augmentation functionality. Passed as callable transformations to
+Dataset classes.
+
+The data augmentation procedures were interpreted from @weiliu89's SSD paper
+http://arxiv.org/abs/1512.02325
+"""
+
+import random
+
+import cv2
+import imgaug.augmenters as iaa
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np
+from YOLOX.yolox.utils import xyxy2cxcywh
+from loguru import logger
+
+
+def mirror(image, boxes, prob=0.5):
+    _, width, _ = image.shape
+    if random.random() < prob:
+        image = image[:, ::-1]
+        boxes[:, 0::2] = width - boxes[:, 2::-2]
+    return image, boxes
+
+
+def preproc(img, input_size, swap=(2, 0, 1)):
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
+
+
+class TrainTransformV3:
+    def __init__(self, enable_weather_transform: bool, max_labels: int = 50):
+        self.max_labels = max_labels
+        self.flip_prob = 0.5
+        self.enable_weather_transform = enable_weather_transform
+        self.aug = self._load_custom_augmentations()
+
+    def _load_custom_augmentations(self):
+        """Create a sequence of imgaug augmentations"""
+
+        def sometimes(aug):
+            return iaa.Sometimes(0.5, aug)
+
+        def rarely(aug):
+            return iaa.Sometimes(0.03, aug)
+
+        def get_rare_transforms():
+            transformations = [
+                iaa.MotionBlur(k=15),
+                iaa.OneOf(
+                    [
+                        # blur images with a sigma between 2 and 4
+                        iaa.GaussianBlur((2, 4)),
+                        # blur image using local means with kernel sizes
+                        # between 2 and 5
+                        iaa.AverageBlur(k=(2, 5)),
+                        # blur image using local medians with kernel sizes
+                        # between 1 and 3
+                        iaa.MedianBlur(k=(1, 3)),
+                    ]
+                ),
+            ]
+            if self.enable_weather_transform:
+                logger.info(
+                    "Weather data augmentations (Cloud & Rain) have been enabled!"
+ ) + transformations += [ + iaa.CloudLayer( + intensity_mean=(220, 255), + intensity_freq_exponent=(-2.0, -1.5), + intensity_coarse_scale=2, + alpha_min=(0.7, 0.9), + alpha_multiplier=0.3, + alpha_size_px_max=(2, 8), + alpha_freq_exponent=(-4.0, -2.0), + sparsity=0.9, + density_multiplier=(0.3, 0.6), + ), + iaa.Rain(nb_iterations=1, drop_size=0.05, speed=0.2), + ] + + return transformations + + return iaa.Sequential( + [ + # execute 0 to 5 of the following (less important) augmenters per + # image don't execute all of them, as that would often be way too + # strong + sometimes( + iaa.SomeOf( + (0, 5), + [ + # convert images into their superpixel representation + iaa.Sharpen( + alpha=(0, 0.6), lightness=(0.9, 1.2) + ), # sharpen images + # add gaussian noise to images + iaa.AdditiveGaussianNoise(scale=(0.0, 0.05 * 255)), + # change brightness of images (by -15 to 15 of original value) + iaa.AddToBrightness((-12, 12)), + iaa.AddToHueAndSaturation((-15, 15)), + iaa.Add((-8, 8), per_channel=0.5), + ], + random_order=True, + ), + ), + rarely(iaa.OneOf(get_rare_transforms())), + ], + random_order=True, + ) + + def __call__(self, image, targets, input_dim): + boxes = targets[:, :4].copy() + labels = targets[:, 4].copy() + if len(boxes) == 0: + targets = np.zeros((self.max_labels, 5), dtype=np.float32) + image, r_o = preproc(image, input_dim) + return image, targets + + image_o = image.copy() + targets_o = targets.copy() + height_o, width_o, _ = image_o.shape + boxes_o = targets_o[:, :4] + labels_o = targets_o[:, 4] + # bbox_o: [xyxy] to [c_x,c_y,w,h] + boxes_o = xyxy2cxcywh(boxes_o) + + image_aug = self.aug(image=image) + image_t, boxes = mirror(image_aug, boxes, self.flip_prob) + + height, width, _ = image_t.shape + image_t, r_ = preproc(image_t, input_dim) + # boxes [xyxy] 2 [cx,cy,w,h] + boxes = xyxy2cxcywh(boxes) + boxes = boxes * r_ + + mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1 + boxes_t = boxes[mask_b] + labels_t = labels[mask_b] + + if len(boxes_t) == 0: + image_t, r_o = preproc(image_o, input_dim) + boxes_o *= r_o + boxes_t = boxes_o + labels_t = labels_o + + labels_t = np.expand_dims(labels_t, 1) + + targets_t = np.hstack((labels_t, boxes_t)) + padded_labels = np.zeros((self.max_labels, 5)) + padded_labels[range(len(targets_t))[: self.max_labels]] = targets_t[ + : self.max_labels + ] + padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32) + return image_t, padded_labels + + def visualize_image_with_boxes(self, image, boxes): + fig, ax = plt.subplots(1) + ax.imshow(image) + for box in boxes: + rect = patches.Rectangle( + (box[0] - box[2] / 2, box[1] - box[3] / 2), + box[2], + box[3], + linewidth=2, + edgecolor="r", + facecolor="none", + ) + ax.add_patch(rect) + plt.show() + + +class ValTransformV3: + """ + Defines the transformations that should be applied to test PIL image + for input into the network + + dimension -> tensorize -> color adj + + Arguments: + resize (int): input dimension to SSD + rgb_means ((int,int,int)): average RGB of the dataset + (104,117,123) + swap ((int,int,int)): final order of channels + + Returns: + transform (transform) : callable transform to be applied to test/val + data + """ + + def __init__(self, swap=(2, 0, 1), legacy=False): + self.swap = swap + self.legacy = legacy + + # assume input is cv2 img for now + def __call__(self, img, res, input_size): + img, _ = preproc(img, input_size, self.swap) + if self.legacy: + img = img[::-1, :, :].copy() + img /= 255.0 + img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1) + img /= 
+        return img, np.zeros((1, 5))
diff --git a/yolox-detection/experiment/YOLOX/yolox/exp/yolox_base.py b/yolox-detection/experiment/YOLOX/yolox/exp/yolox_base.py
index 33796060..bde4b999 100644
--- a/yolox-detection/experiment/YOLOX/yolox/exp/yolox_base.py
+++ b/yolox-detection/experiment/YOLOX/yolox/exp/yolox_base.py
@@ -48,23 +48,25 @@ def __init__(self, args):

         # --------------- transform config ----------------- #
         # prob of applying mosaic aug
-        self.mosaic_prob = 1.0
+        self.mosaic_prob = 0
         # prob of applying mixup aug
         self.mixup_prob = 1.0
         # prob of applying hsv aug
         self.hsv_prob = 1.0
         # prob of applying flip aug
-        self.flip_prob = 0.5
+        self.flip_prob = 1.0
         # rotation angle range, for example, if set to 2, the true range is (-2, 2)
         self.degrees = 10.0
         # translate range, for example, if set to 0.1, the true range is (-0.1, 0.1)
         self.translate = 0.1
-        self.mosaic_scale = (0.1, 2)
+        self.mosaic_scale = (0, 0)
         # apply mixup aug or not
-        self.enable_mixup = True
-        self.mixup_scale = (0.5, 1.5)
+        self.enable_mixup = False
+        self.mixup_scale = (0, 0)
         # shear angle range, for example, if set to 2, the true range is (-2, 2)
-        self.shear = 2.0
+        self.shear = 0
+        # enable the weather data augmentations (rain and cloud)
+        self.enable_weather_transform = args.enable_weather_transform

         # -------------- training config --------------------- #
         # epoch number used for warmup
@@ -148,7 +150,7 @@ def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
             json_file=self.train_ann,
             img_size=self.input_size,
             preproc=TrainTransform(
-                max_labels=50, flip_prob=self.flip_prob, hsv_prob=self.hsv_prob
+                max_labels=120, flip_prob=self.flip_prob, hsv_prob=self.hsv_prob
             ),
             cache=cache,
             cache_type=cache_type,
@@ -167,12 +169,12 @@ def get_data_loader(
             None: Do not use cache, in this case cache_data is also None.
""" from YOLOX.yolox.data import ( - TrainTransform, YoloBatchSampler, DataLoader, InfiniteSampler, - MosaicDetection, + COCODataset, worker_init_reset_seed, + TrainTransformV3, ) from YOLOX.yolox.utils import wait_for_the_master @@ -185,21 +187,14 @@ def get_data_loader( ), "cache_img must be None if you didn't create self.dataset before launch" self.dataset = self.get_dataset(cache=False, cache_type=cache_img) - self.dataset = MosaicDetection( - dataset=self.dataset, - mosaic=not no_aug, + self.dataset = COCODataset( + data_dir=self.data_dir, + json_file=self.train_ann, + name="train2017", img_size=self.input_size, - preproc=TrainTransform( - max_labels=120, flip_prob=self.flip_prob, hsv_prob=self.hsv_prob + preproc=TrainTransformV3( + enable_weather_transform=self.enable_weather_transform, max_labels=120 ), - degrees=self.degrees, - translate=self.translate, - mosaic_scale=self.mosaic_scale, - mixup_scale=self.mixup_scale, - shear=self.shear, - enable_mixup=self.enable_mixup, - mosaic_prob=self.mosaic_prob, - mixup_prob=self.mixup_prob, ) if is_distributed: @@ -301,17 +296,16 @@ def get_lr_scheduler(self, lr, iters_per_epoch): return scheduler def get_eval_dataset(self, **kwargs): - from YOLOX.yolox.data import COCODataset, ValTransform + from YOLOX.yolox.data import COCODataset, ValTransformV3 testdev = kwargs.get("testdev", False) - legacy = kwargs.get("legacy", False) return COCODataset( data_dir=self.data_dir, json_file=self.val_ann if not testdev else self.test_ann, name="val2017" if not testdev else "test2017", img_size=self.test_size, - preproc=ValTransform(legacy=legacy), + preproc=ValTransformV3(), ) def get_eval_loader(self, batch_size, is_distributed, **kwargs): @@ -365,3 +359,7 @@ def eval(self, model, evaluator, is_distributed, half=False, return_outputs=Fals def check_exp_value(exp: Exp): h, w = exp.input_size assert h % 32 == 0 and w % 32 == 0, "input size must be multiples of 32" + + +def my_collate_fn(batch): + return batch diff --git a/yolox-detection/experiment/main.py b/yolox-detection/experiment/main.py index e1489431..5d744771 100644 --- a/yolox-detection/experiment/main.py +++ b/yolox-detection/experiment/main.py @@ -105,6 +105,7 @@ epochs = int(parameters.get("epochs", 100)) image_size = int(parameters.get("image_size", 640)) eval_interval = int(parameters.get("eval_interval", 5)) +enable_weather_transform = bool(parameters.get("enable_weather_transform", False)) # 6 - Launch the training # 6A - Args @@ -126,6 +127,7 @@ args.image_size = (image_size, image_size) args.picsellia_experiment = experiment args.eval_interval = eval_interval +args.enable_weather_transform = enable_weather_transform # 6B - Get model architecture exp = get_exp_by_name(args) @@ -188,12 +190,7 @@ model_path = os.path.join(exp.output_dir, args.experiment_name, "best.onnx") -torch.onnx.export( - model, - dummy_input, - model_path, - output_names=["output_yolox"] -) +torch.onnx.export(model, dummy_input, model_path, output_names=["output_yolox"]) experiment.store("model-latest", model_path) print("Exported the model best.onnx as model-latest") diff --git a/yolox-detection/experiment/test_augmentations.py b/yolox-detection/experiment/test_augmentations.py new file mode 100644 index 00000000..0d4ec7dc --- /dev/null +++ b/yolox-detection/experiment/test_augmentations.py @@ -0,0 +1,64 @@ +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import tqdm + +from YOLOX.yolox.data import TrainTransformV3 + + +# Assuming TrainTransformV2 and other necessary imports are already included here 
+
+
+def load_image(image_path):
+    """Load an image from a file path."""
+    image = cv2.imread(image_path)
+    image = cv2.cvtColor(
+        image, cv2.COLOR_BGR2RGB
+    )  # Convert from BGR to RGB for display purposes
+    return image
+
+
+def visualize_image_with_boxes(ax, image, boxes, title="Image"):
+    """Visualize an image with bounding boxes on a matplotlib axis."""
+    # Check if the image is in channel-first format and adjust if necessary
+    if image.shape[0] == 3:
+        image = image.transpose(1, 2, 0)  # CHW to HWC
+    ax.imshow(image.astype(np.uint8))
+    # Draw the boxes, which are [cx, cy, w, h] after the transform
+    for box in boxes:
+        cx, cy, w, h = box[:4]
+        rect = plt.Rectangle(
+            (cx - w / 2, cy - h / 2), w, h, linewidth=2, edgecolor="r", facecolor="none"
+        )
+        ax.add_patch(rect)
+    ax.set_title(title)
+
+
+def main(image_path):
+    image = load_image(image_path)
+    # Example bounding boxes [x1, y1, x2, y2]
+    boxes = np.array([[50, 50, 200, 200], [150, 150, 300, 300]])
+    labels = np.array([1, 2])  # Dummy labels
+
+    transform = TrainTransformV3(enable_weather_transform=True, max_labels=50)
+
+    fig, axes = plt.subplots(
+        2, 4, figsize=(15, 10)
+    )  # Adjust subplot grid for 8 images
+    axes = axes.ravel()
+
+    for i in tqdm.tqdm(range(8)):
+        # Repeatedly apply transformation to visualize different results
+        transformed_image, transformed_data = transform(
+            image.copy(), np.hstack((boxes, labels[:, None])), (640, 640)
+        )
+        transformed_boxes = transformed_data[:, 1:5]  # Extract transformed boxes
+
+        # Visualize each transformed image and its boxes
+        visualize_image_with_boxes(
+            axes[i],
+            transformed_image,
+            transformed_boxes,
+            title=f"Transformed Image {i + 1}",
+        )
+
+    plt.tight_layout()
+    plt.show()
+
+
+if __name__ == "__main__":
+    image_path = "../../test-yoloxv2_validation-step/images/train2017/images/018fe72d-a345-7fae-99ce-fb8cc94bd5bb.JPG"
+    main(image_path)
diff --git a/yolox-detection/requirements.txt b/yolox-detection/requirements.txt
index b9efe0de..ca2b535f 100644
--- a/yolox-detection/requirements.txt
+++ b/yolox-detection/requirements.txt
@@ -11,5 +11,6 @@ pycocotools==2.0.7
 tqdm==4.66.1
 thop==0.1.1.post2209072238
 ninja==1.11.1.1
-picsellia==6.12.0
+picsellia==6.16.0
 onnx==1.15.0
+git+https://github.com/marcown/imgaug.git@5eb7adda6aa2ea1628e7e3a7d64d32a3335d38f5
\ No newline at end of file
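
A quick way to sanity-check the patch without a full training run is to call the new transforms directly, the same way get_data_loader() and get_eval_dataset() wire them up. The sketch below is not part of the patch: the image, boxes, and sizes are made-up values, and it assumes the repo's YOLOX package is on the import path.

import numpy as np

from YOLOX.yolox.data import TrainTransformV3, ValTransformV3

# Dummy 480x640 RGB image plus two boxes as [x1, y1, x2, y2, class_id],
# the layout COCODataset feeds into the transform.
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
targets = np.array(
    [[50.0, 60.0, 200.0, 220.0, 0.0], [300.0, 100.0, 400.0, 300.0, 1.0]],
    dtype=np.float32,
)

train_tf = TrainTransformV3(enable_weather_transform=False, max_labels=120)
img_t, labels_t = train_tf(image, targets, (640, 640))
# img_t: CHW float32 letterboxed to 640x640 (padding value 114).
# labels_t: (120, 5) array of [class_id, cx, cy, w, h], zero-padded
# past the real boxes.
print(img_t.shape, labels_t.shape)

val_tf = ValTransformV3()
img_v, _ = val_tf(image, targets, (640, 640))
print(img_v.shape)  # (3, 640, 640)

Because the augmentation pipeline is stochastic, running the train transform repeatedly on the same image (as test_augmentations.py does) is the easiest way to eyeball the range of outputs.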