From d0ccf599ccafb379346101591bf76c6653a03e7f Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 11 Apr 2024 22:40:46 +0000 Subject: [PATCH] [export] Restore original placeholder names (part 2: higher-order-op subgraph naming) (#123587) Summary: note: breaking the original diff [D55225818](https://www.internalfb.com/diff/D55225818) into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size. Stacked PR to restore original names to placeholder nodes, replacing the default names arg0_1, arg1_1, ... This PR propagates node names to higher-order-op subgraph placeholders, retaining the top-level names and handling naming collisions by suffixing other non-placeholder nodes in the subgraph with an index. This is the same handling as in fx.Graph/fx.Node, but implemented separately as a pass. Since the input schemas of HOO subgraphs are very different, they are enumerated in _name_hoo_subgraph_placeholders(). Currently cond, map_impl, and wrap_with_set_grad_enabled are handled, but other ops can be easily added. Test Plan: verification checks on placeholder names for all export() calls, unit test in test/export/test_export.py Differential Revision: D55456749 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123587 Approved by: https://github.com/angelayi --- test/export/test_export.py | 89 +++++++++++++++++++++++++++++--- torch/_export/utils.py | 8 +-- torch/export/_trace.py | 25 +++++---- torch/export/exported_program.py | 45 ++++++++++++++++ 4 files changed, 147 insertions(+), 20 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 5ec2deb050aa6..6aa4f04a42425 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3910,19 +3910,19 @@ def forward(self, b_pred, b_t, x, y): self.assertExpectedInline( str(exported_program.graph_module.true_graph_0.code.strip()), """\ -def forward(self, arg1_1, arg0_1, arg2_1): +def forward(self, b_t, x, y): submod_3 = self.submod_1 - add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, arg1_1, arg0_1, arg2_1); submod_3 = arg1_1 = arg0_1 = arg2_1 = None + add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None return (add_1,)""", ) self.assertExpectedInline( str(exported_program.graph_module.true_graph_0.submod_1.code.strip()), """\ -def forward(self, arg1_1, arg0_1, arg2_1): - sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None - add = torch.ops.aten.add.Tensor(sub, arg0_1); sub = arg0_1 = None - add_1 = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None +def forward(self, b_t, x, y): + sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None + add = torch.ops.aten.add.Tensor(sub, x); sub = x = None + add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None return add_1""", ) @@ -4256,6 +4256,83 @@ def forward(self, mul, add, add_1): real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes] self.assertEqual(expected_names_and_ops, real_names_and_ops) + @testing.expectedFailureRetraceability + def test_placeholder_naming_collisions_hoo_subgraphs(self): + # test collisions between user inputs, top-level nodes, and HOO subgraph nodes + class Foo(torch.nn.Module): + def forward(self, x, mul, mul_1): + _mul = x * x + y = cond( + _mul.sum() > 0, + lambda x, y, z: x * y * z, + lambda x, y, z: x + y + z, + [_mul, mul, mul_1], + ) + with torch.enable_grad(): + y = y * y + return y + + with torch.no_grad(): + ep = 
torch.export._trace._export( + Foo(), + (torch.randn(4), torch.randn(4), torch.randn(4)), + pre_dispatch=True, + ) + # test cond subgraph + expected_names_and_ops = [ + ("mul_2", "placeholder"), + ("mul", "placeholder"), + ("mul_1", "placeholder"), + ("mul_3", "call_function"), + ("mul_4", "call_function"), + ("output", "output"), + ] + real_names_and_ops = [ + (node.name, node.op) for node in ep.graph_module.true_graph_0.graph.nodes + ] + self.assertEqual(expected_names_and_ops, real_names_and_ops) + # test set_grad_enabled subgraph + expected_names_and_ops = [ + ("getitem", "placeholder"), + ("mul_1", "call_function"), + ("output", "output"), + ] + real_names_and_ops = [ + (node.name, node.op) for node in ep.graph_module.submod_1.graph.nodes + ] + self.assertEqual(expected_names_and_ops, real_names_and_ops) + + # test collisions between user inputs & higher order op subgraphs + # (please never do this) + class Foo(torch.nn.Module): + def forward(self, input, true_graph, body_graph): + def map_body(x, y): + return x + y + + x = map(map_body, input, body_graph[0]) + x = x + true_graph[0] + true_graph[1] + x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) + x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) + return x + + inputs = ( + torch.randn(10, 4), + (torch.randn(4), torch.randn(4)), + (torch.randn(4),), + ) + ep = export(Foo(), inputs) + expected_getattr_names = [ + "body_graph_1", + "true_graph_2", + "false_graph_0", + "true_graph_3", + "false_graph_1", + ] + real_getattr_names = [ + node.name for node in ep.graph.nodes if node.op == "get_attr" + ] + self.assertEqual(expected_getattr_names, real_getattr_names) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 557788558d271..844af54e77d85 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -9,7 +9,10 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.export import ExportedProgram -from torch.export.exported_program import _rename_without_collisions +from torch.export.exported_program import ( + _name_hoo_subgraph_placeholders, + _rename_without_collisions, +) from torch.export.graph_signature import ConstantArgument, InputKind, OutputKind from torch.utils._pytree import ( _register_pytree_node, @@ -527,9 +530,8 @@ def _extract_pytree_key(x): elif node.name in name_map: node.name = name_map[node.name] - # TODO(pianpwk), in immediate follow-up PR # propagate names to higher order op subgraphs - # name_hoo_subgraph_placeholders(gm) + _name_hoo_subgraph_placeholders(gm) # re-generate graph module code gm.recompile() diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 96f7e3605b12c..2c7249c3f61d5 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -832,23 +832,26 @@ def _verify_placeholder_names(gm: torch.fx.GraphModule, sig: ExportGraphSignatur """ Performs a sanity check on the placeholder node names. 
- User input nodes: no restrictions, should match the original forward() signature
-    - Params/buffers/constants/custom_obj nodes: should start with "p", "b", "c", "obj"
+    - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in <placeholder_prefixes>
     """
     name_to_kind = {
         spec.arg.name: spec.kind
         for spec in sig.input_specs
         if not isinstance(spec.arg, ConstantArgument)
     }
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            if node.name not in name_to_kind:
-                continue
-            node_kind = name_to_kind[node.name]
-            prefix = placeholder_prefixes[node_kind]
-            if not node.name.startswith(prefix):
-                raise SpecViolationError(
-                    f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
-                )
+    for mod in gm.modules():
+        if not isinstance(mod, torch.fx.GraphModule):
+            continue
+        for node in mod.graph.nodes:
+            if node.op == "placeholder":
+                if node.name not in name_to_kind:
+                    continue
+                node_kind = name_to_kind[node.name]
+                prefix = placeholder_prefixes[node_kind]
+                if not node.name.startswith(prefix):
+                    raise SpecViolationError(
+                        f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
+                    )
 
 
 def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py
index 37af592d32176..5c18bbdeda94e 100644
--- a/torch/export/exported_program.py
+++ b/torch/export/exported_program.py
@@ -141,6 +141,48 @@ def _rename_without_collisions(
     return name_map[orig_name]
 
 
+def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
+    """
+    Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs,
+    and handle collisions with non-placeholders by count suffixing.
+    Different HOO subgraph types have different input schemas, so we first enumerate them
+    and gather the top-level named placeholder nodes.
+    """
+    # gather all HOO subgraphs and their top-level named placeholder nodes
+    subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and isinstance(
+            node.target, torch._ops.HigherOrderOperator
+        ):
+            # HOO subgraphs have varying input schemas, so we enumerate them here
+            if node.target._name == "cond":
+                _, true_graph, false_graph, cond_args = node._args
+                subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args))
+                subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args))
+            elif node.target._name == "wrap_with_set_grad_enabled":
+                subgraph, phs = node._args[1], node._args[2:]
+                subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs))
+            elif node.target._name == "map_impl":
+                body_graph, array, args = node._args
+                subgraph_ph_tuples.append(
+                    (getattr(gm, body_graph.target), array + args)
+                )
+
+    # propagate names
+    for subgraph, hoo_phs in subgraph_ph_tuples:
+        name_map: Dict[str, str] = {}
+        for i, node in enumerate(subgraph.graph.nodes):
+            if i < len(hoo_phs):  # placeholder, retain name
+                name_map[node.name] = hoo_phs[i].name
+                node.name = node.target = hoo_phs[i].name
+            else:  # non-placeholder, check for collisions
+                node.name = _rename_without_collisions(name_map, node.name, node.name)
+
+        # recurse and recompile
+        _name_hoo_subgraph_placeholders(subgraph)
+        subgraph.recompile()
+
+
 class ExportedProgram:
     """
     Package of a program from :func:`export`.
It contains @@ -543,6 +585,9 @@ def update_arg(old_arg, new_ph): continue node.name = _rename_without_collisions(name_map, node.name, node.name) + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + # To match the output target with correct input for input mutations # need to find the old to new placeholder map old_new_placeholder_map = {
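---

A quick illustration of the collision handling described in the summary:
placeholders take the top-level names verbatim, and any later node whose name
now collides is suffixed with a count. This is a minimal standalone sketch,
not the actual torch.export pass; propagate_names and rename_without_collisions
below are hypothetical simplifications of the renaming logic:

    import re
    from typing import Dict, List

    def rename_without_collisions(
        name_map: Dict[str, str], orig_name: str, name: str
    ) -> str:
        """Record orig_name -> name, probing for a free _<count> suffix on collision."""
        if name in name_map.values():
            base = re.sub(r"_\d+$", "", name)  # strip any existing count suffix
            count = 1
            while f"{base}_{count}" in name_map.values():
                count += 1
            name = f"{base}_{count}"
        name_map[orig_name] = name
        return name

    def propagate_names(
        subgraph_node_names: List[str], top_level_ph_names: List[str]
    ) -> List[str]:
        """The first len(top_level_ph_names) nodes are placeholders and take the
        top-level names verbatim; remaining nodes are renamed on collision."""
        name_map: Dict[str, str] = {}
        renamed = []
        for i, name in enumerate(subgraph_node_names):
            if i < len(top_level_ph_names):  # placeholder: retain top-level name
                name_map[name] = top_level_ph_names[i]
                renamed.append(top_level_ph_names[i])
            else:  # non-placeholder: check for collisions
                renamed.append(rename_without_collisions(name_map, name, name))
        return renamed

    # Subgraph placeholders arg0_1/arg1_1 take the user-facing names x/mul;
    # the interior "mul" node now collides with a placeholder and is suffixed.
    print(propagate_names(["arg0_1", "arg1_1", "mul", "output"], ["x", "mul"]))
    # ['x', 'mul', 'mul_1', 'output']

This mirrors the names asserted in test_placeholder_naming_collisions_hoo_subgraphs
above, where subgraph call_function nodes pick up count suffixes after the
placeholders inherit the top-level names.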