From f708e92ba10d4cd200059922ad926604e4c1dec9 Mon Sep 17 00:00:00 2001
From: "Sun, Jiayi"
Date: Thu, 21 Nov 2024 05:56:05 -0800
Subject: [PATCH] [Inductor] support Conv/Linear + broadcast add fusion (#138201)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138201
Approved by: https://github.com/jgong5, https://github.com/jansel
---
 aten/src/ATen/native/mkldnn/Conv.cpp         |  14 +-
 aten/src/ATen/native/mkldnn/Linear.cpp       |   4 +-
 test/inductor/test_mkldnn_pattern_matcher.py | 135 +++++++++++++++++++
 torch/_inductor/fx_passes/mkldnn_fusion.py   |  12 +-
 4 files changed, 153 insertions(+), 12 deletions(-)

diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index 9bc382701cc49c..dd12f5e574c4a7 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -371,10 +371,9 @@ Tensor mkldnn_convolution_pointwise_binary(
   auto output_sizes = conv_output_size(
       input_t.sizes(), weight_t.sizes(), padding_expanded, stride_expanded, dilation_expanded);
-  // TODO: support broadcast binary fusion.
   TORCH_CHECK(
-      output_sizes == other_t.sizes(),
-      "Binary Fusion's inputs should have same shape");
+      input_t.dim() == other_t.dim(),
+      "Binary Fusion's inputs should have same dimensions");
   // Only calling fusion path for channels_last path.
   // TODO: OneDNN doesn't optimize well for groups > 1 case, it will be enabled
   // at next OneDNN release.
@@ -405,18 +404,17 @@ Tensor mkldnn_convolution_pointwise_binary(
   auto weight =
       weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
   auto other = other_t.contiguous(memory_format);
-  auto output = at::empty_like(other);
+  auto output = at::empty(output_sizes, input_t.options()).contiguous(memory_format);
   const ideep::tensor x = itensor_from_tensor(input);
   const ideep::tensor w = itensor_from_tensor(weight);
   const ideep::tensor z = itensor_from_tensor(other);
   ideep::tensor y = itensor_from_tensor(output);
-  auto output_size = other.sizes().vec();
   ideep::tag format_tag = ideep::tag::nhwc;
   if (input_t.ndimension() == 5) {
     format_tag = ideep::tag::ndhwc;
   }
   auto other_desc = ideep::tensor::desc(
-      output_size, get_mkldnn_dtype(weight.scalar_type()), format_tag);
+      other.sizes().vec(), get_mkldnn_dtype(other.scalar_type()), format_tag);
 
   ideep::attr_t op_attr;
   ideep::post_ops po;
@@ -433,7 +431,7 @@ Tensor mkldnn_convolution_pointwise_binary(
         z,
         w,
         b,
-        output_size,
+        output_sizes,
         y,
         stride_expanded,
         dilation_expanded,
@@ -447,7 +445,7 @@ Tensor mkldnn_convolution_pointwise_binary(
         x,
         z,
         w,
-        output_size,
+        output_sizes,
         y,
         stride_expanded,
         dilation_expanded,
diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp
index e5dc8a6e0c1da6..aed712cc63d3b0 100644
--- a/aten/src/ATen/native/mkldnn/Linear.cpp
+++ b/aten/src/ATen/native/mkldnn/Linear.cpp
@@ -299,8 +299,8 @@ Tensor mkldnn_linear_pointwise_binary(
   }
 
   TORCH_CHECK(
-      output.sizes() == other_reshaped.sizes(),
-      "linear_binary_run expects the size of output and other tensor to be the same");
+      output.dim() == other_reshaped.dim(),
+      "linear_binary_run expects the dimension of output and other tensor to be the same");
 
   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
   ideep::tensor mkldnn_output = itensor_from_tensor(output);
diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index d2424a62f4778f..07ea40afc5b8d1 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -633,6 +633,92 @@ def test_conv2d_binary(self):
     def test_conv3d_binary(self):
         self._test_conv_binary_base(dim=5)
 
+    def _test_conv_binary_broadcast_shapes_base(self, dim=4):
+        assert dim == 4 or dim == 5
+
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                binary_fn,
+                has_relu,
+                **kwargs,
+            ):
+                super().__init__()
+                if dim == 4:
+                    self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
+                else:
+                    self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
+                self.binary_fn = binary_fn
+                self.has_relu = has_relu
+
+            def forward(self, x, x2):
+                x1 = self.conv(x)
+                if self.has_relu:
+                    return self.binary_fn(x1, x2).relu()
+                else:
+                    return self.binary_fn(x1, x2)
+
+        dtypes = [
+            torch.float,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
+        test_memory_format = [torch.contiguous_format, cl_format]
+        options = itertools.product(
+            binary_list,
+            [True, False],
+            test_memory_format,
+            dtypes,
+        )
+
+        for (
+            binary_fn,
+            has_relu,
+            memory_format,
+            dtype,
+        ) in options:
+            metrics.reset()
+            if dim == 4:
+                x_shape = (1, 3, 56, 56)
+                other_shape = (1, 16, 1, 1)
+            else:
+                x_shape = (1, 3, 20, 56, 56)
+                other_shape = (1, 16, 1, 1, 1)
+            mod = M(binary_fn, has_relu).eval()
+            x = (
+                torch.randn(x_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+            )
+            other = (
+                torch.randn(other_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+                .to(dtype)
+            )
+            match_count = binary_list[binary_fn][0] + 1
+            match_nodes = binary_list[binary_fn][1]
+            if has_relu:
+                match_nodes += 1
+            self._test_common(
+                mod, (x, other), match_count, match_nodes + 1, check_autocast=dtype
+            )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv2d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=4)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv3d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=5)
+
     def test_linear_binary(self):
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -683,6 +769,55 @@ def forward(self, x, y):
             )
             self.assertEqual(metrics.generated_kernel_count, 1)
 
+    def test_linear_binary_broadcast_shapes_cpu(self):
+        class M(torch.nn.Module):
+            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
+                super().__init__()
+                self.linear = torch.nn.Linear(
+                    in_channels, out_channels, bias=bias, **kwargs
+                )
+                self.binary_fn = binary_fn
+
+            def forward(self, x, y):
+                x = self.linear(x)
+                x = self.binary_fn(x, y.clone())
+                return x
+
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        options = itertools.product(
+            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
+        )
+        out_feature = 30
+
+        for binary_fn, input_shape, bias, dtype in options:
+            metrics.reset()
+            # addmm(mm) + (linear+add)
+            match_count = 2
+            match_nodes = 3
+            if len(input_shape) == 3:
+                is_inplace = binary_list[binary_fn][2]
+                # view + linear + view(joint_graph+freeze pass)
+                match_count = match_count + 5 if is_inplace else match_count + 3
+                match_nodes = match_nodes + 8 if is_inplace else match_nodes + 5
+            mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
+            v = torch.randn(input_shape)
+            other = torch.randn(input_shape[:-1] + [1]).to(dtype)
+            self._test_common(
+                mod,
+                (
+                    v,
+                    other,
+                ),
+                match_count,
+                match_nodes,
+                check_autocast=dtype,
+            )
+            self.assertEqual(metrics.generated_kernel_count, 1)
+
     def test_multi_linear_share_same_input(self):
         # llama pattern.
         class M(torch.nn.Module):
diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py
index eb6e383ac0c55a..e181677f0e2f4b 100644
--- a/torch/_inductor/fx_passes/mkldnn_fusion.py
+++ b/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -367,8 +367,14 @@ def get_meta_value(argument: torch.fx.node.Argument):
             for n in binary_nodes
         ):
             return False
+
         if any(
-            get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size()
+            get_meta_value(n.args[0]).dim() != get_meta_value(n.args[1]).dim()
+            or not all(
+                get_meta_value(n.args[0]).size(i) == get_meta_value(n.args[1]).size(i)
+                or get_meta_value(match.kwargs["other"]).size(i) == 1
+                for i in range(get_meta_value(n.args[0]).dim())
+            )
             or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
             or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
             for n in binary_nodes
@@ -538,7 +544,9 @@ def fn(match, *args, **kwargs):
             computation_args += [1.0, None, [], None]
         # Make sure the other is not an alias or mutation(fx side doesn't has such info).
         other.realize()
-        if not _can_be_inplace(other):
+        if not _can_be_inplace(other) or other.data.shape != list(
+            match.nodes[0].meta["val"].size()
+        ):
             return L[outplace_fusion_op](*computation_args)
         return L[inplace_fusion_op](*computation_args)
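
Note for reviewers: the pattern-matcher change above relaxes the binary-fusion
shape check from exact equality to one-directional broadcasting. `other` must
have the same rank as the conv/linear output, and each of its dimensions must
either match the output or be 1. When the shapes differ, the in-place fusion op
is also skipped (the conv output no longer fits into `other`'s buffer), so the
pass falls back to the out-of-place op. Below is a minimal sketch of the
accepted shapes and of the pattern the new tests compile end to end;
`broadcastable_to_output` and `ConvAdd` are illustrative names that are not
part of the patch, and inductor freezing is assumed to be enabled, as it is in
`_test_common`:

    import torch
    import torch._inductor.config as inductor_config

    # Mirrors the relaxed shape check added to mkldnn_fusion.py above: same
    # rank, and every dimension of `other` either matches the computation
    # output or is 1 (only `other` may broadcast, never the output).
    def broadcastable_to_output(out_size, other_size):
        return len(out_size) == len(other_size) and all(
            o == t or t == 1 for o, t in zip(out_size, other_size)
        )

    assert broadcastable_to_output((1, 16, 54, 54), (1, 16, 1, 1))     # now fusable
    assert not broadcastable_to_output((1, 16, 54, 54), (16, 54, 54))  # rank mismatch

    # The pattern the new tests exercise: a Conv2d output plus a per-channel
    # (1, C, 1, 1) tensor, compiled with freezing so the CPU fusion pass can
    # rewrite conv + add into the fused mkldnn op.
    class ConvAdd(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)

        def forward(self, x, other):
            return self.conv(x) + other

    inductor_config.freezing = True
    mod = torch.compile(ConvAdd().eval())
    with torch.no_grad():
        x = torch.randn(1, 3, 56, 56).to(memory_format=torch.channels_last)
        other = torch.randn(1, 16, 1, 1).to(memory_format=torch.channels_last)
        out = mod(x, other)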