diff --git a/tests/parser-cases/include.thrift b/tests/parser-cases/include.thrift index 14678cf..a38743c 100644 --- a/tests/parser-cases/include.thrift +++ b/tests/parser-cases/include.thrift @@ -1,3 +1,4 @@ include "included.thrift" +include "include/included.thrift" const included.Timestamp datetime = 1422009523 diff --git a/tests/parser-cases/include/included.thrift b/tests/parser-cases/include/included.thrift new file mode 100644 index 0000000..1520955 --- /dev/null +++ b/tests/parser-cases/include/included.thrift @@ -0,0 +1 @@ +include "included_1.thrift" \ No newline at end of file diff --git a/tests/parser-cases/include/included_1.thrift b/tests/parser-cases/include/included_1.thrift new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_parser.py b/tests/test_parser.py index a0388dc..3b9609a 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -40,6 +40,16 @@ def test_include(): assert thrift.datetime == 1422009523 +def test_include_with_module_name_prefix(): + thrift = load('parser-cases/include.thrift', module_name='parser_cases.include_thrift') + included1_thrift = thrift.__thrift_meta__['includes'][0] + assert included1_thrift.__name__ == 'parser_cases.include_thrift' + included2_thrift = thrift.__thrift_meta__['includes'][1] + assert included2_thrift.__name__ == 'parser_cases.included.include_thrift' + included3_thrift = included2_thrift.__thrift_meta__['includes'][0] + assert included3_thrift.__name__ == 'parser_cases.included.include_1_thrift' + + def test_include_conflict(): with pytest.raises(ThriftParserError) as excinfo: load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift') diff --git a/thriftpy2/parser/__init__.py b/thriftpy2/parser/__init__.py index 856941a..03c96ce 100644 --- a/thriftpy2/parser/__init__.py +++ b/thriftpy2/parser/__init__.py @@ -48,7 +48,12 @@ def load(path, registered_thrift = sys.modules.get(include_thrift[1].__name__) if registered_thrift is None: sys.modules[include_thrift[1].__name__] = include_thrift[0] - include_thrifts.extend(include_thrift[0].__thrift_meta__["includes"]) + if hasattr(include_thrift[0], "__thrift_meta__"): + include_thrifts.extend( + list( + zip( + include_thrift[0].__thrift_meta__["includes"], + include_thrift[0].__thrift_meta__["sub_modules"]))) else: if registered_thrift.__thrift_file__ != include_thrift[0].__thrift_file__: raise ThriftParserError( diff --git a/thriftpy2/parser/parser.py b/thriftpy2/parser/parser.py index be44975..a960455 100644 --- a/thriftpy2/parser/parser.py +++ b/thriftpy2/parser/parser.py @@ -62,8 +62,14 @@ def p_include(p): for include_dir in replace_include_dirs: path = os.path.join(include_dir, p[2]) if os.path.exists(path): + thrift_file_name_module = os.path.basename(thrift.__thrift_file__) + if thrift_file_name_module.endswith(".thrift"): + thrift_file_name_module = thrift_file_name_module[:-7] + "_thrift" + module_prefix = str(thrift.__name__).rstrip(thrift_file_name_module) + child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__)) child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift") + child_module_name = module_prefix + child_module_name child = parse(path, module_name=child_module_name) setattr(thrift, str(child.__name__).replace("_thrift", ""), child)