diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py
index 323ac9412a9fda..e5815ad266b3e6 100644
--- a/torch/_dynamo/backends/common.py
+++ b/torch/_dynamo/backends/common.py
@@ -77,7 +77,7 @@ def _wrapped_bw_compiler(*args, **kwargs):
         raise


-def aot_autograd(**kwargs):
+def aot_autograd(**kwargs) -> AotAutograd:
     return AotAutograd(**kwargs)


diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py
index fa0afa8d3b7956..5d3995ee41666f 100644
--- a/torch/_dynamo/repro/after_aot.py
+++ b/torch/_dynamo/repro/after_aot.py
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+
 import argparse
 import copy
 import functools
@@ -12,7 +13,8 @@
 import uuid
 from importlib import import_module
 from tempfile import TemporaryFile
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict, Sequence, TYPE_CHECKING, Union
+from typing_extensions import Unpack

 import torch
 import torch.fx as fx
@@ -45,6 +47,12 @@
 from .. import config


+if TYPE_CHECKING:
+    from torch._inductor.codecache import CompiledFxGraph
+    from torch._inductor.compile_fx import _CompileFxCallableEx, _CompileFxKwargsEx
+    from torch._inductor.utils import InputType
+
+
 log = logging.getLogger(__name__)
@@ -56,7 +64,10 @@
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #

-def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str):
+def wrap_compiler_debug(
+    unconfigured_compiler_fn: "_CompileFxCallableEx",
+    compiler_name: str,
+) -> "_CompileFxCallableEx":
     """
     Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
     forward and backward call separately with the backend compiler_fn - like
@@ -66,7 +77,11 @@ def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str):
     """

     @functools.wraps(unconfigured_compiler_fn)
-    def debug_wrapper(gm, example_inputs, **kwargs):
+    def debug_wrapper(
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence["InputType"],
+        **kwargs: Unpack["_CompileFxKwargsEx"],
+    ) -> Union["CompiledFxGraph", str]:
         from torch._subclasses import FakeTensorMode

         compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
@@ -104,11 +119,15 @@ def debug_wrapper(gm, example_inputs, **kwargs):
         # We may run regular PyTorch compute that may trigger Dynamo, do NOT
         # recursively attempt to accuracy minify in that case!

-        def deferred_for_real_inputs(real_inputs):
+        def deferred_for_real_inputs(
+            real_inputs: Sequence["InputType"], **_kwargs: object
+        ) -> Any:
             # This is a bit obscure: if we recursively try to accuracy minify
             # the SAME function, this would trigger.  But most of the time
             # we should never hit this branch
+            assert not _kwargs
             if config.repro_after != "aot":
+                assert not isinstance(inner_compiled_fn, str)
                 return inner_compiled_fn(real_inputs)
             with config.patch(repro_after=None):
                 return inner_debug_fn(real_inputs)
@@ -165,11 +184,11 @@ def inner_debug_fn(real_inputs):
                     raise AccuracyError("Bad accuracy detected")
                 else:
                     # Call the compiled function with real inputs
-                    return inner_compiled_fn(real_inputs)
+                    return inner_compiled_fn(real_inputs)  # type: ignore[operator]
             else:
                 try:
                     # Call the compiled function with real inputs
-                    out = inner_compiled_fn(real_inputs)
+                    out = inner_compiled_fn(real_inputs)  # type: ignore[operator]
                     # sync cuda kernels to ensure IMA detection
                     for arg in example_inputs:
                         if isinstance(arg, torch.Tensor) and arg.is_cuda:
@@ -194,7 +213,7 @@ def inner_debug_fn(real_inputs):
         if config.repro_after == "aot":
             compiled_fn = deferred_for_real_inputs
             compiled_fn._boxed_call = True  # type: ignore[attr-defined]
-            return compiled_fn
+            return compiled_fn  # type: ignore[return-value]
         else:
             return inner_compiled_fn
@@ -432,6 +451,7 @@ def sync():
     try:
         compile_mod = compile_fx_inner(fx_g, args)
+        assert not isinstance(compile_mod, str)
         compile_mod(args)
         sync()
     except Exception as e:
@@ -601,6 +621,7 @@ def save_hook(name, val):
     with intermediate_hook(save_hook), tqdm(
         desc="Saving inductor intermediates", total=total
     ) as pbar:
+        assert not isinstance(compiled, str)
         compiled(new_args)
         assert not new_args
@@ -717,6 +738,7 @@ def repro_run(options, mod, load_args):
     from torch.cuda import synchronize

     compiled = compile_fx_inner(mod, args)
+    assert not isinstance(compiled, str)

     if options.accuracy != "":
         # We don't really respect --accuracy vs --strict-accuracy here, it
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index bc550cbca10edf..7e97d854835dfc 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -516,7 +516,7 @@ def count_calls(g: fx.Graph) -> int:
     return c


-def identity(x):
+def identity(x: T) -> T:
     return x


diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index 6dc73fb11911e2..833958c78cb9ed 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -47,6 +47,7 @@
 if TYPE_CHECKING:
+    from torch._inductor.compile_fx import _CompileFxKwargs
     from torch._inductor.remote_cache import JsonDataTy, RemoteCache
     from torch._inductor.utils import BoxedBool
     from torch.fx.node import Node
@@ -205,7 +206,7 @@ def __init__(
         gm: torch.fx.GraphModule,
         example_inputs,
         aot_config: AOTConfig,
-        fx_config: Dict[str, BoxedBool],
+        fx_config: _CompileFxKwargs,
     ):
         # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info
         self.aot_config = aot_config
