typing compile_fx.py (pytorch#138033)
Type annotations for compile_fx.
- Some of the signatures here are pretty complicated (functions which return functions that take functions), so I bailed on those and used `Any` just to get the rest landed; a sketch of the general pattern is shown below, just before the per-file diffs.
- There are also changes to type signatures in other files, made just to let mypy know more about the types used in compile_fx.py.

Pull Request resolved: pytorch#138033
Approved by: https://github.com/Skylion007
aorenste authored and pytorchmergebot committed Oct 21, 2024
1 parent 8173840 commit 07cc4bd
Showing 13 changed files with 302 additions and 210 deletions.
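
For context, here is a hedged sketch of the typing pattern this PR applies (simplified stand-ins, not the actual definitions in compile_fx.py): a `TypedDict` describes the compiler's keyword arguments, `Unpack` applies it to `**kwargs`, and wrappers reuse the same callable alias. Only the genuinely higher-order pieces fall back to `Any`.

```python
# Hedged sketch only: _CompileFxKwargs, InputType, and CompiledFxGraph below are
# simplified stand-ins for the real torch definitions referenced in this diff.
from typing import Any, Callable, Optional, Sequence, TypedDict, Union

from typing_extensions import Unpack


class CompiledFxGraph:  # stand-in for torch._inductor.codecache.CompiledFxGraph
    ...


InputType = Any  # stand-in for torch._inductor.utils.InputType


class _CompileFxKwargs(TypedDict, total=False):
    # A TypedDict gives each keyword argument a real type instead of Dict[str, Any].
    cudagraphs: Optional[Any]  # BoxedBool in the real code
    is_backward: bool


# A compile function returns either a compiled graph or, on some paths, a string.
_CompileFxCallable = Callable[..., Union[CompiledFxGraph, str]]


def wrap_compiler(
    compiler_fn: _CompileFxCallable,
    **kwargs: Unpack[_CompileFxKwargs],
) -> _CompileFxCallable:
    """Wrap a compiler while keeping its keyword arguments fully typed."""

    def wrapped(
        gm: Any, example_inputs: Sequence[InputType]
    ) -> Union[CompiledFxGraph, str]:
        return compiler_fn(gm, example_inputs, **kwargs)

    return wrapped
```

The `Unpack[...]` annotation is what lets wrappers such as `wrap_compiler_debug` in after_aot.py forward `**kwargs` without collapsing them to `Any`.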
2 changes: 1 addition & 1 deletion torch/_dynamo/backends/common.py
@@ -77,7 +77,7 @@ def _wrapped_bw_compiler(*args, **kwargs):
raise


def aot_autograd(**kwargs):
def aot_autograd(**kwargs) -> AotAutograd:
return AotAutograd(**kwargs)


36 changes: 29 additions & 7 deletions torch/_dynamo/repro/after_aot.py
@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs

import argparse
import copy
import functools
@@ -12,7 +13,8 @@
import uuid
from importlib import import_module
from tempfile import TemporaryFile
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Sequence, TYPE_CHECKING, Union
from typing_extensions import Unpack

import torch
import torch.fx as fx
@@ -45,6 +47,12 @@
from .. import config


if TYPE_CHECKING:
from torch._inductor.codecache import CompiledFxGraph
from torch._inductor.compile_fx import _CompileFxCallableEx, _CompileFxKwargsEx
from torch._inductor.utils import InputType


log = logging.getLogger(__name__)


@@ -56,7 +64,10 @@
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str):
def wrap_compiler_debug(
unconfigured_compiler_fn: "_CompileFxCallableEx",
compiler_name: str,
) -> "_CompileFxCallableEx":
"""
Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
forward and backward call separately with the backend compiler_fn - like
@@ -66,7 +77,11 @@ def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str):
"""

@functools.wraps(unconfigured_compiler_fn)
def debug_wrapper(gm, example_inputs, **kwargs):
def debug_wrapper(
gm: torch.fx.GraphModule,
example_inputs: Sequence["InputType"],
**kwargs: Unpack["_CompileFxKwargsEx"],
) -> Union["CompiledFxGraph", str]:
from torch._subclasses import FakeTensorMode

compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
@@ -104,11 +119,15 @@ def debug_wrapper(gm, example_inputs, **kwargs):

# We may run regular PyTorch compute that may trigger Dynamo, do NOT
# recursively attempt to accuracy minify in that case!
def deferred_for_real_inputs(real_inputs):
def deferred_for_real_inputs(
real_inputs: Sequence["InputType"], **_kwargs: object
) -> Any:
# This is a bit obscure: if we recursively try to accuracy minify
# the SAME function, this would trigger. But most of the time
# we should never hit this branch
assert not _kwargs
if config.repro_after != "aot":
assert not isinstance(inner_compiled_fn, str)
return inner_compiled_fn(real_inputs)
with config.patch(repro_after=None):
return inner_debug_fn(real_inputs)
@@ -165,11 +184,11 @@ def inner_debug_fn(real_inputs):
raise AccuracyError("Bad accuracy detected")
else:
# Call the compiled function with real inputs
return inner_compiled_fn(real_inputs)
return inner_compiled_fn(real_inputs) # type: ignore[operator]
else:
try:
# Call the compiled function with real inputs
out = inner_compiled_fn(real_inputs)
out = inner_compiled_fn(real_inputs) # type: ignore[operator]
# sync cuda kernels to ensure IMA detection
for arg in example_inputs:
if isinstance(arg, torch.Tensor) and arg.is_cuda:
@@ -194,7 +213,7 @@ def inner_debug_fn(real_inputs):
if config.repro_after == "aot":
compiled_fn = deferred_for_real_inputs
compiled_fn._boxed_call = True # type: ignore[attr-defined]
return compiled_fn
return compiled_fn # type: ignore[return-value]
else:
return inner_compiled_fn

@@ -432,6 +451,7 @@ def sync():

try:
compile_mod = compile_fx_inner(fx_g, args)
assert not isinstance(compile_mod, str)
compile_mod(args)
sync()
except Exception as e:
@@ -601,6 +621,7 @@ def save_hook(name, val):
with intermediate_hook(save_hook), tqdm(
desc="Saving inductor intermediates", total=total
) as pbar:
assert not isinstance(compiled, str)
compiled(new_args)
assert not new_args

@@ -717,6 +738,7 @@ def repro_run(options, mod, load_args):
from torch.cuda import synchronize

compiled = compile_fx_inner(mod, args)
assert not isinstance(compiled, str)

if options.accuracy != "":
# We don't really respect --accuracy vs --strict-accuracy here, it
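
The comments in the hunks above (defer the real work until real inputs arrive, then set `_boxed_call`) describe a small wrapper pattern; a minimal, hedged sketch of just that pattern, with illustrative names and none of the minification or accuracy-checking logic:

```python
# Minimal sketch of the deferral pattern in wrap_compiler_debug; names here
# (wrap_debug, deferred) are illustrative, not the PyTorch implementation.
import functools
from typing import Any, Callable, Sequence

CompiledFn = Callable[[Sequence[Any]], Any]


def wrap_debug(compiler_fn: Callable[..., CompiledFn]) -> Callable[..., CompiledFn]:
    @functools.wraps(compiler_fn)
    def debug_wrapper(
        gm: Any, example_inputs: Sequence[Any], **kwargs: Any
    ) -> CompiledFn:
        inner = compiler_fn(gm, example_inputs, **kwargs)

        def deferred(real_inputs: Sequence[Any]) -> Any:
            # The heavy debugging/minification work happens only once real
            # inputs show up; until then we just hold on to `inner`.
            return inner(real_inputs)

        # Inductor passes inputs as a single list ("boxed" calling convention);
        # this flag tells callers to use that convention for the wrapper too.
        deferred._boxed_call = True  # type: ignore[attr-defined]
        return deferred

    return debug_wrapper
```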
2 changes: 1 addition & 1 deletion torch/_dynamo/utils.py
@@ -516,7 +516,7 @@ def count_calls(g: fx.Graph) -> int:
return c


def identity(x):
def identity(x: T) -> T:
return x


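
The typed `identity` above relies on a module-level TypeVar; a self-contained sketch of the same signature and what it buys callers (assuming `T = TypeVar("T")` already exists in utils.py):

```python
from typing import TypeVar

T = TypeVar("T")  # assumed to be defined at module level in torch/_dynamo/utils.py


def identity(x: T) -> T:
    # With a TypeVar, mypy propagates the argument's exact type to the result
    # instead of degrading it to Any.
    return x


n: int = identity(3)      # checked as int
s: str = identity("ok")   # checked as str
```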
13 changes: 7 additions & 6 deletions torch/_functorch/_aot_autograd/autograd_cache.py
@@ -47,6 +47,7 @@


if TYPE_CHECKING:
from torch._inductor.compile_fx import _CompileFxKwargs
from torch._inductor.remote_cache import JsonDataTy, RemoteCache
from torch._inductor.utils import BoxedBool
from torch.fx.node import Node
@@ -205,7 +206,7 @@ def __init__(
gm: torch.fx.GraphModule,
example_inputs,
aot_config: AOTConfig,
fx_config: Dict[str, BoxedBool],
fx_config: _CompileFxKwargs,
):
# FxGraphHashDetails contains all the keys related to inductor. Also includes some system info
self.aot_config = aot_config
@@ -269,7 +270,7 @@ def autograd_cache_key(
gm: torch.fx.GraphModule,
example_inputs,
config: AOTConfig,
fx_config: Dict[str, BoxedBool],
fx_config: _CompileFxKwargs,
# TODO: add args and parameters
) -> Tuple[str, List[str]]:
"""
@@ -295,7 +296,7 @@ class FXGraphCacheLoadable:
def is_backward(self):
return False

def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGraph:
def load(self, example_inputs, fx_config: _CompileFxKwargs) -> CompiledFxGraph:
# [Note: AOTAutogradCache and FXGraphCache Guard interactions]
# As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments.
# FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph.
@@ -332,7 +333,7 @@ def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGraph:
payload_fn=lambda: json.dumps(cache_info),
)

FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])
FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"]) # type: ignore[arg-type]
result._boxed_call = True
return result

@@ -399,7 +400,7 @@ def wrap_post_compile(
self,
args: List[torch.Tensor],
aot_config: AOTConfig,
fx_config: Dict[str, BoxedBool],
fx_config: _CompileFxKwargs,
) -> Callable:
"""
This function takes a cache entry and carefully reconstructs the original callable
@@ -566,7 +567,7 @@ def load(
debug_lines: List[str] = []
cache_event_time = time.time_ns()
cache_state = None
fx_config = {"cudagraphs": cudagraphs}
fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
try:
cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(
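
One detail from this file worth calling out: the `fx_config` dict literal now carries an explicit `_CompileFxKwargs` annotation, because an unannotated literal assigned to a variable is inferred as a plain `Dict[str, BoxedBool]` and would no longer match the updated signatures. A standalone illustration (the TypedDict and BoxedBool here are simplified stand-ins):

```python
# Standalone illustration; _CompileFxKwargs and BoxedBool are simplified
# stand-ins for the real definitions in torch._inductor.
from typing import Optional, TypedDict


class BoxedBool:
    def __init__(self, value: bool) -> None:
        self.value = value


class _CompileFxKwargs(TypedDict, total=False):
    cudagraphs: Optional[BoxedBool]
    is_backward: bool


def cache_key_stub(fx_config: _CompileFxKwargs) -> str:
    return repr(sorted(fx_config))


cudagraphs = BoxedBool(True)

# Annotated as the TypedDict: without the annotation, mypy infers
# Dict[str, BoxedBool] for the literal, which is not assignable to the
# _CompileFxKwargs parameter above.
fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
print(cache_key_stub(fx_config))
```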
9 changes: 7 additions & 2 deletions torch/_inductor/__init__.py
@@ -1,16 +1,21 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Tuple

from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

import torch.fx
import torch.utils._pytree as pytree


if TYPE_CHECKING:
from torch._inductor.utils import InputType


__all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"]


def compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
example_inputs: List["InputType"],
options: Optional[Dict[str, Any]] = None,
):
"""
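
The `if TYPE_CHECKING:` import plus the quoted `"InputType"` annotation in this hunk is the usual way to get type coverage without a runtime import (and without risking an import cycle); a minimal sketch of the pattern, where `heavy_module` is a hypothetical module:

```python
# Minimal sketch of the TYPE_CHECKING pattern; `heavy_module` is hypothetical.
from typing import List, TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, so there is no runtime cost or import cycle.
    from heavy_module import InputType


def compile(example_inputs: List["InputType"]) -> None:
    # The quoted annotation is resolved lazily, so heavy_module never needs to
    # be importable at runtime.
    ...
```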
52 changes: 28 additions & 24 deletions torch/_inductor/codecache.py
@@ -80,7 +80,9 @@
if TYPE_CHECKING:
from collections.abc import KeysView

from .compile_fx import _CompileFxKwargs
from .remote_cache import JsonDataTy, RemoteCache
from .utils import InputType


"""
@@ -751,23 +753,25 @@ class FxGraphHashDetails:
def __init__(
self,
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
fx_kwargs: Dict[str, Any],
example_inputs: Sequence[InputType],
fx_kwargs: _CompileFxKwargs,
inputs_to_check: Sequence[int],
) -> None:
self.gm = gm
self.example_inputs = example_inputs

# Order kwargs so hashing is stable to changes in kwarg order.
self.fx_kwargs = {}
for k in sorted(fx_kwargs):
# Order kwargs so hashing is stable to changes in kwarg order. Although
# it's technically a _CompileFxKwargs we don't actually need it typed as
# such since we're just using it to generate a hash.
self.fx_kwargs: Dict[str, object] = {}
for k, v in sorted(fx_kwargs.items()):
if k not in self.EXCLUDED_KWARGS:
if type(fx_kwargs[k]) is set:
if type(v) is set:
# Special case to handle set params. Python sets can't be
# ordered, so sort the elements and store them in a proxy.
self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
self.fx_kwargs[k] = OrderedSetHolder(sorted(v))
else:
self.fx_kwargs[k] = fx_kwargs[k]
self.fx_kwargs[k] = v

# Alignment checks
self.inputs_to_check = inputs_to_check
@@ -818,8 +822,8 @@ def debug_lines(self) -> List[str]:

def compiled_fx_graph_hash(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
fx_kwargs: Dict[str, Any],
example_inputs: Sequence[InputType],
fx_kwargs: _CompileFxKwargs,
inputs_to_check: Sequence[int],
) -> Tuple[str, List[str]]:
"""
@@ -836,7 +840,7 @@ def compiled_fx_graph_hash(


def cudagraph_post_compile(
example_inputs: List[Any],
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
) -> None:
@@ -879,7 +883,7 @@ def cudagraph_post_compile(
assert current_callable is not None
compiled_graph.current_callable = cudagraphify(
current_callable,
static_input_idxs=static_input_idxs,
static_input_idxs=static_input_idxs or (),
device_index=next(iter(compiled_graph.device_idxs)),
stack_traces=stack_traces,
is_backward=is_backward,
@@ -1018,7 +1022,7 @@ def _get_tmp_dir_for_key(key: str) -> str:
return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)

@staticmethod
def _filter_backed_symints(inputs: List[Any]) -> List[torch.SymInt]:
def _filter_backed_symints(inputs: Sequence[InputType]) -> List[torch.SymInt]:
"""
Get the backed SymInt objects from the input list. Note that we can never
have guards that depend on unbacked symint.
@@ -1038,7 +1042,7 @@ def _get_shape_env() -> Optional[ShapeEnv]:
@staticmethod
def _lookup_graph(
key: str,
example_inputs: List[torch.Tensor],
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
) -> Optional[CompiledFxGraph]:
@@ -1177,7 +1181,7 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]:
@staticmethod
def post_compile(
compiled_graph: CompiledFxGraph,
example_inputs: List[torch.Tensor],
example_inputs: Sequence[InputType],
cudagraphs: BoxedBool,
) -> CompiledFxGraph:
"""
@@ -1224,7 +1228,7 @@ def post_compile(
def _save_graph(
key: str,
compiled_graph: CompiledFxGraph,
example_inputs: List[torch.Tensor],
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
) -> None:
@@ -1333,8 +1337,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
@staticmethod
def prepare_key(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
fx_kwargs: Dict[str, Any],
example_inputs: Sequence[InputType],
fx_kwargs: _CompileFxKwargs,
inputs_to_check: Sequence[int],
remote: bool,
) -> Tuple[Optional[Tuple[str, List[str]]], Dict[str, Any]]:
@@ -1384,7 +1388,7 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
def load_with_key(
key: str,
debug_lines: List[str],
example_inputs: List[torch.Tensor],
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
is_backward: bool,
@@ -1427,8 +1431,8 @@ def load_with_key(
def load( # type: ignore[no-untyped-def]
compile_fx_fn: Callable[..., Any],
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
fx_kwargs: Dict[str, Any],
example_inputs: Sequence[InputType],
fx_kwargs: _CompileFxKwargs,
inputs_to_check: Sequence[int],
local: bool,
remote: bool,
@@ -1514,7 +1518,7 @@ def load( # type: ignore[no-untyped-def]
)
# Use the passed in cudagraphs so that we mutate the BoxedBool correctly
FxGraphCache.post_compile(
compiled_graph, example_inputs, fx_kwargs["cudagraphs"]
compiled_graph, example_inputs, fx_kwargs["cudagraphs"] # type: ignore[arg-type]
)
return compiled_graph

@@ -1561,7 +1565,7 @@ class CompiledFxGraph:
guards_expr: Optional[str]

cudagraph_info: Optional[CudagraphCachedInfo]
fx_kwargs: Dict[str, Any]
fx_kwargs: _CompileFxKwargs
inputs_to_check: Sequence[int]
boxed_forward_device_index: Optional[BoxedDeviceIndex]

@@ -1601,7 +1605,7 @@ def __init__(
self.inputs_to_check = ()
self.boxed_forward_device_index = None

def __call__(self, inputs: List[Any]) -> Any:
def __call__(self, inputs: Sequence[Any]) -> Any:
assert self.current_callable is not None
try:
return self.current_callable(inputs)
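
The `FxGraphHashDetails` hunk in this file normalizes `fx_kwargs` before hashing: keys are iterated in sorted order and set values are replaced by an ordered proxy, so the cache key does not depend on kwarg order. A hedged sketch of that idea (the real code hashes with its own pickler and `OrderedSetHolder`; the names and excluded keys here are illustrative):

```python
# Hedged sketch of the stable-kwargs-hashing idea from FxGraphHashDetails;
# EXCLUDED and the example kwargs below are illustrative.
import hashlib
import pickle
from typing import Any, Dict

EXCLUDED = {"cudagraphs"}  # illustrative: keys that should not affect the key


def stable_kwargs(fx_kwargs: Dict[str, Any]) -> Dict[str, object]:
    out: Dict[str, object] = {}
    # Iterate in sorted key order so the hash does not depend on the order in
    # which callers happened to pass keyword arguments.
    for k, v in sorted(fx_kwargs.items()):
        if k in EXCLUDED:
            continue
        if type(v) is set:
            # Sets have no stable iteration order; hash a sorted proxy instead.
            out[k] = ("ordered_set", sorted(v))
        else:
            out[k] = v
    return out


def cache_key(fx_kwargs: Dict[str, Any]) -> str:
    return hashlib.sha256(pickle.dumps(stable_kwargs(fx_kwargs))).hexdigest()


# Same key regardless of kwarg order or set element order.
print(cache_key({"b": {3, 1, 2}, "a": True}) == cache_key({"a": True, "b": {1, 2, 3}}))
```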
(Diffs for the remaining changed files are not shown here.)
