Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modification of the definition of networks #18

Open
wants to merge 63 commits into
base: doc_refactor
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
28ad980
Update installation packages and installation instructions
SeguinBe Oct 6, 2018
957cd58
Revamp of the network description and architecture in a more flexible…
SeguinBe Oct 12, 2018
ff1edd8
Removing useless files
SeguinBe Oct 12, 2018
a9e0ed7
dh_segment_train as a script
SeguinBe Oct 12, 2018
e0d6c5d
Correcting the deletion of the main script, oops...
SeguinBe Oct 12, 2018
cb1d8fc
Nicer labels for the progress bars
SeguinBe Oct 12, 2018
da1258a
Nicer handling of number of threads
SeguinBe Oct 12, 2018
4de57fe
Removing code which has been made useless
SeguinBe Oct 12, 2018
7e5ccb4
mainly docstring formatting
solivr Oct 22, 2018
ce214c2
changed :param: by :ivar:
solivr Oct 22, 2018
62ec71d
Updating batchnorm training
SeguinBe Oct 26, 2018
cace550
Added MobileNetV2
SeguinBe Oct 26, 2018
ea11126
Documentation of exported model
SeguinBe Oct 29, 2018
82a5f22
Fixed refactoring
Oct 30, 2018
91540f2
Merge pull request #19 from sriak/master
solivr Oct 30, 2018
9889c7d
updated demo
solivr Nov 1, 2018
4e00913
pip install
solivr Nov 2, 2018
4f177b1
typo in attribute
solivr Nov 14, 2018
932fa3c
corrected non exported segment_ids field
solivr Nov 14, 2018
c5a1965
sorting of TextLines in a TextRegion
solivr Nov 15, 2018
346e2fb
force type to be int (for JSON export compatibility)
solivr Nov 20, 2018
7c25b56
specific to int32 and int64 type
solivr Nov 20, 2018
3eefba8
input csv file
solivr Dec 4, 2018
455a8e9
via annotation processing
e-maud Dec 11, 2018
811af9c
via annotation processing - typo
e-maud Dec 11, 2018
48efe87
type correction
solivr Dec 11, 2018
f736aaa
added doc
solivr Dec 12, 2018
7f65ad4
updated doc
solivr Jan 17, 2019
4509bc5
updated installation doc
solivr Jan 18, 2019
e61079f
packages versions
solivr Jan 18, 2019
db46c35
detected contour should have at least 3 points
solivr Jan 21, 2019
7c53e27
LatestExporter if no eval data is provided
solivr Jan 24, 2019
e07f996
update
solivr Dec 14, 2018
b090906
contour option in mask creation
solivr Jan 24, 2019
ba92f50
export regions coordinates to VIA compatible format
solivr Jan 30, 2019
fbb9350
doc and typos
solivr Feb 5, 2019
600acaa
simlified via.py and updated doc
solivr Feb 11, 2019
665af99
doc formatting
solivr Feb 11, 2019
84ec4dd
parse attributes of TextRegion and TextLines 'custom' and 'type'
solivr Dec 4, 2018
77bb4f3
remove git repo dependency
solivr Feb 11, 2019
532131a
merging
solivr Feb 11, 2019
909e8b1
corrected wrong argument names
solivr Feb 13, 2019
6717332
wrong variable name
solivr Feb 13, 2019
704087a
via example and doc formatting
solivr Feb 12, 2019
04ce8b6
Correcting typo masks creation script
alix-tz Feb 20, 2019
2264cf1
Merge pull request #26 from alix-tz/patch-1
solivr Feb 21, 2019
1262b59
Fixing instruction
alix-tz Feb 26, 2019
12d2759
Merge pull request #27 from alix-tz/patch-2
solivr Feb 28, 2019
6fdfcbd
do not export attribute 'type' if it's empty
solivr Mar 7, 2019
8fbd882
array to list of Point method
solivr Feb 25, 2019
2af56f2
update parsing + get list of tags from xml
solivr Mar 12, 2019
7100855
merge from master
SeguinBe Mar 22, 2019
8deae44
miou metric
solivr Mar 8, 2019
540eb36
to_json method for Page class
solivr Apr 4, 2019
605a930
updated via helpers
solivr Apr 9, 2019
6456a69
update packages version
solivr Apr 9, 2019
a072442
update to opencv 4.0
solivr Apr 9, 2019
fbad361
changelog
solivr Apr 9, 2019
9de5ca7
fix tensorflow-gpu version
solivr Apr 10, 2019
875c547
fixes #37
solivr May 15, 2019
7f2a348
merge
SeguinBe May 22, 2019
de461a7
working version corrected
SeguinBe May 22, 2019
1b36fca
formatting
solivr Jul 26, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 14 additions & 36 deletions dh_segment/estimator_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .utils import PredictionType, ModelParams, TrainingParams, \
class_to_label_image, multiclass_to_label_image
import numpy as np
from .network.model import inference_resnet_v1_50, inference_vgg16, inference_u_net


