[hierarchical-compilation][inductor] Support invoke_subgraph HOP (pytorch#138031)

Pull Request resolved: pytorch#138031
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#137538, pytorch#138036, pytorch#137965
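
For context, a hedged sketch of what an invoke_subgraph HOP call looks like. The argument order (subgraph, identifier, operands) is taken from the lowering added below; the helper names, the identifier string, and direct eager invocation are assumptions, not part of this commit:

```python
# Illustrative only: `gn`, the shapes, and "subgraph_0" are invented.
import torch
from torch.fx import symbolic_trace

def gn(x, y):
    return torch.mul(x, y)

subgraph = symbolic_trace(gn)
x, y = torch.randn(8), torch.randn(8)
# HOP signature mirrors the lowering below: (subgraph, identifier, operands)
out = torch.ops.higher_order.invoke_subgraph(subgraph, "subgraph_0", (x, y))
```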
anijain2305 authored and pytorchmergebot committed Oct 23, 2024
1 parent 7622ede commit dd4dd85
Showing 5 changed files with 100 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/higher_order_ops/test_invoke_subgraph.py
@@ -145,7 +145,7 @@ def fn(x, y):

         x_clone = x.clone().detach().requires_grad_(True)
         y_clone = y.clone().detach().requires_grad_(True)
-        res = torch.compile(fn, backend="eager", fullgraph=True)(x_clone, y_clone)
+        res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)

         # Run backward
         ref.sum().backward()
@@ -254,7 +254,7 @@ def fn(x, y):

         x_clone = x.clone().detach().requires_grad_(True)
         y_clone = y.clone().detach().requires_grad_(True)
-        res = torch.compile(fn, backend="eager", fullgraph=True)(x_clone, y_clone)
+        res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)

         # Run backward
         ref.sum().backward()
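
Both hunks switch the compiled backend from eager to inductor so the new Inductor support is exercised. A minimal self-contained sketch of the test pattern, with the surrounding definitions (`fn`, `x`, `y`, `ref`) reconstructed as assumptions rather than copied from the file:

```python
# Sketch of the test pattern above; `fn`'s body and the shapes are
# assumptions -- the real test routes part of `fn` through the
# invoke_subgraph HOP.
import torch

def fn(x, y):
    return x * y + y  # stand-in body

x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
ref = fn(x, y)

x_clone = x.clone().detach().requires_grad_(True)
y_clone = y.clone().detach().requires_grad_(True)
res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone)

# Run backward on both and compare results and gradients
ref.sum().backward()
res.sum().backward()
torch.testing.assert_close(ref, res)
torch.testing.assert_close(x.grad, x_clone.grad)
torch.testing.assert_close(y.grad, y_clone.grad)
```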
5 changes: 5 additions & 0 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -1538,6 +1538,11 @@ def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
                 self.writeline(f"{outer_output}.reset();")
             self.writeline(f"{outer_output} = {src}{self.ending}")

+    def codegen_invoke_subgraph(self, invoke_subgraph):
+        raise NotImplementedError(
+            "codegen invoke_subgraph is not implemented for cpp wrapper"
+        )

     def codegen_conditional(self, conditional):
         name = conditional.get_name()
         outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands]
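
The C++ wrapper path is explicitly unsupported for now. A hedged repro sketch, assuming a function `fn` that compiles down to an invoke_subgraph call (`fn` is hypothetical; `torch._inductor.config.cpp_wrapper` is the real config flag):

```python
# Assumed repro: with the C++ wrapper enabled, compiling a graph that
# contains invoke_subgraph should hit the NotImplementedError above.
import torch
import torch._inductor.config as inductor_config

inductor_config.cpp_wrapper = True

# opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
# opt_fn(x, y)  # expected: NotImplementedError from codegen_invoke_subgraph
```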
8 changes: 8 additions & 0 deletions torch/_inductor/codegen/wrapper.py
@@ -2125,6 +2125,14 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):

         self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)

+    def codegen_invoke_subgraph(self, invoke_subgraph):
+        name = invoke_subgraph.get_name()
+
+        self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}")
+        outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs]
+        outer_outputs = [f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))]
+        self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, outer_outputs)

     def codegen_conditional(self, conditional):
         name = conditional.get_name()

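
The Python wrapper first allocates a list placeholder for the outputs, then reuses the existing `codegen_subgraph` machinery to run the subgraph body and fill each slot. A hypothetical example of the wrapper code this would emit for a two-output subgraph (all buffer names invented):

```python
# Hypothetical generated wrapper code, not actual Inductor output.
buf3 = [None] * 2  # emitted by codegen_invoke_subgraph
# ... subgraph body runs here on the outer inputs ...
buf3[0] = buf1  # outer outputs wired up by the subgraph suffix codegen
buf3[1] = buf2
```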
79 changes: 79 additions & 0 deletions torch/_inductor/ir.py
@@ -6601,6 +6601,85 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
     return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers)


+@ir_dataclass(frozen=False)
+class InvokeSubgraph(ExternKernel):
+    subgraph: Optional[Subgraph] = None
+    operands: Optional[List[TensorBox]] = None
+    outputs: Optional[List[MultiOutput]] = None
+
+    def __init__(
+        self, subgraph: Subgraph, operands: List[TensorBox], layout: MultiOutputLayout
+    ):
+        super().__init__(
+            name=None,
+            layout=layout,
+            inputs=operands,
+        )
+        self.subgraph = subgraph
+        self.name = V.graph.register_buffer(self)
+        V.graph.register_operation(self)
+
+    @classmethod
+    def create(cls, subgraph: Subgraph, operands):
+        # TODO(anijain2305) - Support sym exprs as operands in the future.
+        fx_operands = V.graph.current_node.args[-1]
+        fake_operands = [x.meta["val"] for x in fx_operands]  # type: ignore[union-attr]
+
+        # Realize the inputs. Intermediates can also have different strides
+        # than the subgraph inputs, so force the intermediates to have the
+        # same strides as the subgraph inputs.
+        operands = [cls.realize_input(x) for x in operands]
+
+        def handle_sym_expr(stride):
+            return [s.node.expr if isinstance(s, torch.SymInt) else s for s in stride]
+
+        fake_strides = [fake_operand.stride() for fake_operand in fake_operands]
+        fake_strides = [handle_sym_expr(stride) for stride in fake_strides]
+        operands = [
+            cls.require_exact_strides(x, fake_strides[idx])
+            for idx, x in enumerate(operands)
+        ]
+
+        if subgraph.graph is None:
+            # Create and lower the subgraph.
+            subgraph.graph = V.graph.make_subgraph(
+                gm=subgraph.graph_module,
+                example_inputs=fake_operands,
+                subgraph_name=subgraph.name,
+            )
+            with V.set_graph_handler(subgraph.graph):
+                subgraph.graph.run(*fake_operands)
+
+        outputs = subgraph.graph.graph_outputs  # type: ignore[union-attr]
+        device = operands[0].get_device()
+        invoke_subgraph = InvokeSubgraph(
+            subgraph=subgraph,
+            operands=operands,
+            layout=MultiOutputLayout(device=device),
+        )
+
+        outputs = [
+            MultiOutput(
+                FixedLayout(
+                    device=output.get_device(),
+                    dtype=output.get_dtype(),
+                    size=output.get_size(),
+                    stride=output.get_stride(),
+                    offset=output.get_layout().offset,
+                ),
+                invoke_subgraph,
+                [(list, i)],
+            )
+            for i, output in enumerate(outputs)
+        ]
+
+        invoke_subgraph.outputs = outputs
+        return outputs
+
+    def codegen(self, wrapper):
+        wrapper.codegen_invoke_subgraph(self)


 @ir_dataclass(frozen=False)
 class Conditional(ExternKernel):
     predicate: Optional[IRNode] = None
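
Each subgraph result is exposed through a MultiOutput node whose indices `[(list, i)]` tell codegen to emit a `name[i]` reference into the placeholder list created by the wrapper. A toy standalone illustration of that indexing convention (none of these classes are Inductor's; this is an assumption-level sketch):

```python
# Toy model of Inductor's MultiOutput indexing, not Inductor code.
class ParentOp:
    def __init__(self, name: str):
        self.name = name

class OutputHandle:
    def __init__(self, parent: ParentOp, indices):
        self.parent = parent
        self.indices = indices  # e.g. [(list, 1)]

    def codegen_reference(self) -> str:
        ref = self.parent.name
        for _, i in self.indices:
            ref = f"{ref}[{i}]"
        return ref

parent = ParentOp("buf0")
outs = [OutputHandle(parent, [(list, i)]) for i in range(2)]
assert [o.codegen_reference() for o in outs] == ["buf0[0]", "buf0[1]"]
```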
6 changes: 6 additions & 0 deletions torch/_inductor/lowering.py
@@ -6397,6 +6397,12 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
     return list(map(TensorBox.create, result))


+@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
+def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands):
+    result = ir.InvokeSubgraph.create(subgraph_fn, operands)
+    return list(map(TensorBox.create, result))


 @register_lowering(associative_scan_op, type_promotion_kind=None)
 def associative_scan(combine_fn: ir.Subgraph, xs, dim: int):
     from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph
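
At lowering time, the FX node supplies the subgraph, a string identifier, and the operand tuple (`V.graph.current_node.args[-1]` in `InvokeSubgraph.create` above). A hypothetical readable-FX view of such a node, with all names invented for illustration:

```python
# Hypothetical FX printout, not taken from this commit's tests:
#
# %invoke_subgraph : [num_users=1] = call_function[
#     target=torch.ops.higher_order.invoke_subgraph](
#     args = (%repeated_subgraph0, "subgraph_0", (%x, %y)), kwargs = {})
```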
