Skip to content

Commit

Permalink
Revert "[core IR] Add lift_fresh, split.Tensor, and unbind decomposit…
Browse files Browse the repository at this point in the history
…ions to core ATen decomp table (pytorch#110102)"

This reverts commit 22e706f.

Reverted pytorch#110102 on behalf of https://github.com/atalman due to Breaks internal CI ([comment](pytorch#110102 (comment)))
  • Loading branch information
pytorchmergebot committed Sep 28, 2023
1 parent aaaa3c1 commit e0b035c
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 9 deletions.
3 changes: 3 additions & 0 deletions test/expect/HasDecompTest.test_aten_core_operators.expect
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ aten::lgamma.out
aten::lgamma_
aten::lift
aten::lift.out
aten::lift_fresh
aten::linalg_vector_norm
aten::linalg_vector_norm.out
aten::log
Expand Down Expand Up @@ -457,6 +458,7 @@ aten::special_zeta.other_scalar_out
aten::special_zeta.out
aten::special_zeta.self_scalar
aten::special_zeta.self_scalar_out
aten::split.Tensor
aten::split_with_sizes
aten::sqrt
aten::sqrt.out
Expand Down Expand Up @@ -487,6 +489,7 @@ aten::triu_indices.out
aten::trunc
aten::trunc.out
aten::trunc_
aten::unbind.int
aten::unfold
aten::uniform
aten::uniform.out
Expand Down
3 changes: 1 addition & 2 deletions test/export/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def forward(self, x):

serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-1]
# split.Tensor gets decomposed to split_with_sizes by the core ATen decomposition table
self.assertEqual(node.target, "torch.ops.aten.split_with_sizes.default")
self.assertEqual(node.target, "torch.ops.aten.split.Tensor")
self.assertEqual(len(node.outputs), 1)
# Input looks like:
# tensor([[0, 1],
Expand Down
3 changes: 0 additions & 3 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.leaky_relu_backward,
aten.lerp,
aten.lerp_,
aten.lift_fresh,
aten.linspace,
aten.logaddexp,
aten.logaddexp2,
Expand Down Expand Up @@ -404,7 +403,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.special_entr,
aten.special_log_ndtr,
aten.special_xlog1py,
aten.split.Tensor,
aten.std,
aten.std_mean,
aten.stack,
Expand All @@ -419,7 +417,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.tril_,
aten.triu,
aten.triu_,
aten.unbind,
aten.unfold_backward,
aten.unfold_copy,
aten._unsafe_index,
Expand Down
5 changes: 1 addition & 4 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@
aten._scaled_dot_product_flash_attention.default, # See comments in torch/_decomp/decompositions.py
aten.clamp_max,
aten.clamp_min,
aten.glu, # inductor lowers this directly
aten.lift_fresh, # inductor lowers this directly (to no-op)
aten.split.Tensor, # inductor lowers this directly
aten.unbind, # inductor lowers this directly
aten.glu, # has lowering in inductor
]

remove_decompositions(decompositions, decomps_to_exclude)
Expand Down

0 comments on commit e0b035c

Please sign in to comment.