[export] Restore original placeholder names (part 2: higher-order-op subgraph naming) (pytorch#123587)

Summary:
Note: the original diff [D55225818](https://www.internalfb.com/diff/D55225818) is split into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size.

This stacked PR restores the original names of placeholder nodes, replacing the default names arg0_1, arg1_1, ...

This PR propagates node names to higher-order-op subgraph placeholders, retaining the top-level names and handling naming collisions by suffixing other non-placeholder nodes in the subgraph with an index. This is the same handling as in fx.Graph/fx.Node, but implemented separately as a pass.
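
As an illustration, here is a minimal sketch of that collision handling. The real helper is `_rename_without_collisions` in torch/export/exported_program.py; the standalone function and signature below are simplified stand-ins, not the PR's implementation.

```python
# Minimal, hypothetical sketch of the collision handling described above:
# a requested name is kept if it is unused, otherwise an index suffix is
# bumped until a free name is found.
def rename_without_collisions(taken: set, requested: str) -> str:
    """Return `requested` if unused, otherwise `requested_1`, `requested_2`, ..."""
    if requested not in taken:
        taken.add(requested)
        return requested
    i = 1
    while f"{requested}_{i}" in taken:
        i += 1
    name = f"{requested}_{i}"
    taken.add(name)
    return name

# e.g. a subgraph's non-placeholder node "mul" collides with a propagated
# placeholder name "mul", so it is renamed to "mul_1":
taken = {"mul", "x"}
print(rename_without_collisions(taken, "mul"))  # mul_1
print(rename_without_collisions(taken, "add"))  # add
```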

Since HOO subgraphs have very different input schemas, the supported ops are enumerated in _name_hoo_subgraph_placeholders(). Currently cond, map_impl, and wrap_with_set_grad_enabled are handled; other ops can easily be added.
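
Illustrative usage with a hypothetical toy module (not taken from this PR; the import path and the exact printed names are assumptions): after this change, subgraph placeholders should carry the propagated user-level names rather than arg0_1, arg1_1, ...

```python
import torch
from torch.export import export
from functorch.experimental.control_flow import cond  # hypothetical example; import path may vary by PyTorch version

class M(torch.nn.Module):
    def forward(self, x, y):
        # cond() lowers to a higher-order op with true_graph_0 / false_graph_0 submodules
        return cond(x.sum() > 0, lambda x, y: x + y, lambda x, y: x - y, [x, y])

ep = export(M(), (torch.randn(4), torch.randn(4)))
phs = [n.name for n in ep.graph_module.true_graph_0.graph.nodes if n.op == "placeholder"]
print(phs)  # expected with this PR: ['x', 'y'] (previously ['arg0_1', 'arg1_1'])
```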

Test Plan: verification checks on placeholder names for all export() calls, plus a unit test in test/export/test_export.py.

Differential Revision: D55456749

Pull Request resolved: pytorch#123587
Approved by: https://github.com/angelayi
pianpwk authored and pytorchmergebot committed Apr 11, 2024
1 parent b9675e8 commit d0ccf59
Showing 4 changed files with 147 additions and 20 deletions.
89 changes: 83 additions & 6 deletions test/export/test_export.py
@@ -3910,19 +3910,19 @@ def forward(self, b_pred, b_t, x, y):
self.assertExpectedInline(
str(exported_program.graph_module.true_graph_0.code.strip()),
"""\
-def forward(self, arg1_1, arg0_1, arg2_1):
+def forward(self, b_t, x, y):
submod_3 = self.submod_1
-add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, arg1_1, arg0_1, arg2_1); submod_3 = arg1_1 = arg0_1 = arg2_1 = None
+add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
return (add_1,)""",
)

self.assertExpectedInline(
str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
"""\
-def forward(self, arg1_1, arg0_1, arg2_1):
-sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
-add = torch.ops.aten.add.Tensor(sub, arg0_1); sub = arg0_1 = None
-add_1 = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None
+def forward(self, b_t, x, y):
+sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
+add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
+add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
return add_1""",
)

@@ -4256,6 +4256,83 @@ def forward(self, mul, add, add_1):
real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes]
self.assertEqual(expected_names_and_ops, real_names_and_ops)

@testing.expectedFailureRetraceability
def test_placeholder_naming_collisions_hoo_subgraphs(self):
# test collisions between user inputs, top-level nodes, and HOO subgraph nodes
class Foo(torch.nn.Module):
def forward(self, x, mul, mul_1):
_mul = x * x
y = cond(
_mul.sum() > 0,
lambda x, y, z: x * y * z,
lambda x, y, z: x + y + z,
[_mul, mul, mul_1],
)
with torch.enable_grad():
y = y * y
return y

with torch.no_grad():
ep = torch.export._trace._export(
Foo(),
(torch.randn(4), torch.randn(4), torch.randn(4)),
pre_dispatch=True,
)
# test cond subgraph
expected_names_and_ops = [
("mul_2", "placeholder"),
("mul", "placeholder"),
("mul_1", "placeholder"),
("mul_3", "call_function"),
("mul_4", "call_function"),
("output", "output"),
]
real_names_and_ops = [
(node.name, node.op) for node in ep.graph_module.true_graph_0.graph.nodes
]
self.assertEqual(expected_names_and_ops, real_names_and_ops)
# test set_grad_enabled subgraph
expected_names_and_ops = [
("getitem", "placeholder"),
("mul_1", "call_function"),
("output", "output"),
]
real_names_and_ops = [
(node.name, node.op) for node in ep.graph_module.submod_1.graph.nodes
]
self.assertEqual(expected_names_and_ops, real_names_and_ops)

# test collisions between user inputs & higher order op subgraphs
# (please never do this)
class Foo(torch.nn.Module):
def forward(self, input, true_graph, body_graph):
def map_body(x, y):
return x + y

x = map(map_body, input, body_graph[0])
x = x + true_graph[0] + true_graph[1]
x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x])
x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x])
return x

inputs = (
torch.randn(10, 4),
(torch.randn(4), torch.randn(4)),
(torch.randn(4),),
)
ep = export(Foo(), inputs)
expected_getattr_names = [
"body_graph_1",
"true_graph_2",
"false_graph_0",
"true_graph_3",
"false_graph_1",
]
real_getattr_names = [
node.name for node in ep.graph.nodes if node.op == "get_attr"
]
self.assertEqual(expected_getattr_names, real_getattr_names)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):
8 changes: 5 additions & 3 deletions torch/_export/utils.py
@@ -9,7 +9,10 @@
from torch._subclasses.fake_tensor import FakeTensor

from torch.export import ExportedProgram
-from torch.export.exported_program import _rename_without_collisions
+from torch.export.exported_program import (
+_name_hoo_subgraph_placeholders,
+_rename_without_collisions,
+)
from torch.export.graph_signature import ConstantArgument, InputKind, OutputKind
from torch.utils._pytree import (
_register_pytree_node,
@@ -527,9 +530,8 @@ def _extract_pytree_key(x):
elif node.name in name_map:
node.name = name_map[node.name]

-# TODO(pianpwk), in immediate follow-up PR
-# propagate names to higher order op subgraphs
-# name_hoo_subgraph_placeholders(gm)
+_name_hoo_subgraph_placeholders(gm)

# re-generate graph module code
gm.recompile()
25 changes: 14 additions & 11 deletions torch/export/_trace.py
@@ -832,23 +832,26 @@ def _verify_placeholder_names(gm: torch.fx.GraphModule, sig: ExportGraphSignatur
"""
Performs a sanity check on the placeholder node names.
- User input nodes: no restrictions, should match the original forward() signature
-- Params/buffers/constants/custom_obj nodes: should start with "p", "b", "c", "obj"
+- Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in <placeholder_prefixes>
"""
name_to_kind = {
spec.arg.name: spec.kind
for spec in sig.input_specs
if not isinstance(spec.arg, ConstantArgument)
}
-for node in gm.graph.nodes:
-if node.op == "placeholder":
-if node.name not in name_to_kind:
-continue
-node_kind = name_to_kind[node.name]
-prefix = placeholder_prefixes[node_kind]
-if not node.name.startswith(prefix):
-raise SpecViolationError(
-f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
-)
+for mod in gm.modules():
+if not isinstance(mod, torch.fx.GraphModule):
+continue
+for node in mod.graph.nodes:
+if node.op == "placeholder":
+if node.name not in name_to_kind:
+continue
+node_kind = name_to_kind[node.name]
+prefix = placeholder_prefixes[node_kind]
+if not node.name.startswith(prefix):
+raise SpecViolationError(
+f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}"
+)


def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
45 changes: 45 additions & 0 deletions torch/export/exported_program.py
@@ -141,6 +141,48 @@ def _rename_without_collisions(
return name_map[orig_name]


def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
"""
Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs,
and handle collisions with non-placeholders by count suffixing.
Different HOO subgraph types have different input schemas, so we first enumerate them
and gather the top-level named placeholder nodes.
"""
# gather all HOO subgraphs and their top-level named placeholder nodes
subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
for node in gm.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.HigherOrderOperator
):
# HOO subgraphs have varying input schemas, so we enumerate them there
if node.target._name == "cond":
_, true_graph, false_graph, cond_args = node._args
subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args))
subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args))
elif node.target._name == "wrap_with_set_grad_enabled":
subgraph, phs = node._args[1], node._args[2:]
subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs))
elif node.target._name == "map_impl":
body_graph, array, args = node._args
subgraph_ph_tuples.append(
(getattr(gm, body_graph.target), array + args)
)

# propagate names
for subgraph, hoo_phs in subgraph_ph_tuples:
name_map: Dict[str, str] = {}
for i, node in enumerate(subgraph.graph.nodes):
if i < len(hoo_phs): # placeholder, retain name
name_map[node.name] = hoo_phs[i].name
node.name = node.target = hoo_phs[i].name
else: # non-placeholder, check for collisions
node.name = _rename_without_collisions(name_map, node.name, node.name)

# recurse and recompile
_name_hoo_subgraph_placeholders(subgraph)
subgraph.recompile()


class ExportedProgram:
"""
Package of a program from :func:`export`. It contains
@@ -543,6 +585,9 @@ def update_arg(old_arg, new_ph):
continue
node.name = _rename_without_collisions(name_map, node.name, node.name)

# propagate names to higher order op subgraphs
_name_hoo_subgraph_placeholders(gm)

# To match the output target with correct input for input mutations
# need to find the old to new placeholder map
old_new_placeholder_map = {
