Skip to content

Commit

Permalink
Merge pull request tensorflow#1447 from ahundt/resnet_fix
Browse files Browse the repository at this point in the history
resnet_v1 and v2 segmentation bugs from 7e2435e resolved
  • Loading branch information
nealwu authored May 6, 2017
2 parents 68dab50 + cb4a485 commit c029806
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
18 changes: 14 additions & 4 deletions slim/nets/resnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def resnet_v1(inputs,
normalizer_fn=None, scope='logits')
if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
else:
logits = net
# Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None:
Expand All @@ -215,6 +217,7 @@ def resnet_v1_50(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_50'):
"""ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
Expand All @@ -230,7 +233,8 @@ def resnet_v1_50(inputs,
]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_50.default_image_size = resnet_v1.default_image_size


Expand All @@ -239,6 +243,7 @@ def resnet_v1_101(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_101'):
"""ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
Expand All @@ -254,7 +259,8 @@ def resnet_v1_101(inputs,
]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_101.default_image_size = resnet_v1.default_image_size


Expand All @@ -263,6 +269,7 @@ def resnet_v1_152(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_152'):
"""ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
Expand All @@ -277,7 +284,8 @@ def resnet_v1_152(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_152.default_image_size = resnet_v1.default_image_size


Expand All @@ -286,6 +294,7 @@ def resnet_v1_200(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v1_200'):
"""ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
Expand All @@ -300,5 +309,6 @@ def resnet_v1_200(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v1(inputs, blocks, num_classes, is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v1_200.default_image_size = resnet_v1.default_image_size
18 changes: 14 additions & 4 deletions slim/nets/resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def resnet_v2(inputs,
normalizer_fn=None, scope='logits')
if spatial_squeeze:
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
else:
logits = net
# Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
if num_classes is not None:
Expand All @@ -224,6 +226,7 @@ def resnet_v2_50(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_50'):
"""ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
Expand All @@ -238,7 +241,8 @@ def resnet_v2_50(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_50.default_image_size = resnet_v2.default_image_size


Expand All @@ -247,6 +251,7 @@ def resnet_v2_101(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_101'):
"""ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
Expand All @@ -261,7 +266,8 @@ def resnet_v2_101(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_101.default_image_size = resnet_v2.default_image_size


Expand All @@ -270,6 +276,7 @@ def resnet_v2_152(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_152'):
"""ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
Expand All @@ -284,7 +291,8 @@ def resnet_v2_152(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_152.default_image_size = resnet_v2.default_image_size


Expand All @@ -293,6 +301,7 @@ def resnet_v2_200(inputs,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
reuse=None,
scope='resnet_v2_200'):
"""ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
Expand All @@ -307,5 +316,6 @@ def resnet_v2_200(inputs,
'block4', bottleneck, [(2048, 512, 1)] * 3)]
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
global_pool=global_pool, output_stride=output_stride,
include_root_block=True, reuse=reuse, scope=scope)
include_root_block=True, spatial_squeeze=spatial_squeeze,
reuse=reuse, scope=scope)
resnet_v2_200.default_image_size = resnet_v2.default_image_size

0 comments on commit c029806

Please sign in to comment.