forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[2/N] Dynamo supports skip by function & removes skipfiles circular i…
…mport (pytorch#110835) Several improvements for skipfiles: * Add ```FUNC_INLINELIST``` to support function level skip/inline check. * Use ```fn.__code__``` to match function since we can't get the function object sometimes. * Use python module string name for ```FILE_INLINELIST``` and ```SUBMODULE_INLINELIST```. * Use filename to match file and python module, which can fundamentally resolved the circular import issues introduced by skipfiles. * Use ```TYPE_CHECKING``` to ensure the python module string name is correct. * Add unit tests. Pull Request resolved: pytorch#110835 Approved by: https://github.com/ezyang
- Loading branch information
1 parent
a6b452d
commit 986ad3b
Showing
6 changed files
with
299 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Owner(s): ["module: dynamo"] | ||
import importlib | ||
import types | ||
import unittest | ||
|
||
import torch | ||
import torch._dynamo.test_case | ||
from torch._dynamo.skipfiles import ( | ||
FILE_INLINELIST, | ||
FUNC_INLINELIST, | ||
SUBMODULE_INLINELIST, | ||
) | ||
from torch._dynamo.utils import istype | ||
|
||
try: | ||
from .utils import create_dummy_module_and_function | ||
except ImportError: | ||
from utils import create_dummy_module_and_function | ||
|
||
|
||
def gen_get_func_inlinelist(dummy_func_inlinelist): | ||
def get_func_inlinelist(): | ||
inlinelist = set() | ||
for f in dummy_func_inlinelist: | ||
module_name, fn_name = f.rsplit(".", 1) | ||
m = importlib.import_module(module_name) | ||
fn = getattr(m, fn_name) | ||
inlinelist.add(fn.__code__) | ||
return inlinelist | ||
|
||
return get_func_inlinelist | ||
|
||
|
||
class AllowInlineSkipTests(torch._dynamo.test_case.TestCase): | ||
# We are using python function and module string names for these inlinelist, | ||
# this unit test is to make sure the functions/modules can be correctly imported | ||
# or loaded in case there is typo in the strings. | ||
def test_skipfiles_inlinelist_correctness(self): | ||
for m in FILE_INLINELIST.union(SUBMODULE_INLINELIST): | ||
self.assertTrue(isinstance(importlib.import_module(m), types.ModuleType)) | ||
for f in FUNC_INLINELIST: | ||
module_name, fn_name = f.rsplit(".", 1) | ||
m = importlib.import_module(module_name) | ||
self.assertTrue(isinstance(getattr(m, fn_name), types.FunctionType)) | ||
|
||
def test_func_inlinelist_torch_function(self): | ||
def fn(x): | ||
if istype(x, torch.Tensor): | ||
return x + 1 | ||
else: | ||
return x - 1 | ||
|
||
func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() | ||
func_inlinelist.add("torch._dynamo.utils.istype") | ||
|
||
self.assertTrue( | ||
"torch._dynamo.utils" not in torch._dynamo.skipfiles.FILE_INLINELIST | ||
) | ||
self.assertTrue( | ||
"torch._dynamo" not in torch._dynamo.skipfiles.SUBMODULE_INLINELIST | ||
) | ||
|
||
with unittest.mock.patch( | ||
"torch._dynamo.skipfiles.get_func_inlinelist", | ||
gen_get_func_inlinelist(func_inlinelist), | ||
): | ||
x = torch.rand(3) | ||
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) | ||
ref = fn(x) | ||
res = opt_fn(x) | ||
self.assertEqual(ref, res) | ||
|
||
def test_func_inlinelist_third_party_function(self): | ||
mod, func = create_dummy_module_and_function() | ||
|
||
def fn(x): | ||
return func(x) | ||
|
||
func_inlinelist = torch._dynamo.skipfiles.FUNC_INLINELIST.copy() | ||
func_inlinelist.add(f"{mod.__name__}.{func.__name__}") | ||
|
||
with unittest.mock.patch( | ||
"torch._dynamo.skipfiles.get_func_inlinelist", | ||
gen_get_func_inlinelist(func_inlinelist), | ||
), unittest.mock.patch( | ||
"torch._dynamo.skipfiles.SKIP_DIRS", | ||
torch._dynamo.skipfiles.SKIP_DIRS.copy(), | ||
): | ||
# First adding the module to SKIP_DIRS so that it will be skipped. | ||
torch._dynamo.skipfiles.add(mod.__name__) | ||
x = torch.rand(3) | ||
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) | ||
ref = fn(x) | ||
res = opt_fn(x) | ||
self.assertEqual(ref, res) | ||
|
||
|
||
if __name__ == "__main__": | ||
from torch._dynamo.test_case import run_tests | ||
|
||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.