diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index d5f6837cba01e..a7eb81341eb54 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1130,6 +1130,18 @@ def replace_special_case(hint: str) -> str: ) ) ], + "xpu": [ + "def xpu({}) -> Tensor: ...".format( + ", ".join( + [ + "self", + "device: Optional[Union[_device, _int, str]] = None", + "non_blocking: _bool = False", + "memory_format: torch.memory_format = torch.preserve_format", + ] + ) + ) + ], "cpu": [ "def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..." ],