Skip to content
This repository has been archived by the owner on Dec 9, 2024. It is now read-only.

Add option support for native NCHW input data format. #268

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions scripts/tf_cnn_benchmarks/benchmark_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,11 @@
flags.DEFINE_enum('device', 'gpu', ('cpu', 'gpu', 'CPU', 'GPU'),
'Device to use for computation: cpu or gpu')
flags.DEFINE_enum('data_format', 'NCHW', ('NHWC', 'NCHW'),
'Data layout to use: NHWC (TF native) or NCHW (cuDNN '
'native, requires GPU).')
'Data layout to use for model layers: NHWC (TF native) or NCHW '
'(NVIDIA cuDNN native / Intel MKL native).')
flags.DEFINE_enum('input_data_format', 'NHWC', ('NHWC', 'NCHW'),
'Data format of input image: the image format of input data'
'NHWC by default (tf.image generated data format).')
flags.DEFINE_integer('num_intra_threads', None,
'Number of threads to use for intra-op parallelism. If '
'set to 0, the system will pick an appropriate number.')
Expand Down Expand Up @@ -1755,6 +1758,7 @@ def print_info(self):
log_fn('Num epochs: %.2f' % self.num_epochs)
log_fn('Devices: %s' % benchmark_info['device_list'])
log_fn('Data format: %s' % self.params.data_format)
log_fn('Input image: %s' % self.params.input_data_format)
if self.rewriter_config:
log_fn('RewriterConfig: %s' % self.rewriter_config)
log_fn('Optimizer: %s' % self.params.optimizer)
Expand Down Expand Up @@ -1813,6 +1817,7 @@ def _log_benchmark_run(self):
'num_batches': self.num_batches,
'num_epochs': self.num_epochs,
'data_format': self.params.data_format,
'input_data_format': self.params.input_data_format,
'rewrite_config': self.rewriter_config,
'optimizer': self.params.optimizer,
'session_config': create_config_proto(self.params),
Expand Down Expand Up @@ -3170,6 +3175,18 @@ def device_aware_reshape(tensor, shape):

subset = 'validation' if self._doing_eval else 'train'
input_shapes = self.model.get_input_shapes(subset)

# Till now, existing pre-defined input is of NHWC format.
# Could extend branch here in the future for the case of feeding native NCHW images.
if self.model.input_data_format == 'NCHW':
# Temporarily format above NHWC data to NCHW if expecting NCHW data as input
images_, labels_ = input_list
images_ = tf.transpose(images_, [0, 3, 1, 2])
input_list = (images_, labels_)
input_shape_, output_shape_ = input_shapes
input_shape_ = [input_shape_[0], input_shape_[3], input_shape_[1], input_shape_[2]]
input_shapes = [input_shape_, output_shape_]

input_list = [
device_aware_reshape(input_list[i], shape=input_shapes[i])
for i in range(len(input_list))
Expand Down
9 changes: 8 additions & 1 deletion scripts/tf_cnn_benchmarks/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __init__(self,
self.depth = 3
self.params = params
self.data_format = params.data_format if params else 'NCHW'
self.input_data_format = params.input_data_format

def get_layer_counts(self):
return self.layer_counts
Expand Down Expand Up @@ -273,8 +274,14 @@ def build_network(self,
information.
"""
images = inputs[0]
if self.data_format == 'NCHW':
if self.data_format == 'NCHW' and self.input_data_format == 'NHWC':
images = tf.transpose(images, [0, 3, 1, 2])
elif self.data_format == 'NHWC' and self.input_data_format == 'NCHW':
images = tf.transpose(images, [0, 2, 3, 1])
else:
# No need to transpose since self.data_format == self.input_data_format
pass

var_type = tf.float32
if self.data_type == tf.float16 and self.fp16_vars:
var_type = tf.float16
Expand Down