Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Changes in TF Data Augmenter to match 5.8 performance and accuracy (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyajain17 authored Dec 13, 2019
1 parent 91570b1 commit 51f6bc2
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 360 deletions.
56 changes: 36 additions & 20 deletions src/ml/neural_net/tf_compute_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,39 +194,29 @@ tf_model_backend::~tf_model_backend() {

class tf_image_augmenter : public float_array_image_augmenter {
public:
tf_image_augmenter(const options& opts);
~tf_image_augmenter() override = default;
tf_image_augmenter(const options& opts, pybind11::object augmenter);

~tf_image_augmenter();

float_array_result prepare_augmented_images(
labeled_float_image data_to_augment) override;
private:
pybind11::object augmenter_;

};

tf_image_augmenter::tf_image_augmenter(const options& opts) : float_array_image_augmenter(opts) {}
tf_image_augmenter::tf_image_augmenter(const options& opts, pybind11::object augmenter) : float_array_image_augmenter(opts), augmenter_(augmenter) {}

float_array_image_augmenter::float_array_result
tf_image_augmenter::prepare_augmented_images(
float_array_image_augmenter::labeled_float_image data_to_augment) {
options opts = get_options();
float_array_image_augmenter::float_array_result image_annotations;

call_pybind_function([&]() {
// Import the module from python that does data augmentation
pybind11::module tf_aug = pybind11::module::import(
"turicreate.toolkits.object_detector._tf_image_augmenter");

const size_t output_height = opts.output_height;
const size_t output_width = opts.output_width;

// TODO: Remove resize_only by passing all the augmentation options
bool resize_only = false;
if (opts.crop_prob == 0.f) {
resize_only = true;
}

// Get augmented images and annotations from tensorflow
pybind11::object augmented_data = tf_aug.attr("get_augmented_data")(
data_to_augment.images, data_to_augment.annotations, output_height,
output_width, resize_only);
pybind11::object augmented_data = augmenter_.attr("get_augmented_data")(
data_to_augment.images, data_to_augment.annotations);
std::pair<pybind11::buffer, std::vector<pybind11::buffer>> aug_data =
augmented_data
.cast<std::pair<pybind11::buffer, std::vector<pybind11::buffer>>>();
Expand Down Expand Up @@ -257,6 +247,10 @@ tf_image_augmenter::prepare_augmented_images(
return image_annotations;
}

tf_image_augmenter::~tf_image_augmenter() {
call_pybind_function([&]() { augmenter_ = pybind11::object(); });
}

namespace {

std::unique_ptr<compute_context> create_tf_compute_context() {
Expand Down Expand Up @@ -338,7 +332,29 @@ std::unique_ptr<model_backend> tf_compute_context::create_activity_classifier(

std::unique_ptr<image_augmenter> tf_compute_context::create_image_augmenter(
const image_augmenter::options& opts) {
return std::unique_ptr<image_augmenter>(new tf_image_augmenter(opts));
std::unique_ptr<tf_image_augmenter> result;

call_pybind_function([&]() {

const size_t output_height = opts.output_height;
const size_t output_width = opts.output_width;
const size_t batch_size = opts.batch_size;

// TODO: Remove resize_only by passing all the augmentation options
bool resize_only = false;
if (opts.crop_prob == 0.f) {
resize_only = true;
}

pybind11::module tf_aug = pybind11::module::import(
"turicreate.toolkits.object_detector._tf_image_augmenter");

// Make an instance of python object
pybind11::object image_augmenter =
tf_aug.attr("DataAugmenter")(output_height, output_width, batch_size, resize_only);
result.reset(new tf_image_augmenter(opts, image_augmenter));
});
return result;
}

std::unique_ptr<model_backend> tf_compute_context::create_style_transfer(
Expand Down
Loading

0 comments on commit 51f6bc2

Please sign in to comment.