@@ -269,7 +270,7 @@ def autograd_cache_key(
     gm: torch.fx.GraphModule,
     example_inputs,
     config: AOTConfig,
-    fx_config: Dict[str, BoxedBool],
+    fx_config: _CompileFxKwargs,
     # TODO: add args and parameters
 ) -> Tuple[str, List[str]]:
     """
@@ -295,7 +296,7 @@ class FXGraphCacheLoadable:
     def is_backward(self):
         return False

-    def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGraph:
+    def load(self, example_inputs, fx_config: _CompileFxKwargs) -> CompiledFxGraph:
         # [Note: AOTAutogradCache and FXGraphCache Guard interactions]
         # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments.
         # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph.
@@ -332,7 +333,7 @@ def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGra
             payload_fn=lambda: json.dumps(cache_info),
         )

-        FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])
+        FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])  # type: ignore[arg-type]
         result._boxed_call = True
         return result
@@ -399,7 +400,7 @@ def wrap_post_compile(
         self,
         args: List[torch.Tensor],
         aot_config: AOTConfig,
-        fx_config: Dict[str, BoxedBool],
+        fx_config: _CompileFxKwargs,
     ) -> Callable:
         """
         This function takes a cache entry and carefully reconstructs the original callable
@@ -566,7 +567,7 @@ def load(
         debug_lines: List[str] = []
         cache_event_time = time.time_ns()
         cache_state = None
-        fx_config = {"cudagraphs": cudagraphs}
+        fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
         try:
             cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
             entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py
index a4beb7dbaaa80b..954ce6cde88d0e 100644
--- a/torch/_inductor/__init__.py
+++ b/torch/_inductor/__init__.py
@@ -1,16 +1,21 @@
 # mypy: allow-untyped-defs
-from typing import Any, Dict, List, Optional, Tuple
+
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

 import torch.fx
 import torch.utils._pytree as pytree

+if TYPE_CHECKING:
+    from torch._inductor.utils import InputType
+
+
 __all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]


 def compile(
     gm: torch.fx.GraphModule,
-    example_inputs: List[torch.Tensor],
+    example_inputs: List["InputType"],
     options: Optional[Dict[str, Any]] = None,
 ):
     """
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 147066655d7e0e..075fbfb090b1a0 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -80,7 +80,9 @@
 if TYPE_CHECKING:
     from collections.abc import KeysView

+    from .compile_fx import _CompileFxKwargs
     from .remote_cache import JsonDataTy, RemoteCache
+    from .utils import InputType


 """
@@ -751,23 +753,25 @@ class FxGraphHashDetails:
     def __init__(
         self,
         gm: torch.fx.GraphModule,
-        example_inputs: List[torch.Tensor],
-        fx_kwargs: Dict[str, Any],
+        example_inputs: Sequence[InputType],
+        fx_kwargs: _CompileFxKwargs,
         inputs_to_check: Sequence[int],
     ) -> None:
         self.gm = gm
         self.example_inputs = example_inputs

-        # Order kwargs so hashing is stable to changes in kwarg order.
-        self.fx_kwargs = {}
-        for k in sorted(fx_kwargs):
+        # Order kwargs so hashing is stable to changes in kwarg order. Although
+        # it's technically a _CompileFxKwargs we don't actually need it typed as
+        # such since we're just using it to generate a hash.
+        self.fx_kwargs: Dict[str, object] = {}
+        for k, v in sorted(fx_kwargs.items()):
             if k not in self.EXCLUDED_KWARGS:
-                if type(fx_kwargs[k]) is set:
+                if type(v) is set:
                     # Special case to handle set params. Python sets can't be
                     # ordered, so sort the elements and store them in a proxy.
-                    self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
+                    self.fx_kwargs[k] = OrderedSetHolder(sorted(v))
                 else:
-                    self.fx_kwargs[k] = fx_kwargs[k]
+                    self.fx_kwargs[k] = v

         # Alignment checks
         self.inputs_to_check = inputs_to_check
@@ -818,8 +822,8 @@ def debug_lines(self) -> List[str]:

 def compiled_fx_graph_hash(
     gm: torch.fx.GraphModule,
-    example_inputs: List[torch.Tensor],
-    fx_kwargs: Dict[str, Any],
+    example_inputs: Sequence[InputType],
+    fx_kwargs: _CompileFxKwargs,
     inputs_to_check: Sequence[int],
 ) -> Tuple[str, List[str]]:
     """
@@ -836,7 +840,7 @@ def compiled_fx_graph_hash(


 def cudagraph_post_compile(
-    example_inputs: List[Any],
+    example_inputs: Sequence[InputType],
     compiled_graph: CompiledFxGraph,
     cudagraphs: BoxedBool,
 ) -> None:
@@ -879,7 +883,7 @@ def cudagraph_post_compile(
         assert current_callable is not None
         compiled_graph.current_callable = cudagraphify(
             current_callable,
-            static_input_idxs=static_input_idxs,
+            static_input_idxs=static_input_idxs or (),
             device_index=next(iter(compiled_graph.device_idxs)),
             stack_traces=stack_traces,
             is_backward=is_backward,
@@ -1018,7 +1022,7 @@ def _get_tmp_dir_for_key(key: str) -> str:
         return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)

     @staticmethod
-    def _filter_backed_symints(inputs: List[Any]) -> List[torch.SymInt]:
+    def _filter_backed_symints(inputs: Sequence[InputType]) -> List[torch.SymInt]:
         """
         Get the backed SymInt objects from the input list. Note that we can never
         have guards that depend on unbacked symint.
@@ -1038,7 +1042,7 @@ def _get_shape_env() -> Optional[ShapeEnv]:
     @staticmethod
     def _lookup_graph(
         key: str,
-        example_inputs: List[torch.Tensor],
+        example_inputs: Sequence[InputType],
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
     ) -> Optional[CompiledFxGraph]:
