diff --git a/pyproject.toml b/pyproject.toml index b763aba..d64566b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pyprotostuben" -version = "0.2.0" +version = "0.2.1" description = "Generate Python MyPy stub modules from protobuf files." authors = ["zerlok "] readme = "README.md" diff --git a/src/pyprotostuben/codegen/mypy/plugin.py b/src/pyprotostuben/codegen/mypy/plugin.py index 82102b4..2ddd198 100644 --- a/src/pyprotostuben/codegen/mypy/plugin.py +++ b/src/pyprotostuben/codegen/mypy/plugin.py @@ -47,11 +47,11 @@ def run(self, request: CodeGeneratorRequest) -> CodeGeneratorResponse: def __create_generator(self, context: CodeGeneratorContext) -> ProtoFileGenerator: return ModuleASTBasedProtoFileGenerator( - context_factory=self.__create_visitor_context, + context_factory=_MultiProcessFuncs.create_visitor_context, visitor=MypyStubASTGenerator( registry=context.registry, - message_context_factory=partial(self.__create_file_message_context, context.params), - grpc_context_factory=partial(self.__create_file_grpc_context, context.params), + message_context_factory=partial(_MultiProcessFuncs.create_file_message_context, context.params), + grpc_context_factory=partial(_MultiProcessFuncs.create_file_grpc_context, context.params), ), ) @@ -71,7 +71,16 @@ def __build_file(self, item: GeneratedItem) -> CodeGeneratorResponse.File: content=item.content, ) - def __create_visitor_context(self, file: ProtoFile) -> MypyStubContext: + +class _MultiProcessFuncs: + """ + A set of picklable functions that can be passed to `MultiProcessPool`. + + For more info: https://docs.python.org/3/library/multiprocessing.html#programming-guidelines + """ + + @staticmethod + def create_visitor_context(file: ProtoFile) -> MypyStubContext: return MypyStubContext( file=file, modules={}, @@ -79,7 +88,8 @@ def __create_visitor_context(self, file: ProtoFile) -> MypyStubContext: grpcs=MutableStack(), ) - def __create_file_message_context(self, params: CodeGeneratorParameters, file: ProtoFile) -> MessageContext: + @staticmethod + def create_file_message_context(params: CodeGeneratorParameters, file: ProtoFile) -> MessageContext: deps: t.Set[ModuleInfo] = set() module = file.pb2_message @@ -94,7 +104,8 @@ def __create_file_message_context(self, params: CodeGeneratorParameters, file: P ), ) - def __create_file_grpc_context(self, params: CodeGeneratorParameters, file: ProtoFile) -> GRPCContext: + @staticmethod + def create_file_grpc_context(params: CodeGeneratorParameters, file: ProtoFile) -> GRPCContext: deps: t.Set[ModuleInfo] = set() module = file.pb2_grpc diff --git a/tests/integration/case.py b/tests/integration/cases/case.py similarity index 92% rename from tests/integration/case.py rename to tests/integration/cases/case.py index c15690a..3aa256b 100644 --- a/tests/integration/case.py +++ b/tests/integration/cases/case.py @@ -31,6 +31,7 @@ def __init__( # noqa: PLR0913 self, filename: str, plugin: ProtocPlugin, + parameter: t.Optional[str] = None, proto_source: t.Optional[str] = None, proto_paths: t.Optional[t.Sequence[str]] = None, expected_gen_source: t.Optional[str] = None, @@ -38,6 +39,7 @@ def __init__( # noqa: PLR0913 ) -> None: self.__case_dir = Path(filename).parent self.__plugin = plugin + self.__parameter = parameter self.__proto_source = self.__case_dir / (proto_source or "proto") self.__expected_gen_source = self.__case_dir / (expected_gen_source or "expected_gen") @@ -57,12 +59,14 @@ def __init__( # noqa: PLR0913 ) def provide(self, tmp_path: Path) -> Case: - gen_request = read_request(self.__proto_source, self.__proto_paths, tmp_path) - gen_request.parameter = "no-parallel" # for easier debug + request = read_request(self.__proto_source, self.__proto_paths, tmp_path) + + if self.__parameter is not None: + request.parameter = self.__parameter return Case( generator=self.__plugin, - request=gen_request, + request=request, gen_expected_files=[ load_codegen_response_file_content(self.__expected_gen_source, path) for path in self.__expected_gen_paths diff --git a/tests/integration/cases/case_000_greeting/case.py b/tests/integration/cases/case_000_greeting/case.py index be8cef8..7bf4e80 100644 --- a/tests/integration/cases/case_000_greeting/case.py +++ b/tests/integration/cases/case_000_greeting/case.py @@ -1,4 +1,5 @@ from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin -from tests.integration.case import DirCaseProvider +from tests.integration.cases.case import DirCaseProvider -mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin()) +mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel") +mypy_case_multiprocessing = DirCaseProvider(__file__, MypyStubProtocPlugin()) diff --git a/tests/integration/cases/case_001_types/case.py b/tests/integration/cases/case_001_types/case.py index be8cef8..a72e5e4 100644 --- a/tests/integration/cases/case_001_types/case.py +++ b/tests/integration/cases/case_001_types/case.py @@ -1,4 +1,4 @@ from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin -from tests.integration.case import DirCaseProvider +from tests.integration.cases.case import DirCaseProvider -mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin()) +mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel") diff --git a/tests/integration/cases/case_002_grpc_service/case.py b/tests/integration/cases/case_002_grpc_service/case.py index be8cef8..a72e5e4 100644 --- a/tests/integration/cases/case_002_grpc_service/case.py +++ b/tests/integration/cases/case_002_grpc_service/case.py @@ -1,4 +1,4 @@ from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin -from tests.integration.case import DirCaseProvider +from tests.integration.cases.case import DirCaseProvider -mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin()) +mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel") diff --git a/tests/integration/cases/case_003_message_nesting/case.py b/tests/integration/cases/case_003_message_nesting/case.py index be8cef8..a72e5e4 100644 --- a/tests/integration/cases/case_003_message_nesting/case.py +++ b/tests/integration/cases/case_003_message_nesting/case.py @@ -1,4 +1,4 @@ from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin -from tests.integration.case import DirCaseProvider +from tests.integration.cases.case import DirCaseProvider -mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin()) +mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel") diff --git a/tests/integration/cases/case_004_message_import/case.py b/tests/integration/cases/case_004_message_import/case.py index be8cef8..a72e5e4 100644 --- a/tests/integration/cases/case_004_message_import/case.py +++ b/tests/integration/cases/case_004_message_import/case.py @@ -1,4 +1,4 @@ from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin -from tests.integration.case import DirCaseProvider +from tests.integration.cases.case import DirCaseProvider -mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin()) +mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel") diff --git a/tests/integration/cases/case_005_comments/case.py b/tests/integration/cases/case_005_comments/case.py index be8cef8..a72e5e4 100644 --- a/tests/integration/cases/case_005_comments/case.py +++ b/tests/integration/cases/case_005_comments/case.py @@ -1,4 +1,4 @@ from pyprotostuben.codegen.mypy.plugin import MypyStubProtocPlugin -from tests.integration.case import DirCaseProvider +from tests.integration.cases.case import DirCaseProvider -mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin()) +mypy_case = DirCaseProvider(__file__, MypyStubProtocPlugin(), "no-parallel") diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7752235..93e199e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -5,7 +5,7 @@ import pytest from _pytest.fixtures import SubRequest -from tests.integration.case import Case, CaseProvider +from tests.integration.cases.case import Case, CaseProvider @pytest.fixture( diff --git a/tests/integration/test_generator_cases.py b/tests/integration/test_generator_cases.py index 128bdfe..ad2f3c7 100644 --- a/tests/integration/test_generator_cases.py +++ b/tests/integration/test_generator_cases.py @@ -2,7 +2,7 @@ from google.protobuf.compiler.plugin_pb2 import CodeGeneratorResponse -from tests.integration.case import Case +from tests.integration.cases.case import Case def test_file_content_matches(case: Case) -> None: