[ts_converter] Basic support for prim::If conversion (pytorch#127336)
Script module:
```
graph(%self : __torch__.M,
      %x.1 : Tensor,
      %y.1 : Tensor):
  %11 : int = prim::Constant[value=1]()
  %5 : bool = aten::Bool(%x.1) # /data/users/angelayi/pytorch2/test/export/test_converter.py:27:19
  %21 : Tensor = prim::If(%5) # /data/users/angelayi/pytorch2/test/export/test_converter.py:27:16
    block0():
      %8 : Tensor = aten::mul(%y.1, %y.1) # /data/users/angelayi/pytorch2/test/export/test_converter.py:28:27
      -> (%8)
    block1():
      %12 : Tensor = aten::add(%y.1, %y.1, %11) # /data/users/angelayi/pytorch2/test/export/test_converter.py:30:27
      -> (%12)
  return (%21)
```
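
The graph above comes from scripting a module with a data-dependent `if`; a minimal sketch along the lines of the test case added in this PR (the `M` below is illustrative, not part of the diff):
```python
import torch

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if x:            # data-dependent branch -> prim::If with two blocks
            return y * y
        else:
            return y + y

# Scripting (rather than tracing) preserves the control flow as prim::If.
print(torch.jit.script(M()).graph)
```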
ExportedProgram:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_1: "b8[]", y_1: "i64[]"):
            # File: <eval_with_key>.23:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, [l_args_3_0_]);  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional = torch.ops.higher_order.cond(x_1, true_graph_0, false_graph_0, [y_1]);  x_1 = true_graph_0 = false_graph_0 = y_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, y_1: "i64[]"):
                # File: <eval_with_key>.20:6 in forward, code: mul_tensor = torch.ops.aten.mul.Tensor(l_args_3_0__1, l_args_3_0__1);  l_args_3_0__1 = None
                mul: "i64[]" = torch.ops.aten.mul.Tensor(y_1, y_1);  y_1 = None
                return mul

        class <lambda>(torch.nn.Module):
            def forward(self, y_1: "i64[]"):
                # File: <eval_with_key>.21:6 in forward, code: add_tensor = torch.ops.aten.add.Tensor(l_args_3_0__1, l_args_3_0__1, alpha = 1);  l_args_3_0__1 = None
                add: "i64[]" = torch.ops.aten.add.Tensor(y_1, y_1);  y_1 = None
                return add
```
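
In the exported program, `prim::If` is rewritten as a call to the `cond` higher-order operator, which takes a predicate, the two branch callables, and the operands shared by both branches. A minimal sketch of the equivalent functional form (assuming `torch.cond` can be run eagerly in your build):
```python
import torch

def true_fn(y):
    return y * y

def false_fn(y):
    return y + y

# Both branches must accept the same operands and return outputs with matching
# shape and dtype, which is why the converter passes the union of the blocks'
# inputs to torch.cond.
out = torch.cond(torch.tensor(True), true_fn, false_fn, [torch.tensor(4)])
print(out)  # tensor(16)
```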

This PR also adds support for converting prim::TupleIndex and incorporates some changes from pytorch#127341.
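
`prim::TupleIndex` maps onto a plain `operator.getitem` node in the FX graph (see `convert_prim_TupleIndex` in the diff below). A hand-built sketch of the resulting node, constructed directly rather than through the converter:
```python
import operator
import torch

graph = torch.fx.Graph()
tup = graph.placeholder("tup")
idx = graph.placeholder("idx")  # in the converter the index typically comes from a prim::Constant
item = graph.call_function(operator.getitem, (tup, idx))
graph.output(item)

gm = torch.fx.GraphModule({}, graph)
print(gm((torch.tensor(1.0), torch.tensor(2.0)), 1))  # tensor(2.)
```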
Pull Request resolved: pytorch#127336
Approved by: https://github.com/BoyuanFeng
angelayi authored and pytorchmergebot committed May 31, 2024
1 parent 3e66052 commit b2f5fd8
Showing 2 changed files with 172 additions and 21 deletions.
46 changes: 45 additions & 1 deletion test/export/test_converter.py
@@ -8,14 +8,15 @@

from torch._dynamo.test_case import TestCase
from torch._export.converter import TS2EPConverter
from torch.export import ExportedProgram

from torch.testing._internal.common_utils import run_tests

requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


class TestConverter(TestCase):
def _check_equal_ts_ep_converter(self, mod, inp):
def _check_equal_ts_ep_converter(self, mod, inp) -> ExportedProgram:
ts_model = torch.jit.script(mod)
ep = TS2EPConverter(ts_model, inp).convert()
ep_out, _ = pytree.tree_flatten(ep.module()(*inp))
@@ -24,6 +25,7 @@ def _check_equal_ts_ep_converter(self, mod, inp):
for ep_t, orig_t in zip(ep_out, orig_out):
self.assertEqual(ep_t.shape, orig_t.shape)
self.assertTrue(torch.allclose(ep_t, orig_t))
return ep

def test_ts2ep_converter_basic(self):
class MSingle(torch.nn.Module):
@@ -108,6 +110,48 @@ def forward(self, x):
inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
self._check_equal_ts_ep_converter(Module(), inp)

def test_convert_if_basic(self):
class M(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
if x:
return y * y
else:
return y + y

inp = (torch.tensor(True), torch.tensor(4))
ep = self._check_equal_ts_ep_converter(M(), inp)

torch.testing.assert_close(
ep.module()(torch.tensor(False), torch.tensor(4)),
M()(torch.tensor(False), torch.tensor(4)),
)

def test_convert_if_multiple_out(self):
class M(torch.nn.Module):
def true_fn(self, y, z):
return (z * z, z + z)

def false_fn(self, y, z):
return (y * y * y, y + y)

def forward(self, x: torch.Tensor, y: torch.Tensor):
z = y * y

if x:
res = self.true_fn(y, z)
else:
res = self.false_fn(y, z)

return res[0] + res[1]

inp = (torch.tensor(True), torch.tensor(4))
ep = self._check_equal_ts_ep_converter(M(), inp)

torch.testing.assert_close(
ep.module()(torch.tensor(False), torch.tensor(4)),
M()(torch.tensor(False), torch.tensor(4)),
)


if __name__ == "__main__":
run_tests()
147 changes: 127 additions & 20 deletions torch/_export/converter.py
@@ -1,3 +1,4 @@
import operator
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
@@ -54,22 +55,16 @@ def get_op_overload(node: torch._C.Node):
return op_overload


class TS2EPConverter:
# TorchScript model to ExportedProgram converter
class TS2FXGraphConverter:
def __init__(
self,
ts_model,
sample_args: Tuple[Any, ...],
sample_kwargs: Optional[Dict[str, Any]] = None,
ts_graph: Union[torch._C.Graph, torch._C.Block],
param_names: Set[str],
buffer_names: Set[str],
):
self.ts_model = ts_model
self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)

self.sample_args = sample_args
self.sample_kwargs = sample_kwargs

self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()}
self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()}
self.ts_graph = ts_graph
self.param_names = param_names
self.buffer_names = buffer_names

self.fx_graph: torch.fx.Graph = torch.fx.Graph()
self.input_specs: List[InputSpec] = []
@@ -82,6 +77,13 @@ def __init__(
self.attribute_map: Dict[str, Any] = {}
self.tensor_constants: Dict[str, torch.Tensor] = {}

self.subgraphs: Dict[str, torch.fx.GraphModule] = {}

def add_subgraph(self, subgraph) -> str:
name = f"subgraph_{len(self.subgraphs)}"
self.subgraphs[name] = subgraph
return name

def get_args_kwargs(self, node: torch._C.Node, schema):
args = []
kwargs = {}
@@ -110,22 +112,21 @@ def get_fx_value(self, value: torch._C.Value):
else:
raise ValueError(f"Input {value_name} not found")

def convert(self) -> ExportedProgram:
def convert(self) -> torch.fx.GraphModule:
self.convert_graph_inputs()

for node in self.ts_graph.nodes():
self.convert_node(node)

self.convert_graph_outputs()

gm = torch.fx.GraphModule({}, self.fx_graph)
gm = torch.fx.GraphModule(self.subgraphs, self.fx_graph)

inplace_optimize_sym_size_div(gm)

gm.graph.lint()

ep = self.retrace_as_exported_program(gm)
return ep
return gm

def convert_graph_inputs(self):
for graph_input in self.ts_graph.inputs():
@@ -234,7 +235,10 @@ def get_attr(name: str):
)

def convert_aten_op(self, node: torch._C.Node):
target = get_op_overload(node)
try:
target = get_op_overload(node)
except Exception as e:
raise RuntimeError(f"Unsupported node {node.kind()}") from e

if target is torch.ops.aten.size.int:
target = torch.ops.aten.sym_size.int
@@ -280,6 +284,13 @@ def convert_prim_DictConstruct(self, node: torch._C.Node):
output_name = node.output().debugName()
self.name_to_node[output_name] = output_dict

def convert_prim_TupleIndex(self, node: torch._C.Node):
args = tuple(self.get_fx_value(input) for input in node.inputs())
getitem_node = self.fx_graph.call_function(operator.getitem, args)

output_name = node.output().debugName()
self.name_to_node[output_name] = getitem_node

def convert_aten_Int(self, node: torch._C.Node):
# converts aten::Int as aten._to_copy + aten::_local_scalar_dense
target = torch.ops.aten._to_copy.default
@@ -352,6 +363,70 @@ def convert_aten_div(self, node: torch._C.Node):

self.convert_aten_op(node)

def convert_prim_if(self, node: torch._C.Node):
inputs = list(node.inputs())
assert len(inputs) == 1
predicate = self.get_fx_value(inputs[0])

# Get union of inputs to blocks
arguments = set()
for block in node.blocks():
block_args = set()

# TODO: block.inputs(), not sure what they're used for

for block_node in block.nodes():
for block_node_in in block_node.inputs():
if block_node_in.debugName() in self.name_to_node:
block_args.add(block_node_in.debugName())

arguments.update(block_args)

arguments = list(arguments)

# Convert blocks to subgraphs
subgraph_nodes = []
for block in node.blocks():
subgraph_converter = TS2FXGraphConverter(block, set(), set())
subgraph_converter.constant_map = self.constant_map

for block_arg in arguments:
normalized_block_arg_name = normalize_name(block_arg)
placeholder_node = subgraph_converter.fx_graph.placeholder(
normalized_block_arg_name
)
subgraph_converter.name_to_node[block_arg] = placeholder_node

subgraph = subgraph_converter.convert()
subgraph_name = self.add_subgraph(subgraph)
subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))

assert len(subgraph_nodes) == 2

fx_block_args = [self.name_to_node[arg_name] for arg_name in arguments]
args = (
predicate,
subgraph_nodes[0],
subgraph_nodes[1],
tuple(fx_block_args),
)

cond_node = self.fx_graph.call_function(torch.cond, args, {})

output_name = node.output().debugName()
self.name_to_node[output_name] = cond_node

def convert_as_noop(self, node: torch._C.Node):
# Converts the node to a no-op by mapping its output to its first input (args[0])

target = get_op_overload(node)
schema = target._schema

args, kwargs = self.get_args_kwargs(node, schema)

output_name = node.output().debugName()
self.name_to_node[output_name] = args[0]

def convert_node(self, node: torch._C.Node):
node_kind = node.kind()
if node_kind == "prim::CreateObject":
@@ -371,12 +446,18 @@ def convert_node(self, node: torch._C.Node):
self.convert_prim_dtype(node)
elif node_kind == "prim::DictConstruct":
self.convert_prim_DictConstruct(node)
elif node_kind == "prim::TupleIndex":
self.convert_prim_TupleIndex(node)
# elif node_kind == "aten::Int":
# convert_aten_Int(node)
elif node_kind == "aten::_convolution":
self.convert_aten__convolution(node)
elif node_kind == "aten::div":
self.convert_aten_div(node)
elif node_kind == "prim::If":
self.convert_prim_if(node)
elif node_kind == "aten::Bool":
self.convert_as_noop(node)
elif node_kind.startswith("aten::"):
self.convert_aten_op(node)
else:
@@ -413,9 +494,35 @@ def convert_graph_outputs(self):
args[0]
) # Get rid of an extra list wrapped around final output.

def retrace_as_exported_program(self, gm: torch.fx.GraphModule):

class TS2EPConverter:
# TorchScript model to ExportedProgram converter
def __init__(
self,
ts_model,
sample_args: Tuple[Any, ...],
sample_kwargs: Optional[Dict[str, Any]] = None,
):
self.ts_model = ts_model
self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)

self.sample_args = sample_args
self.sample_kwargs = sample_kwargs

self.param_names: Set[str] = {name for name, _ in ts_model.named_parameters()}
self.buffer_names: Set[str] = {name for name, _ in ts_model.named_buffers()}

def convert(self) -> ExportedProgram:
graph_converter = TS2FXGraphConverter(
self.ts_graph, self.param_names, self.buffer_names
)
gm = graph_converter.convert()
ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants)
return ep

def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants):
# TODO: adjust input orders to match GraphSignature convention
inputs = [*self.sample_args, *self.params, *self.tensor_constants.values()]
inputs = [*self.sample_args, *self.params, *tensor_constants.values()]

ep = torch.export._trace._export(
gm,
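
For reference, an end-to-end usage sketch of the refactored converter (mirroring the `_check_equal_ts_ep_converter` helper in the test file): `TS2EPConverter` builds the JIT graph, `TS2FXGraphConverter` turns it into an FX graph with `cond` subgraphs, and the result is retraced into an `ExportedProgram`. The module below is illustrative:
```python
import torch
from torch._export.converter import TS2EPConverter

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if x:
            return y * y
        return y + y

inp = (torch.tensor(True), torch.tensor(4))
ep = TS2EPConverter(torch.jit.script(M()), inp).convert()  # ExportedProgram

# The converted program should follow either branch, not just the sampled one.
print(ep.module()(*inp))                                  # tensor(16)
print(ep.module()(torch.tensor(False), torch.tensor(4)))  # tensor(8)
```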
