From 42f7bcd755b627f3b1df68931230c85977f12d12 Mon Sep 17 00:00:00 2001
From: CUI Wei
Date: Sat, 10 Nov 2018 12:14:49 +0800
Subject: [PATCH] Add option to support native NCHW input data format.

---
 scripts/tf_cnn_benchmarks/benchmark_cnn.py | 21 +++++++++++++++++++--
 scripts/tf_cnn_benchmarks/models/model.py  |  9 ++++++++-
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/scripts/tf_cnn_benchmarks/benchmark_cnn.py b/scripts/tf_cnn_benchmarks/benchmark_cnn.py
index df55540b..a0bf12b0 100644
--- a/scripts/tf_cnn_benchmarks/benchmark_cnn.py
+++ b/scripts/tf_cnn_benchmarks/benchmark_cnn.py
@@ -274,8 +274,11 @@
 flags.DEFINE_enum('device', 'gpu', ('cpu', 'gpu', 'CPU', 'GPU'),
                   'Device to use for computation: cpu or gpu')
 flags.DEFINE_enum('data_format', 'NCHW', ('NHWC', 'NCHW'),
-                  'Data layout to use: NHWC (TF native) or NCHW (cuDNN '
-                  'native, requires GPU).')
+                  'Data layout to use for model layers: NHWC (TF native) or '
+                  'NCHW (NVIDIA cuDNN native / Intel MKL native).')
+flags.DEFINE_enum('input_data_format', 'NHWC', ('NHWC', 'NCHW'),
+                  'Data layout of the input images: NHWC by default (the '
+                  'format produced by tf.image ops).')
 flags.DEFINE_integer('num_intra_threads', None,
                      'Number of threads to use for intra-op parallelism. If '
                      'set to 0, the system will pick an appropriate number.')
@@ -1755,6 +1758,7 @@ def print_info(self):
     log_fn('Num epochs: %.2f' % self.num_epochs)
     log_fn('Devices: %s' % benchmark_info['device_list'])
     log_fn('Data format: %s' % self.params.data_format)
+    log_fn('Input data format: %s' % self.params.input_data_format)
     if self.rewriter_config:
       log_fn('RewriterConfig: %s' % self.rewriter_config)
     log_fn('Optimizer: %s' % self.params.optimizer)
@@ -1813,6 +1817,7 @@ def _log_benchmark_run(self):
         'num_batches': self.num_batches,
         'num_epochs': self.num_epochs,
         'data_format': self.params.data_format,
+        'input_data_format': self.params.input_data_format,
         'rewrite_config': self.rewriter_config,
         'optimizer': self.params.optimizer,
         'session_config': create_config_proto(self.params),
@@ -3170,6 +3175,18 @@ def device_aware_reshape(tensor, shape):
 
     subset = 'validation' if self._doing_eval else 'train'
     input_shapes = self.model.get_input_shapes(subset)
+
+    # So far, every pre-defined input pipeline produces NHWC data. This branch
+    # could be extended later to feed native NCHW images directly.
+    if self.model.input_data_format == 'NCHW':
+      # Transpose the NHWC data (and its static shape) to NCHW as requested.
+      images_, labels_ = input_list
+      images_ = tf.transpose(images_, [0, 3, 1, 2])
+      input_list = (images_, labels_)
+      input_shape_, output_shape_ = input_shapes
+      input_shape_ = [input_shape_[0], input_shape_[3], input_shape_[1], input_shape_[2]]
+      input_shapes = [input_shape_, output_shape_]
+
     input_list = [
         device_aware_reshape(input_list[i], shape=input_shapes[i])
         for i in range(len(input_list))
diff --git a/scripts/tf_cnn_benchmarks/models/model.py b/scripts/tf_cnn_benchmarks/models/model.py
index a4570ca0..97713759 100644
--- a/scripts/tf_cnn_benchmarks/models/model.py
+++ b/scripts/tf_cnn_benchmarks/models/model.py
@@ -174,6 +174,7 @@ def __init__(self,
     self.depth = 3
     self.params = params
     self.data_format = params.data_format if params else 'NCHW'
+    self.input_data_format = params.input_data_format if params else 'NHWC'
 
   def get_layer_counts(self):
     return self.layer_counts
@@ -273,8 +274,14 @@ def build_network(self,
       information.
""" images = inputs[0] - if self.data_format == 'NCHW': + if self.data_format == 'NCHW' and self.input_data_format == 'NHWC': images = tf.transpose(images, [0, 3, 1, 2]) + elif self.data_format == 'NHWC' and self.input_data_format == 'NCHW': + images = tf.transpose(images, [0, 2, 3, 1]) + else: + # No need to transpose since self.data_format == self.input_data_format + pass + var_type = tf.float32 if self.data_type == tf.float16 and self.fp16_vars: var_type = tf.float16