Skip to content

Commit

Permalink
[2/N] Dynamo supports skip by function & removes skipfiles circular i…
Browse files Browse the repository at this point in the history
…mport (pytorch#110835)

Several improvements for skipfiles:
* Add ```FUNC_INLINELIST``` to support function level skip/inline check.
  * Use ```fn.__code__``` to match function since we can't get the function object sometimes.
* Use python module string name for ```FILE_INLINELIST``` and ```SUBMODULE_INLINELIST```.
  * Use filename to match file and python module, which can fundamentally resolved the circular import issues introduced by skipfiles.
  * Use ```TYPE_CHECKING``` to ensure the python module string name is correct.
* Add unit tests.

Pull Request resolved: pytorch#110835
Approved by: https://github.com/ezyang
  • Loading branch information
yanboliang authored and pytorchmergebot committed Oct 12, 2023
1 parent a6b452d commit 986ad3b
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 82 deletions.
101 changes: 101 additions & 0 deletions test/dynamo/test_allow_inline_skip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Owner(s): ["module: dynamo"]
import importlib
import types
import unittest

import torch
import torch._dynamo.test_case
from torch._dynamo.skipfiles import (
FILE_INLINELIST,
FUNC_INLINELIST,
SUBMODULE_INLINELIST,
)
from torch._dynamo.utils import istype

try:
from .utils import create_dummy_module_and_function
except ImportError:
from utils import create_dummy_module_and_function


def gen_get_func_inlinelist(dummy_func_inlinelist):
def get_func_inlinelist():
inlinelist = set()
for f in dummy_func_inlinelist:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
fn = getattr(m, fn_name)
inlinelist.add(fn.__code__)
return inlinelist

return get_func_inlinelist


class AllowInlineSkipTests(torch._dynamo.test_case.TestCase):
# We are using python function and module string names for these inlinelist,
# this unit test is to make sure the functions/modules can be correctly imported
# or loaded in case there is typo in the strings.
def test_skipfiles_inlinelist_correctness(self):
for m in FILE_INLINELIST.union(SUBMODULE_INLINELIST):
self.assertTrue(isinstance(importlib.import_module(m), types.ModuleType))
for f in FUNC_INLINELIST:
module_name, fn_name = f.rsplit(".", 1)
m = importlib.import_module(module_name)
self.assertTrue(isinstance(getattr(m, fn_name), types.FunctionType))

def test_func_inlinelist_torch_function(self):
def fn(x):
if istype(x, torch.Tensor):
return x + 1
else:
return x - 1

func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add("torch._dynamo.utils.istype")

self.assertTrue(
"torch._dynamo.utils" not in torch._dynamo.skipfiles.FILE_INLINELIST
)
self.assertTrue(
"torch._dynamo" not in torch._dynamo.skipfiles.SUBMODULE_INLINELIST
)

with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
):
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)

def test_func_inlinelist_third_party_function(self):
mod, func = create_dummy_module_and_function()

def fn(x):
return func(x)

func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy()
func_inlinelist.add(f"{mod.__name__}.{func.__name__}")

with unittest.mock.patch(
"torch._dynamo.skipfiles.get_func_inlinelist",
gen_get_func_inlinelist(func_inlinelist),
), unittest.mock.patch(
"torch._dynamo.skipfiles.SKIP_DIRS",
torch._dynamo.skipfiles.SKIP_DIRS.copy(),
):
# First adding the module to SKIP_DIRS so that it will be skipped.
torch._dynamo.skipfiles.add(mod.__name__)
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
ref = fn(x)
res = opt_fn(x)
self.assertEqual(ref, res)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
28 changes: 28 additions & 0 deletions test/dynamo/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Owner(s): ["module: dynamo"]
import importlib
import os
import sys
import types

import torch
import torch._dynamo
Expand All @@ -20,3 +24,27 @@ def wrapped(*args):
return torch.sin(a + 1), inner_func()

return wrapped


# Create a dummy python module and function to test skipfiles rules.
module_code = """
def add(x):
return x + 1
"""


def add(x):
return x + 1


def create_dummy_module_and_function():
module = types.ModuleType("dummy_module")
module.__spec__ = importlib.machinery.ModuleSpec(
"dummy_module", None, origin=os.path.abspath(__file__)
)
exec(module_code, module.__dict__)
sys.modules["dummy_module"] = module
# Need to override the original function since its __code__.co_filename is not a regular python file name,
# and the skipfiles rules use filename when checking SKIP_DIRS.
module.add = add
return module, module.add
8 changes: 4 additions & 4 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx):
def _initialize(self):
# Do this stuff in constructor to lower overhead slightly
if isinstance(self._orig_mod.forward, types.MethodType) and skipfiles.check(
inspect.getsourcefile(self._orig_mod.forward)
self._orig_mod.forward
):
# This may be a torch.nn.* instance in skipfiles.py which
# won't trigger a frame evaluation workaround to add an extra
Expand Down Expand Up @@ -362,7 +362,7 @@ def get_compiler_config():
except TypeError:
filename = None
if (
(filename is None or skipfiles.check(filename))
(filename is None or skipfiles.check(fn))
and (
getattr(fn, "__name__", "") not in ["_call_impl", "_wrapped_call_impl"]
)
Expand Down Expand Up @@ -519,7 +519,7 @@ def catch_errors(frame, cache_entry, frame_state):
if (
# TODO: the first condition is not covered by any test
frame.f_lasti >= first_real_inst_idx(frame.f_code)
or skipfiles.check(frame.f_code.co_filename)
or skipfiles.check(frame.f_code)
or config.disable
):
log.debug("skipping %s %s", frame.f_code.co_name, frame.f_code.co_filename)
Expand Down Expand Up @@ -1218,7 +1218,7 @@ def result_capturing_wrapper(*graph_inputs):
if (
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
and not skipfiles.check(call_to_inspect)
):
dim_constraints.solve()
dim_constraints.remove_redundant_dynamic_results()
Expand Down
Loading

0 comments on commit 986ad3b

Please sign in to comment.