@@ -1177,7 +1181,7 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]:
     @staticmethod
     def post_compile(
         compiled_graph: CompiledFxGraph,
-        example_inputs: List[torch.Tensor],
+        example_inputs: Sequence[InputType],
         cudagraphs: BoxedBool,
     ) -> CompiledFxGraph:
         """
@@ -1224,7 +1228,7 @@ def post_compile(
     def _save_graph(
         key: str,
         compiled_graph: CompiledFxGraph,
-        example_inputs: List[torch.Tensor],
+        example_inputs: Sequence[InputType],
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
     ) -> None:
@@ -1333,8 +1337,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
     @staticmethod
     def prepare_key(
         gm: torch.fx.GraphModule,
-        example_inputs: List[torch.Tensor],
-        fx_kwargs: Dict[str, Any],
+        example_inputs: Sequence[InputType],
+        fx_kwargs: _CompileFxKwargs,
         inputs_to_check: Sequence[int],
         remote: bool,
     ) -> Tuple[Optional[Tuple[str, List[str]]], Dict[str, Any]]:
@@ -1384,7 +1388,7 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
     def load_with_key(
         key: str,
         debug_lines: List[str],
-        example_inputs: List[torch.Tensor],
+        example_inputs: Sequence[InputType],
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
         is_backward: bool,
@@ -1427,8 +1431,8 @@ def load_with_key(
     def load(  # type: ignore[no-untyped-def]
         compile_fx_fn: Callable[..., Any],
         gm: torch.fx.GraphModule,
-        example_inputs: List[torch.Tensor],
-        fx_kwargs: Dict[str, Any],
+        example_inputs: Sequence[InputType],
+        fx_kwargs: _CompileFxKwargs,
         inputs_to_check: Sequence[int],
         local: bool,
         remote: bool,
@@ -1514,7 +1518,7 @@ def load(  # type: ignore[no-untyped-def]
                 )
             # Use the passed in cudagraphs so that we mutate the BoxedBool correctly
             FxGraphCache.post_compile(
-                compiled_graph, example_inputs, fx_kwargs["cudagraphs"]
+                compiled_graph, example_inputs, fx_kwargs["cudagraphs"]  # type: ignore[arg-type]
             )
         return compiled_graph
@@ -1561,7 +1565,7 @@ class CompiledFxGraph:
     guards_expr: Optional[str]

     cudagraph_info: Optional[CudagraphCachedInfo]
-    fx_kwargs: Dict[str, Any]
+    fx_kwargs: _CompileFxKwargs
     inputs_to_check: Sequence[int]
     boxed_forward_device_index: Optional[BoxedDeviceIndex]
@@ -1601,7 +1605,7 @@ def __init__(
         self.inputs_to_check = ()
         self.boxed_forward_device_index = None

-    def __call__(self, inputs: List[Any]) -> Any:
+    def __call__(self, inputs: Sequence[Any]) -> Any:
         assert self.current_callable is not None
         try:
             return self.current_callable(inputs)
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 26fe3abb80393a..1320cbc30a0546 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -1,5 +1,5 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
 import contextlib
 import functools
 import io
@@ -15,12 +15,16 @@
     Callable,
     ContextManager,
     Dict,
+    Generator,
     List,
     Optional,
     Sequence,
     Tuple,
+    TYPE_CHECKING,
+    TypeVar,
     Union,
 )
+from typing_extensions import Never, ParamSpec, Protocol, TypedDict, Unpack
 from unittest import mock

 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
@@ -72,15 +76,15 @@
     tensor_is_aligned,
 )
 from torch._logging import trace_structured
-from torch._ops import OpOverload
 from torch._utils_internal import compile_time_strobelight_meta
+from torch.fx import GraphModule
 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
 from torch.monitor import _WaitCounter
 from torch.utils._ordered_set import OrderedSet

 from .._dynamo.backends.common import aot_autograd
-from ..fx._lazy_graph_module import _use_lazy_graph_module  # type: ignore[attr-defined]
+from ..fx._lazy_graph_module import _use_lazy_graph_module
 from ..fx.graph import _PyTreeCodeGen
 from . import config, metrics
 from .debug import DebugContext
@@ -89,7 +93,6 @@
 from .fx_passes.post_grad import post_grad_passes, view_to_reshape
 from .fx_passes.pre_grad import pre_grad_passes
 from .graph import GraphLowering
-from .ir import ExternKernelNode
 from .utils import (
     align_inputs_from_check_idxs,
     clone_preserve_strides,
@@ -104,13 +107,33 @@
 from .virtualized import V


-if config.is_fbcode():
-    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
-else:
+if TYPE_CHECKING:
+    from torch._ops import OpOverload
+
+    from .ir import ExternKernelNode
+
+
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+
+if TYPE_CHECKING or not config.is_fbcode():
     # no-op decorator
-    def time_and_log(attr: str):
+    def time_and_log(attr: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
         return dynamo_utils.identity

+    def log_optimus_to_scuba(*args: object, **kwargs: object) -> None:
+        pass
+
+else:
+    from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log
+
+if TYPE_CHECKING:
+    from torch._functorch._aot_autograd.schemas import (
+        FQN,
+        GraphInputName,
+        GraphSignature,
+    )

 log = logging.getLogger(__name__)

 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -124,7 +147,7 @@ def time_and_log(attr: str):
 # for expanded dimensions (a dimension which used to have size 1 -> ?)
 # we can select one element from that dimension and write to it
 # to achieve writing to all values of that dimension of the input tensor
-def get_expanded_dims(t):
+def get_expanded_dims(t: torch.Tensor) -> List[int]:
     if not isinstance(t, torch.Tensor):
         return None
     return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
@@ -155,7 +178,7 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:
     return False


-def get_static_input_idxs(num_fixed):
+def get_static_input_idxs(num_fixed: int) -> List[int]:
     # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
     # of cudagraphs. Rather than copying these into cudagraph-owned memory
     # like we do for normal inputs on each run, we will re-record a cudagraph if these
@@ -169,12 +192,12 @@ def get_static_input_idxs(num_fixed):


 @functools.lru_cache(None)
-def _step_logger():
+def _step_logger() -> Callable[..., None]:
     return dynamo_logging.get_step_logger(log)


 @functools.lru_cache(None)
-def _warn_tf32_disabled():
+def _warn_tf32_disabled() -> None:
     if (
         torch.cuda.is_available()
         and not torch.backends.cuda.matmul.allow_tf32
@@ -186,10 +209,12 @@ def _warn_tf32_disabled():
         )


-def _unlift_graph(mod, gm, graph_signature):
+def _unlift_graph(
+    mod: GraphModule, gm: GraphModule, graph_signature: GraphSignature
+) -> GraphModule:
     from torch.export.unflatten import _assign_attr, _AttrKind

-    state_dict = {}
+    state_dict: Dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
     for name, param in mod.named_parameters(remove_duplicate=False):
         state_dict[name] = param
         _assign_attr(
@@ -208,7 +233,7 @@ def _unlift_graph(mod, gm, graph_signature):
         )

     placeholder_nodes = gm.graph.find_nodes(op="placeholder")
-    lifted_inputs = []
+    lifted_inputs: List[Optional[FQN]] = []

     # In AOTI, module parameters and buffers are not lifted as graph inputs.
     # As a result, mutation to buffers has side effect which makes their initial
@@ -238,7 +263,7 @@ def _unlift_graph(mod, gm, graph_signature):
     user_input_mutations = graph_signature.user_inputs_to_mutate
     output_tokens = graph_signature.output_tokens
     for idx, out in enumerate(outputs):
-        value = None
+        value: Optional[Union[FQN, GraphInputName]] = None

         if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
             if out.name in buffer_mutations:
@@ -260,7 +285,7 @@ def _unlift_graph(mod, gm, graph_signature):
     return unlifted_gm


-def _get_subgraph_names(gm):
+def _get_subgraph_names(gm: GraphModule) -> Generator[str, None, None]:
     for node in sorted(
         itertools.chain(
             gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond),
@@ -281,24 +306,26 @@ def _get_subgraph_names(gm):
             yield body_subgraph_name


-def _recursive_pre_grad_passes(gm, example_inputs):
+def _recursive_pre_grad_passes(
+    gm: GraphModule, example_inputs: Sequence[InputType]
+) -> GraphModule:
     with dynamo_timed("_recursive_pre_grad_passes"):
         for subgraph_name in _get_subgraph_names(gm):
             subgraph = getattr(gm, subgraph_name)
-            # as we don't have recursive example inputs, passing None here
-            new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
+            # as we don't have recursive example inputs, passing empty set here
+            new_subgraph = _recursive_pre_grad_passes(subgraph, ())
             setattr(gm, subgraph_name, new_subgraph)
         return pre_grad_passes(gm, example_inputs)


-def _recursive_joint_graph_passes(gm):
+def _recursive_joint_graph_passes(gm: GraphModule) -> None:
     for subgraph_name in _get_subgraph_names(gm):
         subgraph = getattr(gm, subgraph_name)
         _recursive_joint_graph_passes(subgraph)
     joint_graph_passes(gm)


-def _recursive_post_grad_passes(gm, is_inference: bool = False):
+def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None:
     with dynamo_timed("_recursive_post_grad_passes"):
         for subgraph_name in _get_subgraph_names(gm):
             subgraph = getattr(gm, subgraph_name)
@@ -307,10 +334,10 @@ def split_const_gm(
-    gm: torch.fx.GraphModule,
+    gm: GraphModule,
     lifted_constants: Optional[Dict[str, Any]] = None,
     skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
-) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
+) -> Tuple[GraphModule, Dict[str, int]]:
     """
     This function takes an GraphModule input "gm".
     The gm will be split into 2 components,
