diff --git a/test/export/test_converter.py b/test/export/test_converter.py
index b6d0e54a59e173..64cea8cf8ac9e5 100644
--- a/test/export/test_converter.py
+++ b/test/export/test_converter.py
@@ -1,5 +1,7 @@
 # Owner(s): ["oncall: export"]
 
+import unittest
+
 import torch
 
 import torch.utils._pytree as pytree
@@ -9,6 +11,8 @@
 
 from torch.testing._internal.common_utils import run_tests
 
+requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
+
 
 class TestConverter(TestCase):
     def _check_equal_ts_ep_converter(self, mod, inp):
@@ -64,6 +68,46 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         self._check_equal_ts_ep_converter(MOutputTuple(), inp)
         self._check_equal_ts_ep_converter(MOutputDict(), inp)
 
+    def test_prim_device(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                device = x.device
+                return torch.ones(2, 3, device=device)
+
+        inp = (torch.rand(3, 4),)
+        self._check_equal_ts_ep_converter(Module(), inp)
+
+    @requires_cuda
+    def test_prim_device_cuda(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                device = x.device
+                return torch.ones(2, 3, device=device)
+
+        inp = (torch.rand((3, 4), device="cuda:0"),)
+        self._check_equal_ts_ep_converter(Module(), inp)
+
+    def test_prim_dtype(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                dtype = x.dtype
+                return torch.ones(2, 3, dtype=dtype)
+
+        for dtype in [
+            torch.float32,
+            torch.double,
+        ]:
+            inp = (torch.rand((3, 4), dtype=dtype),)
+            self._check_equal_ts_ep_converter(Module(), inp)
+
+        for dtype in [
+            torch.uint8,
+            torch.int8,
+            torch.int32,
+        ]:
+            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
+            self._check_equal_ts_ep_converter(Module(), inp)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_export/converter.py b/torch/_export/converter.py
index 459f534ca63621..7e6812985badb2 100644
--- a/torch/_export/converter.py
+++ b/torch/_export/converter.py
@@ -5,6 +5,7 @@
 
 from torch.export.exported_program import ExportedProgram
 from torch.export.graph_signature import (
+    ConstantArgument,
     InputKind,
     InputSpec,
     OutputKind,
@@ -201,6 +202,20 @@ def convert_prim_Constant(self, node: torch._C.Node):
 
         self.constant_map[name] = value
 
+    def convert_prim_device(self, node: torch._C.Node):
+        input_type = node.input().type()
+        if input_type.isSubtypeOf(torch._C.TensorType.get()):
+            device = input_type.device()  # type: ignore[attr-defined]
+            output_name = node.output().debugName()
+            self.constant_map[output_name] = device
+        else:
+            raise ValueError(f"Unsupported JitType ({input_type}) when getting device")
+
+    def convert_prim_dtype(self, node: torch._C.Node):
+        dtype = node.input().type().dtype()
+        output_name = node.output().debugName()
+        self.constant_map[output_name] = dtype
+
     def convert_prim_GetAttr(self, node: torch._C.Node):
         def get_attr(name: str):
             if name in self.attribute_map:
@@ -350,6 +365,10 @@ def convert_node(self, node: torch._C.Node):
         elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}:
             # Tuple is just a non-mutable List, so we can handle them together.
             self.convert_prim_ListConstruct(node)
+        elif node_kind == "prim::device":
+            self.convert_prim_device(node)
+        elif node_kind == "prim::dtype":
+            self.convert_prim_dtype(node)
         elif node_kind == "prim::DictConstruct":
             self.convert_prim_DictConstruct(node)
         # elif node_kind == "aten::Int":
@@ -369,17 +388,27 @@ def convert_graph_outputs(self):
             output_name = graph_output.debugName()
             if output_name in self.name_to_node:
                 args.append(self.name_to_node[output_name])
+                self.output_specs.append(
+                    OutputSpec(
+                        OutputKind.USER_OUTPUT,
+                        arg=TensorArgument(name=output_name),
+                        target=output_name,
+                    )
+                )
+            elif output_name in self.constant_map:
+                args.append(self.constant_map[output_name])
+                self.output_specs.append(
+                    OutputSpec(
+                        OutputKind.USER_OUTPUT,
+                        arg=ConstantArgument(
+                            name=output_name, value=self.constant_map[output_name]
+                        ),
+                        target=output_name,
+                    )
+                )
             else:
                 raise ValueError(f"Output {output_name} not found")
 
-            self.output_specs.append(
-                OutputSpec(
-                    OutputKind.USER_OUTPUT,
-                    arg=TensorArgument(name=output_name),
-                    target=output_name,
-                )
-            )
-
         self.fx_graph.output(
             args[0]
         )  # Get rid of an extra list wrapped around final output.
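
Note (illustration only, not part of the patch): the new handlers target the prim::device and prim::dtype nodes that TorchScript emits when a scripted forward reads x.device or x.dtype. A minimal sketch of how such a graph can be inspected, using only standard torch.jit.script APIs:

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            # Reading x.device and x.dtype produces prim::device and prim::dtype
            # nodes in the TorchScript IR; their outputs feed the aten::ones call.
            return torch.ones(2, 3, device=x.device, dtype=x.dtype)

    scripted = torch.jit.script(M())
    print(scripted.graph)  # the printed IR contains prim::device and prim::dtype nodes

The converter resolves these nodes by recording the device/dtype of the input's static tensor type in constant_map, so downstream nodes (and, with the graph-output change above, the ExportedProgram's output signature via ConstantArgument) can consume them as constants.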