Skip to content

Commit

Permalink
Input pipeline on CPU. 1700 images/sec to 8000 on GTX 1080
Browse files Browse the repository at this point in the history
  • Loading branch information
tfboyd committed Jun 8, 2017
1 parent 68a18b7 commit 082e65c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
13 changes: 9 additions & 4 deletions tutorials/image/cifar10/cifar10_multi_gpu_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
"""Whether to log device placement.""")


def tower_loss(scope):
def tower_loss(scope, images, labels):
"""Calculate the total loss on a single tower running the CIFAR model.
Args:
Expand All @@ -71,8 +71,7 @@ def tower_loss(scope):
Returns:
Tensor of shape [] containing the total loss for a batch of data
"""
# Get images and labels for CIFAR-10.
images, labels = cifar10.distorted_inputs()


# Build inference Graph.
logits = cifar10.inference(images)
Expand Down Expand Up @@ -160,6 +159,12 @@ def train():
# Create an optimizer that performs gradient descent.
opt = tf.train.GradientDescentOptimizer(lr)

# Get images and labels for CIFAR-10.
# Force input pipeline to CPU:0 to avoid operations sometimes ending up on GPU
# and resulting in a slowdown.
with tf.device('/CPU:0'):
images, labels = cifar10.distorted_inputs()

# Calculate the gradients for each model tower.
tower_grads = []
with tf.variable_scope(tf.get_variable_scope()):
Expand All @@ -169,7 +174,7 @@ def train():
# Calculate the loss for one tower of the CIFAR model. This function
# constructs the entire CIFAR model but shares the variables across
# all towers.
loss = tower_loss(scope)
loss = tower_loss(scope, images, labels)

# Reuse variables for the next tower.
tf.get_variable_scope().reuse_variables()
Expand Down
5 changes: 4 additions & 1 deletion tutorials/image/cifar10/cifar10_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def train():
global_step = tf.contrib.framework.get_or_create_global_step()

# Get images and labels for CIFAR-10.
images, labels = cifar10.distorted_inputs()
# Force input pipeline to CPU:0 to avoid operations sometimes ending up
# on GPU and resulting in a slowdown.
with tf.device('/CPU:0'):
images, labels = cifar10.distorted_inputs()

# Build a Graph that computes the logits predictions from the
# inference model.
Expand Down

0 comments on commit 082e65c

Please sign in to comment.