@@ -372,7 +399,7 @@ def split_const_gm(
     return const_gm, const_output_index


-def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
+def is_tf32_warning_applicable(gm: GraphModule) -> bool:
     aten = torch.ops.aten
     tf32_ops = {
         aten.mm.default,
@@ -391,7 +418,9 @@ def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
     return False


-def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]):
+def maybe_disable_comprehensive_padding(
+    example_inputs: Sequence[InputType],
+) -> contextlib.AbstractContextManager[None, None]:
     """
     For CPU backend, enable comprehensive padding causes some unit tests
     fail due to changing number of generated kernels. Skip for now.
@@ -408,10 +437,10 @@ def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]):


 def fake_tensor_prop(
-    gm: torch.fx.GraphModule,
-    example_inputs: List[torch.Tensor],
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
     force_allow_non_fake_inputs: bool = False,
-):
+) -> torch._subclasses.FakeTensorMode:
     """
     If we can not detect fake mode from the context of inputs, create one.
@@ -438,13 +467,15 @@ def fake_tensor_prop(


 # pass config dict back to user
-def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
+def get_patched_config_dict(
+    config_patches: Optional[Union[str, Dict[str, Any]]] = None
+) -> Dict[str, Any]:
     with config.patch(config_patches):
         return config.get_config_copy()


 @contextlib.contextmanager
-def with_fresh_cache_if_config():
+def with_fresh_cache_if_config() -> Generator[None, None, None]:
     if config.force_disable_caches:
         # Don't delete the cache dir because it has to survive beyond the
         # compile_fx call. Let's put the temp dirs under the default cache
@@ -455,7 +486,50 @@ def with_fresh_cache_if_config():
         yield


-def compile_fx_inner(*args, **kwargs):
+class _CompileFxKwargs(TypedDict, total=False):
+    cudagraphs: Optional[BoxedBool]
+    static_input_idxs: Sequence[int]
+    is_backward: bool
+    graph_id: Optional[int]
+    cpp_wrapper: bool
+    aot_mode: bool
+    is_inference: bool
+    user_visible_outputs: Optional[Dict[str, None]]
+    layout_opt: Optional[bool]
+    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]]
+
+
+class _CompileFxKwargsEx(_CompileFxKwargs, total=False):
+    boxed_forward_device_index: Optional[BoxedDeviceIndex]
+
+
+class _CompileFxCallableEx(Protocol):
+    def __call__(
+        self,
+        gm: GraphModule,
+        example_inputs: Sequence[InputType],
+        **kwargs: Unpack[_CompileFxKwargsEx],
+    ) -> Union[CompiledFxGraph, str]:
+        ...
+
+
+def compile_fx_inner(
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    **kwargs: Unpack[_CompileFxKwargsEx],
+) -> Union[CompiledFxGraph, str]:
+    kwargs.setdefault("cudagraphs", None)
+    kwargs.setdefault("static_input_idxs", ())
+    kwargs.setdefault("is_backward", False)
+    kwargs.setdefault("graph_id", None)
+    kwargs.setdefault("cpp_wrapper", False)
+    kwargs.setdefault("aot_mode", False)
+    kwargs.setdefault("is_inference", False)
+    kwargs.setdefault("boxed_forward_device_index", None)
+    kwargs.setdefault("user_visible_outputs", None)
+    kwargs.setdefault("layout_opt", None)
+    kwargs.setdefault("extern_node_serializer", None)
+
     # Need with_fresh_cache_if_config for compile_fx_inner even if we already have one for
     # compile_fx. The reason is the compilation for backward graph may happen after
     # compile_fx return and we may want to use the _LazyGraphModule for compiling
