Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async Stub Annotations #611

Merged
merged 7 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ per-file-ignores =
*.py: E203, E301, E302, E305, E501
*.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037
*_pb2.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021
*_pb2_grpc.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021
*_pb2_grpc.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021, Y023

extend_exclude = venv*,*_pb2.py,*_pb2_grpc.py,build/
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

- Mark top-level mangled identifiers as `TypeAlias`.
- Change the top-level mangling prefix from `global___` to `Global___` to respect
[Y042](https://github.com/PyCQA/flake8-pyi/blob/main/ERRORCODES.md#list-of-warnings) naming convention.
[Y042](https://github.com/PyCQA/flake8-pyi/blob/main/ERRORCODES.md#list-of-warnings) naming convention.
- Support client stub async typing overloads

## 3.6.0

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ black .
- [@fergyfresh](https://github.com/fergyfresh)
- [@AlexWaygood](https://github.com/AlexWaygood)
- [@Avasam](https://github.com/Avasam)
- [@artificial-aidan](https://github.com/artificial-aidan)

## Licence etc.

Expand Down
107 changes: 83 additions & 24 deletions mypy_protobuf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
Iterator,
List,
Optional,
Set,
Sequence,
Set,
Tuple,
)

import google.protobuf.descriptor_pb2 as d
from google.protobuf.compiler import plugin_pb2 as plugin_pb2
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from google.protobuf.internal.well_known_types import WKTBASES

from . import extensions_pb2

__version__ = "3.6.0"
Expand Down Expand Up @@ -85,6 +86,11 @@
}


def _build_typevar_name(service_name: str, method_name: str) -> str:
# Prefix with underscore to avoid public api error: https://stackoverflow.com/a/78871465
return f"_{service_name}{method_name}Type"


def _mangle_global_identifier(name: str) -> str:
"""
Module level identifiers are mangled and aliased so that they can be disambiguated
Expand Down Expand Up @@ -168,9 +174,7 @@ def _import(self, path: str, name: str) -> str:
eg. self._import("typing", "Literal") -> "Literal"
"""
if path == "typing_extensions":
stabilization = {
"TypeAlias": (3, 10),
}
stabilization = {"TypeAlias": (3, 10), "TypeVar": (3, 13)}
assert name in stabilization
if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
self.typing_extensions_min = stabilization[name]
Expand Down Expand Up @@ -732,6 +736,46 @@ def write_grpc_async_hacks(self) -> None:
wl("...")
wl("")

def write_grpc_type_vars(self, service: d.ServiceDescriptorProto) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
if not methods:
return
for _, method in methods:
wl("{} = {}(", _build_typevar_name(service.name, method.name), self._import("typing_extensions", "TypeVar"))
with self._indent():
wl("'{}',", _build_typevar_name(service.name, method.name))
wl("{}[", self._callable_type(method, is_async=False))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")
wl("{}[", self._callable_type(method, is_async=True))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")
wl("default={}[", self._callable_type(method, is_async=False))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")
wl(")")
wl("")

def write_self_types(self, service: d.ServiceDescriptorProto, is_async: bool) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
if not methods:
return
for _, method in methods:
with self._indent():
wl("{}[", self._callable_type(method, is_async=is_async))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")

def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
Expand Down Expand Up @@ -769,11 +813,7 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
for i, method in methods:
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("]")
wl("{}: {}", method.name, f"{_build_typevar_name(service.name, method.name)}")
self._write_comments(scl)
wl("")

Expand All @@ -791,29 +831,48 @@ def write_grpc_services(

scl = scl_prefix + [i]

# Type vars
self.write_grpc_type_vars(service)

# The stub client
class_name = f"{service.name}Stub"
wl(
"class {}Stub:",
service.name,
"class {}({}[{}]):",
class_name,
self._import("typing", "Generic"),
", ".join(f"{_build_typevar_name(service.name, method.name)}" for method in service.method),
)
with self._indent():
if self._write_comments(scl):
wl("")
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
wl("def __init__(self, channel: {}) -> None: ...", channel)

# Write sync overload
wl("@{}", self._import("typing", "overload"))
wl("def __init__(self: {}[", class_name)
self.write_self_types(service, False)
wl(
"], channel: {}) -> None: ...",
self._import("grpc", "Channel"),
)
wl("")

# Write async overload
wl("@{}", self._import("typing", "overload"))
wl("def __init__(self: {}[", class_name)
self.write_self_types(service, True)
wl(
"], channel: {}) -> None: ...",
self._import("grpc.aio", "Channel"),
)
wl("")

self.write_grpc_stub_methods(service, scl)

# The (fake) async stub client
wl(
"class {}AsyncStub:",
service.name,
)
with self._indent():
if self._write_comments(scl):
wl("")
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
self.write_grpc_stub_methods(service, scl, is_async=True)
# Write AsyncStub alias
wl("{}AsyncStub: {} = {}[", service.name, self._import("typing_extensions", "TypeAlias"), class_name)
self.write_self_types(service, True)
wl("]")
wl("")

# The service definition interface
wl(
Expand Down
3 changes: 2 additions & 1 deletion run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
# Write output to file. Make variant w/ omitted line numbers for easy diffing / CR
PY_VER_MYPY_TARGET=$(echo "$1" | cut -d. -f1-2)
export MYPYPATH=$MYPYPATH:test/generated
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="venv_$1/bin/python" --python-version="$PY_VER_MYPY_TARGET" "${@: 2}" > "$MYPY_OUTPUT/mypy_output" || true
# Use --no-incremental to avoid caching issues: https://github.com/python/mypy/issues/16363
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="venv_$1/bin/python" --no-incremental --python-version="$PY_VER_MYPY_TARGET" "${@: 2}" > "$MYPY_OUTPUT/mypy_output" || true
cut -d: -f1,3- "$MYPY_OUTPUT/mypy_output" > "$MYPY_OUTPUT/mypy_output.omit_linenos"
}

Expand Down
153 changes: 121 additions & 32 deletions test/generated/testproto/grpc/dummy_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@ import abc
import collections.abc
import grpc
import grpc.aio
import sys
import testproto.grpc.dummy_pb2
import typing

if sys.version_info >= (3, 13):
import typing as typing_extensions
else:
import typing_extensions

_T = typing.TypeVar("_T")

class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
Expand All @@ -19,60 +25,143 @@ class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type:

GRPC_GENERATED_VERSION: str
GRPC_VERSION: str
class DummyServiceStub:
"""DummyService"""

def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ...
UnaryUnary: grpc.UnaryUnaryMultiCallable[
_DummyServiceUnaryUnaryType = typing_extensions.TypeVar(
'_DummyServiceUnaryUnaryType',
grpc.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""UnaryUnary"""
],
grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

UnaryStream: grpc.UnaryStreamMultiCallable[
_DummyServiceUnaryStreamType = typing_extensions.TypeVar(
'_DummyServiceUnaryStreamType',
grpc.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""UnaryStream"""
],
grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

StreamUnary: grpc.StreamUnaryMultiCallable[
_DummyServiceStreamUnaryType = typing_extensions.TypeVar(
'_DummyServiceStreamUnaryType',
grpc.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""StreamUnary"""
],
grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

StreamStream: grpc.StreamStreamMultiCallable[
_DummyServiceStreamStreamType = typing_extensions.TypeVar(
'_DummyServiceStreamStreamType',
grpc.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""StreamStream"""
],
grpc.aio.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

class DummyServiceAsyncStub:
class DummyServiceStub(typing.Generic[_DummyServiceUnaryUnaryType, _DummyServiceUnaryStreamType, _DummyServiceStreamUnaryType, _DummyServiceStreamStreamType]):
"""DummyService"""

UnaryUnary: grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
@typing.overload
def __init__(self: DummyServiceStub[
grpc.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
], channel: grpc.Channel) -> None: ...

@typing.overload
def __init__(self: DummyServiceStub[
grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
], channel: grpc.aio.Channel) -> None: ...

UnaryUnary: _DummyServiceUnaryUnaryType
"""UnaryUnary"""

UnaryStream: grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
UnaryStream: _DummyServiceUnaryStreamType
"""UnaryStream"""

StreamUnary: grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
StreamUnary: _DummyServiceStreamUnaryType
"""StreamUnary"""

StreamStream: grpc.aio.StreamStreamMultiCallable[
StreamStream: _DummyServiceStreamStreamType
"""StreamStream"""

DummyServiceAsyncStub: typing_extensions.TypeAlias = DummyServiceStub[
grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""StreamStream"""
],
grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
]

class DummyServiceServicer(metaclass=abc.ABCMeta):
"""DummyService"""
Expand Down
Loading
Loading