Skip to content

Commit

Permalink
[export] optimize unflattener (pytorch#115364)
Browse files Browse the repository at this point in the history
Unflattening was slow on the APS FM model (which has thousands of nn.EmbeddingBag modules).

A quick glance at the profile shows 75% of unflattening time was spent copying this node list, which is immutable and globally shared. Simply passing it around as a tuple yields a 4x speedup.

Differential Revision: [D51929775](https://our.internmc.facebook.com/intern/diff/D51929775/)
Pull Request resolved: pytorch#115364
Approved by: https://github.com/zhxchen17
  • Loading branch information
suo authored and pytorchmergebot committed Dec 8, 2023
1 parent 494cb28 commit 3d999d2
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions torch/_export/unflatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ class ModuleFrame:
def __init__(
self,
flat_graph,
nodes,
seen_nodes,
seen_modules,
parent,
Expand All @@ -252,6 +253,7 @@ def __init__(
graph_module=None,
):
self.flat_graph = flat_graph
self.nodes = nodes
self.seen_nodes = seen_nodes
self.seen_modules = seen_modules
self.parent = parent
Expand Down Expand Up @@ -286,7 +288,6 @@ def __init__(
self.cached_graph_module = None
self.seen_modules[self.module_id] = self.graph_module

self.nodes = list(self.flat_graph.nodes)
self.graph = self.graph_module.graph

# Mapping of nodes in the flat graph to nodes in this graph.
Expand Down Expand Up @@ -341,19 +342,19 @@ def __init__(
self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node

with self.parent.graph.inserting_before(self.parent_call_module):
nodes: List[Optional[torch.fx.Node]] = []
input_nodes: List[Optional[torch.fx.Node]] = []
for input in signature.inputs:
if isinstance(input, ConstantArgument) and input.value is None:
nodes.append(None)
input_nodes.append(None)
else:
assert isinstance(input, (TensorArgument, SymIntArgument))
nodes.append(
input_nodes.append(
self.parent.remap_input(self.seen_nodes[input.name])
)

inputs_node = _generate_unflatten(
self.parent.graph_module,
nodes,
input_nodes,
signature.in_spec,
)

Expand Down Expand Up @@ -540,6 +541,7 @@ def run_from(self, node_idx):
# counter. Once it is complete, continue from that point.
node_idx = ModuleFrame(
self.flat_graph,
self.nodes,
self.seen_nodes,
self.seen_modules,
self,
Expand All @@ -562,6 +564,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: torch.fx.GraphM
seen_modules: Dict[int, torch.nn.Module] = {}
ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),
seen_nodes,
seen_modules,
None,
Expand Down

0 comments on commit 3d999d2

Please sign in to comment.