@@ -472,25 +546,18 @@ def compile_fx_inner(*args, **kwargs):
         stack.enter_context(DebugContext())

         return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
-            *args, **kwargs
+            gm,
+            example_inputs,
+            **kwargs,
         )


 @time_and_log(attr="compilation time (in seconds)")
 def _compile_fx_inner(
-    gm: torch.fx.GraphModule,
-    example_inputs: List[torch.Tensor],
-    cudagraphs: Optional[BoxedBool] = None,
-    static_input_idxs: Optional[List[int]] = None,
-    is_backward: bool = False,
-    graph_id: Optional[int] = None,
-    cpp_wrapper: bool = False,
-    aot_mode: bool = False,
-    is_inference: bool = False,
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
     boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
-    user_visible_outputs: Optional[Dict[str, None]] = None,
-    layout_opt: Optional[bool] = None,
-    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
+    **graph_kwargs: Unpack[_CompileFxKwargs],
 ) -> Union[CompiledFxGraph, str]:
     """
     Inductor API that compiles a single graph.
@@ -498,6 +565,8 @@ def _compile_fx_inner(
     If you change the argument list for this function, make sure you also update the call
     to save_args_for_compile_fx_inner below accordingly.
     """
+    aot_mode: bool = graph_kwargs.setdefault("aot_mode", False)
+
     if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
         # trigger the real recompilation for _LazyGraphModule before returning
         # the forward method.
@@ -506,62 +575,35 @@ def _compile_fx_inner(
         _LazyGraphModule.force_recompile(gm)
         return make_boxed_func(gm.forward)

-    if static_input_idxs is None:
-        static_input_idxs = []
-
+    static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ())
     static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)

     assert isinstance(
         next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
     ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"

+    if (cudagraphs := graph_kwargs.get("cudagraphs")) is None:
+        graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs)
     if config.save_args:
         save_args_for_compile_fx_inner(
             gm,
             example_inputs,
-            cudagraphs=cudagraphs,
-            static_input_idxs=static_input_idxs,
-            is_backward=is_backward,
-            graph_id=graph_id,
-            cpp_wrapper=cpp_wrapper,
-            aot_mode=aot_mode,
-            is_inference=is_inference,
             boxed_forward_device_index=boxed_forward_device_index,
-            user_visible_outputs=user_visible_outputs,
-            layout_opt=layout_opt,
+            **graph_kwargs,
         )

-    if cudagraphs is None:
-        cudagraphs = BoxedBool(config.triton.cudagraphs)
-
-    # Inputs to fx_codegen_and_compile
-    # Anything that affects codegen should go here, so if the signature
-    # of fx_codegen_and_compile changes, the dict should be updated accordingly
-    graph_kwargs = {
-        "cudagraphs": cudagraphs,
-        "static_input_idxs": static_input_idxs,
-        "is_backward": is_backward,
-        "graph_id": graph_id,
-        "cpp_wrapper": cpp_wrapper,
-        "aot_mode": aot_mode,
-        "is_inference": is_inference,
-        "user_visible_outputs": user_visible_outputs,
-        "layout_opt": layout_opt,
-        "extern_node_serializer": extern_node_serializer,
-    }
-
     start = time.time()

     fx_graph_remote_cache = should_use_remote_fx_graph_cache()
-    inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)  # type: ignore[arg-type]
+    inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)

     def codegen_and_compile(
-        gm,
-        example_inputs,
-        inputs_to_check,
-        fx_kwargs,
-    ):
+        gm: GraphModule,
+        example_inputs: Sequence[InputType],
+        inputs_to_check: Sequence[int],
+        fx_kwargs: _CompileFxKwargs,
+    ) -> Union[CompiledFxGraph, str]:
         """
         This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting compiled fx graph.
         The metadata is saved to FXGraphCache.
@@ -600,7 +642,7 @@ def codegen_and_compile(
                     check_for_mutation_ignore_cuda_graph_managed_tensor(
                         gm,
                         compiled_graph,
-                        static_input_idxs,  # type:ignore[arg-type]
+                        static_input_idxs,
                     )
                 )
                 has_mutation = has_mutation_str is not None
@@ -666,7 +708,7 @@ def codegen_and_compile(
         )
     else:
         compiled_graph = codegen_and_compile(
-            gm, example_inputs, inputs_to_check, graph_kwargs  # type: ignore[arg-type]
+            gm, example_inputs, inputs_to_check, graph_kwargs
         )
         if aot_mode:
             # AOT mode is special because codegen_and_compile returns a string.
@@ -682,8 +724,8 @@ def codegen_and_compile(
     _step_logger()(
         logging.INFO,
         "torchinductor done compiling "
-        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
-        f"graph {graph_id}",
+        f"{'BACKWARDS' if graph_kwargs['is_backward'] else 'FORWARDS'} "
+        f"graph {graph_kwargs['graph_id']}",
     )
     # aot autograd needs to know to pass in inputs as a list
     compiled_graph._boxed_call = True
@@ -691,10 +733,10 @@ def codegen_and_compile(


 def fx_codegen_and_compile(
-    gm: torch.fx.GraphModule,
-    example_inputs: List[torch.Tensor],
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
     cudagraphs: Optional[BoxedBool] = None,
-    static_input_idxs: Optional[List[int]] = None,
+    static_input_idxs: Optional[Sequence[int]] = None,
     is_backward: bool = False,
     graph_id: Optional[int] = None,
     cpp_wrapper: bool = False,
@@ -729,7 +771,7 @@ def fx_codegen_and_compile(
         f"graph {graph_id}",
     )

-    def log_graph_runnable():
+    def log_graph_runnable() -> str:
         fd = io.StringIO()
         torch._dynamo.repro.after_aot.save_graph_repro(
             fd, gm, example_inputs, "inductor", save_dir=None
         )
@@ -942,7 +984,7 @@ def get_input_idxs_to_check(
-    inputs: List[InputType],
+    inputs: Sequence[InputType],
     static_input_idxs: Sequence[int],
 ) -> Sequence[int]:
     """
