Skip to content

Commit

Permalink
fix: sub module conflict error
Browse files Browse the repository at this point in the history
  • Loading branch information
StellarisW committed Nov 29, 2024
1 parent b740707 commit c5b2411
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/parser-cases/include.thrift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include "included.thrift"
include "include/included.thrift"

const included.Timestamp datetime = 1422009523
1 change: 1 addition & 0 deletions tests/parser-cases/include/included.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include "included_1.thrift"
Empty file.
10 changes: 10 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ def test_include():
assert thrift.datetime == 1422009523


def test_include_with_module_name_prefix():
thrift = load('parser-cases/include.thrift', module_name='parser_cases.include_thrift')
included1_thrift = thrift.__thrift_meta__['includes'][0]
assert included1_thrift.__name__ == 'parser_cases.include_thrift'
included2_thrift = thrift.__thrift_meta__['includes'][1]
assert included2_thrift.__name__ == 'parser_cases.included.include_thrift'
included3_thrift = included2_thrift.__thrift_meta__['includes'][0]
assert included3_thrift.__name__ == 'parser_cases.included.include_1_thrift'


def test_include_conflict():
with pytest.raises(ThriftParserError) as excinfo:
load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift')
Expand Down
7 changes: 6 additions & 1 deletion thriftpy2/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def load(path,
registered_thrift = sys.modules.get(include_thrift[1].__name__)
if registered_thrift is None:
sys.modules[include_thrift[1].__name__] = include_thrift[0]
include_thrifts.extend(include_thrift[0].__thrift_meta__["includes"])
if hasattr(include_thrift[0], "__thrift_meta__"):
include_thrifts.extend(
list(
zip(
include_thrift[0].__thrift_meta__["includes"],
include_thrift[0].__thrift_meta__["sub_modules"])))
else:
if registered_thrift.__thrift_file__ != include_thrift[0].__thrift_file__:
raise ThriftParserError(
Expand Down
6 changes: 6 additions & 0 deletions thriftpy2/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,14 @@ def p_include(p):
for include_dir in replace_include_dirs:
path = os.path.join(include_dir, p[2])
if os.path.exists(path):
thrift_file_name_module = os.path.basename(thrift.__thrift_file__)
if thrift_file_name_module.endswith(".thrift"):
thrift_file_name_module = thrift_file_name_module[:-7] + "_thrift"
module_prefix = str(thrift.__name__).rstrip(thrift_file_name_module)

child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__))
child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift")
child_module_name = module_prefix + child_module_name

child = parse(path, module_name=child_module_name)
setattr(thrift, str(child.__name__).replace("_thrift", ""), child)
Expand Down

0 comments on commit c5b2411

Please sign in to comment.