diff --git a/tests/parser-cases/include.thrift b/tests/parser-cases/include.thrift index a38743c0..37dbc8b4 100644 --- a/tests/parser-cases/include.thrift +++ b/tests/parser-cases/include.thrift @@ -1,4 +1,4 @@ include "included.thrift" -include "include/included.thrift" +include "include/included_1.thrift" const included.Timestamp datetime = 1422009523 diff --git a/tests/parser-cases/include/included.thrift b/tests/parser-cases/include/included.thrift deleted file mode 100644 index 15209552..00000000 --- a/tests/parser-cases/include/included.thrift +++ /dev/null @@ -1 +0,0 @@ -include "included_1.thrift" \ No newline at end of file diff --git a/tests/parser-cases/include/included_1.thrift b/tests/parser-cases/include/included_1.thrift index e69de29b..a803db8e 100644 --- a/tests/parser-cases/include/included_1.thrift +++ b/tests/parser-cases/include/included_1.thrift @@ -0,0 +1 @@ +include "included_2.thrift" \ No newline at end of file diff --git a/tests/parser-cases/include/included_2.thrift b/tests/parser-cases/include/included_2.thrift new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_parser.py b/tests/test_parser.py index 3b9609a1..5752dcfd 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import sys import threading import pytest @@ -42,12 +42,10 @@ def test_include(): def test_include_with_module_name_prefix(): thrift = load('parser-cases/include.thrift', module_name='parser_cases.include_thrift') - included1_thrift = thrift.__thrift_meta__['includes'][0] - assert included1_thrift.__name__ == 'parser_cases.include_thrift' - included2_thrift = thrift.__thrift_meta__['includes'][1] - assert included2_thrift.__name__ == 'parser_cases.included.include_thrift' - included3_thrift = included2_thrift.__thrift_meta__['includes'][0] - assert included3_thrift.__name__ == 'parser_cases.included.include_1_thrift' + assert sys.modules['parser_cases.include_thrift'] is not None + assert sys.modules['parser_cases.included_thrift'] is not None + assert sys.modules['parser_cases.include.included_1_thrift'] is not None + assert sys.modules['parser_cases.include.included_2_thrift'] is not None def test_include_conflict(): diff --git a/thriftpy2/parser/parser.py b/thriftpy2/parser/parser.py index a960455a..aafd93c0 100644 --- a/thriftpy2/parser/parser.py +++ b/thriftpy2/parser/parser.py @@ -72,7 +72,11 @@ def p_include(p): child_module_name = module_prefix + child_module_name child = parse(path, module_name=child_module_name) - setattr(thrift, str(child.__name__).replace("_thrift", ""), child) + child_include_module_name = os.path.basename(path) + if child_include_module_name.endswith(".thrift"): + child_include_module_name = child_include_module_name[:-7] + child.__name__=child_include_module_name + setattr(thrift, child.__name__, child) _add_thrift_meta('includes', child) _add_thrift_meta('sub_modules', types.ModuleType(child_module_name)) return