[torchscript] Handle prim::device and prim::dtype (pytorch#127466)
- Support prim::device and prim::dtype during torchscript migration to export
- Add unit tests
Pull Request resolved: pytorch#127466
Approved by: https://github.com/SherlockNoMad
BoyuanFeng authored and pytorchmergebot committed May 30, 2024
1 parent fa426b0 commit 4afc5c7
Showing 2 changed files with 81 additions and 8 deletions.
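Before the file diffs, a minimal usage sketch of what this change enables: converting a scripted module that reads x.device and x.dtype into an exported program. The TS2EPConverter class and its convert() method are assumed to be the entry point exposed by torch/_export/converter.py; neither appears in the diff below, so treat this as illustrative only.

# Minimal sketch (assumed API): convert a scripted module that reads
# x.device and x.dtype.
import torch
from torch._export.converter import TS2EPConverter  # assumed entry point


class M(torch.nn.Module):
    def forward(self, x):
        # These attribute reads become prim::device / prim::dtype nodes
        # once the module is scripted.
        return torch.ones(2, 3, device=x.device, dtype=x.dtype)


inp = (torch.rand(3, 4),)
ts_model = torch.jit.script(M())
ep = TS2EPConverter(ts_model, inp).convert()  # assumed signature
print(ep)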
44 changes: 44 additions & 0 deletions test/export/test_converter.py
@@ -1,5 +1,7 @@
# Owner(s): ["oncall: export"]

import unittest

import torch

import torch.utils._pytree as pytree
@@ -9,6 +11,8 @@

from torch.testing._internal.common_utils import run_tests

requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


class TestConverter(TestCase):
    def _check_equal_ts_ep_converter(self, mod, inp):
@@ -64,6 +68,46 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
        self._check_equal_ts_ep_converter(MOutputTuple(), inp)
        self._check_equal_ts_ep_converter(MOutputDict(), inp)

    def test_prim_device(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand(3, 4),)
        self._check_equal_ts_ep_converter(Module(), inp)

    @requires_cuda
    def test_prim_device_cuda(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand((3, 4), device="cuda:0"),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_dtype(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                dtype = x.dtype
                return torch.ones(2, 3, dtype=dtype)

        for dtype in [
            torch.float32,
            torch.double,
        ]:
            inp = (torch.rand((3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

        for dtype in [
            torch.uint8,
            torch.int8,
            torch.int32,
        ]:
            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)


if __name__ == "__main__":
    run_tests()
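For reference, a short illustrative sketch (not part of this commit) showing the TorchScript graph nodes that the tests above exercise and that the converter changes below handle:

# Inspect the scripted graph of a module like the ones tested above.
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.ones(2, 3, device=x.device, dtype=x.dtype)


print(torch.jit.script(M()).graph)
# The printed graph typically contains nodes along the lines of
#   %device : Device = prim::device(%x)
#   %dtype : int = prim::dtype(%x)
# which convert_prim_device / convert_prim_dtype (added below) resolve to
# Python constants during conversion.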
45 changes: 37 additions & 8 deletions torch/_export/converter.py
@@ -5,6 +5,7 @@

from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
    ConstantArgument,
    InputKind,
    InputSpec,
    OutputKind,
@@ -201,6 +202,20 @@ def convert_prim_Constant(self, node: torch._C.Node):

        self.constant_map[name] = value

    def convert_prim_device(self, node: torch._C.Node):
        input_type = node.input().type()
        if input_type.isSubtypeOf(torch._C.TensorType.get()):
            device = input_type.device()  # type: ignore[attr-defined]
            output_name = node.output().debugName()
            self.constant_map[output_name] = device
        else:
            raise ValueError(f"Unsupported JitType ({input_type}) when getting device")

    def convert_prim_dtype(self, node: torch._C.Node):
        dtype = node.input().type().dtype()
        output_name = node.output().debugName()
        self.constant_map[output_name] = dtype

    def convert_prim_GetAttr(self, node: torch._C.Node):
        def get_attr(name: str):
            if name in self.attribute_map:
@@ -350,6 +365,10 @@ def convert_node(self, node: torch._C.Node):
        elif node_kind in {"prim::ListConstruct", "prim::TupleConstruct"}:
            # Tuple is just a non-mutable List, so we can handle them together.
            self.convert_prim_ListConstruct(node)
        elif node_kind == "prim::device":
            self.convert_prim_device(node)
        elif node_kind == "prim::dtype":
            self.convert_prim_dtype(node)
        elif node_kind == "prim::DictConstruct":
            self.convert_prim_DictConstruct(node)
        # elif node_kind == "aten::Int":
Expand All @@ -369,17 +388,27 @@ def convert_graph_outputs(self):
            output_name = graph_output.debugName()
            if output_name in self.name_to_node:
                args.append(self.name_to_node[output_name])
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=TensorArgument(name=output_name),
                        target=output_name,
                    )
                )
            elif output_name in self.constant_map:
                args.append(self.constant_map[output_name])
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=ConstantArgument(
                            name=output_name, value=self.constant_map[output_name]
                        ),
                        target=output_name,
                    )
                )
            else:
                raise ValueError(f"Output {output_name} not found")

            self.output_specs.append(
                OutputSpec(
                    OutputKind.USER_OUTPUT,
                    arg=TensorArgument(name=output_name),
                    target=output_name,
                )
            )

        self.fx_graph.output(
            args[0]
        )  # Get rid of an extra list wrapped around final output.
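Finally, a hypothetical sketch of the output path the reworked convert_graph_outputs covers: when a graph output is itself a constant (for example, a module that returns x.dtype directly), the new elif branch records it with a ConstantArgument instead of a TensorArgument. The TS2EPConverter entry point and attribute names are the same assumptions as in the sketch above.

# Hypothetical sketch: a module whose output is the constant produced by
# prim::dtype.  Under the new branch, this output is looked up in
# constant_map and described by a ConstantArgument output spec.
import torch
from torch._export.converter import TS2EPConverter  # assumed entry point


class ReturnsDtype(torch.nn.Module):
    def forward(self, x):
        return x.dtype  # the prim::dtype result is returned directly


ep = TS2EPConverter(
    torch.jit.script(ReturnsDtype()), (torch.rand(3, 4),)
).convert()
print(ep.graph_signature.output_specs)  # expect a ConstantArgument entry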