Skip to content

Commit

Permalink
Update inception_resnet_v2.py (tensorflow#1555)
Browse files Browse the repository at this point in the history
- Add inception_resnet_v2_base
- Provide option to use SAME padding for all inception resnet v2
  layers. This is to align feature maps sizes.
  • Loading branch information
derekjchow authored and sguada committed Jun 14, 2017
1 parent cb31aef commit 7e0016c
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 125 deletions.
1 change: 1 addition & 0 deletions slim/nets/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# pylint: disable=unused-import
from nets.inception_resnet_v2 import inception_resnet_v2
from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope
from nets.inception_resnet_v2 import inception_resnet_v2_base
from nets.inception_v1 import inception_v1
from nets.inception_v1 import inception_v1_arg_scope
from nets.inception_v1 import inception_v1_base
Expand Down
321 changes: 199 additions & 122 deletions slim/nets/inception_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,187 @@ def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
return net


def inception_resnet_v2_base(inputs,
final_endpoint='Conv2d_7b_1x1',
output_stride=16,
align_feature_maps=False,
scope=None):
"""Inception model from http://arxiv.org/abs/1602.07261.
Constructs an Inception Resnet v2 network from inputs to the given final
endpoint. This method can construct the network up to the final inception
block Conv2d_7b_1x1.
Args:
inputs: a tensor of size [batch_size, height, width, channels].
final_endpoint: specifies the endpoint to construct the network up to. It
can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3',
'Mixed_5b', 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1']
output_stride: A scalar that specifies the requested ratio of input to
output spatial resolution. Only supports 8 and 16.
align_feature_maps: When true, changes all the VALID paddings in the network
to SAME padding so that the feature maps are aligned.
scope: Optional variable_scope.
Returns:
tensor_out: output tensor corresponding to the final_endpoint.
end_points: a set of activations for external use, for example summaries or
losses.
Raises:
ValueError: if final_endpoint is not set to one of the predefined values,
or if the output_stride is not 8 or 16, or if the output_stride is 8 and
we request an end point after 'PreAuxLogits'.
"""
if output_stride != 8 and output_stride != 16:
raise ValueError('output_stride must be 8 or 16.')

padding = 'SAME' if align_feature_maps else 'VALID'

end_points = {}

def add_and_check_final(name, net):
end_points[name] = net
return name == final_endpoint

with tf.variable_scope(scope, 'InceptionResnetV2', [inputs]):
with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
stride=1, padding='SAME'):
# 149 x 149 x 32
net = slim.conv2d(inputs, 32, 3, stride=2, padding=padding,
scope='Conv2d_1a_3x3')
if add_and_check_final('Conv2d_1a_3x3', net): return net, end_points

# 147 x 147 x 32
net = slim.conv2d(net, 32, 3, padding=padding,
scope='Conv2d_2a_3x3')
if add_and_check_final('Conv2d_2a_3x3', net): return net, end_points
# 147 x 147 x 64
net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
if add_and_check_final('Conv2d_2b_3x3', net): return net, end_points
# 73 x 73 x 64
net = slim.max_pool2d(net, 3, stride=2, padding=padding,
scope='MaxPool_3a_3x3')
if add_and_check_final('MaxPool_3a_3x3', net): return net, end_points
# 73 x 73 x 80
net = slim.conv2d(net, 80, 1, padding=padding,
scope='Conv2d_3b_1x1')
if add_and_check_final('Conv2d_3b_1x1', net): return net, end_points
# 71 x 71 x 192
net = slim.conv2d(net, 192, 3, padding=padding,
scope='Conv2d_4a_3x3')
if add_and_check_final('Conv2d_4a_3x3', net): return net, end_points
# 35 x 35 x 192
net = slim.max_pool2d(net, 3, stride=2, padding=padding,
scope='MaxPool_5a_3x3')
if add_and_check_final('MaxPool_5a_3x3', net): return net, end_points

