
Commit

Don't decompose functional composite ops in export inference IR (pytorch#128077)

Recently we decided to split the export IR into two different IRs (training vs. inference). One major change we decided to introduce in the inference IR is keeping the composite ops that the user specifies should be preserved. This PR does that by overriding the CompositeImplicitAutograd decompositions in the export inference path.
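As context, here is a minimal sketch of the flow the new tests below exercise (the toy module and shapes are illustrative; `_preserve_ops` is the private keyword argument used by the tests in this diff):

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

# Export, then run decompositions with an empty decomp table while asking to
# keep aten.linear: the composite op stays as a single node in the inference
# IR instead of being decomposed into its aten constituents.
ep = torch.export.export(Foo(), (torch.randn(3, 3),))
ep = ep.run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,))
print(ep.graph_module.code)  # expected to still contain torch.ops.aten.linear.default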

Differential Revision: [D58701607](https://our.internmc.facebook.com/intern/diff/D58701607)
Pull Request resolved: pytorch#128077
Approved by: https://github.com/bdhirsh
tugsbayasgalan authored and pytorchmergebot committed Jun 26, 2024
1 parent 64f1111 commit 90f6043
Showing 10 changed files with 578 additions and 47 deletions.
163 changes: 163 additions & 0 deletions test/export/test_export.py
@@ -661,6 +661,38 @@ def forward(self, x, weight, bias):
actual_result.append(node.meta.get("torch_fn"))
self.assertEqual(actual_result, expected_result)

def test_export_preserve_linear_at_aot_level(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.linear(x)
return torch.ops.aten.chunk.default(x, 3, 0)

gm = (
torch.export.export(
Foo(),
(torch.randn(3, 3),),
)
.run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,))
.graph_module
)
# linear is a CompositeImplicitAutograd functional op, so we should preserve it.
# chunk is a CompositeImplicitAutograd non-functional op, so we decompose it.
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, p_linear_weight, p_linear_bias, x):
linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None
split = torch.ops.aten.split.Tensor(linear, 1); linear = None
getitem = split[0]
getitem_1 = split[1]
getitem_2 = split[2]; split = None
return (getitem, getitem_1, getitem_2)""",
)

# TODO(yidi)
# Expected failure for test cases that call run_decompositions().
# The top-level cond node has pre-existing metadata,
@@ -1015,6 +1047,137 @@ def forward(self, x, y):
"dy - 6 = 6" not in exc.args[0]
) # don't suggest fix for non-root dim

def test_keep_composite_ops_invalid(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.linear(x)
return torch.ops.aten.chunk.default(x, 3, 0)

with self.assertRaisesRegex(
RuntimeError, "aten.chunk.default is a mutating/aliasing op"
):
_ = torch.export.export(
Foo(),
(torch.randn(3, 3),),
).run_decompositions({}, _preserve_ops=(torch.ops.aten.chunk.default,))

with self.assertRaisesRegex(
RuntimeError,
"aten.add.Tensor is not CompositeImplicitAutograd op, so we will preserve it as",
):
_ = torch.export.export(
Foo(),
(torch.randn(3, 3),),
).run_decompositions({}, _preserve_ops=(torch.ops.aten.add.Tensor,))

with self.assertRaisesRegex(
RuntimeError, "aten.sym_size.default is a metadata query function"
):
_ = torch.export.export(
Foo(),
(torch.randn(3, 3),),
).run_decompositions({}, _preserve_ops=(torch.ops.aten.sym_size.default,))

with self.assertRaisesRegex(
RuntimeError,
"We can't detect aten.native_batch_norm.default as a functional op statically",
):
_ = torch.export.export(
Foo(),
(torch.randn(3, 3),),
).run_decompositions(
{}, _preserve_ops=(torch.ops.aten.native_batch_norm.default,)
)

def test_keep_composite_ops_linear_convd(self):
class MyLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn(20, 98)
self.bias = torch.randn(20)

def forward(self, x):
return torch.nn.functional.linear(x, self.weight, self.bias)

class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(16, 33, 3)
self.conv1d = torch.nn.Conv1d(16, 33, 3)
self.linear = MyLinear()

def forward(self, x, y):
x_conv = self.conv(x)
y_conv_1d = self.conv1d(y)
x_linear = self.linear(x_conv)
return x_linear.cos() + y_conv_1d.sum()

ep = torch.export.export(
Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
)
ep_has_linear_convd = ep.run_decompositions(
decomp_table={},
_preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
)
self.assertExpectedInline(
str(ep_has_linear_convd.graph_module.code).strip(),
"""\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
linear = torch.ops.aten.linear.default(conv2d, c_linear_weight, c_linear_bias); conv2d = c_linear_weight = c_linear_bias = None
cos = torch.ops.aten.cos.default(linear); linear = None
sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
return (add,)""",
)

ep_has_convd = ep.run_decompositions(
decomp_table=None,
_preserve_ops=[
torch.ops.aten.conv2d.default,
torch.ops.aten.conv1d.default,
],
)
self.assertExpectedInline(
str(ep_has_convd.graph_module.code).strip(),
"""\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None
permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
addmm = torch.ops.aten.addmm.default(c_linear_bias, view, permute); c_linear_bias = view = permute = None
view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None
cos = torch.ops.aten.cos.default(view_1); view_1 = None
sum_1 = torch.ops.aten.sum.dim_IntList(conv1d, []); conv1d = None
add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
return (add,)""",
)

ep_has_convd = ep_has_convd.run_decompositions(
decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
)
self.assertExpectedInline(
str(ep_has_convd.graph_module.code).strip(),
"""\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
convolution = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None
view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None
permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
addmm = torch.ops.aten.addmm.default(c_linear_bias, view, permute); c_linear_bias = view = permute = None
view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None
cos = torch.ops.aten.cos.default(view_1); view_1 = None
sum_1 = torch.ops.aten.sum.dim_IntList(convolution, []); convolution = None
add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
return (add,)""",
)

def test_derived_dim_out_of_order_simplified(self):
_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
191 changes: 191 additions & 0 deletions test/export/testing.py
@@ -2,6 +2,197 @@
import unittest
from unittest.mock import patch

import torch

aten = torch.ops.aten

# This list is not meant to be comprehensive
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
aten.arctan2.default,
aten.divide.Tensor,
aten.divide.Scalar,
aten.divide.Tensor_mode,
aten.divide.Scalar_mode,
aten.multiply.Tensor,
aten.multiply.Scalar,
aten.subtract.Tensor,
aten.subtract.Scalar,
aten.true_divide.Tensor,
aten.true_divide.Scalar,
aten.greater.Tensor,
aten.greater.Scalar,
aten.greater_equal.Tensor,
aten.greater_equal.Scalar,
aten.less_equal.Tensor,
aten.less_equal.Scalar,
aten.less.Tensor,
aten.less.Scalar,
aten.not_equal.Tensor,
aten.not_equal.Scalar,
aten.cat.names,
aten.sum.dim_DimnameList,
aten.mean.names_dim,
aten.prod.dim_Dimname,
aten.all.dimname,
aten.norm.names_ScalarOpt_dim,
aten.norm.names_ScalarOpt_dim_dtype,
aten.var.default,
aten.var.dim,
aten.var.names_dim,
aten.var.correction_names,
aten.std.default,
aten.std.dim,
aten.std.names_dim,
aten.std.correction_names,
aten.absolute.default,
aten.arccos.default,
aten.arccosh.default,
aten.arcsin.default,
aten.arcsinh.default,
aten.arctan.default,
aten.arctanh.default,
aten.clip.default,
aten.clip.Tensor,
aten.fix.default,
aten.negative.default,
aten.square.default,
aten.size.int,
aten.size.Dimname,
aten.stride.int,
aten.stride.Dimname,
aten.repeat_interleave.self_Tensor,
aten.repeat_interleave.self_int,
aten.sym_size.int,
aten.sym_stride.int,
aten.atleast_1d.Sequence,
aten.atleast_2d.Sequence,
aten.atleast_3d.Sequence,
aten.linear.default,
aten.conv2d.default,
aten.conv2d.padding,
aten.mish_backward.default,
aten.silu_backward.default,
aten.index_add.dimname,
aten.pad_sequence.default,
aten.index_copy.dimname,
aten.upsample_nearest1d.vec,
aten.upsample_nearest2d.vec,
aten.upsample_nearest3d.vec,
aten._upsample_nearest_exact1d.vec,
aten._upsample_nearest_exact2d.vec,
aten._upsample_nearest_exact3d.vec,
aten.rnn_tanh.input,
aten.rnn_tanh.data,
aten.rnn_relu.input,
aten.rnn_relu.data,
aten.lstm.input,
aten.lstm.data,
aten.gru.input,
aten.gru.data,
aten._upsample_bilinear2d_aa.vec,
aten._upsample_bicubic2d_aa.vec,
aten.upsample_bilinear2d.vec,
aten.upsample_trilinear3d.vec,
aten.upsample_linear1d.vec,
aten.matmul.default,
aten.upsample_bicubic2d.vec,
aten.__and__.Scalar,
aten.__and__.Tensor,
aten.__or__.Tensor,
aten.__or__.Scalar,
aten.__xor__.Tensor,
aten.__xor__.Scalar,
aten.scatter.dimname_src,
aten.scatter.dimname_value,
aten.scatter_add.dimname,
aten.is_complex.default,
aten.logsumexp.names,
aten.where.ScalarOther,
aten.where.ScalarSelf,
aten.where.Scalar,
aten.where.default,
aten.item.default,
aten.any.dimname,
aten.std_mean.default,
aten.std_mean.dim,
aten.std_mean.names_dim,
aten.std_mean.correction_names,
aten.var_mean.default,
aten.var_mean.dim,
aten.var_mean.names_dim,
aten.var_mean.correction_names,
aten.broadcast_tensors.default,
aten.stft.default,
aten.stft.center,
aten.istft.default,
aten.index_fill.Dimname_Scalar,
aten.index_fill.Dimname_Tensor,
aten.index_select.dimname,
aten.diag.default,
aten.cumsum.dimname,
aten.cumprod.dimname,
aten.meshgrid.default,
aten.meshgrid.indexing,
aten.fft_fft.default,
aten.fft_ifft.default,
aten.fft_rfft.default,
aten.fft_irfft.default,
aten.fft_hfft.default,
aten.fft_ihfft.default,
aten.fft_fftn.default,
aten.fft_ifftn.default,
aten.fft_rfftn.default,
aten.fft_ihfftn.default,
aten.fft_irfftn.default,
aten.fft_hfftn.default,
aten.fft_fft2.default,
aten.fft_ifft2.default,
aten.fft_rfft2.default,
aten.fft_irfft2.default,
aten.fft_hfft2.default,
aten.fft_ihfft2.default,
aten.fft_fftshift.default,
aten.fft_ifftshift.default,
aten.selu.default,
aten.margin_ranking_loss.default,
aten.hinge_embedding_loss.default,
aten.nll_loss.default,
aten.prelu.default,
aten.relu6.default,
aten.pairwise_distance.default,
aten.pdist.default,
aten.special_ndtr.default,
aten.cummax.dimname,
aten.cummin.dimname,
aten.logcumsumexp.dimname,
aten.max.other,
aten.max.names_dim,
aten.min.other,
aten.min.names_dim,
aten.linalg_eigvals.default,
aten.median.names_dim,
aten.nanmedian.names_dim,
aten.mode.dimname,
aten.gather.dimname,
aten.sort.dimname,
aten.sort.dimname_stable,
aten.argsort.default,
aten.argsort.dimname,
aten.rrelu.default,
aten.conv_transpose1d.default,
aten.conv_transpose2d.input,
aten.conv_transpose3d.input,
aten.conv1d.default,
aten.conv1d.padding,
aten.conv3d.default,
aten.conv3d.padding,
aten.float_power.Tensor_Tensor,
aten.float_power.Tensor_Scalar,
aten.float_power.Scalar,
aten.ldexp.Tensor,
aten._version.default,
]


def make_test_cls_with_mocked_export(
cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