@@ -1009,7 +1051,7 @@ def cudagraphify(

     compiled_fn = None

-    def run(new_inputs):
+    def run(new_inputs: Sequence[InputType]) -> Any:
         nonlocal compiled_fn
         if compiled_fn is None:
             with dynamo_utils.dynamo_timed(
@@ -1032,7 +1074,7 @@ def index_expanded_dims_and_copy_(
     dst: torch.Tensor,
     src: torch.Tensor,
     expanded_dims: List[int],
-):
+) -> None:
     "Index into expanded dimensions of both dst and src then copy_"
     dst = index_expanded_dims(dst, expanded_dims)
     src = index_expanded_dims(src, expanded_dims)
@@ -1043,7 +1085,7 @@ def cudagraphify_impl(
     model: Callable[..., Any],
     inputs: List[torch.Tensor],
     static_input_idxs: Sequence[int] = (),
-):
+) -> Callable[[List[InputType]], Any]:
     """
     Assumes inputs[static_input_idxs[i]] are always the same memory address
     """
@@ -1095,14 +1137,15 @@ def cudagraphify_impl(

     if config.size_asserts:

-        def run(new_inputs):
+        def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
             assert len(static_inputs) == len(new_inputs)
             for idx, (dst, src, expanded_dims) in enumerate(
                 zip(static_inputs, new_inputs, inps_expanded_dims)
             ):
                 if not isinstance(dst, torch.Tensor):
-                    pass
-                elif idx in static_input_idxs:
+                    continue
+                assert isinstance(src, torch.Tensor)
+                if idx in static_input_idxs:
                     assert dst.data_ptr() == src.data_ptr()
                 else:
                     # TODO - could make one single op of multiple slices
@@ -1118,12 +1161,12 @@ def run(new_inputs):
             idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
         ]

-        def run(new_inputs):
+        def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
             for idx in copy_indices:
                 expanded_dims = inps_expanded_dims[idx]
-                index_expanded_dims_and_copy_(
-                    static_inputs[idx], new_inputs[idx], expanded_dims
-                )
+                src = new_inputs[idx]
+                assert isinstance(src, torch.Tensor)
+                index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims)
             new_inputs.clear()
             graph.replay()
             return static_outputs
@@ -1132,11 +1175,11 @@ def run(new_inputs):


 def compile_fx_aot(
-    model_: torch.fx.GraphModule,
-    example_inputs_: List[torch.Tensor],
-    inner_compile: Callable[..., Any] = compile_fx_inner,
-    config_patches: Optional[Dict[str, Any]] = None,
-):
+    model_: GraphModule,
+    example_inputs_: List[InputType],
+    inner_compile: _CompileFxCallableEx = compile_fx_inner,
+    config_patches: Optional[Dict[str, str]] = None,
+) -> str:
     config_patches: Dict[str, Any] = (
         {"cpp_wrapper": True}
         if config_patches is None
@@ -1164,6 +1207,7 @@ def compile_fx_aot(
             ),
             config_patches=config_patches,
         )
