diff --git a/src/ml/neural_net/tf_compute_context.cpp b/src/ml/neural_net/tf_compute_context.cpp
index 1b0e290a74..094d04229f 100644
--- a/src/ml/neural_net/tf_compute_context.cpp
+++ b/src/ml/neural_net/tf_compute_context.cpp
@@ -194,39 +194,29 @@ tf_model_backend::~tf_model_backend() {
 
 class tf_image_augmenter : public float_array_image_augmenter {
  public:
-  tf_image_augmenter(const options& opts);
-  ~tf_image_augmenter() override = default;
+  tf_image_augmenter(const options& opts, pybind11::object augmenter);
+
+  ~tf_image_augmenter();
 
   float_array_result prepare_augmented_images(
       labeled_float_image data_to_augment) override;
+
+ private:
+  pybind11::object augmenter_;
+
 };
 
-tf_image_augmenter::tf_image_augmenter(const options& opts) : float_array_image_augmenter(opts) {}
+tf_image_augmenter::tf_image_augmenter(const options& opts, pybind11::object augmenter) : float_array_image_augmenter(opts), augmenter_(augmenter) {}
 
 float_array_image_augmenter::float_array_result tf_image_augmenter::prepare_augmented_images(
     float_array_image_augmenter::labeled_float_image data_to_augment) {
-  options opts = get_options();
   float_array_image_augmenter::float_array_result image_annotations;
   call_pybind_function([&]() {
-    // Import the module from python that does data augmentation
-    pybind11::module tf_aug = pybind11::module::import(
-        "turicreate.toolkits.object_detector._tf_image_augmenter");
-
-    const size_t output_height = opts.output_height;
-    const size_t output_width = opts.output_width;
-
-    // TODO: Remove resize_only by passing all the augmentation options
-    bool resize_only = false;
-    if (opts.crop_prob == 0.f) {
-      resize_only = true;
-    }
     // Get augmented images and annotations from tensorflow
-    pybind11::object augmented_data = tf_aug.attr("get_augmented_data")(
-        data_to_augment.images, data_to_augment.annotations, output_height,
-        output_width, resize_only);
+    pybind11::object augmented_data = augmenter_.attr("get_augmented_data")(
+        data_to_augment.images, data_to_augment.annotations);
     std::pair<pybind11::object, pybind11::object> aug_data = augmented_data
         .cast<std::pair<pybind11::object, pybind11::object>>();
@@ -257,6 +247,10 @@ tf_image_augmenter::prepare_augmented_images(
   return image_annotations;
 }
 
+tf_image_augmenter::~tf_image_augmenter() {
+  call_pybind_function([&]() { augmenter_ = pybind11::object(); });
+}
+
 namespace {
 
 std::unique_ptr<compute_context> create_tf_compute_context() {
@@ -338,7 +332,29 @@ std::unique_ptr<model_backend> tf_compute_context::create_activity_classifier(
 
 std::unique_ptr<image_augmenter> tf_compute_context::create_image_augmenter(
    const image_augmenter::options& opts) {
-  return std::unique_ptr<image_augmenter>(new tf_image_augmenter(opts));
+  std::unique_ptr<image_augmenter> result;
+
+  call_pybind_function([&]() {
+
+    const size_t output_height = opts.output_height;
+    const size_t output_width = opts.output_width;
+    const size_t batch_size = opts.batch_size;
+
+    // TODO: Remove resize_only by passing all the augmentation options
+    bool resize_only = false;
+    if (opts.crop_prob == 0.f) {
+      resize_only = true;
+    }
+
+    pybind11::module tf_aug = pybind11::module::import(
+        "turicreate.toolkits.object_detector._tf_image_augmenter");
+
+    // Make an instance of the Python DataAugmenter object
+    pybind11::object image_augmenter =
+        tf_aug.attr("DataAugmenter")(output_height, output_width, batch_size, resize_only);
+    result.reset(new tf_image_augmenter(opts, image_augmenter));
+  });
+  return result;
 }
 
 std::unique_ptr<model_backend> tf_compute_context::create_style_transfer(
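The factory above builds one Python-side `DataAugmenter` per model and reuses it for every batch, instead of re-importing the module and rebuilding the augmentation on each call as the old code did. A minimal sketch of the equivalent call sequence from Python (sizes are illustrative, and it assumes plain NumPy arrays are accepted where the C++ layer normally passes shared float arrays):

```python
import numpy as np
from turicreate.toolkits.object_detector import _tf_image_augmenter as tf_aug

# Mirrors create_image_augmenter(): one augmenter per output shape / batch size.
augmenter = tf_aug.DataAugmenter(output_height=416, output_width=416,
                                 batch_size=2, resize_only=False)

# One HxWx3 float image per batch element, and one Nx6 annotation array per
# image; each row is [identifier, x, y, width, height, confidence] in [0, 1].
images = [np.random.uniform(size=(480, 640, 3)).astype(np.float32) for _ in range(2)]
annotations = [np.array([[0.0, 0.25, 0.25, 0.5, 0.5, 1.0]], dtype=np.float32)] * 2

# Mirrors prepare_augmented_images(): returns the augmented batch + annotations.
aug_images, aug_annotations = augmenter.get_augmented_data(images, annotations)
```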
diff --git a/src/python/turicreate/toolkits/object_detector/_tf_image_augmenter.py b/src/python/turicreate/toolkits/object_detector/_tf_image_augmenter.py
index e0b3b4b5d2..09dba756f6 100644
--- a/src/python/turicreate/toolkits/object_detector/_tf_image_augmenter.py
+++ b/src/python/turicreate/toolkits/object_detector/_tf_image_augmenter.py
@@ -22,351 +22,405 @@ tf.disable_v2_behavior()
 
-def get_augmented_data(images, annotations, output_height, output_width, resize_only):
+_DEFAULT_AUG_PARAMS = {
+    'max_hue_adjust': 0.05,
+    'max_brightness': 0.05,
+    'max_contrast': 1.25,
+    'max_saturation': 1.25,
+    'skip_probability_flip': 0.5,
+    'min_aspect_ratio': 0.8,
+    'max_aspect_ratio': 1.25,
+    'min_area_fraction_crop': 0.15,
+    'max_area_fraction_crop': 1.0,
+    'min_area_fraction_pad': 1.0,
+    'max_area_fraction_pad': 2.0,
+    'max_attempts': 50,
+    'skip_probability_pad': 0.1,
+    'skip_probability_crop': 0.1,
+    'min_object_covered': 0.0,
+    'min_eject_coverage': 0.5
+}
+
+def hue_augmenter(image, annotation,
+                  max_hue_adjust=_DEFAULT_AUG_PARAMS["max_hue_adjust"]):
+
+    # Sample a random rotation around the color wheel.
+    hue_adjust = 0.0
+    if (max_hue_adjust is not None) and (max_hue_adjust > 0.0):
+        hue_adjust += np.pi * np.random.uniform(-max_hue_adjust, max_hue_adjust)
+
+    # Apply the rotation to the hue.
+    image = tf.image.random_hue(image, max_delta=max_hue_adjust)
+    image = tf.clip_by_value(image, 0, 1)
+
+    # No geometry changes, so just copy the annotations.
+    return image, annotation
+
+def color_augmenter(image, annotation,
+                    max_brightness=_DEFAULT_AUG_PARAMS["max_brightness"],
+                    max_contrast=_DEFAULT_AUG_PARAMS["max_contrast"],
+                    max_saturation=_DEFAULT_AUG_PARAMS["max_saturation"]):
+
+    # Sample a random adjustment to brightness.
+    if max_brightness is not None and max_brightness > 0:
+        image = tf.image.random_brightness(image, max_delta=max_brightness)
+
+    # Sample a random adjustment to saturation.
+    if max_saturation is not None and max_saturation > 1.0:
+        log_sat = np.log(max_saturation)
+        image = tf.image.random_saturation(image, lower=np.exp(-log_sat), upper=np.exp(log_sat))
+
+    # Sample a random adjustment to contrast.
+    if max_contrast is not None and max_contrast > 1.0:
+        log_con = np.log(max_contrast)
+        image = tf.image.random_contrast(image, lower=np.exp(-log_con), upper=np.exp(log_con))
+
+    image = tf.clip_by_value(image, 0, 1)
+
+    # No geometry changes, so just copy the annotations.
+    return image, annotation
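The `lower`/`upper` bounds for saturation and contrast are derived in log space so that scaling a channel up and scaling it down by the same ratio are equally likely. A standalone check of that arithmetic (illustrative, not part of the module):

```python
import numpy as np

max_saturation = 1.25  # same default as _DEFAULT_AUG_PARAMS
log_sat = np.log(max_saturation)

# tf.image.random_saturation draws uniformly from [lower, upper]; with these
# bounds, a factor of 1.25 and its exact inverse 1/1.25 = 0.8 are both
# reachable, so the jitter is symmetric in ratio rather than in offset.
lower, upper = np.exp(-log_sat), np.exp(log_sat)
print(lower, upper)  # approximately 0.8 and 1.25
```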
+
+def resize_augmenter(image, annotation,
+                     output_shape):
+
-    # Suppresses verbosity to only errors
-    tf.logging.set_verbosity(tf.logging.ERROR)
+    new_height = tf.cast(output_shape[0], dtype=tf.int32)
+    new_width = tf.cast(output_shape[1], dtype=tf.int32)
-    graph = tf.Graph()
-    with graph.as_default():
-        with tf.Session() as session:
-            output_shape = (output_height, output_width)
+
+    # Determine the affine transform to apply and apply to the image itself.
+    image_scaled = tf.squeeze(tf.image.resize_bilinear(
+        tf.expand_dims(image, 0), [new_height, new_width]), [0])
+    image_clipped = tf.clip_by_value(image_scaled, 0.0, 1.0)
+    annotation = tf.clip_by_value(annotation, 0.0, 1.0)
-            if resize_only:
-                images = get_resized_images(images, output_shape)
-                resized_images = session.run(images)
-                resized_images = np.array(resized_images, dtype=np.float32)
-                return tuple((resized_images, len(resized_images)*[np.zeros(6)]))
-            else:
-                imgs, transformations = get_augmented_images(images, output_shape)
-                augmented_images, trans = session.run([imgs, transformations])
-                augmented_annotations = apply_bounding_box_transformation(images, annotations, trans, output_shape)
-                augmented_images = np.array(augmented_images, dtype=np.float32)
-                return tuple((augmented_images, augmented_annotations))
-
-def is_tensor(x):
-    # Checks if `x` is a symbolic tensor-like object.
-    return isinstance(x, (ops.Tensor, variables.Variable))
-
-def image_dimensions(images, static_only=True):
-    # Returns the dimensions of an image tensor.
-    if static_only or images.get_shape().is_fully_defined():
-        return images.get_shape().as_list()
-    else:
-        return tf.unstack(tf.shape(images))
-
-def get_image_dimensions(image, rank):
-    # Returns the dimensions of an image tensor.
-    if image.get_shape().is_fully_defined():
-        return image.get_shape().as_list()
-    else:
-        static_shape = image.get_shape().with_rank(rank).as_list()
-        dynamic_shape = array_ops.unstack(array_ops.shape(image), rank)
-        return [s if s is not None else d for s, d in zip(static_shape, dynamic_shape)]
-
-def check_three_dim_image(image, require_static=True):
-    # Assert image is three dimensional
-    try:
-        image_shape = image.get_shape().with_rank(3)
-    except ValueError:
-        raise ValueError("'image' must be three-dimensional.")
-    if require_static and not image_shape.is_fully_defined():
-        raise ValueError("'image' must be fully defined.")
-    if any(x == 0 for x in image_shape):
-        raise ValueError("all dims of 'image.shape' must be > 0: %s" %
-                         image_shape)
-    if not image_shape.is_fully_defined():
-        return [check_ops.assert_positive(array_ops.shape(image),
-                                          ["all dims of 'image.shape' must be > 0."])]
-    else:
-        return []
-
-def check_atlease_three_dim_image(image, require_static=True):
-    # Assert image is atleast three dimensional
-    try:
-        if image.get_shape().ndims is None:
-            image_shape = image.get_shape().with_rank(3)
-        else:
-            image_shape = image.get_shape().with_rank_at_least(3)
-    except ValueError:
-        raise ValueError("'image' must be at least three-dimensional.")
-    if require_static and not image_shape.is_fully_defined():
-        raise ValueError('\'image\' must be fully defined.')
-    if any(x == 0 for x in image_shape):
-        raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
-                         image_shape)
-    if not image_shape.is_fully_defined():
-        return [check_ops.assert_positive(array_ops.shape(image), ["all dims of 'image.shape' "
-                                                                   "must be > 0."])]
-    else:
-        return []
-
-def pad_to_ensure_size(image, target_height, target_width, random=True):
-    image = ops.convert_to_tensor(image, name='image')
-
-    assert_ops = []
-    assert_ops += check_three_dim_image(image, require_static=False)
+    # No geometry changes (because of the relative co-ordinate system).
+    return image_clipped, annotation
+
+
+def horizontal_flip_augmenter(image, annotation, skip_probability=_DEFAULT_AUG_PARAMS["skip_probability_flip"]):
+
+    if np.random.uniform(0.0, 1.0) < skip_probability:
+        return image, annotation
-    image = control_flow_ops.with_dependencies(assert_ops, image)
-    # `crop_to_bounding_box` and `pad_to_bounding_box` have their own checks.
-    # Make sure our checks come first, so that error messages are clearer.
-    if is_tensor(target_height):
-        target_height = control_flow_ops.with_dependencies(assert_ops, target_height)
-    if is_tensor(target_width):
-        target_width = control_flow_ops.with_dependencies(assert_ops, target_width)
-
-    def max_(x, y):
-        if is_tensor(x) or is_tensor(y):
-            return math_ops.maximum(x, y)
-        else:
-            return max(x, y)
-
-    height, width, _ = image_dimensions(image, static_only=False)
-    width_diff = target_width - width
-    offset_crop_width = max_(-width_diff // 2, 0)
-    if random:
-        offset_pad_width = tf.random_uniform([], minval=0, maxval=max_(width_diff, 1), dtype=tf.int32)
-    else:
-        offset_pad_width = max_(width_diff // 2, 0)
-
-    height_diff = target_height - height
-    offset_crop_height = max_(-height_diff // 2, 0)
-    if random:
-        offset_pad_height = tf.random_uniform([], minval=0, maxval=max_(height_diff, 1), dtype=tf.int32)
-    else:
-        offset_pad_height = max_(height_diff // 2, 0)
-
-    # Maybe pad if needed.
-    resized = pad_to_bounding_box(image, offset_pad_height, offset_pad_width,
-                                  max_(target_height, height), max_(target_width, width))
-
-    # In theory all the checks below are redundant.
-    if resized.get_shape().ndims is None:
-        raise ValueError('resized contains no shape.')
-
-    resized_height, resized_width, _ = \
-        image_dimensions(resized, static_only=False)
-    return resized, (offset_pad_height, offset_pad_width)
-
-def pad_to_bounding_box(image, offset_height, offset_width, target_height,
-                        target_width):
-    image = ops.convert_to_tensor(image, name='image')
-
-    is_batch = True
-    image_shape = image.get_shape()
-    if image_shape.ndims == 3:
-        is_batch = False
-        image = array_ops.expand_dims(image, 0)
-    elif image_shape.ndims is None:
-        is_batch = False
-        image = array_ops.expand_dims(image, 0)
-        image.set_shape([None] * 4)
-    elif image_shape.ndims != 4:
-        raise ValueError('\'image\' must have either 3 or 4 dimensions.')
-
-    assert_ops = check_atlease_three_dim_image(image, require_static=False)
-
-    batch, height, width, depth = get_image_dimensions(image, rank=4)
-
-    after_padding_width = target_width - offset_width - width
-    after_padding_height = target_height - offset_height - height
-
-    assert_ops += _assert(offset_height >= 0, ValueError,
-                          'offset_height must be >= 0')
-    assert_ops += _assert(offset_width >= 0, ValueError,
-                          'offset_width must be >= 0')
-    assert_ops += _assert(after_padding_width >= 0, ValueError,
-                          'width must be <= target - offset')
-    assert_ops += _assert(after_padding_height >= 0, ValueError,
-                          'height must be <= target - offset')
-    image = control_flow_ops.with_dependencies(assert_ops, image)
-
-    # Do not pad on the depth dimensions.
-    paddings = array_ops.reshape(array_ops.stack([
-        0, 0,
-        offset_height, after_padding_height,
-        offset_width, after_padding_width,
-        0, 0]), [4, 2])
-    padded = array_ops.pad(image, paddings, constant_values=0.5)
-
-    padded_shape = [None if is_tensor(i) else i
-                    for i in [batch, target_height, target_width, depth]]
-    padded.set_shape(padded_shape)
-
-    if not is_batch:
-        padded = array_ops.squeeze(padded, squeeze_dims=[0])
-
-    return padded
-
-def apply_bounding_box_transformation(images, annotations, transformations, clip_to_shape=None):
-
-    aug_anns = []
-    for i in range(len(annotations)):
-        image = _utils.convert_shared_float_array_to_numpy(images[i])
-        height = image.shape[0]
-        width = image.shape[0]
-        ann = annotations[i]
-        annotation = _utils.convert_shared_float_array_to_numpy(ann)
-        identifier = np.expand_dims(annotation[:, 0], axis=1)
-        box = np.zeros(annotation[:, 1:5].shape)
-        for j in range(len(annotation)):
-            box[j][0] = annotation[j][2]*float(height)
-            box[j][1] = annotation[j][1]*float(width)
-            box[j][2] = (annotation[j][4]+annotation[j][2])*float(height)
-            box[j][3] = (annotation[j][3]+annotation[j][1])*float(width)
-
-        confidence = np.expand_dims(annotation[:, 5], axis=1)
-
-        # The bounding box is [n, 4] reshaped and ones added to multiply to tranformation matrix
-        v = np.concatenate([box.reshape(-1, 2), np.ones((box.shape[0]*2, 1), dtype=np.float32)], axis=1)
-        # Transform
-        v = np.dot(v, np.transpose(transformations[i]))
-        # Reverse shape
-        bbox_out = v[:, :2].reshape(-1, 4)
-
-        # Make points correctly ordered (lower < upper)
-        # Can probably be made much nicer (numpy-ified?)
-        for i in range(len(bbox_out)):
-            if bbox_out[i][0] > bbox_out[i][2]:
-                bbox_out[i][0], bbox_out[i][2] = bbox_out[i][2], bbox_out[i][0]
-            if bbox_out[i][1] > bbox_out[i][3]:
-                bbox_out[i][1], bbox_out[i][3] = bbox_out[i][3], bbox_out[i][1]
-
-        if clip_to_shape is not None:
-            bbox_out[:, 0::2] = np.clip(bbox_out[:, 0::2], 0, clip_to_shape[0])
-            bbox_out[:, 1::2] = np.clip(bbox_out[:, 1::2], 0, clip_to_shape[1])
-
-        bbox = np.zeros(bbox_out.shape)
-        for k in range(len(bbox_out)):
-            bbox[k][0] = bbox_out[k][1]/float(clip_to_shape[0])
-            bbox[k][1] = bbox_out[k][0]/float(clip_to_shape[1])
-            bbox[k][2] = (bbox_out[k][3] - bbox_out[k][1])/float(clip_to_shape[0])
-            bbox[k][3] = (bbox_out[k][2] - bbox_out[k][0])/float(clip_to_shape[1])
-
-        an = np.hstack((np.hstack((identifier, bbox)), confidence))
-        an = np.ascontiguousarray(an, dtype=np.float32)
-        aug_anns.append(an)
-    return aug_anns
-
-def _assert(cond, ex_type, msg):
-    # A polymorphic assert, works with tensors and boolean expressions.
-    # If `cond` is not a tensor, behave like an ordinary assert statement, except
-    # that a empty list is returned. If `cond` is a tensor, return a list
-    # containing a single TensorFlow assert op.
-
-    if is_tensor(cond):
-        return [control_flow_ops.Assert(cond, [msg])]
-    else:
-        if not cond:
-            raise ex_type(msg)
-        else:
-            return []
-
-def get_augmented_images(images, output_shape):
-
-    # Store transformations and augmented_images for the input batch
-    transformations = []
-    augmented_images = []
-
-    # Augmentation option
-    min_scale = 1/1.5
-    max_scale = 1.5
-    max_aspect_ratio=1.5
-    max_hue=0.05
-    max_brightness=0.05
-    max_saturation=1.25
-    max_contrast=1.25
-    horizontal_flip=True
-
-    for i in range(len(images)):
-
-        image = images[i]
-        image = _utils.convert_shared_float_array_to_numpy(image)
+    image_height, image_width, _ = image.shape
+    flipped_image = np.flip(image, 1)
+
+    # One image can have more than one annotation, so loop through the
+    # annotations and flip each one across the horizontal axis.
+    for i in range(len(annotation)):
+        # Only flip if the annotation is not an empty annotation.
+        if np.any(annotation[i][1:5]):
+            annotation[i][1] = 1 - annotation[i][1] - annotation[i][3]
+
+    return flipped_image, annotation
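The flip only rewrites the normalized x origin of each box: a box spanning `[x, x + w]` maps to `[1 - x - w, 1 - x]`, while y, width, and height are unchanged. A standalone check with a hypothetical annotation row:

```python
import numpy as np

# [identifier, x, y, width, height, confidence]; this box spans x in [0.1, 0.4].
ann = np.array([[0.0, 0.1, 0.2, 0.3, 0.4, 1.0]], dtype=np.float32)

# The same update horizontal_flip_augmenter applies to non-empty annotations.
ann[0][1] = 1 - ann[0][1] - ann[0][3]
print(ann[0][1])  # 0.6 -- the mirrored box spans x in [0.6, 0.9]
```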
+
+def padding_augmenter(image,
+                      annotation,
+                      skip_probability=_DEFAULT_AUG_PARAMS["skip_probability_pad"],
+                      min_aspect_ratio=_DEFAULT_AUG_PARAMS["min_aspect_ratio"],
+                      max_aspect_ratio=_DEFAULT_AUG_PARAMS["max_aspect_ratio"],
+                      min_area_fraction=_DEFAULT_AUG_PARAMS["min_area_fraction_pad"],
+                      max_area_fraction=_DEFAULT_AUG_PARAMS["max_area_fraction_pad"],
+                      max_attempts=_DEFAULT_AUG_PARAMS["max_attempts"]):
+    if np.random.uniform(0.0, 1.0) < skip_probability:
+        return np.array(image), annotation
+
+    image_height, image_width, _ = image.shape
+
+    # Randomly sample aspect ratios until one derives a non-empty range of
+    # compatible heights, or until reaching the upper limit on attempts.
+    for i in range(max_attempts):
+        # Randomly sample an aspect ratio.
+        aspect_ratio = np.random.uniform(min_aspect_ratio, max_aspect_ratio)
+
+        # The padded height must be at least as large as the original height.
+        # h' >= h
+        min_height = float(image_height)
-        height, width, _ = tf.unstack(tf.shape(image))
-        scale_h = tf.random_uniform([], minval=min_scale, maxval=max_scale)
-        scale_w = scale_h * tf.exp(tf.random_uniform([], minval=-np.log(max_aspect_ratio), maxval=np.log(max_aspect_ratio)))
-        new_height = tf.to_int32(tf.to_float(height) * scale_h)
-        new_width = tf.to_int32(tf.to_float(width) * scale_w)
-
-        image_scaled = tf.squeeze(tf.image.resize_bilinear(tf.expand_dims(image, 0), [new_height, new_width]), [0])
-        # Image padding
-        pad_image, pad_offset = pad_to_ensure_size(image_scaled, output_shape[0], output_shape[1])
-
-        new_height = tf.maximum(output_shape[0], new_height)
-        new_width = tf.maximum(output_shape[1], new_width)
-
-        slice_offset = (tf.random_uniform([], minval=0, maxval=new_height - output_shape[0] + 1, dtype=tf.int32),
-                        tf.random_uniform([], minval=0, maxval=new_width - output_shape[1] + 1, dtype=tf.int32))
-        augmented_image = array_ops.slice(pad_image, [slice_offset[0], slice_offset[1], 0], [output_shape[0], output_shape[1], 3])
-
-        if horizontal_flip:
-            uniform_random = random_ops.random_uniform([], 0, 1.0)
-            did_horiz_flip = math_ops.less(uniform_random, .5)
-            augmented_image = control_flow_ops.cond(did_horiz_flip,
-                                                    lambda: array_ops.reverse(augmented_image, [1]),
-                                                    lambda: augmented_image)
-            flip_sign = 1 - tf.to_float(did_horiz_flip) * 2
-        else:
-            flip_sign = 1
-            did_horiz_flip = tf.constant(False)
-
-        ty = tf.to_float(pad_offset[0] - slice_offset[0])
-        tx = flip_sign * tf.to_float(pad_offset[1] - slice_offset[1]) + tf.to_float(did_horiz_flip) * output_shape[1]
-
-        # Make the transformation matrix
-        transformation = tf.reshape(tf.stack([
-            scale_h, 0.0, ty,
-            0.0, flip_sign * scale_w, tx,
-            0.0, 0.0, 1.0]
-        ), (3, 3))
-
-        if max_hue is not None and max_hue > 0:
-            image = tf.image.random_hue(augmented_image, max_delta=max_hue)
-
-        if max_brightness is not None and max_brightness > 0:
-            image = tf.image.random_brightness(augmented_image, max_delta=max_brightness)
-
-        if max_saturation is not None and max_saturation > 1.0:
-            log_sat = np.log(max_saturation)
-            image = tf.image.random_saturation(augmented_image, lower=np.exp(-log_sat), upper=np.exp(log_sat))
-
-        if max_contrast is not None and max_contrast > 1.0:
-            log_con = np.log(max_contrast)
-            image = tf.image.random_contrast(augmented_image, lower=np.exp(-log_con), upper=np.exp(log_con))
-
-        augmented_image = tf.clip_by_value(augmented_image, 0, 1)
-        augmented_images.append(augmented_image)
-        transformations.append(transformation)
+        # The padded width must be at least as large as the original width.
+        # w' >= w IMPLIES ah' >= w IMPLIES h' >= w / a
+        min_height_from_width = float(image_width) / aspect_ratio
+        if min_height < min_height_from_width:
+            min_height = min_height_from_width
-    return augmented_images, transformations
-
-def get_resized_images(images, output_shape):
+        # The padded area must attain the minimum area fraction.
+        # w'h' >= fhw IMPLIES ah'h' >= fhw IMPLIES h' >= sqrt(fhw/a)
+        min_height_from_area = np.sqrt(image_height * image_width * min_area_fraction / aspect_ratio)
+        if min_height < min_height_from_area:
+            min_height = min_height_from_area
+
+        # The padded area must not exceed the maximum area fraction.
+        max_height = np.sqrt(image_height * image_width * max_area_fraction / aspect_ratio)
+
+        # A non-empty range of compatible heights exists for this aspect ratio.
+        if min_height <= max_height:
+            break
+
+    # We did not find a compatible aspect ratio. Just return the original data.
+    if min_height > max_height:
+        return np.array(image), annotation
+
+    # Sample a final size, given the sampled aspect ratio and range of heights.
+    padded_height = np.random.uniform(min_height, max_height)
+    padded_width = padded_height * aspect_ratio
+
+    # Sample the offset of the source image inside the padded image.
+    x_offset = np.random.uniform(0.0, (padded_width - image_width))
+    y_offset = np.random.uniform(0.0, (padded_height - image_height))
+
+    # Compute the padding needed after the image on each axis.
+    after_padding_width = padded_width - image_width - x_offset
+    after_padding_height = padded_height - image_height - y_offset
+
+    # Pad the image.
+    npad = ((int(y_offset), int(after_padding_height)),
+            (int(x_offset), int(after_padding_width)),
+            (0, 0))
+    padded_image = np.pad(image, pad_width=npad, mode='constant', constant_values=0.5)
+
+    ty = float(y_offset)
+    tx = float(x_offset)
+
+    # Transformation matrix for the annotations.
+    transformation_matrix = np.array([
+        [1.0, 0.0, ty],
+        [0.0, 1.0, tx],
+        [0.0, 0.0, 1.0]
+    ])
-    resized_images = []
-    for i in range(len(images)):
+    # Use the transformation matrix to augment the annotations.
+    formatted_annotation = []
+    for aug in annotation:
+        identifier = aug[0:1]
+        bounds = aug[1:5]
+        confidence = aug[5:6]
+
+        if not np.any(bounds):
+            formatted_annotation.append(np.concatenate([identifier, np.array([0, 0, 0, 0]), confidence]))
+            continue
+
+        width = bounds[2]
+        height = bounds[3]
+
+        x1 = bounds[0] * image_width
+        y1 = bounds[1] * image_height
+        x2 = (bounds[0] + width) * image_width
+        y2 = (bounds[1] + height) * image_height
-        image = images[i]
-        image = _utils.convert_shared_float_array_to_numpy(image)
-        height, width, _ = tf.unstack(tf.shape(image))
-        orig_shape = (height, width)
-        scale_h = tf.constant(output_shape[0], dtype=tf.float32) / tf.to_float(height)
-        scale_w = tf.constant(output_shape[1], dtype=tf.float32) / tf.to_float(width)
-        new_height = tf.to_int32(tf.to_float(height) * scale_h)
-        new_width = tf.to_int32(tf.to_float(width) * scale_w)
-
-        image_scaled = tf.squeeze(tf.image.resize_bilinear(tf.expand_dims(image, 0), [new_height, new_width]), [0])
-
-        pad_image, pad_offset = pad_to_ensure_size(image_scaled, output_shape[0], output_shape[1],
-                                                   random=False)
-
-        new_height = tf.maximum(output_shape[0], new_height)
-        new_width = tf.maximum(output_shape[1], new_width)
-
-        slice_offset = (tf.random_uniform([], minval=0, maxval=new_height - output_shape[0] + 1, dtype=tf.int32),
-                        tf.random_uniform([], minval=0, maxval=new_width - output_shape[1] + 1, dtype=tf.int32))
-        image = array_ops.slice(pad_image, [slice_offset[0], slice_offset[1], 0], [output_shape[0], output_shape[1], 3])
-        image = tf.clip_by_value(image, 0, 1)
-        resized_images.append(image)
-
-    return resized_images
+        augmentation_coordinates = np.array([y1, x1, y2, x2], dtype=np.float32)
+
+        v = np.concatenate([augmentation_coordinates.reshape((2, 2)), np.ones((2, 1), dtype=np.float32)], axis=1)
+        transposed_v = np.dot(v, np.transpose(transformation_matrix))
+        t_intersection = np.squeeze(transposed_v[:, :2].reshape(-1, 4))
+
+        # Sort the points top, left, bottom, right.
+        if t_intersection[0] > t_intersection[2]:
+            t_intersection[0], t_intersection[2] = t_intersection[2], t_intersection[0]
+        if t_intersection[1] > t_intersection[3]:
+            t_intersection[1], t_intersection[3] = t_intersection[3], t_intersection[1]
+
+        # Normalize the elements to the padded width and height.
+        ele_1 = t_intersection[1] / padded_width
+        ele_2 = t_intersection[0] / padded_height
+        ele_3 = (t_intersection[3] - t_intersection[1]) / padded_width
+        ele_4 = (t_intersection[2] - t_intersection[0]) / padded_height
+
+        formatted_annotation.append(np.concatenate([identifier, np.array([ele_1, ele_2, ele_3, ele_4]), confidence]))
+
+    return np.array(padded_image), np.array(formatted_annotation, dtype=np.float32)
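`padding_augmenter` above and `crop_augmenter` below move box corners through the same homogeneous-coordinate step: the `[y1, x1, y2, x2]` corners are reshaped into two `(y, x)` rows, a column of ones is appended, and a single 3x3 matrix translates both corners in one `np.dot`. A standalone rendition with made-up offsets:

```python
import numpy as np

ty, tx = 10.0, 20.0  # offset of the source image inside the padded image
transformation_matrix = np.array([
    [1.0, 0.0, ty],
    [0.0, 1.0, tx],
    [0.0, 0.0, 1.0]
])

corners = np.array([5.0, 7.0, 50.0, 70.0], dtype=np.float32)  # [y1, x1, y2, x2]
v = np.concatenate([corners.reshape((2, 2)), np.ones((2, 1), dtype=np.float32)], axis=1)
moved = np.dot(v, transformation_matrix.T)[:, :2].reshape(-1)
print(moved)  # [15. 27. 60. 90.] -- each corner shifted by (ty, tx)
```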
+
+def crop_augmenter(image,
+                   annotation,
+                   skip_probability=_DEFAULT_AUG_PARAMS["skip_probability_crop"],
+                   min_aspect_ratio=_DEFAULT_AUG_PARAMS["min_aspect_ratio"],
+                   max_aspect_ratio=_DEFAULT_AUG_PARAMS["max_aspect_ratio"],
+                   min_area_fraction=_DEFAULT_AUG_PARAMS["min_area_fraction_crop"],
+                   max_area_fraction=_DEFAULT_AUG_PARAMS["max_area_fraction_crop"],
+                   min_object_covered=_DEFAULT_AUG_PARAMS["min_object_covered"],
+                   max_attempts=_DEFAULT_AUG_PARAMS["max_attempts"],
+                   min_eject_coverage=_DEFAULT_AUG_PARAMS["min_eject_coverage"]):
+
+    if np.random.uniform(0.0, 1.0) < skip_probability:
+        return np.array(image), annotation
+
+    image_height, image_width, _ = image.shape
+
+    # Sample crop rects until one satisfies our constraints (by yielding a valid
+    # list of cropped annotations), or reaching the limit on attempts.
+    for i in range(max_attempts):
+        # Randomly sample an aspect ratio.
+        aspect_ratio = np.random.uniform(min_aspect_ratio, max_aspect_ratio)
+
+        # Next we'll sample a height (which, combined with the now-known aspect
+        # ratio, determines the size and area). But first we must compute the
+        # range of valid heights. The crop cannot be taller than the original
+        # image, cannot be wider than the original image, and must have an area
+        # in the specified range.
+
+        # The cropped height must be no larger than the original height.
+        # h' <= h
+        max_height = float(image_height)
+
+        # The cropped width must be no larger than the original width.
+        # w' <= w IMPLIES ah' <= w IMPLIES h' <= w / a
+        max_height_from_width = float(image_width) / aspect_ratio
+        if max_height > max_height_from_width:
+            max_height = max_height_from_width
+
+        # The cropped area must not exceed the maximum area fraction.
+        max_height_from_area = np.sqrt(image_height * image_width * max_area_fraction / aspect_ratio)
+        if max_height > max_height_from_area:
+            max_height = max_height_from_area
+
+        # The cropped area must attain the minimum area fraction.
+        min_height = np.sqrt(image_height * image_width * min_area_fraction / aspect_ratio)
+
+        # If the range is empty, then crops with the sampled aspect ratio cannot
+        # satisfy the area constraint.
+        if min_height > max_height:
+            continue
+
+        # Sample a final size, given the sampled aspect ratio and range of heights.
+        cropped_height = np.random.uniform(min_height, max_height)
+        cropped_width = cropped_height * aspect_ratio
+
+        # Sample a position for the crop, constrained to lie within the image.
+        x_offset = np.random.uniform(0.0, (image_width - cropped_width))
+        y_offset = np.random.uniform(0.0, (image_height - cropped_height))
+
+        crop_bounds_x1 = x_offset
+        crop_bounds_y1 = y_offset
+        crop_bounds_x2 = x_offset + cropped_width
+        crop_bounds_y2 = y_offset + cropped_height
+
+        formatted_annotation = []
+        is_min_object_covered = True
+        for aug in annotation:
+            identifier = aug[0:1]
+            bounds = aug[1:5]
+            confidence = aug[5:6]
+
+            width = bounds[2]
+            height = bounds[3]
+
+            x1 = bounds[0] * image_width
+            y1 = bounds[1] * image_height
+            x2 = (bounds[0] + width) * image_width
+            y2 = (bounds[1] + height) * image_height
+
+            # Test whether the crop rect intersects the annotated bounds; if it
+            # does not, emit an empty annotation.
+            if crop_bounds_x1 < x2 and crop_bounds_y1 < y2 and x1 < crop_bounds_x2 and y1 < crop_bounds_y2:
+                x_bounds = [x1, x2, x_offset, x_offset + cropped_width]
+                y_bounds = [y1, y2, y_offset, y_offset + cropped_height]
+
+                x_bounds.sort()
+                y_bounds.sort()
+
+                x_pairs = x_bounds[1:3]
+                y_pairs = y_bounds[1:3]
+
+                intersection = np.array([y_pairs[0], x_pairs[0], y_pairs[1], x_pairs[1]])
+
+                intersection_area = (intersection[3] - intersection[1]) * (intersection[2] - intersection[0])
+                annotation_area = (y2 - y1) * (x2 - x1)
+
+                area_coverage = intersection_area / annotation_area
+
+                # Invalidate the crop if it did not sufficiently overlap each
+                # annotation, and try again.
+                if area_coverage < min_object_covered:
+                    is_min_object_covered = False
+                    break
+
+                # If the area coverage is greater than min_eject_coverage, keep
+                # the annotation.
+                if area_coverage >= min_eject_coverage:
+                    # Transformation matrix for the annotations.
+                    transformation_matrix = np.array([
+                        [1.0, 0.0, -y_offset],
+                        [0.0, 1.0, -x_offset],
+                        [0.0, 0.0, 1.0]
+                    ])
+
+                    v = np.concatenate([intersection.reshape((2, 2)), np.ones((2, 1), dtype=np.float32)], axis=1)
+                    transposed_v = np.dot(v, np.transpose(transformation_matrix))
+                    t_intersection = np.squeeze(transposed_v[:, :2].reshape(-1, 4))
+
+                    # Sort the points top, left, bottom, right.
+                    if t_intersection[0] > t_intersection[2]:
+                        t_intersection[0], t_intersection[2] = t_intersection[2], t_intersection[0]
+                    if t_intersection[1] > t_intersection[3]:
+                        t_intersection[1], t_intersection[3] = t_intersection[3], t_intersection[1]
+
+                    # Normalize the elements to the cropped width and height.
+                    ele_1 = t_intersection[1] / cropped_width
+                    ele_2 = t_intersection[0] / cropped_height
+                    ele_3 = (t_intersection[3] - t_intersection[1]) / cropped_width
+                    ele_4 = (t_intersection[2] - t_intersection[0]) / cropped_height
+
+                    formatted_annotation.append(np.concatenate([identifier, np.array([ele_1, ele_2, ele_3, ele_4]), confidence]))
+                else:
+                    formatted_annotation.append(np.concatenate([identifier, np.array([0.0, 0.0, 0.0, 0.0]), confidence]))
+            else:
+                formatted_annotation.append(np.concatenate([identifier, np.array([0.0, 0.0, 0.0, 0.0]), confidence]))
+
+        if not is_min_object_covered:
+            continue
+
+        y_offset = int(y_offset)
+        x_offset = int(x_offset)
+        end_y = int(cropped_height + y_offset)
+        end_x = int(cropped_width + x_offset)
+
+        image_cropped = image[y_offset:end_y, x_offset:end_x]
+
+        return np.array(image_cropped), np.array(formatted_annotation, dtype=np.float32)
+
+    return np.array(image), annotation
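`complete_augmenter` below stitches the NumPy augmenters into the TensorFlow graph with `tf.numpy_function`. One caveat: tensors returned by `tf.numpy_function` carry no static shape, so every downstream op (like the bilinear resize here) must tolerate unknown shapes. A minimal standalone illustration, assuming the TF 1.x graph mode this module enables:

```python
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def identity_aug(image, annotation):
    # Stand-in for crop/padding/flip: any (np.ndarray, np.ndarray) -> same types.
    return image, annotation

img_in = tf.placeholder(tf.float32, [None, None, 3])
ann_in = tf.placeholder(tf.float32, [None, 6])
img_out, ann_out = tf.numpy_function(func=identity_aug, inp=[img_in, ann_in],
                                     Tout=[tf.float32, tf.float32])

print(img_out.shape)                # unknown static shape after the NumPy hop
img_out.set_shape([None, None, 3])  # can be restored by hand when required
```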
+
+def complete_augmenter(img_tf, ann_tf, output_height, output_width):
+    img_tf, ann_tf = tf.numpy_function(func=crop_augmenter, inp=[img_tf, ann_tf], Tout=[tf.float32, tf.float32])
+    img_tf, ann_tf = tf.numpy_function(func=padding_augmenter, inp=[img_tf, ann_tf], Tout=[tf.float32, tf.float32])
+    img_tf, ann_tf = tf.numpy_function(func=horizontal_flip_augmenter, inp=[img_tf, ann_tf], Tout=[tf.float32, tf.float32])
+    img_tf, ann_tf = color_augmenter(img_tf, ann_tf)
+    img_tf, ann_tf = hue_augmenter(img_tf, ann_tf)
+    img_tf, ann_tf = resize_augmenter(img_tf, ann_tf, (output_height, output_width))
+    return img_tf, ann_tf
+
+
+class DataAugmenter(object):
+    def __init__(self, output_height, output_width, batch_size, resize_only):
+        self.batch_size = batch_size
+        self.graph = tf.Graph()
+        self.resize_only = resize_only
+        with self.graph.as_default():
+            self.img_tf = [tf.placeholder(tf.float32, [None, None, 3]) for x in range(0, self.batch_size)]
+            self.ann_tf = [tf.placeholder(tf.float32, [None, 6]) for x in range(0, self.batch_size)]
+            self.resize_op_batch = []
+            for i in range(0, self.batch_size):
+                if resize_only:
+                    aug_img_tf, aug_ann_tf = resize_augmenter(self.img_tf[i], self.ann_tf[i], (output_height, output_width))
+                    self.resize_op_batch.append([aug_img_tf, aug_ann_tf])
+                else:
+                    aug_img_tf, aug_ann_tf = complete_augmenter(self.img_tf[i], self.ann_tf[i], output_height, output_width)
+                    self.resize_op_batch.append([aug_img_tf, aug_ann_tf])
+
+    def get_augmented_data(self, images, annotations):
+        with tf.Session(graph=self.graph) as session:
+            feed_dict = dict()
+            graph_op = self.resize_op_batch[0:len(images)]
+            for i in range(0, len(images)):
+                feed_dict[self.img_tf[i]] = _utils.convert_shared_float_array_to_numpy(images[i])
+                if self.resize_only:
+                    feed_dict[self.ann_tf[i]] = self.batch_size * [np.zeros(6)]
+                else:
+                    feed_dict[self.ann_tf[i]] = _utils.convert_shared_float_array_to_numpy(annotations[i])
+            aug_output = session.run(graph_op, feed_dict=feed_dict)
+            processed_images = []
+            processed_annotations = []
+            for o in aug_output:
+                processed_images.append(o[0])
+                processed_annotations.append(np.ascontiguousarray(o[1], dtype=np.float32))
+            processed_images = np.array(processed_images, dtype=np.float32)
+            processed_images = np.ascontiguousarray(processed_images, dtype=np.float32)
+            return (processed_images, processed_annotations)
+
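Because `get_augmented_data` runs only `resize_op_batch[0:len(images)]`, a short final batch needs no padding: the unused placeholders are never fed or executed. A sketch of that behavior (sizes are illustrative, and it assumes plain NumPy arrays are accepted where the C++ layer normally passes shared float arrays):

```python
import numpy as np
from turicreate.toolkits.object_detector._tf_image_augmenter import DataAugmenter

aug = DataAugmenter(output_height=416, output_width=416, batch_size=32,
                    resize_only=True)

# Only 5 images remain in the epoch's last batch; the other 27 placeholder
# pairs are simply not fed, since only the first len(images) ops are run.
images = [np.random.uniform(size=(300, 400, 3)).astype(np.float32) for _ in range(5)]
annotations = [np.zeros((1, 6), dtype=np.float32)] * 5

out_images, out_annotations = aug.get_augmented_data(images, annotations)
print(out_images.shape)  # (5, 416, 416, 3)
```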