def model_fn(mode, features, labels, params):
Expand All @@ -18,45 +17,22 @@ def model_fn(mode, features, labels, params):
input_images = tf.pad(input_images, [[0, 0], [margin, margin], [margin, margin], [0, 0]],
mode='SYMMETRIC', name='mirror_padding')

if model_params.pretrained_model_name == 'vgg16':
network_output = inference_vgg16(input_images,
model_params,
model_params.n_classes,
use_batch_norm=model_params.batch_norm,
weight_decay=model_params.weight_decay,
is_training=(mode == tf.estimator.ModeKeys.TRAIN)
)
key_restore_model = 'vgg_16'
encoder_class = model_params.get_encoder()
encoder = encoder_class(**model_params.encoder_params)
decoder_class = model_params.get_decoder()
decoder = decoder_class(**model_params.decoder_params)

elif model_params.pretrained_model_name == 'resnet50':
network_output = inference_resnet_v1_50(input_images,
model_params,
model_params.n_classes,
use_batch_norm=model_params.batch_norm,
weight_decay=model_params.weight_decay,
is_training=(mode == tf.estimator.ModeKeys.TRAIN)
)
key_restore_model = 'resnet_v1_50'
elif model_params.pretrained_model_name == 'unet':
network_output = inference_u_net(input_images,
model_params,
model_params.n_classes,
use_batch_norm=model_params.batch_norm,
weight_decay=model_params.weight_decay,
is_training=(mode == tf.estimator.ModeKeys.TRAIN)
)
key_restore_model = None
else:
raise NotImplementedError
feature_maps = encoder(input_images)
network_output = decoder(feature_maps, num_classes=model_params.n_classes)

if mode == tf.estimator.ModeKeys.TRAIN:
if key_restore_model is not None:
pretrained_file, pretrained_vars = encoder.pretrained_information()
if pretrained_file:
# Pretrained weights as initialization
pretrained_restorer = tf.train.Saver(var_list=[v for v in tf.global_variables()
if key_restore_model in v.name])
pretrained_restorer = tf.train.Saver(var_list=pretrained_vars)

def init_fn(scaffold, session):
pretrained_restorer.restore(session, model_params.pretrained_model_file)
pretrained_restorer.restore(session, pretrained_file)
else:
init_fn = None
else:
Expand Down Expand Up @@ -92,8 +68,10 @@ def init_fn(scaffold, session):
if prediction_type == PredictionType.CLASSIFICATION:
onehot_labels = tf.one_hot(indices=labels, depth=model_params.n_classes)
with tf.name_scope("loss"):
per_pixel_loss = tf.nn.softmax_cross_entropy_with_logits(logits=network_output,
labels=onehot_labels, name='per_pixel_loss')
#per_pixel_loss = tf.nn.softmax_cross_entropy_with_logits(logits=network_output,
# labels=onehot_labels, name='per_pixel_loss')
per_pixel_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=network_output,
labels=onehot_labels, name='per_pixel_loss')
if training_params.focal_loss_gamma > 0.0:
# Probability per pixel of getting the correct label
probs_correct_label = tf.reduce_max(tf.multiply(prediction_probs, onehot_labels))
Expand Down
27 changes: 16 additions & 11 deletions dh_segment/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class InputCase(Enum):
INPUT_CSV = 'INPUT_CSV'