# 35 x 35 x 320
with tf.variable_scope('Mixed_5b'):
with tf.variable_scope('Branch_0'):
tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
with tf.variable_scope('Branch_1'):
tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
scope='Conv2d_0b_5x5')
with tf.variable_scope('Branch_2'):
tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
scope='Conv2d_0b_3x3')
tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
scope='Conv2d_0c_3x3')
with tf.variable_scope('Branch_3'):
tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
scope='AvgPool_0a_3x3')
tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
scope='Conv2d_0b_1x1')
net = tf.concat(
[tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1], 3)

if add_and_check_final('Mixed_5b', net): return net, end_points
# TODO(alemi): Register intermediate endpoints
net = slim.repeat(net, 10, block35, scale=0.17)

# 17 x 17 x 1088 if output_stride == 8,
# 33 x 33 x 1088 if output_stride == 16
use_atrous = output_stride == 8

with tf.variable_scope('Mixed_6a'):
with tf.variable_scope('Branch_0'):
tower_conv = slim.conv2d(net, 384, 3, stride=1 if use_atrous else 2,
padding=padding,
scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_1'):
tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
scope='Conv2d_0b_3x3')
tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
stride=1 if use_atrous else 2,
padding=padding,
scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_2'):
tower_pool = slim.max_pool2d(net, 3, stride=1 if use_atrous else 2,
padding=padding,
scope='MaxPool_1a_3x3')
net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)

if add_and_check_final('Mixed_6a', net): return net, end_points

# TODO(alemi): register intermediate endpoints
with slim.arg_scope([slim.conv2d], rate=2 if use_atrous else 1):
net = slim.repeat(net, 20, block17, scale=0.10)
if add_and_check_final('PreAuxLogits', net): return net, end_points

if output_stride == 8:
# TODO(gpapan): Properly support output_stride for the rest of the net.
raise ValueError('output_stride==8 is only supported up to the '
'PreAuxlogits end_point for now.')

# 8 x 8 x 2080
with tf.variable_scope('Mixed_7a'):
with tf.variable_scope('Branch_0'):
tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
padding=padding,
scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_1'):
tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
padding=padding,
scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_2'):
tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
scope='Conv2d_0b_3x3')
tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
padding=padding,
scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_3'):
tower_pool = slim.max_pool2d(net, 3, stride=2,
padding=padding,
scope='MaxPool_1a_3x3')
net = tf.concat(
[tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool], 3)

if add_and_check_final('Mixed_7a', net): return net, end_points

# TODO(alemi): register intermediate endpoints
net = slim.repeat(net, 9, block8, scale=0.20)
net = block8(net, activation_fn=None)

# 8 x 8 x 1536
net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
if add_and_check_final('Conv2d_7b_1x1', net): return net, end_points

raise ValueError('final_endpoint (%s) not recognized', final_endpoint)


def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
dropout_keep_prob=0.8,
reuse=None,
scope='InceptionResnetV2'):
scope='InceptionResnetV2',
create_aux_logits=True):
"""Creates the Inception Resnet V2 model.
Args:
Expand All @@ -105,95 +282,25 @@ def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
create_aux_logits: Whether to include the auxilliary logits.
Returns:
logits: the logits outputs of the model.
end_points: the set of end_points from the inception model.
"""
end_points = {}

with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse):
with tf.variable_scope(scope, 'InceptionResnetV2', [inputs, num_classes],
reuse=reuse) as scope:
with slim.arg_scope([slim.batch_norm, slim.dropout],
is_training=is_training):
with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
stride=1, padding='SAME'):

