From 986ad3bfa6c3bb904ffadb2edad00efed839311c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 12 Oct 2023 00:44:41 +0000 Subject: [PATCH] [2/N] Dynamo supports skip by function & removes skipfiles circular import (#110835) Several improvements for skipfiles: * Add ```FUNC_INLINELIST``` to support function level skip/inline check. * Use ```fn.__code__``` to match function since we can't get the function object sometimes. * Use python module string name for ```FILE_INLINELIST``` and ```SUBMODULE_INLINELIST```. * Use filename to match file and python module, which can fundamentally resolved the circular import issues introduced by skipfiles. * Use ```TYPE_CHECKING``` to ensure the python module string name is correct. * Add unit tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110835 Approved by: https://github.com/ezyang --- test/dynamo/test_allow_inline_skip.py | 101 +++++++++++ test/dynamo/utils.py | 28 +++ torch/_dynamo/eval_frame.py | 8 +- torch/_dynamo/skipfiles.py | 237 ++++++++++++++++++-------- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/variables/builder.py | 5 +- 6 files changed, 299 insertions(+), 82 deletions(-) create mode 100644 test/dynamo/test_allow_inline_skip.py diff --git a/test/dynamo/test_allow_inline_skip.py b/test/dynamo/test_allow_inline_skip.py new file mode 100644 index 0000000000000..b8c1a5e357843 --- /dev/null +++ b/test/dynamo/test_allow_inline_skip.py @@ -0,0 +1,101 @@ +# Owner(s): ["module: dynamo"] +import importlib +import types +import unittest + +import torch +import torch._dynamo.test_case +from torch._dynamo.skipfiles import ( + FILE_INLINELIST, + FUNC_INLINELIST, + SUBMODULE_INLINELIST, +) +from torch._dynamo.utils import istype + +try: + from .utils import create_dummy_module_and_function +except ImportError: + from utils import create_dummy_module_and_function + + +def gen_get_func_inlinelist(dummy_func_inlinelist): + def get_func_inlinelist(): + inlinelist = set() + for f in dummy_func_inlinelist: + module_name, fn_name = f.rsplit(".", 1) + m = importlib.import_module(module_name) + fn = getattr(m, fn_name) + inlinelist.add(fn.__code__) + return inlinelist + + return get_func_inlinelist + + +class AllowInlineSkipTests(torch._dynamo.test_case.TestCase): + # We are using python function and module string names for these inlinelist, + # this unit test is to make sure the functions/modules can be correctly imported + # or loaded in case there is typo in the strings. + def test_skipfiles_inlinelist_correctness(self): + for m in FILE_INLINELIST.union(SUBMODULE_INLINELIST): + self.assertTrue(isinstance(importlib.import_module(m), types.ModuleType)) + for f in FUNC_INLINELIST: + module_name, fn_name = f.rsplit(".", 1) + m = importlib.import_module(module_name) + self.assertTrue(isinstance(getattr(m, fn_name), types.FunctionType)) + + def test_func_inlinelist_torch_function(self): + def fn(x): + if istype(x, torch.Tensor): + return x + 1 + else: + return x - 1 + + func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() + func_inlinelist.add("torch._dynamo.utils.istype") + + self.assertTrue( + "torch._dynamo.utils" not in torch._dynamo.skipfiles.FILE_INLINELIST + ) + self.assertTrue( + "torch._dynamo" not in torch._dynamo.skipfiles.SUBMODULE_INLINELIST + ) + + with unittest.mock.patch( + "torch._dynamo.skipfiles.get_func_inlinelist", + gen_get_func_inlinelist(func_inlinelist), + ): + x = torch.rand(3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_func_inlinelist_third_party_function(self): + mod, func = create_dummy_module_and_function() + + def fn(x): + return func(x) + + func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() + func_inlinelist.add(f"{mod.__name__}.{func.__name__}") + + with unittest.mock.patch( + "torch._dynamo.skipfiles.get_func_inlinelist", + gen_get_func_inlinelist(func_inlinelist), + ), unittest.mock.patch( + "torch._dynamo.skipfiles.SKIP_DIRS", + torch._dynamo.skipfiles.SKIP_DIRS.copy(), + ): + # First adding the module to SKIP_DIRS so that it will be skipped. + torch._dynamo.skipfiles.add(mod.__name__) + x = torch.rand(3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/utils.py b/test/dynamo/utils.py index 30a59e94092ff..719eec47d9dad 100644 --- a/test/dynamo/utils.py +++ b/test/dynamo/utils.py @@ -1,4 +1,8 @@ # Owner(s): ["module: dynamo"] +import importlib +import os +import sys +import types import torch import torch._dynamo @@ -20,3 +24,27 @@ def wrapped(*args): return torch.sin(a + 1), inner_func() return wrapped + + +# Create a dummy python module and function to test skipfiles rules. +module_code = """ +def add(x): + return x + 1 +""" + + +def add(x): + return x + 1 + + +def create_dummy_module_and_function(): + module = types.ModuleType("dummy_module") + module.__spec__ = importlib.machinery.ModuleSpec( + "dummy_module", None, origin=os.path.abspath(__file__) + ) + exec(module_code, module.__dict__) + sys.modules["dummy_module"] = module + # Need to override the original function since its __code__.co_filename is not a regular python file name, + # and the skipfiles rules use filename when checking SKIP_DIRS. + module.add = add + return module, module.add diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index d8cdf808e41b8..d9c65c4a08172 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -182,7 +182,7 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx): def _initialize(self): # Do this stuff in constructor to lower overhead slightly if isinstance(self._orig_mod.forward, types.MethodType) and skipfiles.check( - inspect.getsourcefile(self._orig_mod.forward) + self._orig_mod.forward ): # This may be a torch.nn.* instance in skipfiles.py which # won't trigger a frame evaluation workaround to add an extra @@ -362,7 +362,7 @@ def get_compiler_config(): except TypeError: filename = None if ( - (filename is None or skipfiles.check(filename)) + (filename is None or skipfiles.check(fn)) and ( getattr(fn, "__name__", "") not in ["_call_impl", "_wrapped_call_impl"] ) @@ -519,7 +519,7 @@ def catch_errors(frame, cache_entry, frame_state): if ( # TODO: the first condition is not covered by any test frame.f_lasti >= first_real_inst_idx(frame.f_code) - or skipfiles.check(frame.f_code.co_filename) + or skipfiles.check(frame.f_code) or config.disable ): log.debug("skipping %s %s", frame.f_code.co_name, frame.f_code.co_filename) @@ -1218,7 +1218,7 @@ def result_capturing_wrapper(*graph_inputs): if ( (shape_env := getattr(fake_mode, "shape_env", None)) is not None and (dim_constraints := shape_env.dim_constraints) is not None - and not skipfiles.check(inspect.getsourcefile(call_to_inspect)) + and not skipfiles.check(call_to_inspect) ): dim_constraints.solve() dim_constraints.remove_redundant_dynamic_results() diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py index 26e1bc4dede5b..bd5357afacff8 100644 --- a/torch/_dynamo/skipfiles.py +++ b/torch/_dynamo/skipfiles.py @@ -8,7 +8,6 @@ import dataclasses import enum import functools -import glob import importlib import inspect import linecache @@ -35,8 +34,14 @@ import torch._inductor.test_operators import torch.distributed import torch.utils._content_store +from .utils import getfile + +from .variables.functions import ( + NestedUserFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) -from . import comptime, external_utils, polyfill """ A note on skipfiles: @@ -59,10 +64,10 @@ * BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc. * THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc. * Functions in these two SKIPLISTs are always skipped, except when they are explicitly - put into the two INLINELIST: FILENAME_INLINELIST and SUBMODULE_INLINELIST. + put into the three INLINELIST: FUNC_INLINELIST, FILE_INLINELIST and SUBMODULE_INLINELIST. * PyTorch(torch) is in the BUILTIN_SKIPLIST by default, but there are many cases where we want inline the functions under torch namespace. We should add them - into FILENAME_INLINELIST or SUBMODULE_INLINELIST to make dynamo inline those functions. + into one of the three *_INLINELIST to make dynamo inline those functions. * If you call functions under skipped modules/files, Dynamo will wrap these functions as SkipFilesVariable. There are a few functions(e.g, collections.OrderedDict) that we have special handling at SkipFilesVariable.call_function. @@ -70,11 +75,18 @@ Overall: *_INLINELIST has precedence over *_SKIPLIST has precedence over DEFAULT (inline) To figure out what the behavior is, check the following list in order: -* FILENAME_INLINELIST (Inline if YES) +* FUNC_INLINELIST (Inline if YES) +* FILE_INLINELIST (Inline if YES) * SUBMODULE_INLINELIST (Inline if YES) * BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES) * Inline by default +In general, if you want to force inline a function or module, please consider adding +the function's file or python module to FILE_INLINELIST first. +Use the FUNC_INLINELIST only when there are other functions under the same file that +you don't want to inline. +In the future, we will consolidate FILE_INLINELIST and SUBMODULE_INLINELIST into one list +as we use the same logic (filename.startswith) to determine if a file or module is skipped. """ @@ -102,7 +114,7 @@ tempfile, threading, tokenize, - torch, # torch/* is skipped by default unless specified in FILENAME_INLINELIST or SUBMODULE_INLINELIST + torch, # torch/* is skipped by default unless specified in FILE_INLINELIST or SUBMODULE_INLINELIST traceback, types, typing, @@ -145,74 +157,107 @@ def _module_dir(m: types.ModuleType): return _strip_init_py(m.__file__) -# TODO(ybliang): Change to user *.__file__ rather than hard code string for this list. -# Force inline functions in these files, even the files is in *_SKIPLIST. -FILENAME_INLINELIST = { - torch.nn.Sequential.__init__.__code__.co_filename, - torch.set_rng_state.__code__.co_filename, - torch._inductor.test_operators.__file__, - torch.utils._content_store.__file__, - external_utils.__file__, - comptime.__file__, - polyfill.__file__, - torch.optim._functional.__file__, - torch.utils._foreach_utils.__file__, - _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py", - _module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py", - _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py", - _module_dir(torch) + "ao/quantization/pt2e/utils.py", - _module_dir(torch) + "ao/quantization/pt2e/eval_utils.py", - _module_dir(torch) + "_dynamo/_trace_wrapped_higher_order_op.py", - _module_dir(torch) + "_export/constraints.py", - _module_dir(torch) + "_higher_order_ops/cond.py", - _module_dir(torch) + "_functorch/apis.py", - _module_dir(torch) + "_functorch/deprecated.py", - _module_dir(torch) + "distributed/tensor/parallel/_utils.py", - _module_dir(torch) + "distributed/tensor/parallel/style.py", - _module_dir(torch) + "distributed/tensor/parallel/_data_parallel_utils.py", - _module_dir(torch) + "distributed/_tensor/api.py", - _module_dir(torch) + "distributed/_tensor/device_mesh.py", +# TODO: Add a decoractor for easily adding functions to FUNC_INLINELIST +# after resolving all circular import issues. +FUNC_INLINELIST = { + "torch._constrain_as_size", + "torch._constrain_as_value", } -if torch.distributed.is_available(): - # Inline the checkpoint code from distributed - import torch.distributed.algorithms._checkpoint.checkpoint_wrapper - FILENAME_INLINELIST |= { - torch.distributed.algorithms._checkpoint.checkpoint_wrapper.__file__ +# Force inline functions in these files or directories, even they are in *_SKIPLIST. +# We are using python module name instead of file or directory object to avoid circular dependency. +# Please keep this sorted alphabetically. +# TODO: Merge FILE_INLINELIST into SUBMODULE_INLINELIST. +FILE_INLINELIST = { + "torch._dynamo._trace_wrapped_higher_order_op", + "torch._dynamo.comptime", + "torch._dynamo.external_utils", + "torch._dynamo.polyfill", + "torch._export.db.examples", + "torch._export.wrappers", + "torch._functorch.apis", + "torch._functorch.deprecated", + "torch._higher_order_ops.cond", + "torch._inductor.test_operators", + "torch.ao.quantization.pt2e.eval_utils", + "torch.ao.quantization.pt2e.qat_utils", + "torch.ao.quantization.pt2e.representation.rewrite", + "torch.ao.quantization.pt2e.utils", + "torch.ao.quantization.quantizer.xnnpack_quantizer", + "torch.nn.modules.container", + "torch.optim._functional", + "torch.random", + "torch.utils._content_store", + "torch.utils._foreach_utils", +} + + +if torch.distributed.is_available(): + FILE_INLINELIST |= { + "torch.distributed._tensor.api", + "torch.distributed._tensor.device_mesh", + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper", + "torch.distributed.tensor.parallel._data_parallel_utils", + "torch.distributed.tensor.parallel._utils", + "torch.distributed.tensor.parallel.style", } # Include optimizer code for tracing -FILENAME_INLINELIST |= { - inspect.getfile(obj) - for obj in torch.optim.__dict__.values() - if inspect.isclass(obj) -} - -# TODO (zhxchen17) Make exportdb importable here. -FILENAME_INLINELIST |= set( - glob.glob(_module_dir(torch) + "_export/db/examples/*.py"), -) | { - _module_dir(torch) + "_export/wrappers.py", +FILE_INLINELIST |= { + str(obj.__module__) for obj in torch.optim.__dict__.values() if inspect.isclass(obj) } +# TODO: consolidate SUBMODULE_INLINELIST and FILE_INLINELIST into one list # Force inline functions under these modules, even the modules is in *_SKIPLIST. SUBMODULE_INLINELIST = { - torch.nn, - torch.distributions, - torch.testing, - torch.ao.nn, - torch._refs, - torch._prims, - torch._decomp, - torch.utils._contextlib, - torch.utils._pytree, - torch.fx._pytree, - torch.sparse, + "torch._refs", + "torch._prims", + "torch._decomp", + "torch.ao.nn", + "torch.distributions", + "torch.fx._pytree", + "torch.nn", + "torch.sparse", + "torch.testing", + "torch.utils._contextlib", + "torch.utils._pytree", } +if torch.distributed.is_available(): + SUBMODULE_INLINELIST.add("torch.distributed._functional_collectives") + + +# TODO: support adding bound method into this list +@functools.lru_cache(None) +def get_func_inlinelist(): + inlinelist = set() + for f in FUNC_INLINELIST: + module_name, fn_name = f.rsplit(".", 1) + m = importlib.import_module(module_name) + fn = getattr(m, fn_name) + inlinelist.add(fn.__code__) + return inlinelist + + +@functools.lru_cache(None) +def get_file_inlinelist(): + inlinelist = set() + for f in FILE_INLINELIST: + inlinelist.add(_module_dir(torch) + f[len("torch.") :].replace(".", "/")) + return inlinelist + + +@functools.lru_cache(None) +def get_submodule_inlinelist(): + inlinelist = set() + for m in SUBMODULE_INLINELIST: + inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + return inlinelist + + # skip some standard python builtin libs SKIP_DIRS = [ "