def input_fn(input_data: Union[str, List[str]], params: dict, input_label_dir: str=None,
data_augmentation: bool=False, batch_size: int=5, make_patches: bool=False, num_epochs: int=1,
num_threads: int=4, image_summaries: bool=False):
def input_fn(input_data: Union[str, List[str]], params: dict, input_label_dir: str = None,
data_augmentation: bool = False, batch_size: int = 5, make_patches: bool = False, num_epochs: int = 1,
num_threads: int = 4, image_summaries: bool = False, progressbar_description: str = 'Dataset'):
"""
Input_fn for estimator

Expand All @@ -33,6 +33,7 @@ def input_fn(input_data: Union[str, List[str]], params: dict, input_label_dir: s
:param num_epochs: number of epochs to cycle trough data (set it to None for infinite repeat)
:param num_threads: number of thread to use in parallele when usin tf.data.Dataset.map
:param image_summaries: boolean, whether to make tf.Summary to watch on tensorboard
:param progressbar_description: what will appear in the progressbar showing the number of files read
:return: fn
"""
training_params = utils.TrainingParams.from_dict(params['training_params'])
Expand Down Expand Up @@ -96,8 +97,9 @@ def _scaling_and_patch_fn(input_image, label_image):

# Data augmentation
def _augment_data_fn(input_image, label_image): \
return data_augmentation_fn(input_image, label_image, training_params.data_augmentation_flip_lr,
training_params.data_augmentation_flip_ud, training_params.data_augmentation_color)
return data_augmentation_fn(input_image, label_image, training_params.data_augmentation_flip_lr,
training_params.data_augmentation_flip_ud,
training_params.data_augmentation_color)

# Assign color to class id
def _assign_color_to_class_id(input_image, label_image):
Expand All @@ -112,27 +114,29 @@ def _assign_color_to_class_id(input_image, label_image):
output['weight_maps'] = local_entropy(tf.equal(label_image, 1),
sigma=training_params.local_entropy_sigma)
return output

# ---

# Finding the list of images to be used
if isinstance(input_data, list):
input_case = InputCase.INPUT_LIST
input_image_filenames = input_data
print('Found {} images'.format(len(input_image_filenames)))
#print('Found {} images'.format(len(input_image_filenames)))

elif os.path.isdir(input_data):
input_case = InputCase.INPUT_DIR
input_image_filenames = glob(os.path.join(input_data, '**', '*.jpg'),
recursive=True) + \
glob(os.path.join(input_data, '**', '*.png'),
recursive=True)
print('Found {} images'.format(len(input_image_filenames)))
#print('Found {} images'.format(len(input_image_filenames)))

elif os.path.isfile(input_data) and \
input_data.endswith('.csv'):
input_case = InputCase.INPUT_CSV
else:
raise NotImplementedError('Input data should be a directory, a csv file or a list of filenames but got {}'.format(input_data))
raise NotImplementedError(
'Input data should be a directory, a csv file or a list of filenames but got {}'.format(input_data))

# Finding the list of labelled images if available
has_labelled_data = False
Expand Down Expand Up @@ -169,15 +173,16 @@ def _assign_color_to_class_id(input_image, label_image):
def fn():
if not has_labelled_data:
encoded_filenames = [f.encode() for f in input_image_filenames]
dataset = tf.data.Dataset.from_generator(lambda: tqdm(encoded_filenames, desc='Dataset'),
dataset = tf.data.Dataset.from_generator(lambda: tqdm(encoded_filenames, desc=progressbar_description),
tf.string, tf.TensorShape([]))
dataset = dataset.repeat(count=num_epochs)
dataset = dataset.map(lambda filename: {'images': load_and_resize_image(filename, 3,
training_params.input_resized_size)})
else:
encoded_filenames = [(i.encode(), l.encode()) for i, l in zip(input_image_filenames, label_image_filenames)]
dataset = tf.data.Dataset.from_generator(lambda: tqdm(utils.shuffled(encoded_filenames), desc='Dataset'),
(tf.string, tf.string), (tf.TensorShape([]), tf.TensorShape([])))
dataset = tf.data.Dataset.from_generator(lambda: tqdm(utils.shuffled(encoded_filenames),
desc=progressbar_description),
(tf.string, tf.string), (tf.TensorShape([]), tf.TensorShape([])))

dataset = dataset.repeat(count=num_epochs)
dataset = dataset.map(_load_image_fn, num_threads).flat_map(_scaling_and_patch_fn)
Expand Down
Loading