# 149 x 149 x 32
net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
scope='Conv2d_1a_3x3')
end_points['Conv2d_1a_3x3'] = net
# 147 x 147 x 32
net = slim.conv2d(net, 32, 3, padding='VALID',
scope='Conv2d_2a_3x3')
end_points['Conv2d_2a_3x3'] = net
# 147 x 147 x 64
net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
end_points['Conv2d_2b_3x3'] = net
# 73 x 73 x 64
net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
scope='MaxPool_3a_3x3')
end_points['MaxPool_3a_3x3'] = net
# 73 x 73 x 80
net = slim.conv2d(net, 80, 1, padding='VALID',
scope='Conv2d_3b_1x1')
end_points['Conv2d_3b_1x1'] = net
# 71 x 71 x 192
net = slim.conv2d(net, 192, 3, padding='VALID',
scope='Conv2d_4a_3x3')
end_points['Conv2d_4a_3x3'] = net
# 35 x 35 x 192
net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
scope='MaxPool_5a_3x3')
end_points['MaxPool_5a_3x3'] = net

# 35 x 35 x 320
with tf.variable_scope('Mixed_5b'):
with tf.variable_scope('Branch_0'):
tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
with tf.variable_scope('Branch_1'):
tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
scope='Conv2d_0b_5x5')
with tf.variable_scope('Branch_2'):
tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
scope='Conv2d_0b_3x3')
tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
scope='Conv2d_0c_3x3')
with tf.variable_scope('Branch_3'):
tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
scope='AvgPool_0a_3x3')
tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
scope='Conv2d_0b_1x1')
net = tf.concat(axis=3, values=[tower_conv, tower_conv1_1,
tower_conv2_2, tower_pool_1])

end_points['Mixed_5b'] = net
net = slim.repeat(net, 10, block35, scale=0.17)

# 17 x 17 x 1088
with tf.variable_scope('Mixed_6a'):
with tf.variable_scope('Branch_0'):
tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID',
scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_1'):
tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
scope='Conv2d_0b_3x3')
tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
stride=2, padding='VALID',
scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_2'):
tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
scope='MaxPool_1a_3x3')
net = tf.concat(axis=3, values=[tower_conv, tower_conv1_2, tower_pool])

end_points['Mixed_6a'] = net
net = slim.repeat(net, 20, block17, scale=0.10)

# Auxiliary tower
net, end_points = inception_resnet_v2_base(inputs, scope=scope)

if create_aux_logits:
with tf.variable_scope('AuxLogits'):
aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID',
aux = end_points['PreAuxLogits']
aux = slim.avg_pool2d(aux, 5, stride=3, padding='VALID',
scope='Conv2d_1a_3x3')
aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
Expand All @@ -203,49 +310,19 @@ def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
scope='Logits')
end_points['AuxLogits'] = aux

with tf.variable_scope('Mixed_7a'):
with tf.variable_scope('Branch_0'):
tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
padding='VALID', scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_1'):
tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
padding='VALID', scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_2'):
tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
scope='Conv2d_0b_3x3')
tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
padding='VALID', scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_3'):
tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
scope='MaxPool_1a_3x3')
net = tf.concat(axis=3, values=[tower_conv_1, tower_conv1_1,
tower_conv2_2, tower_pool])

end_points['Mixed_7a'] = net

net = slim.repeat(net, 9, block8, scale=0.20)
net = block8(net, activation_fn=None)

net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
end_points['Conv2d_7b_1x1'] = net

with tf.variable_scope('Logits'):
end_points['PrePool'] = net
net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
scope='AvgPool_1a_8x8')
net = slim.flatten(net)

net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
scope='Dropout')

end_points['PreLogitsFlatten'] = net
logits = slim.fully_connected(net, num_classes, activation_fn=None,
scope='Logits')
end_points['Logits'] = logits
end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')
with tf.variable_scope('Logits'):
net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
scope='AvgPool_1a_8x8')
net = slim.flatten(net)

net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
scope='Dropout')

end_points['PreLogitsFlatten'] = net
logits = slim.fully_connected(net, num_classes, activation_fn=None,
scope='Logits')
end_points['Logits'] = logits
end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')

return logits, end_points
inception_resnet_v2.default_image_size = 299
Expand Down
Loading

0 comments on commit 7e0016c

Please sign in to comment.