[Inductor] support Conv/Linear + broadcast add fusion (pytorch#138201)
jiayisunx authored and pytorchmergebot committed Nov 22, 2024
1 parent 5ab5a61 commit f708e92
Showing 4 changed files with 153 additions and 12 deletions.
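
The change relaxes the shape requirement on the oneDNN Conv/Linear + binary (add) fusion paths: the tensor being added no longer has to match the convolution/linear output shape exactly; it only needs the same number of dimensions and must broadcast against the output. A minimal sketch of the kind of graph this enables (illustrative only, not taken from the commit; whether the fusion actually fires still depends on the usual Inductor CPU conditions such as freezing being enabled):

import torch
import torch._inductor.config as inductor_config

class ConvBroadcastAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)

    def forward(self, x, other):
        # "other" has shape (1, 16, 1, 1): it merely broadcasts over H and W,
        # which the fused conv + add path can now accept.
        return (self.conv(x) + other).relu()

inductor_config.freezing = True  # constant-fold weights so the mkldnn fusion passes run
mod = ConvBroadcastAdd().eval()
x = torch.randn(1, 3, 56, 56).to(memory_format=torch.channels_last)
other = torch.randn(1, 16, 1, 1).to(memory_format=torch.channels_last)
with torch.no_grad():
    out = torch.compile(mod)(x, other)
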
14 changes: 6 additions & 8 deletions aten/src/ATen/native/mkldnn/Conv.cpp
@@ -371,10 +371,9 @@ Tensor mkldnn_convolution_pointwise_binary(

  auto output_sizes = conv_output_size(
      input_t.sizes(), weight_t.sizes(), padding_expanded, stride_expanded, dilation_expanded);
-  // TODO: support broadcast binary fusion.
  TORCH_CHECK(
-      output_sizes == other_t.sizes(),
-      "Binary Fusion's inputs should have same shape");
+      input_t.dim() == other_t.dim(),
+      "Binary Fusion's inputs should have same dimensions");
  // Only calling fusion path for channels_last path.
  // TODO: OneDNN doesn't optimize well for groups > 1 case, it will be enabled
  // at next OneDNN release.
@@ -405,18 +404,17 @@ Tensor mkldnn_convolution_pointwise_binary(
  auto weight =
      weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
  auto other = other_t.contiguous(memory_format);
-  auto output = at::empty_like(other);
+  auto output = at::empty(output_sizes, input_t.options()).contiguous(memory_format);
  const ideep::tensor x = itensor_from_tensor(input);
  const ideep::tensor w = itensor_from_tensor(weight);
  const ideep::tensor z = itensor_from_tensor(other);
  ideep::tensor y = itensor_from_tensor(output);
-  auto output_size = other.sizes().vec();
  ideep::tag format_tag = ideep::tag::nhwc;
  if (input_t.ndimension() == 5) {
    format_tag = ideep::tag::ndhwc;
  }
  auto other_desc = ideep::tensor::desc(
-      output_size, get_mkldnn_dtype(weight.scalar_type()), format_tag);
+      other.sizes().vec(), get_mkldnn_dtype(other.scalar_type()), format_tag);

  ideep::attr_t op_attr;
  ideep::post_ops po;
@@ -433,7 +431,7 @@ Tensor mkldnn_convolution_pointwise_binary(
        z,
        w,
        b,
-        output_size,
+        output_sizes,
        y,
        stride_expanded,
        dilation_expanded,
@@ -447,7 +445,7 @@ Tensor mkldnn_convolution_pointwise_binary(
        x,
        z,
        w,
-        output_size,
+        output_sizes,
        y,
        stride_expanded,
        dilation_expanded,
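
Worth noting in the hunks above: the output used to be allocated with at::empty_like(other), which is only correct when "other" already has the full convolution output shape; with broadcasting allowed, the output is now sized from conv_output_size and "other" gets its own descriptor built from its real sizes. A rough Python restatement of the intended semantics (hypothetical helper name, for illustration only):

import torch

def conv_binary_add_reference(x, weight, bias, other):
    # The convolution decides the output shape; "other" only broadcasts into it.
    y = torch.nn.functional.conv2d(x, weight, bias)
    return y + other

x = torch.randn(1, 3, 56, 56)
w = torch.randn(16, 3, 3, 3)
b = torch.randn(16)
other = torch.randn(1, 16, 1, 1)
out = conv_binary_add_reference(x, w, b, other)
assert out.shape == (1, 16, 54, 54)  # sized by the conv, not by "other"
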
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mkldnn/Linear.cpp
@@ -299,8 +299,8 @@ Tensor mkldnn_linear_pointwise_binary(
  }

  TORCH_CHECK(
-      output.sizes() == other_reshaped.sizes(),
-      "linear_binary_run expects the size of output and other tensor to be the same");
+      output.dim() == other_reshaped.dim(),
+      "linear_binary_run expects the dimension of output and other tensor to be the same");

  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  ideep::tensor mkldnn_output = itensor_from_tensor(output);
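
The linear path gets the same relaxation: the fused output and "other" only need matching rank. The new test below exercises this with an "other" of shape (2, 3, 1) against a (2, 3, 30) linear output under bfloat16/float16 autocast; a minimal eager-level sketch of that shape pattern (illustrative only, not taken from the commit):

import torch

class LinearBroadcastAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 30)

    def forward(self, x, y):
        # y has shape (2, 3, 1) and broadcasts over the 30 output features.
        return self.linear(x) + y

mod = LinearBroadcastAdd().eval()
x = torch.randn(2, 3, 10)
y = torch.randn(2, 3, 1)
with torch.no_grad():
    out = torch.compile(mod)(x, y)
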
135 changes: 135 additions & 0 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -633,6 +633,92 @@ def test_conv2d_binary(self):
    def test_conv3d_binary(self):
        self._test_conv_binary_base(dim=5)

+    def _test_conv_binary_broadcast_shapes_base(self, dim=4):
+        assert dim == 4 or dim == 5
+
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                binary_fn,
+                has_relu,
+                **kwargs,
+            ):
+                super().__init__()
+                if dim == 4:
+                    self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
+                else:
+                    self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
+                self.binary_fn = binary_fn
+                self.has_relu = has_relu
+
+            def forward(self, x, x2):
+                x1 = self.conv(x)
+                if has_relu:
+                    return self.binary_fn(x1, x2).relu()
+                else:
+                    return self.binary_fn(x1, x2)
+
+        dtypes = [
+            torch.float,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
+        test_memory_format = [torch.contiguous_format, cl_format]
+        options = itertools.product(
+            binary_list,
+            [True, False],
+            test_memory_format,
+            dtypes,
+        )
+
+        for (
+            binary_fn,
+            has_relu,
+            memory_format,
+            dtype,
+        ) in options:
+            metrics.reset()
+            if dim == 4:
+                x_shape = (1, 3, 56, 56)
+                other_shape = (1, 16, 1, 1)
+            else:
+                x_shape = (1, 3, 20, 56, 56)
+                other_shape = (1, 16, 1, 1, 1)
+            mod = M(binary_fn, has_relu).eval()
+            x = (
+                torch.randn(x_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+            )
+            other = (
+                torch.randn(other_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+                .to(dtype)
+            )
+            match_count = binary_list[binary_fn][0] + 1
+            match_nodes = binary_list[binary_fn][1]
+            if has_relu:
+                match_nodes += 1
+            self._test_common(
+                mod, (x, other), match_count, match_nodes + 1, check_autocast=dtype
+            )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv2d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=4)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv3d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=5)
+
    def test_linear_binary(self):
        class M(torch.nn.Module):
            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -683,6 +769,55 @@ def forward(self, x, y):
            )
            self.assertEqual(metrics.generated_kernel_count, 1)

+    def test_linear_binary_broadcast_shapes_cpu(self):
+        class M(torch.nn.Module):
+            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
+                super().__init__()
+                self.linear = torch.nn.Linear(
+                    in_channels, out_channels, bias=bias, **kwargs
+                )
+                self.binary_fn = binary_fn
+
+            def forward(self, x, y):
+                x = self.linear(x)
+                x = self.binary_fn(x, y.clone())
+                return x
+
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        options = itertools.product(
+            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
+        )
+        out_feature = 30
+
+        for binary_fn, input_shape, bias, dtype in options:
+            metrics.reset()
+            # addmm(mm) + (linear+add)
+            match_count = 2
+            match_nodes = 3
+            if len(input_shape) == 3:
+                is_inplace = binary_list[binary_fn][2]
+                # view + linear + view(joint_graph+freeze pass)
+                match_count = match_count + 5 if is_inplace else match_count + 3
+                match_nodes = match_nodes + 8 if is_inplace else match_nodes + 5
+            mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
+            v = torch.randn(input_shape)
+            other = torch.randn(input_shape[:-1] + [1]).to(dtype)
+            self._test_common(
+                mod,
+                (
+                    v,
+                    other,
+                ),
+                match_count,
+                match_nodes,
+                check_autocast=dtype,
+            )
+            self.assertEqual(metrics.generated_kernel_count, 1)
+
    def test_multi_linear_share_same_input(self):
        # llama pattern.
        class M(torch.nn.Module):
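
The new tests can be run on their own from a PyTorch source checkout with something like pytest test/inductor/test_mkldnn_pattern_matcher.py -k broadcast_shapes (assuming the Inductor test dependencies are installed); per the decorators above, they are skipped without Dynamo or oneDNN support and on ROCm.
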
12 changes: 10 additions & 2 deletions torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -367,8 +367,14 @@ def get_meta_value(argument: torch.fx.node.Argument):
        for n in binary_nodes
    ):
        return False
+
    if any(
-        get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size()
+        get_meta_value(n.args[0]).dim() != get_meta_value(n.args[1]).dim()
+        or not all(
+            get_meta_value(n.args[0]).size(i) == get_meta_value(n.args[1]).size(i)
+            or get_meta_value(match.kwargs["other"]).size(i) == 1
+            for i in range(get_meta_value(n.args[0]).dim())
+        )
        or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
        or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
        for n in binary_nodes
Expand Down Expand Up @@ -538,7 +544,9 @@ def fn(match, *args, **kwargs):
computation_args += [1.0, None, [], None]
# Make sure the other is not an alias or mutation(fx side doesn't has such info).
other.realize()
if not _can_be_inplace(other):
if not _can_be_inplace(other) or other.data.shape != list(
match.nodes[0].meta["val"].size()
):
return L[outplace_fusion_op](*computation_args)
return L[inplace_fusion_op](*computation_args)
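
The matcher condition added in the first hunk boils down to: both add operands must have the same rank, and along every dimension either their sizes agree or the extra "other" operand has size 1. Separately, the second hunk only picks the in-place fusion op when "other" matches the output shape exactly, so a genuinely broadcast operand falls back to the out-of-place op. A standalone restatement of the shape rule (hypothetical helper, not the Inductor code itself):

def broadcastable_against_output(out_sizes, other_sizes):
    # Same rank, and per dimension either sizes match or "other" has size 1.
    if len(out_sizes) != len(other_sizes):
        return False
    return all(o == s or s == 1 for o, s in zip(out_sizes, other_sizes))

assert broadcastable_against_output((1, 16, 54, 54), (1, 16, 1, 1))
assert not broadcastable_against_output((1, 16, 54, 54), (1, 8, 1, 1))
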