+        assert isinstance(compiled_lib_path, str)
         assert os.path.exists(
             compiled_lib_path
         ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
@@ -1174,15 +1218,15 @@ def compile_fx_aot(


 def fw_compiler_freezing(
-    aot_autograd_model: torch.fx.GraphModule,
-    aot_example_inputs: List[torch.Tensor],
-    dynamo_model: torch.fx.GraphModule,
+    aot_autograd_model: GraphModule,
+    aot_example_inputs: Sequence[InputType],
+    dynamo_model: GraphModule,
     num_example_inputs: int,
     inner_compile: Callable[..., Any],
     cudagraphs: BoxedBool,
     graph_id: int,
     forward_device: BoxedDeviceIndex,
-):
+) -> Callable[[List[object]], Sequence[torch.Tensor]]:
     from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

     # partition_fn won't be called
@@ -1265,7 +1309,7 @@ def fw_compiler_freezing(
     if V.aot_compilation is True:
         return optimized_function

-    def wrapper(args):
+    def wrapper(args: List[object]) -> Sequence[torch.Tensor]:
         args_new = [
             args[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
             for i in preserved_arg_indices
         ]
@@ -1278,7 +1322,7 @@ def wrapper(args):
     return wrapper


-def get_cpp_wrapper_config():
+def get_cpp_wrapper_config() -> Dict[str, object]:
     return {
         # Set autotune_at_compile_time to True as default if the option is not explicitly set
         "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
@@ -1321,12 +1365,12 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:


 def compile_fx(
-    model_: torch.fx.GraphModule,
-    example_inputs_: List[torch.Tensor],
+    model_: GraphModule,
+    example_inputs_: Sequence[InputType],
     inner_compile: Callable[..., Any] = compile_fx_inner,
     config_patches: Optional[Dict[str, Any]] = None,
     decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
-):
+) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str]:
     with _use_lazy_graph_module(dynamo_config.use_lazy_graph_module):
         """Main entrypoint to a compile given FX graph"""
         if config_patches:
@@ -1346,8 +1390,8 @@ def compile_fx(
                 **get_cpp_wrapper_config(),
             }
         ), V.set_real_inputs(example_inputs_):
-            inputs_ = example_inputs_
-            if isinstance(model_, torch.fx.GraphModule):
+            inputs_: Sequence[InputType] = example_inputs_
+            if isinstance(model_, GraphModule):
                 fake_inputs = [
                     node.meta.get("val")
                     for node in model_.graph.nodes
@@ -1363,13 +1407,15 @@ def compile_fx(
                 if all(v is not None for v in fake_inputs):
                     # Validate devices before switching to fake tensors.
                     for idx, fi, i in zip(count(), fake_inputs, inputs_):
-                        if fi is not None and fi.device != i.device:
-                            raise ValueError(
-                                f"Device mismatch between fake input and example input at position #{idx}: "
-                                f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
-                                "make sure torch.export() and torch.aot_compile() run on the same device."
-                            )
-                    inputs_ = fake_inputs  # type: ignore[assignment]
+                        if fi is not None:
+                            assert isinstance(i, torch.Tensor)
+                            if fi.device != i.device:
+                                raise ValueError(
+                                    f"Device mismatch between fake input and example input at position #{idx}: "
+                                    f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
+                                    "make sure torch.export() and torch.aot_compile() run on the same device."
+                                )
+                    inputs_ = fake_inputs
             return compile_fx(
                 model_,
                 inputs_,
@@ -1390,7 +1436,7 @@ def compile_fx(
                 recursive_compile_fx,
             )

-        if isinstance(model_, torch.fx.GraphModule):
+        if isinstance(model_, GraphModule):
             if isinstance(model_.graph._codegen, _PyTreeCodeGen):
                 # this graph is the result of dynamo.export()
                 return handle_dynamo_export_graph(
@@ -1420,18 +1466,18 @@ def compile_fx(
         )

         def fw_compiler_base(
-            model: torch.fx.GraphModule,
-            example_inputs: List[torch.Tensor],
+            model: GraphModule,
+            example_inputs: List[InputType],
             is_inference: bool,
-        ):
+        ) -> CompiledFxGraph:
             with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
                 return _fw_compiler_base(model, example_inputs, is_inference)

         def _fw_compiler_base(
-            model: torch.fx.GraphModule,
-            example_inputs: List[torch.Tensor],
+            model: GraphModule,
+            example_inputs: List[InputType],
             is_inference: bool,
-        ):
+        ) -> CompiledFxGraph:
             if is_inference:
                 # partition_fn won't be called
                 _recursive_joint_graph_passes(model)
@@ -1456,7 +1502,7 @@ def _fw_compiler_base(
             else:
                 original_output_start_index = 0

-            if isinstance(model_, torch.fx.GraphModule):
+            if isinstance(model_, GraphModule):
                 *_, orig_model_outputs_node = model_.graph.nodes
                 assert orig_model_outputs_node.op == "output"
                 orig_model_outputs, _ = pytree.tree_flatten(
@@ -1510,7 +1556,7 @@ def _fw_compiler_base(
         fw_compiler = functools.partial(fw_compiler_base, is_inference=False)

         if config.freezing and not torch.is_grad_enabled():
-            inference_compiler = functools.partial(
+            inference_compiler: Callable[..., Any] = functools.partial(
                 fw_compiler_freezing,
                 dynamo_model=model_,
                 num_example_inputs=num_example_inputs,
@@ -1522,7 +1568,11 @@ def _fw_compiler_base(
         else:
             inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

-        def partition_fn(gm: torch.fx.GraphModule, joint_inputs, **kwargs):
+        def partition_fn(
+            gm: GraphModule,
+            joint_inputs: Sequence[object],
+            **kwargs: object,
+        ) -> Tuple[GraphModule, GraphModule]:
             cuda_context = get_cuda_device_context(gm)
             with cuda_context:
                 _recursive_joint_graph_passes(gm)
@@ -1532,8 +1582,8 @@ def partition_fn(gm: torch.fx.GraphModule, joint_inputs, **kwargs):

         @compile_time_strobelight_meta(phase_name="backward")
         def bw_compiler(
-            model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
-        ):
+            model: GraphModule, example_inputs: List[InputType]
+        ) -> Union[CompiledFxGraph, str]:
             with dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"):
                 user_visible_outputs = {}
@@ -1613,9 +1663,9 @@ def bw_compiler(
         )(model_, example_inputs_)


-def graph_returns_tuple(gm: torch.fx.GraphModule):
+def graph_returns_tuple(gm: GraphModule) -> bool:
     """True if a FX graph returns a tuple"""
-    if not isinstance(gm, torch.fx.GraphModule):
+    if not isinstance(gm, GraphModule):
         return True  # can't check this, assume true
     (rv,) = output_node(gm).args
     if isinstance(rv, (list, tuple)):
@@ -1632,10 +1682,10 @@ def graph_returns_tuple(gm: torch.fx.GraphModule):


 def make_graph_return_tuple(
-    gm: torch.fx.GraphModule,
-    inputs: List[torch.Tensor],
+    gm: GraphModule,
+    inputs: Sequence[InputType],
     compile_gm: Callable[..., Any],
-):
+) -> Callable[..., Any]:
     """
     Mutate gm so it returns a tuple.  This is only needed for graphs
     not created by torchdynamo that return non-tuples.
@@ -1651,17 +1701,17 @@ def make_graph_return_tuple(
     compiled_fn = compile_gm(gm, inputs)

     @functools.wraps(compiled_fn)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
         return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)

     return wrapper


 def handle_dynamo_export_graph(
-    gm: torch.fx.GraphModule,
-    inputs: List[torch.Tensor],
+    gm: GraphModule,
+    inputs: Sequence[InputType],
     compile_gm: Callable[..., Any],
-):
+) -> Callable[..., Any]:
     """
     `torch._dynamo.export` embeds pytrees in the FX graph codegen object,
     convert that to a normal FX graph so inductor can compile it.
@@ -1673,14 +1723,14 @@ def handle_dynamo_export_graph(
     compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))

     @functools.wraps(compiled_fn)
-    def wrapper(*args):
+    def wrapper(*args: Any) -> Any:
         return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))

     return wrapper


 def _check_triton_bf16_support(graph: GraphLowering) -> None:
-    def warn_and_skip(device) -> None:
+    def warn_and_skip(device: torch.device) -> Never:
         from torch._dynamo.exc import SkipFrame

         device_interface = get_interface_for_device(device.type)
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 15038409175001..1ccfc6e65055f2 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -853,7 +853,7 @@ def __init__(

 def maybe_get_static_data_ptr(
     idx: int,
-    inputs: List[Union[torch.Tensor, int]],
+    inputs: List[InputType],
     static_input_idxs: List[int],
 ) -> Optional[int]:
     inp = inputs[idx]
@@ -1576,7 +1576,7 @@ def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage:

     def _allocate_and_copy_recording_inputs(
         self, inputs: List[InputType]
-    ) -> List[Union[torch.Tensor, int]]:
+    ) -> List[InputType]:
         """
         Allocate inputs for non static, non cudagraph managed tensors in the memory pool
         and copy over the tensor values.
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index 4309a1a77b1e5b..16a6a74aea1463 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -2,7 +2,7 @@
 import copy
 import itertools
 import logging
-from typing import Dict, Optional
+from typing import Dict, Optional, Sequence

 import torch
 import torch.nn as nn
@@ -112,7 +112,9 @@ def lazy_init():
         from . import fb  # type: ignore[attr-defined]  # noqa: F401


-def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
+def pre_grad_passes(
+    gm: torch.fx.GraphModule, example_inputs: Sequence[object] = ()
+) -> torch.fx.GraphModule:
     """
     Apply passes on the input FX graph using Torch IR.
@@ -138,7 +140,7 @@ def shape_prop(mod) -> None:
                 gm=mod,
                 # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
                 fake_mode=detect_fake_mode(example_inputs),
-            ).propagate(*example_inputs)
+            ).propagate(*tuple(example_inputs))

         # normalization pass
         pass_execution_and_save(
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index a550ccaad2f40b..1db6df67f51479 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -315,7 +315,7 @@ def static_sizes_strides(
     def __init__(
         self,
         gm: torch.fx.GraphModule,
-        example_inputs: Optional[List[torch.Tensor]] = None,
+        example_inputs: Optional[Sequence[object]] = None,
         shape_env: Optional[ShapeEnv] = None,
         graph_id: Optional[int] = None,
         cpp_wrapper: bool = False,
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 9f7c199b5cd263..5ddd9b0a90cacc 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -88,7 +88,7 @@ def get_gpu_type():
 _T = TypeVar("_T")
 VarRanges = Dict[sympy.Expr, sympy.Expr]
-InputType = Union[torch.Tensor, int]
+InputType = Optional[Union[torch.Tensor, int, torch.SymInt]]

 GPU_ALIGN_BYTES = 16
@@ -1970,7 +1970,7 @@ def run_and_get_cpp_code(fn, *args, **kwargs):
     return result, s


-def shape_env_from_inputs(inputs: List[torch.Tensor]):
+def shape_env_from_inputs(inputs: Sequence[InputType]):
     shape_env = None
     fake_mode = detect_fake_mode(inputs)
@@ -2024,9 +2024,9 @@ def copy_misaligned_inputs(


 def remove_unaligned_input_idxs(
-    inputs: List[InputType],
+    inputs: Sequence[InputType],
     static_input_idxs: Sequence[int],
-):
+) -> Sequence[int]:
     """
     We require all inputs to be aligned, so introduce a copy for any
     that aren't.
diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py
index 129d135d5ab516..8973fd013f37ec 100644
--- a/torch/_utils_internal.py
+++ b/torch/_utils_internal.py
@@ -4,12 +4,16 @@
 import os
 import sys
 import tempfile
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing_extensions import ParamSpec

 import torch
 from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler


+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
 log = logging.getLogger(__name__)

 if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
@@ -76,12 +80,16 @@ def throw_abstract_impl_not_imported_error(opname, module, context):


 # NB! This treats "skip" kwarg specially!!
-def compile_time_strobelight_meta(phase_name):
-    def compile_time_strobelight_meta_inner(function):
+def compile_time_strobelight_meta(
+    phase_name: str,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    def compile_time_strobelight_meta_inner(
+        function: Callable[_P, _T],
+    ) -> Callable[_P, _T]:
         @functools.wraps(function)
-        def wrapper_function(*args, **kwargs):
-            if "skip" in kwargs:
-                kwargs["skip"] = kwargs["skip"] + 1
+        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+            if "skip" in kwargs and isinstance(skip := kwargs["skip"], int):
+                kwargs["skip"] = skip + 1

             if not StrobelightCompileTimeProfiler.enabled:
                 return function(*args, **kwargs)
diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py
index fdb16e6568fa45..a422950fa4788e 100644
--- a/torch/export/_unlift.py
+++ b/torch/export/_unlift.py
@@ -2,7 +2,7 @@
 import copy
 import warnings
 from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple

 import torch
 import torch.utils._pytree as pytree
@@ -43,7 +43,7 @@ def _check_input_constraints_pre_hook(self, *args, **kwargs):

 def _unlift_inputs_as_getattr(
     gm: torch.fx.GraphModule,
-    lifted_inputs: List[Optional[str]],
+    lifted_inputs: Sequence[Optional[str]],
 ) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
     """
     Unlift inputs referring to params/buffers/constants as getattr nodes in the
@@ -72,7 +72,7 @@ def _unlift_inputs_as_getattr(

 def _insert_copy_for_mutations(
     gm: torch.fx.GraphModule,
-    mutated_outputs: List[Optional[str]],
+    mutated_outputs: Sequence[Optional[str]],
     unlifted_name_to_node: Dict[str, torch.fx.Node],
     input_name_to_node: Dict[str, torch.fx.Node],
 ) -> None:
@@ -158,8 +158,8 @@ def _get_codegen(

 def _unlift(
     gm: torch.fx.GraphModule,
-    lifted_inputs: List[Optional[str]],
-    mutated_outputs: List[Optional[str]],
+    lifted_inputs: Sequence[Optional[str]],
+    mutated_outputs: Sequence[Optional[str]],
     in_spec: pytree.TreeSpec,
     out_spec: Optional[pytree.TreeSpec],
     state_dict: Dict[str, Any],