Update inception model based on tf API changes: replace tf.op_scope with tf.name_scope and tf.variable_op_scope with tf.variable_scope; fix the order of arguments for tf.concat; replace tf.mul with tf.multiply.
lilao committed Jan 20, 2017
1 parent e9e470d commit e5079c8
Showing 7 changed files with 70 additions and 73 deletions.
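In outline, the renames map pre-1.0 API calls to their TensorFlow 1.0 equivalents. A minimal before/after sketch, assuming TensorFlow 1.0 (the tensor and scope names here are illustrative, not from the repo):

    import tensorflow as tf  # assumes TensorFlow 1.0

    x = tf.placeholder(tf.float32, [None, 4], name='x')
    y = tf.placeholder(tf.float32, [None, 4], name='y')

    # tf.op_scope([x, y], scope, 'Blend') becomes:
    with tf.name_scope(None, 'Blend', [x, y]):
        # tf.mul(x, y) becomes:
        z = tf.multiply(x, y, name='z')
    # tf.concat(1, [x, y]) becomes:
    xy = tf.concat([x, y], 1)
    # tf.variable_op_scope([x], None, 'proj') becomes:
    with tf.variable_scope('proj'):
        w = tf.get_variable('w', shape=[4, 4])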
4 changes: 2 additions & 2 deletions inception/inception/inception_model.py
@@ -147,8 +147,8 @@ def _activation_summary(x):
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
# session. This helps the clarity of presentation on tensorboard.
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
-tf.histogram_summary(tensor_name + '/activations', x)
-tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
+tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _activation_summaries(endpoints):
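Note on the hunk above: TF 1.0 removed the top-level tf.histogram_summary and tf.scalar_summary, and the tf.contrib.deprecated endpoints keep the old summaries working during migration. A minimal sketch, assuming TensorFlow 1.0 (tag and tensor names are illustrative; the tf.summary.* calls are the non-deprecated equivalents):

    import tensorflow as tf  # assumes TensorFlow 1.0

    act = tf.nn.relu(tf.random_normal([8, 8]), name='act')
    # Transitional endpoints used in this commit:
    tf.contrib.deprecated.histogram_summary('act/activations', act)
    tf.contrib.deprecated.scalar_summary('act/sparsity', tf.nn.zero_fraction(act))
    # Long-term TF 1.0 replacements:
    tf.summary.histogram('act/activations', act)
    tf.summary.scalar('act/sparsity', tf.nn.zero_fraction(act))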
3 changes: 0 additions & 3 deletions inception/inception/slim/README.md
@@ -246,11 +246,8 @@ number. More concretely, the scopes in the example above would be 'conv3_1',

In addition to the types of scope mechanisms in TensorFlow ([name_scope]
(https://www.tensorflow.org/api_docs/python/framework.html#name_scope),
-[op_scope](https://www.tensorflow.org/api_docs/python/framework.html#op_scope),
[variable_scope]
(https://www.tensorflow.org/api_docs/python/state_ops.html#variable_scope),
-[variable_op_scope]
-(https://www.tensorflow.org/api_docs/python/state_ops.html#variable_op_scope)),
TF-Slim adds a new scoping mechanism called "argument scope" or [arg_scope]
(scopes.py). This new scope allows a user to specify one or more operations and
a set of arguments which will be passed to each of the operations defined in the
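Note on the README hunk: the deleted entries track the API itself, since TF 1.0 folds op_scope into name_scope and variable_op_scope into variable_scope. The surviving tf.name_scope takes (name, default_name, values), where tf.op_scope took (values, name, default_name). A sketch of the updated idiom, assuming TensorFlow 1.0 (the function and scope names are illustrative):

    import tensorflow as tf  # assumes TensorFlow 1.0

    def double(tensor, scope=None):
        # Formerly: with tf.op_scope([tensor], scope, 'Double'):
        with tf.name_scope(scope, 'Double', [tensor]):
            return tf.multiply(tensor, 2.0, name='value')

    out = double(tf.constant([1.0, 2.0]))  # op names land under 'Double/'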
42 changes: 21 additions & 21 deletions inception/inception/slim/inception_model.py
@@ -69,15 +69,15 @@ def inception_v3(inputs,
is_training: whether is training or not.
restore_logits: whether or not the logits layers should be restored.
Useful for fine-tuning a model with different num_classes.
-scope: Optional scope for op_scope.
+scope: Optional scope for name_scope.
Returns:
a list containing 'logits', 'aux_logits' Tensors.
"""
# end_points will collect relevant activations for external use, for example
# summaries or losses.
end_points = {}
-with tf.op_scope([inputs], scope, 'inception_v3'):
+with tf.name_scope(scope, 'inception_v3', [inputs]):
with scopes.arg_scope([ops.conv2d, ops.fc, ops.batch_norm, ops.dropout],
is_training=is_training):
with scopes.arg_scope([ops.conv2d, ops.max_pool, ops.avg_pool],
@@ -122,7 +122,7 @@ def inception_v3(inputs,
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 32, [1, 1])
-net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
+net = tf.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], 3)
end_points['mixed_35x35x256a'] = net
# mixed_1: 35 x 35 x 288.
with tf.variable_scope('mixed_35x35x288a'):
@@ -138,7 +138,7 @@ def inception_v3(inputs,
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
-net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
+net = tf.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], 3)
end_points['mixed_35x35x288a'] = net
# mixed_2: 35 x 35 x 288.
with tf.variable_scope('mixed_35x35x288b'):
@@ -154,7 +154,7 @@ def inception_v3(inputs,
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
-net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
+net = tf.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], 3)
end_points['mixed_35x35x288b'] = net
# mixed_3: 17 x 17 x 768.
with tf.variable_scope('mixed_17x17x768a'):
@@ -167,7 +167,7 @@ def inception_v3(inputs,
stride=2, padding='VALID')
with tf.variable_scope('branch_pool'):
branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
-net = tf.concat(3, [branch3x3, branch3x3dbl, branch_pool])
+net = tf.concat([branch3x3, branch3x3dbl, branch_pool], 3)
end_points['mixed_17x17x768a'] = net
# mixed4: 17 x 17 x 768.
with tf.variable_scope('mixed_17x17x768b'):
@@ -186,7 +186,7 @@ def inception_v3(inputs,
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
end_points['mixed_17x17x768b'] = net
# mixed_5: 17 x 17 x 768.
with tf.variable_scope('mixed_17x17x768c'):
@@ -205,7 +205,7 @@ def inception_v3(inputs,
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
end_points['mixed_17x17x768c'] = net
# mixed_6: 17 x 17 x 768.
with tf.variable_scope('mixed_17x17x768d'):
@@ -224,7 +224,7 @@ def inception_v3(inputs,
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
end_points['mixed_17x17x768d'] = net
# mixed_7: 17 x 17 x 768.
with tf.variable_scope('mixed_17x17x768e'):
@@ -243,7 +243,7 @@ def inception_v3(inputs,
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
end_points['mixed_17x17x768e'] = net
# Auxiliary Head logits
aux_logits = tf.identity(end_points['mixed_17x17x768e'])
@@ -276,43 +276,43 @@ def inception_v3(inputs,
stride=2, padding='VALID')
with tf.variable_scope('branch_pool'):
branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
-net = tf.concat(3, [branch3x3, branch7x7x3, branch_pool])
+net = tf.concat([branch3x3, branch7x7x3, branch_pool], 3)
end_points['mixed_17x17x1280a'] = net
# mixed_9: 8 x 8 x 2048.
with tf.variable_scope('mixed_8x8x2048a'):
with tf.variable_scope('branch1x1'):
branch1x1 = ops.conv2d(net, 320, [1, 1])
with tf.variable_scope('branch3x3'):
branch3x3 = ops.conv2d(net, 384, [1, 1])
-branch3x3 = tf.concat(3, [ops.conv2d(branch3x3, 384, [1, 3]),
-ops.conv2d(branch3x3, 384, [3, 1])])
+branch3x3 = tf.concat([ops.conv2d(branch3x3, 384, [1, 3]),
+ops.conv2d(branch3x3, 384, [3, 1])], 3)
with tf.variable_scope('branch3x3dbl'):
branch3x3dbl = ops.conv2d(net, 448, [1, 1])
branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
-branch3x3dbl = tf.concat(3, [ops.conv2d(branch3x3dbl, 384, [1, 3]),
-ops.conv2d(branch3x3dbl, 384, [3, 1])])
+branch3x3dbl = tf.concat([ops.conv2d(branch3x3dbl, 384, [1, 3]),
+ops.conv2d(branch3x3dbl, 384, [3, 1])], 3)
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-net = tf.concat(3, [branch1x1, branch3x3, branch3x3dbl, branch_pool])
+net = tf.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], 3)
end_points['mixed_8x8x2048a'] = net
# mixed_10: 8 x 8 x 2048.
with tf.variable_scope('mixed_8x8x2048b'):
with tf.variable_scope('branch1x1'):
branch1x1 = ops.conv2d(net, 320, [1, 1])
with tf.variable_scope('branch3x3'):
branch3x3 = ops.conv2d(net, 384, [1, 1])
-branch3x3 = tf.concat(3, [ops.conv2d(branch3x3, 384, [1, 3]),
-ops.conv2d(branch3x3, 384, [3, 1])])
+branch3x3 = tf.concat([ops.conv2d(branch3x3, 384, [1, 3]),
+ops.conv2d(branch3x3, 384, [3, 1])], 3)
with tf.variable_scope('branch3x3dbl'):
branch3x3dbl = ops.conv2d(net, 448, [1, 1])
branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
-branch3x3dbl = tf.concat(3, [ops.conv2d(branch3x3dbl, 384, [1, 3]),
-ops.conv2d(branch3x3dbl, 384, [3, 1])])
+branch3x3dbl = tf.concat([ops.conv2d(branch3x3dbl, 384, [1, 3]),
+ops.conv2d(branch3x3dbl, 384, [3, 1])], 3)
with tf.variable_scope('branch_pool'):
branch_pool = ops.avg_pool(net, [3, 3])
branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-net = tf.concat(3, [branch1x1, branch3x3, branch3x3dbl, branch_pool])
+net = tf.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], 3)
end_points['mixed_8x8x2048b'] = net
# Final pooling and prediction
with tf.variable_scope('logits'):
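Note on the concat hunks above: every mixed block ends with the same mechanical rewrite, because tf.concat's signature changed from tf.concat(concat_dim, values) to tf.concat(values, axis); the depth axis (3 for NHWC tensors) moves to the second position. A minimal sketch with made-up branch shapes, assuming TensorFlow 1.0:

    import tensorflow as tf  # assumes TensorFlow 1.0

    branch1x1 = tf.zeros([1, 35, 35, 64])
    branch5x5 = tf.zeros([1, 35, 35, 32])
    # Formerly: net = tf.concat(3, [branch1x1, branch5x5])
    net = tf.concat([branch1x1, branch5x5], 3)
    print(net.get_shape())  # (1, 35, 35, 96): depths add, other dims must match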
4 changes: 2 additions & 2 deletions inception/inception/slim/inception_test.py
@@ -65,9 +65,9 @@ def testVariablesSetDevice(self):
inception.inception_v3(inputs, num_classes)
with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
inception.inception_v3(inputs, num_classes)
-for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
self.assertDeviceEqual(v.device, '/cpu:0')
-for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
self.assertDeviceEqual(v.device, '/gpu:0')

def testHalfSizeImages(self):
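Note on the test hunk: tf.GraphKeys.VARIABLES becomes tf.GraphKeys.GLOBAL_VARIABLES in TF 1.0; the collection contents are unchanged. A small sketch of filtering that collection by scope, as the test does, assuming TensorFlow 1.0:

    import tensorflow as tf  # assumes TensorFlow 1.0

    with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
        v = tf.get_variable('w', shape=[2])
    # Formerly tf.GraphKeys.VARIABLES:
    cpu_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu')
    assert v in cpu_vars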
44 changes: 22 additions & 22 deletions inception/inception/slim/losses.py
@@ -39,17 +39,17 @@ def l1_regularizer(weight=1.0, scope=None):
Args:
weight: scale the loss by this factor.
-scope: Optional scope for op_scope.
+scope: Optional scope for name_scope.
Returns:
a regularizer function.
"""
def regularizer(tensor):
-with tf.op_scope([tensor], scope, 'L1Regularizer'):
+with tf.name_scope(scope, 'L1Regularizer', [tensor]):
l1_weight = tf.convert_to_tensor(weight,
dtype=tensor.dtype.base_dtype,
name='weight')
-return tf.mul(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+return tf.multiply(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
return regularizer
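Note on this hunk: l1_regularizer is a factory that captures weight and scope in a closure and returns a function applied to a tensor later. A hypothetical usage sketch, assuming TensorFlow 1.0 (the import path mirrors this repo's layout but is an assumption):

    import tensorflow as tf  # assumes TensorFlow 1.0
    from inception.slim import losses  # hypothetical import path

    w = tf.get_variable('w', shape=[10, 10])
    reg_fn = losses.l1_regularizer(weight=1e-4)
    reg_loss = reg_fn(w)  # tf.multiply(1e-4, reduce_sum(|w|), name='value')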


@@ -58,17 +58,17 @@ def l2_regularizer(weight=1.0, scope=None):
Args:
weight: scale the loss by this factor.
-scope: Optional scope for op_scope.
+scope: Optional scope for name_scope.
Returns:
a regularizer function.
"""
def regularizer(tensor):
-with tf.op_scope([tensor], scope, 'L2Regularizer'):
+with tf.name_scope(scope, 'L2Regularizer', [tensor]):
l2_weight = tf.convert_to_tensor(weight,
dtype=tensor.dtype.base_dtype,
name='weight')
-return tf.mul(l2_weight, tf.nn.l2_loss(tensor), name='value')
+return tf.multiply(l2_weight, tf.nn.l2_loss(tensor), name='value')
return regularizer


@@ -78,22 +78,22 @@ def l1_l2_regularizer(weight_l1=1.0, weight_l2=1.0, scope=None):
Args:
weight_l1: scale the L1 loss by this factor.
weight_l2: scale the L2 loss by this factor.
-scope: Optional scope for op_scope.
+scope: Optional scope for name_scope.
Returns:
a regularizer function.
"""
def regularizer(tensor):
-with tf.op_scope([tensor], scope, 'L1L2Regularizer'):
+with tf.name_scope(scope, 'L1L2Regularizer', [tensor]):
weight_l1_t = tf.convert_to_tensor(weight_l1,
dtype=tensor.dtype.base_dtype,
name='weight_l1')
weight_l2_t = tf.convert_to_tensor(weight_l2,
dtype=tensor.dtype.base_dtype,
name='weight_l2')
-reg_l1 = tf.mul(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
+reg_l1 = tf.multiply(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
name='value_l1')
-reg_l2 = tf.mul(weight_l2_t, tf.nn.l2_loss(tensor),
+reg_l2 = tf.multiply(weight_l2_t, tf.nn.l2_loss(tensor),
name='value_l2')
return tf.add(reg_l1, reg_l2, name='value')
return regularizer
@@ -105,16 +105,16 @@ def l1_loss(tensor, weight=1.0, scope=None):
Args:
tensor: tensor to regularize.
weight: scale the loss by this factor.
-scope: Optional scope for op_scope.
+scope: Optional scope for name_scope.
Returns:
the L1 loss op.
"""
-with tf.op_scope([tensor], scope, 'L1Loss'):
+with tf.name_scope(scope, 'L1Loss', [tensor]):
weight = tf.convert_to_tensor(weight,
dtype=tensor.dtype.base_dtype,
name='loss_weight')
-loss = tf.mul(weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+loss = tf.multiply(weight, tf.reduce_sum(tf.abs(tensor)), name='value')
tf.add_to_collection(LOSSES_COLLECTION, loss)
return loss

@@ -125,16 +125,16 @@ def l2_loss(tensor, weight=1.0, scope=None):
Args:
tensor: tensor to regularize.
weight: an optional weight to modulate the loss.
-scope: Optional scope for op_scope.
+scope: Optional scope for name_scope.
Returns:
the L2 loss op.
"""
-with tf.op_scope([tensor], scope, 'L2Loss'):
+with tf.name_scope(scope, 'L2Loss', [tensor]):
weight = tf.convert_to_tensor(weight,
dtype=tensor.dtype.base_dtype,
name='loss_weight')
-loss = tf.mul(weight, tf.nn.l2_loss(tensor), name='value')
+loss = tf.multiply(weight, tf.nn.l2_loss(tensor), name='value')
tf.add_to_collection(LOSSES_COLLECTION, loss)
return loss

@@ -150,25 +150,25 @@ def cross_entropy_loss(logits, one_hot_labels, label_smoothing=0,
one_hot_labels: [batch_size, num_classes] target one_hot_encoded labels.
label_smoothing: if greater than 0 then smooth the labels.
weight: scale the loss by this factor.
-scope: Optional scope for op_scope.
+scope: Optional scope for name_scope.
Returns:
A tensor with the softmax_cross_entropy loss.
"""
logits.get_shape().assert_is_compatible_with(one_hot_labels.get_shape())
-with tf.op_scope([logits, one_hot_labels], scope, 'CrossEntropyLoss'):
+with tf.name_scope(scope, 'CrossEntropyLoss', [logits, one_hot_labels]):
num_classes = one_hot_labels.get_shape()[-1].value
one_hot_labels = tf.cast(one_hot_labels, logits.dtype)
if label_smoothing > 0:
smooth_positives = 1.0 - label_smoothing
smooth_negatives = label_smoothing / num_classes
one_hot_labels = one_hot_labels * smooth_positives + smooth_negatives
-cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
-labels=one_hot_labels,
-name='xentropy')
+cross_entropy = tf.contrib.nn.deprecated_flipped_softmax_cross_entropy_with_logits(
+logits, one_hot_labels, name='xentropy')

weight = tf.convert_to_tensor(weight,
dtype=logits.dtype.base_dtype,
name='loss_weight')
-loss = tf.mul(weight, tf.reduce_mean(cross_entropy), name='value')
+loss = tf.multiply(weight, tf.reduce_mean(cross_entropy), name='value')
tf.add_to_collection(LOSSES_COLLECTION, loss)
return loss
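Note on the cross-entropy hunk: this is more than a rename. TF 1.0 reordered tf.nn.softmax_cross_entropy_with_logits to put labels before logits (keyword arguments recommended), and the tf.contrib.nn shim used above preserves the pre-1.0 (logits, labels) positional order. A sketch of the two equivalent calls, assuming TensorFlow 1.0 (shapes are illustrative):

    import tensorflow as tf  # assumes TensorFlow 1.0

    logits = tf.random_normal([4, 10])
    one_hot_labels = tf.one_hot([0, 1, 2, 3], depth=10)
    # Pre-1.0 positional order, kept alive under contrib:
    xent_old = tf.contrib.nn.deprecated_flipped_softmax_cross_entropy_with_logits(
        logits, one_hot_labels, name='xentropy')
    # TF 1.0 form, labels first:
    xent_new = tf.nn.softmax_cross_entropy_with_logits(
        labels=one_hot_labels, logits=logits)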
