From 97770e96ecbd8ff61b7749c085e4f65e03ac26c9 Mon Sep 17 00:00:00 2001 From: "Aaron Orenstein (Meta Employee)" Date: Tue, 16 Jul 2024 23:52:15 -0700 Subject: [PATCH] typing: convert_frame (#130670) Summary: X-link: https://github.com/pytorch/pytorch/pull/130670 Approved by: https://github.com/Skylion007 ghstack dependencies: #130669 Reviewed By: atalman Differential Revision: D59842185 Pulled By: aorenste fbshipit-source-id: fd76404791ed6cf3ebc9a2adffc3857dc892b3ad --- .../dynamo/dynamobench/_dynamo/utils.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index d7a40fa2b..2fec4c5cf 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -37,19 +37,25 @@ DefaultDict, Deque, Dict, + Iterable, Iterator, KeysView, List, Optional, + overload, Set, Tuple, Type, + TypeVar, Union, ValuesView, ) +from typing_extensions import TypeGuard from ..utils.hooks import RemovableHandle +T = TypeVar("T") + try: import numpy as np except ModuleNotFoundError: @@ -498,6 +504,23 @@ def clear(self): self.values.clear() +@overload +def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]: + ... + + +@overload +def istype( + obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]] +) -> TypeGuard[T]: + ... + + +@overload +def istype(obj: object, allowed_types: Iterable[type]) -> bool: + ... + + def istype(obj, allowed_types): """isinstance() without subclasses""" if isinstance(allowed_types, (tuple, list, set)): @@ -631,7 +654,7 @@ def is_numpy_ndarray(value): def istensor(obj): """Check of obj is a tensor""" - tensor_list = ( + tensor_list: Tuple[type, ...] = ( torch.Tensor, torch.nn.Parameter, *config.traceable_tensor_subclasses, @@ -1061,7 +1084,7 @@ def rot_n_helper(n): return fn -common_constant_types = { +common_constant_types: Set[type] = { int, float, complex,