Skip to content

Commit

Permalink
Updating TF pets training example (#179)
Browse files Browse the repository at this point in the history
* Add tf_pets training example

* Adding dockerfile and readme changes

* Resolving review comments

* Resolving review comments

---------

Co-authored-by: SundarRajan98 <[email protected]>
  • Loading branch information
SundarRajan28 and SundarRajan98 authored Jul 22, 2024
1 parent e4a57b7 commit bda170d
Show file tree
Hide file tree
Showing 8 changed files with 583 additions and 156 deletions.
17 changes: 1 addition & 16 deletions docker/rocal-with-tensorflow.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get -y install python3 python3-pip git g+
git clone https://github.com/Tencent/rapidjson.git && cd rapidjson && mkdir build && cd build && \
cmake ../ && make -j4 && sudo make install && cd ../../ && \
pip install pytest==7.3.1 && git clone -b v2.11.1 https://github.com/pybind/pybind11 && cd pybind11 && mkdir build && cd build && \
cmake -DDOWNLOAD_CATCH=ON -DDOWNLOAD_EIGEN=ON ../ && make -j4 && sudo make install && cd ../../ && \
pip install numpy==1.24.2 scipy==1.9.3 cython==0.29.* git+https://github.com/ROCm/hipify_torch.git && \
env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py && \
git clone -b rocm6.1_internal_testing https://github.com/ROCm/cupy.git && cd cupy && git submodule update --init && \
pip install -e . --no-cache-dir -vvvv
cmake -DDOWNLOAD_CATCH=ON -DDOWNLOAD_EIGEN=ON ../ && make -j4 && sudo make install && cd ../../

# install MIVisionX
RUN git clone https://github.com/ROCm/MIVisionX.git && cd MIVisionX && \
Expand All @@ -56,14 +52,3 @@ WORKDIR $ROCAL_WORKSPACE
RUN pip install --upgrade pip
RUN git clone -b develop https://github.com/ROCm/rocAL && \
mkdir build && cd build && cmake -D PYTHON_VERSION_SUGGESTED=${ROCAL_PYTHON_VERSION_SUGGESTED} ../rocAL && make -j8 && cmake --build . --target PyPackageInstall && make install

WORKDIR /workspace

#install tensorflow dependencies - Compile protos and Install TensorFlow Object Detection API.
RUN cd ~/models/research/ && \
protoc object_detection/protos/*.proto --python_out=. && \
cp object_detection/packages/tf2/setup.py . && \
python -m pip install .

#upgrade pip
RUN pip3 install --upgrade pip
9 changes: 6 additions & 3 deletions docs/examples/tf/pets_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

### Running the training

* For first run, to setup dataset, edit "train.py" and set "DATASET_DOWNLOAD_AND_PREPROCESS = True"
* For subsequent runs, after the dataset has already been downloaded and preprocessed, set "DATASET_DOWNLOAD_AND_PREPROCESS = False"
* To set up the dataset, run

To run this example for the first run or subsequent runs, just execute:
```shell
bash download_and_preprocess_dataset.sh
```

* To run this example, just execute:

```shell
python3 train.py
Expand Down
314 changes: 314 additions & 0 deletions docs/examples/tf/pets_training/create_pet_tf_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Convert the Oxford pet dataset to TFRecord for object_detection.
See: O. M. Parkhi, A. Vedaldi, A. Zisserman, C. V. Jawahar
Cats and Dogs
IEEE Conference on Computer Vision and Pattern Recognition, 2012
http://www.robots.ox.ac.uk/~vgg/data/pets/
Example usage:
python object_detection/dataset_tools/create_pet_tf_record.py \
--data_dir=/home/user/pet \
--output_dir=/home/user/pet/output
"""

import hashlib
import io
import logging
import os
import random
import re

import contextlib2
from lxml import etree
import PIL.Image
from six.moves import range
import tensorflow.compat.v1 as tf

flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw pet dataset.')
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
flags.DEFINE_string('label_map_path', 'pet_label_map.pbtxt',
'Path to label map proto')
flags.DEFINE_integer('num_shards', 10, 'Number of TFRecord shards')

FLAGS = flags.FLAGS


def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
  """Opens one TFRecord writer per shard and registers each on an exit stack.

  Args:
    exit_stack: A contextlib2.ExitStack used to automatically close the
      TFRecord writers opened in this function.
    base_path: The base path for all shards.
    num_shards: The number of shards.

  Returns:
    The list of opened TFRecords. Position k in the list corresponds to
    shard k.
  """
  shard_names = [
      '{}-{:05d}-of-{:05d}'.format(base_path, shard, num_shards)
      for shard in range(num_shards)
  ]
  writers = []
  for shard_name in shard_names:
    writer = tf.python_io.TFRecordWriter(shard_name)
    writers.append(exit_stack.enter_context(writer))
  return writers


def int64_feature(value):
  """Returns a tf.train.Feature wrapping a single int64 value."""
  wrapped = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=wrapped)


def int64_list_feature(value):
  """Returns a tf.train.Feature wrapping a list of int64 values."""
  wrapped = tf.train.Int64List(value=value)
  return tf.train.Feature(int64_list=wrapped)


def bytes_feature(value):
  """Returns a tf.train.Feature wrapping a single bytes value."""
  wrapped = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=wrapped)


def bytes_list_feature(value):
  """Returns a tf.train.Feature wrapping a list of bytes values."""
  wrapped = tf.train.BytesList(value=value)
  return tf.train.Feature(bytes_list=wrapped)


def float_feature(value):
  """Returns a tf.train.Feature wrapping a single float value."""
  wrapped = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=wrapped)


def float_list_feature(value):
  """Returns a tf.train.Feature wrapping a list of float values."""
  wrapped = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=wrapped)


def get_class_name_from_filename(file_name):
  """Gets the class name from a file name.

  Args:
    file_name: The file name to get the class name from,
      i.e. "american_pit_bull_terrier_105.jpg".

  Returns:
    A string of the class name, e.g. "american_pit_bull_terrier".

  Raises:
    ValueError: if file_name does not look like "<class>_<index>.jpg".
      ValueError (rather than the AttributeError a None match would
      produce) lets create_tf_record log and skip the example instead of
      aborting the whole conversion.
  """
  match = re.match(r'([A-Za-z_]+)(_[0-9]+\.jpg)', file_name, re.I)
  if match is None:
    raise ValueError(
        'Could not extract class name from file name: {}'.format(file_name))
  return match.group(1)


def read_examples_list(path):
  """Reads the list of training or validation examples.

  Each line is assumed to hold a single example, where the first
  space-separated token is an identifier that locates both the image and
  the annotation xml for that example. For example, the line:
    xyz 3
  would allow us to find files xyz.jpg and xyz.xml (the 3 is ignored).

  Args:
    path: absolute path to examples list file.

  Returns:
    list of example identifiers (strings).
  """
  with tf.gfile.GFile(path) as fid:
    raw_lines = fid.readlines()
  identifiers = []
  for raw_line in raw_lines:
    identifiers.append(raw_line.strip().split(' ')[0])
  return identifiers


def recursive_parse_xml_to_dict(xml):
  """Recursively parses XML contents to python dict.

  We assume that `object` tags are the only ones that can appear
  multiple times at the same level of a tree.

  Args:
    xml: xml tree obtained by parsing XML file contents using lxml.etree.

  Returns:
    Python dictionary holding XML contents.
  """
  # Use an explicit child-count check instead of `if not xml:` —
  # truth-testing an Element is deprecated in both lxml and xml.etree
  # (emits FutureWarning/DeprecationWarning on modern versions).
  if len(xml) == 0:
    # Leaf node: map its tag directly to its text content.
    return {xml.tag: xml.text}
  result = {}
  for child in xml:
    child_result = recursive_parse_xml_to_dict(child)
    if child.tag != 'object':
      result[child.tag] = child_result[child.tag]
    else:
      # <object> may repeat at the same level; collect them into a list.
      if child.tag not in result:
        result[child.tag] = []
      result[child.tag].append(child_result[child.tag])
  return {xml.tag: result}


def dict_to_tf_example(data,
                       mask_path,
                       label_map_dict,
                       image_subdirectory):
  """Converts an XML-derived dict to a tf.Example proto.

  Notice that this function normalizes nothing about the image itself;
  it reads the encoded JPEG bytes and packs them, together with metadata
  and the class label, into a tf.Example.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running recursive_parse_xml_to_dict).
    mask_path: String path to PNG encoded mask. NOTE(review): accepted for
      interface compatibility but not used by this function.
    label_map_dict: A map from string label names to integer ids.
    image_subdirectory: String specifying subdirectory within the
      Pascal dataset directory holding the actual image data.

  Returns:
    example: The converted tf.Example.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a
      valid JPEG.
  """
  full_path = os.path.join(image_subdirectory, data['filename'])
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  image = PIL.Image.open(io.BytesIO(encoded_jpg))
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  # SHA-256 of the raw bytes serves as a stable per-image key.
  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  classes_text = []
  classes = []
  if 'object' in data:
    # The class label is encoded in the file name, not in the XML objects.
    class_name = get_class_name_from_filename(data['filename'])
    classes_text.append(class_name.encode('utf8'))
    classes.append(label_map_dict[class_name])

  filename_utf8 = data['filename'].encode('utf8')
  feature_dict = {
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
      'image/filename': bytes_feature(filename_utf8),
      'image/source_id': bytes_feature(filename_utf8),
      'image/key/sha256': bytes_feature(key.encode('utf8')),
      'image/encoded': bytes_feature(encoded_jpg),
      'image/format': bytes_feature('jpeg'.encode('utf8')),
      'image/object/class/text': bytes_list_feature(classes_text),
      'image/object/class/label': int64_list_feature(classes)
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_dict))


def create_tf_record(output_filename,
                     num_shards,
                     label_map_dict,
                     annotations_dir,
                     image_dir,
                     examples):
  """Creates a sharded TFRecord file from examples.

  Args:
    output_filename: Path to where output file is saved.
    num_shards: Number of shards for output file.
    label_map_dict: The label map dictionary.
    annotations_dir: Directory where annotation files are stored.
    image_dir: Directory where image files are stored.
    examples: Examples to parse and save to tf record.
  """
  with contextlib2.ExitStack() as tf_record_close_stack:
    writers = open_sharded_output_tfrecords(
        tf_record_close_stack, output_filename, num_shards)
    total = len(examples)
    for index, example in enumerate(examples):
      if index % 100 == 0:
        logging.info('On image %d of %d', index, total)
      xml_path = os.path.join(annotations_dir, 'xmls', example + '.xml')
      mask_path = os.path.join(annotations_dir, 'trimaps', example + '.png')

      # Some examples lack annotations; skip them rather than failing.
      if not os.path.exists(xml_path):
        logging.warning('Could not find %s, ignoring example.', xml_path)
        continue
      with tf.gfile.GFile(xml_path, 'r') as fid:
        xml_str = fid.read()
      parsed = recursive_parse_xml_to_dict(etree.fromstring(xml_str))
      data = parsed['annotation']

      try:
        tf_example = dict_to_tf_example(
            data, mask_path, label_map_dict, image_dir)
        if tf_example:
          # Round-robin the serialized example across the shard writers.
          writers[index % num_shards].write(tf_example.SerializeToString())
      except ValueError:
        logging.warning('Invalid example: %s, ignoring.', xml_path)


def main(_):
  """Builds train/val TFRecords for the Oxford-IIIT Pet dataset."""
  data_dir = FLAGS.data_dir
  # Hard-coded breed-name -> integer-id label map.
  # NOTE(review): the --label_map_path flag is defined but not read here.
  label_map_dict = {
      'Abyssinian': 1, 'american_bulldog': 2,
      'american_pit_bull_terrier': 3, 'basset_hound': 4, 'beagle': 5,
      'Bengal': 6, 'Birman': 7, 'Bombay': 8, 'boxer': 9,
      'British_Shorthair': 10, 'chihuahua': 11, 'Egyptian_Mau': 12,
      'english_cocker_spaniel': 13, 'english_setter': 14,
      'german_shorthaired': 15, 'great_pyrenees': 16, 'havanese': 17,
      'japanese_chin': 18, 'keeshond': 19, 'leonberger': 20,
      'Maine_Coon': 21, 'miniature_pinscher': 22, 'newfoundland': 23,
      'Persian': 24, 'pomeranian': 25, 'pug': 26, 'Ragdoll': 27,
      'Russian_Blue': 28, 'saint_bernard': 29, 'samoyed': 30,
      'scottish_terrier': 31, 'shiba_inu': 32, 'Siamese': 33,
      'Sphynx': 34, 'staffordshire_bull_terrier': 35,
      'wheaten_terrier': 36, 'yorkshire_terrier': 37,
  }
  logging.info('Reading from Pet dataset.')
  image_dir = os.path.join(data_dir, 'images')
  annotations_dir = os.path.join(data_dir, 'annotations')
  examples_list = read_examples_list(
      os.path.join(annotations_dir, 'trainval.txt'))

  # Test images are not included in the downloaded data set, so we shall
  # perform our own 70/30 split (seeded for reproducibility).
  random.seed(42)
  random.shuffle(examples_list)
  num_train = int(0.7 * len(examples_list))
  train_examples = examples_list[:num_train]
  val_examples = examples_list[num_train:]
  logging.info('%d training and %d validation examples.',
               len(train_examples), len(val_examples))

  create_tf_record(
      os.path.join(FLAGS.output_dir, 'pet_faces_train.record'),
      FLAGS.num_shards,
      label_map_dict,
      annotations_dir,
      image_dir,
      train_examples)
  create_tf_record(
      os.path.join(FLAGS.output_dir, 'pet_faces_val.record'),
      FLAGS.num_shards,
      label_map_dict,
      annotations_dir,
      image_dir,
      val_examples)


if __name__ == '__main__':
  # tf.app.run parses the command-line flags declared above and calls main().
  tf.app.run()
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DATASET_URL="http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
GROUNDTRUTH_URL="http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"

apt-get install wget
pip install contextlib2 pillow
printf "\nDownloading Oxford-IIIT-Pet dataset from $DATASET_URL..."
wget $DATASET_URL
printf "\nDownloading Oxford-IIIT-Pet ground truth from $GROUNDTRUTH_URL..."
Expand All @@ -12,11 +13,7 @@ printf "\nExtracting..."
tar xzvf images.tar.gz
tar xzvf annotations.tar.gz
mkdir tf_pets_records
git clone https://github.com/swetha097/rocALmodels.git
cd rocALmodels
git checkout TF_V2
cd ../
python3 rocALmodels/models/research/object_detection/dataset_tools/create_pet_tf_record.py --data_dir=./ --output_dir=tf_pets_records/ --label_map_path=rocALmodels/models/research/object_detection/data/pet_label_map.pbtxt
python3 create_pet_tf_record.py --data_dir=./ --output_dir=tf_pets_records/ --label_map_path=pet_label_map.pbtxt
cd tf_pets_records
mkdir train
mv pet_faces_train.record-0000* train
Expand Down
Loading

0 comments on commit bda170d

Please sign in to comment.