Revert "Support fused batchnorm with any ndims and axis" (#695)
This reverts commit 65dabe1.
liutongxuan authored Feb 17, 2023
1 parent 08c81ad commit aec3c96
Showing 3 changed files with 62 additions and 234 deletions.
tensorflow/python/keras/BUILD (4 changes: 2 additions & 2 deletions)
@@ -943,15 +943,15 @@ tf_py_test(
 
 tf_py_test(
     name = "normalization_test",
-    size = "large",
+    size = "medium",
     srcs = ["layers/normalization_test.py"],
     additional_deps = [
         ":keras",
         "@absl_py//absl/testing:parameterized",
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 8,
+    shard_count = 4,
     tags = [
         "no_rocm",
         "notsan",
tensorflow/python/keras/layers/normalization.py (113 changes: 43 additions & 70 deletions)
@@ -19,7 +19,6 @@
 from __future__ import print_function
 
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -187,8 +186,10 @@ def __init__(self,
     if self._USE_V2_BEHAVIOR:
       if fused:
         self._raise_if_fused_cannot_be_used()
-      elif fused is None:
-        fused = self._fused_can_be_used()
+      # We leave fused as None if self._fused_can_be_used()==True, since we
+      # still may set it to False in self.build() if the input rank is not 4.
+      elif fused is None and not self._fused_can_be_used():
+        fused = False
     elif fused is None:
       fused = True
     self.supports_masking = True
@@ -209,16 +210,22 @@ def __init__(self,
 
   def _raise_if_fused_cannot_be_used(self):
     """Raises a ValueError if fused implementation cannot be used.
+    In addition to the checks done in this function, the input tensors rank must
+    be 4. The input rank check can only be done once the input shape is known.
     """
     # Currently fused batch norm doesn't support renorm. It also only supports a
-    # single axis, when no virtual batch size or adjustment is used.
+    # channel dimension on axis 1 or 3, when no virtual batch size or adjustment
+    # is used.
     if self.renorm:
       raise ValueError('Passing both fused=True and renorm=True is '
                        'unsupported')
     axis = [self.axis] if isinstance(self.axis, int) else self.axis
-    if len(axis) > 1:
-      raise ValueError('Passing fused=True is only supported when operating '
-                       'over a single axis.')
+    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
+    # input rank is required to be 4 (which is checked later).
+    if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
+      raise ValueError('Passing fused=True is only supported when axis is 1 '
+                       'or 3')
     if self.virtual_batch_size is not None:
       raise ValueError('Passing fused=True is unsupported when '
                        'virtual_batch_size is specified.')
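For illustration only (not part of the diff), here is a minimal standalone sketch of the constraint the restored check enforces; the helper name is hypothetical, and the rank-4 requirement itself is still verified later in build():

def _axis_ok_for_fused(axis):
  # Mirrors the restored rule: exactly one normalization axis, and it must map
  # to the channel dimension of a 4D tensor, i.e. 1/-3 (NCHW) or 3/-1 (NHWC).
  axis = [axis] if isinstance(axis, int) else list(axis)
  return len(axis) == 1 and axis[0] in (-3, -1, 1, 3)

assert _axis_ok_for_fused(3)           # NHWC channel axis
assert _axis_ok_for_fused(-3)          # negative alias of axis 1
assert not _axis_ok_for_fused(2)       # not a 4D channel axis
assert not _axis_ok_for_fused([1, 2])  # multi-axis norm is never fused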
@@ -262,62 +269,6 @@ def _support_zero_size_input(self):
         distribution_strategy_context.get_strategy().extended,
         'experimental_enable_get_next_as_optional', False)
 
-  def _get_shape_and_axis_for_fused(self, nd_shape, nd_axis):
-    """Compute an equivalent shape and axis that are compatible with the fused
-    implementation.
-    The input/output of the layer can be reshaped to/from the shape returned by
-    this function without affecting the correctness of the computation.
-    Arguments:
-      nd_shape: Tensor. The original shape of the operation.
-      nd_axis: Integer. The original axis of the operation.
-    Returns:
-      shape: Tensor. A 4D shape.
-      axis: Integer. An axis (always 1 or 3).
-    """
-    assert isinstance(nd_axis, int)
-    ndims = nd_shape.shape[0]
-    shape = nd_shape[:]
-    axis = ndims + nd_axis if nd_axis < 0 else nd_axis
-    # First check if the axis needs to be moved.
-    if axis not in (1, ndims - 1):
-      # Move axis to dim 1.
-      if axis == 0:
-        # Transform [C, ...] to [1, C, ...].
-        shape = array_ops.concat([constant_op.constant([1]), shape], axis=0)
-        ndims += 1
-      else:
-        # Merge excess pre-axis dims into first dim.
-        # Transform [N, ..., C, ...] to [product(N, ...), C, ...].
-        product = math_ops.reduce_prod(shape[:axis], keepdims=True)
-        shape = array_ops.concat([product, shape[axis:]], axis=0)
-        ndims -= (axis - 1)
-      axis = 1
-    # Now change shape to 4D.
-    is_channels_last = axis == ndims - 1
-    if ndims < 4:
-      # Insert new dims after existing spatial dim or before channel dim.
-      new_dims = constant_op.constant([1] * (4 - ndims))
-      if is_channels_last:
-        # Transform [..., C] to [..., 1..., C] (ndims=4).
-        shape = array_ops.concat([shape[:-1], new_dims, shape[-1:]], axis=0)
-      else:
-        # Transform [N, C, ...] to [N, C, ..., 1...] (ndims=4).
-        shape = array_ops.concat([shape, new_dims], axis=0)
-    elif ndims > 4:
-      # Merge excess spatial dims into the second spatial dim.
-      # Transform [N, C, H, W, ...] to [N, C, H, product(W, ...)].
-      # Or [N, H, W, ..., C] to [N, H, product(W, ...), C].
-      merge_dim = 2 if is_channels_last else 3
-      product = math_ops.reduce_prod(
-          shape[merge_dim:merge_dim + 1 + (ndims - 4)], keepdims=True)
-      shape = array_ops.concat([shape[:merge_dim], product,
-                                shape[merge_dim + 1 + (ndims - 4):]], axis=0)
-    axis = 3 if is_channels_last else 1
-    return shape, axis
 
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     if not input_shape.ndims:
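For illustration only (not part of the diff): the helper deleted above rewrote an arbitrary-rank shape and axis into an equivalent 4D shape whose channel axis is 1 or 3, so the input could be reshaped for the fused kernel and reshaped back afterwards. A minimal pure-Python sketch of that mapping, using plain lists instead of tensors and a hypothetical function name:

def equivalent_4d_shape_and_axis(shape, axis):
  # Pure-Python analogue of the deleted tensor-based helper.
  shape = list(shape)
  ndims = len(shape)
  axis = ndims + axis if axis < 0 else axis
  if axis not in (1, ndims - 1):
    if axis == 0:
      shape = [1] + shape                     # [C, ...] -> [1, C, ...]
      ndims += 1
    else:
      lead = 1
      for d in shape[:axis]:                  # merge leading dims into one
        lead *= d
      shape = [lead] + shape[axis:]
      ndims -= (axis - 1)
    axis = 1
  channels_last = (axis == ndims - 1)
  if ndims < 4:
    pad = [1] * (4 - ndims)                   # insert unit spatial dims
    shape = (shape[:-1] + pad + shape[-1:]) if channels_last else (shape + pad)
  elif ndims > 4:
    merge_dim = 2 if channels_last else 3     # merge excess spatial dims
    merged = 1
    for d in shape[merge_dim:merge_dim + 1 + (ndims - 4)]:
      merged *= d
    shape = shape[:merge_dim] + [merged] + shape[merge_dim + 1 + (ndims - 4):]
  return shape, (3 if channels_last else 1)

# A 5D NCDHW input folds its trailing spatial dims; a 3D channels-last input
# gains a unit spatial dim:
assert equivalent_4d_shape_and_axis([8, 16, 4, 4, 4], 1) == ([8, 16, 4, 16], 1)
assert equivalent_4d_shape_and_axis([8, 10, 32], -1) == ([8, 10, 1, 32], 3)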
@@ -352,8 +303,33 @@ def build(self, input_shape):
       raise ValueError('When using virtual_batch_size, adjustment cannot '
                        'be specified')
 
-    if self.fused and not self._USE_V2_BEHAVIOR:
-      self.fused = self._fused_can_be_used()
+    if self.fused in (None, True):
+      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
+      # output back to its original shape accordingly.
+      if self._USE_V2_BEHAVIOR:
+        if self.fused is None:
+          self.fused = (ndims == 4)
+        elif self.fused and ndims != 4:
+          raise ValueError('Batch normalization layers with fused=True only '
+                           'support 4D input tensors.')
+      else:
+        assert self.fused is not None
+        self.fused = (ndims == 4 and self._fused_can_be_used())
+      # TODO(chrisying): fused batch norm is currently not supported for
+      # multi-axis batch norm and by extension virtual batches. In some cases,
+      # it might be possible to use fused batch norm but would require reshaping
+      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
+      # particularly tricky. A compromise might be to just support the most
+      # common use case (turning 5D w/ virtual batch to NCHW)
+
+    if self.fused:
+      if self.axis == [1]:
+        self._data_format = 'NCHW'
+      elif self.axis == [3]:
+        self._data_format = 'NHWC'
+      else:
+        raise ValueError('Unsupported axis, fused batch norm only supports '
+                         'axis == [1] or axis == [3]')
 
     axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
     for x in axis_to_dim:
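For illustration only (not part of the diff), a rough sketch of how the restored build() behavior looks from the Keras API, assuming the stock tf.keras BatchNormalization layer as it stands after this revert: with the default fused=None, the fused kernel is only selected for 4D inputs, and the axis determines the data format.

import tensorflow as tf

bn_nhwc = tf.keras.layers.BatchNormalization(axis=3)  # channels-last -> 'NHWC' when fused
bn_nchw = tf.keras.layers.BatchNormalization(axis=1)  # channels-first -> 'NCHW' when fused

bn_nhwc.build((None, 8, 8, 16))      # 4D input: fused resolves to True
bn_nchw.build((None, 16, 8, 8))
print(bn_nhwc.fused, bn_nchw.fused)  # True True

bn_3d = tf.keras.layers.BatchNormalization(axis=-1)
bn_3d.build((None, 10, 32))          # 3D input: fused resolves to False
print(bn_3d.fused)                   # False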
@@ -548,7 +524,7 @@ def _fused_batch_norm_training():
           gamma,
           beta,
           epsilon=self.epsilon,
-          data_format=data_format)
+          data_format=self._data_format)
 
     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
@@ -559,7 +535,7 @@ def _fused_batch_norm_inference():
           variance=self.moving_variance,
           epsilon=self.epsilon,
           is_training=False,
-          data_format=data_format)
+          data_format=self._data_format)
 
     output, mean, variance = tf_utils.smart_cond(
         training, _fused_batch_norm_training, _fused_batch_norm_inference)
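For illustration only (not part of the diff), a small sketch of the low-level op these two branches delegate to; the data_format argument has to agree with where the channel axis actually lives, and the values below are made up:

import numpy as np
import tensorflow as tf

x = np.random.rand(2, 8, 8, 16).astype(np.float32)   # NHWC input
scale = np.ones(16, np.float32)
offset = np.zeros(16, np.float32)
mean = np.zeros(16, np.float32)
variance = np.ones(16, np.float32)

# Inference-mode call, mirroring _fused_batch_norm_inference above.
y, _, _ = tf.compat.v1.nn.fused_batch_norm(
    x, scale, offset, mean=mean, variance=variance,
    epsilon=1e-3, data_format='NHWC', is_training=False)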
@@ -572,9 +548,6 @@ def _fused_batch_norm_inference():
       factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
       variance *= factor
 
-    if original_shape is not None:
-      output = array_ops.reshape(output, original_shape)
-
     training_value = tf_utils.constant_value(training)
     if training_value is None:
       momentum = tf_utils.smart_cond(training,
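For context (not part of the diff): the factor applied a few lines above removes Bessel's correction, since the fused kernel reports a sample (unbiased) variance while the non-fused path works with the population variance. A small numeric sketch of the same rescaling, with made-up data:

import numpy as np

x = np.random.rand(2, 8, 8, 16).astype(np.float32)
sample_size = x.size // 16                                # elements per channel
unbiased = x.reshape(-1, 16).var(axis=0, ddof=1)          # Bessel-corrected variance
rescaled = unbiased * (sample_size - 1.0) / sample_size   # the factor used above
np.testing.assert_allclose(rescaled, x.reshape(-1, 16).var(axis=0), rtol=1e-4)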
(The diff for the third changed file was not loaded on this page.)
