Skip to content

Commit

Permalink
fix mypy stub gen in multiprocessing case
Browse files Browse the repository at this point in the history
  • Loading branch information
zerlok committed Nov 21, 2024
1 parent 9a62c88 commit 62daa37
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pyprotostuben"
version = "0.2.0"
version = "0.2.1"
description = "Generate Python MyPy stub modules from protobuf files."
authors = ["zerlok <[email protected]>"]
readme = "README.md"
Expand Down
23 changes: 17 additions & 6 deletions src/pyprotostuben/codegen/mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def run(self, request: CodeGeneratorRequest) -> CodeGeneratorResponse:

def __create_generator(self, context: CodeGeneratorContext) -> ProtoFileGenerator:
return ModuleASTBasedProtoFileGenerator(
context_factory=self.__create_visitor_context,
context_factory=_MultiProcessFuncs.create_visitor_context,
visitor=MypyStubASTGenerator(
registry=context.registry,
message_context_factory=partial(self.__create_file_message_context, context.params),
grpc_context_factory=partial(self.__create_file_grpc_context, context.params),
message_context_factory=partial(_MultiProcessFuncs.create_file_message_context, context.params),
grpc_context_factory=partial(_MultiProcessFuncs.create_file_grpc_context, context.params),
),
)

Expand All @@ -71,15 +71,25 @@ def __build_file(self, item: GeneratedItem) -> CodeGeneratorResponse.File:
content=item.content,
)

def __create_visitor_context(self, file: ProtoFile) -> MypyStubContext:

class _MultiProcessFuncs:
"""
A set of picklable functions that can be passed to `MultiProcessPool`.
For more info: https://docs.python.org/3/library/multiprocessing.html#programming-guidelines
"""

@staticmethod
def create_visitor_context(file: ProtoFile) -> MypyStubContext:
return MypyStubContext(
file=file,
modules={},
messages=MutableStack(),
grpcs=MutableStack(),
)

def __create_file_message_context(self, params: CodeGeneratorParameters, file: ProtoFile) -> MessageContext:
@staticmethod
def create_file_message_context(params: CodeGeneratorParameters, file: ProtoFile) -> MessageContext:
deps: t.Set[ModuleInfo] = set()
module = file.pb2_message

Expand All @@ -94,7 +104,8 @@ def __create_file_message_context(self, params: CodeGeneratorParameters, file: P
),
)

def __create_file_grpc_context(self, params: CodeGeneratorParameters, file: ProtoFile) -> GRPCContext:
@staticmethod
def create_file_grpc_context(params: CodeGeneratorParameters, file: ProtoFile) -> GRPCContext:
deps: t.Set[ModuleInfo] = set()
module = file.pb2_grpc

Expand Down
10 changes: 7 additions & 3 deletions tests/integration/case.py → tests/integration/cases/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ def __init__( # noqa: PLR0913
self,
filename: str,
plugin: ProtocPlugin,
parameter: t.Optional[str] = None,
proto_source: t.Optional[str] = None,
proto_paths: t.Optional[t.Sequence[str]] = None,
expected_gen_source: t.Optional[str] = None,
expected_gen_paths: t.Optional[t.Sequence[str]] = None,
) -> None:
self.__case_dir = Path(filename).parent
self.__plugin = plugin
self.__parameter = parameter

self.__proto_source = self.__case_dir / (proto_source or "proto")
self.__expected_gen_source = self.__case_dir / (expected_gen_source or "expected_gen")
Expand All @@ -57,12 +59,14 @@ def __init__( # noqa: PLR0913
)

def provide(self, tmp_path: Path) -> Case:
gen_request = read_request(self.__proto_source, self.__proto_paths, tmp_path)
gen_request.parameter = "no-parallel" # for easier debug
request = read_request(self.__proto_source, self.__proto_paths, tmp_path)

if self.__parameter is not None:
request.parameter = self.__parameter

return Case(
generator=self.__plugin,
request=gen_request,
request=request,
gen_expected_files=[
load_codegen_response_file_content(self.__expected_gen_source, path)
for path in self.__expected_gen_paths
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/cases/case_000_greeting/case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin
from tests.integration.case import DirCaseProvider
from tests.integration.cases.case import DirCaseProvider

mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin())
mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel")
mypy_case_multiprocessing = DirCaseProvider(__file__, MypyStubProtocPlugin())
4 changes: 2 additions & 2 deletions tests/integration/cases/case_001_types/case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin
from tests.integration.case import DirCaseProvider
from tests.integration.cases.case import DirCaseProvider

mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin())
mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel")
4 changes: 2 additions & 2 deletions tests/integration/cases/case_002_grpc_service/case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin
from tests.integration.case import DirCaseProvider
from tests.integration.cases.case import DirCaseProvider

mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin())
mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel")
4 changes: 2 additions & 2 deletions tests/integration/cases/case_003_message_nesting/case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin
from tests.integration.case import DirCaseProvider
from tests.integration.cases.case import DirCaseProvider

mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin())
mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel")
4 changes: 2 additions & 2 deletions tests/integration/cases/case_004_message_import/case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin
from tests.integration.case import DirCaseProvider
from tests.integration.cases.case import DirCaseProvider

mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin())
mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel")
4 changes: 2 additions & 2 deletions tests/integration/cases/case_005_comments/case.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin
from tests.integration.case import DirCaseProvider
from tests.integration.cases.case import DirCaseProvider

mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin())
mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel")
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from _pytest.fixtures import SubRequest

from tests.integration.case import Case, CaseProvider
from tests.integration.cases.case import Case, CaseProvider


@pytest.fixture(
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_generator_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from google.protobuf.compiler.plugin_pb2 import CodeGeneratorResponse

from tests.integration.case import Case
from tests.integration.cases.case import Case


def test_file_content_matches(case: Case) -> None:
Expand Down

0 comments on commit 62daa37

Please sign in to comment.