diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index 52f32e6969a8e8..6b2380e79695d2 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -288,6 +288,7 @@ aten::lgamma.out aten::lgamma_ aten::lift aten::lift.out +aten::lift_fresh aten::linalg_vector_norm aten::linalg_vector_norm.out aten::log @@ -457,6 +458,7 @@ aten::special_zeta.other_scalar_out aten::special_zeta.out aten::special_zeta.self_scalar aten::special_zeta.self_scalar_out +aten::split.Tensor aten::split_with_sizes aten::sqrt aten::sqrt.out @@ -487,6 +489,7 @@ aten::triu_indices.out aten::trunc aten::trunc.out aten::trunc_ +aten::unbind.int aten::unfold aten::uniform aten::uniform.out diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index e8899408f79a30..5e1e107cf54a52 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -105,8 +105,7 @@ def forward(self, x): serialized, _ = ExportedProgramSerializer().serialize(exported_module) node = serialized.graph_module.graph.nodes[-1] - # split.Tensor gets decomposed to split_with_sizes by the core ATen decomposition table - self.assertEqual(node.target, "torch.ops.aten.split_with_sizes.default") + self.assertEqual(node.target, "torch.ops.aten.split.Tensor") self.assertEqual(len(node.outputs), 1) # Input looks like: # tensor([[0, 1], diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index abeeaaf8e42767..093d5868a2779c 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -329,7 +329,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.leaky_relu_backward, aten.lerp, aten.lerp_, - aten.lift_fresh, aten.linspace, aten.logaddexp, aten.logaddexp2, @@ -404,7 +403,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.special_entr, aten.special_log_ndtr, aten.special_xlog1py, - aten.split.Tensor, aten.std, aten.std_mean, aten.stack, @@ -419,7 +417,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.tril_, aten.triu, aten.triu_, - aten.unbind, aten.unfold_backward, aten.unfold_copy, aten._unsafe_index, diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 73bbf75ba5d033..2ae2c6f977c0c2 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -69,10 +69,7 @@ aten._scaled_dot_product_flash_attention.default, # See comments in torch/_decomp/decompositions.py aten.clamp_max, aten.clamp_min, - aten.glu, # inductor lowers this directly - aten.lift_fresh, # inductor lowers this directly (to no-op) - aten.split.Tensor, # inductor lowers this directly - aten.unbind, # inductor lowers this directly + aten.glu, # has lowering in inductor ] remove_decompositions(decompositions, decomps_to_exclude)