diff --git a/example/client.py b/example/client.py index 5b8fd78..d2df794 100644 --- a/example/client.py +++ b/example/client.py @@ -1,15 +1,66 @@ +try: + import asyncio + import aiohttp +except ModuleNotFoundError: + pass + from twirp.context import Context from twirp.exceptions import TwirpServerException from generated import haberdasher_twirp, haberdasher_pb2 -client = haberdasher_twirp.HaberdasherClient("http://localhost:3000") -try: - response = client.MakeHat( - ctx=Context(), request=haberdasher_pb2.Size(inches=12), server_path_prefix="/twirpy") - if not response.HasField('name'): - print("We didn't get a name!") - print(response) -except TwirpServerException as e: - print(e.code, e.message, e.meta, e.to_dict()) +server_url = "http://localhost:3000" +timeout_s = 5 + + +def main(): + client = haberdasher_twirp.HaberdasherClient(server_url, timeout_s) + + try: + response = client.MakeHat( + ctx=Context(), + request=haberdasher_pb2.Size(inches=12), + server_path_prefix="/twirpy", + ) + if not response.HasField("name"): + print("We didn't get a name!") + print(response) + except TwirpServerException as e: + print(e.code, e.message, e.meta, e.to_dict()) + + +async def async_main(): + # The caller must provide their own ClientSession to the twirp client + # either on init or per request, and ensure it is closed properly on app shutdown. + + # NOTE: ClientSession may only be created (or closed) within a coroutine. + session = aiohttp.ClientSession( + server_url, timeout=aiohttp.ClientTimeout(total=timeout_s) + ) + client = haberdasher_twirp.AsyncHaberdasherClient(server_url, session=session) + + try: + response = await client.MakeHat( + ctx=Context(), + request=haberdasher_pb2.Size(inches=12), + server_path_prefix="/twirpy", + # Optionally provide a session per request + # session=session, + ) + if not response.HasField("name"): + print("We didn't get a name!") + print(response) + except TwirpServerException as e: + print(e.code, e.message, e.meta, e.to_dict()) + finally: + # Close the session (could also use a context manager) + await session.close() + + +if __name__ == "__main__": + if hasattr(haberdasher_twirp, "AsyncHaberdasherClient"): + print("using async client") + asyncio.run(async_main()) + else: + main() diff --git a/example/generated/haberdasher_twirp.py b/example/generated/haberdasher_twirp.py index 42fb488..019450b 100644 --- a/example/generated/haberdasher_twirp.py +++ b/example/generated/haberdasher_twirp.py @@ -7,6 +7,11 @@ from twirp.base import Endpoint from twirp.server import TwirpServer from twirp.client import TwirpClient +try: + from twirp.async_client import AsyncTwirpClient + _async_available = True +except ModuleNotFoundError: + _async_available = False _sym_db = _symbol_database.Default() @@ -35,3 +40,17 @@ def MakeHat(self, *args, ctx, request, server_path_prefix="/twirp", **kwargs): response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"), **kwargs, ) + + +if _async_available: + class AsyncHaberdasherClient(AsyncTwirpClient): + + async def MakeHat(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs): + return await self._make_request( + url=F"{server_path_prefix}/twitch.twirp.example.Haberdasher/MakeHat", + ctx=ctx, + request=request, + response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"), + session=session, + **kwargs, + ) diff --git a/example/requirements.txt b/example/requirements.txt index afa8e88..b73afa1 100644 --- a/example/requirements.txt +++ b/example/requirements.txt @@ -1,2 +1,2 @@ -twirp==0.0.2 -uvicorn==0.12.2 +twirp==0.0.8 +uvicorn==0.23.2 diff --git a/protoc-gen-twirpy/generator/template.go b/protoc-gen-twirpy/generator/template.go index 1ed0f57..e4bd968 100644 --- a/protoc-gen-twirpy/generator/template.go +++ b/protoc-gen-twirpy/generator/template.go @@ -38,6 +38,11 @@ from google.protobuf import symbol_database as _symbol_database from twirp.base import Endpoint from twirp.server import TwirpServer from twirp.client import TwirpClient +try: + from twirp.async_client import AsyncTwirpClient + _async_available = True +except ModuleNotFoundError: + _async_available = False _sym_db = _symbol_database.Default() {{range .Services}} @@ -66,4 +71,18 @@ class {{.Name}}Client(TwirpClient): response_obj=_sym_db.GetSymbol("{{.Output}}"), **kwargs, ) +{{end}} + +if _async_available: + class Async{{.Name}}Client(AsyncTwirpClient): +{{range .Methods}} + async def {{.Name}}(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs): + return await self._make_request( + url=F"{server_path_prefix}/{{.ServiceURL}}/{{.Name}}", + ctx=ctx, + request=request, + response_obj=_sym_db.GetSymbol("{{.Output}}"), + session=session, + **kwargs, + ) {{end}}{{end}}`)) diff --git a/setup.py b/setup.py index 1ff0252..1b4c415 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,9 @@ 'structlog', 'protobuf' ], + extras_require={ + 'async': ['aiohttp'], + }, test_requires=[ ], zip_safe=False) diff --git a/twirp/async_client.py b/twirp/async_client.py new file mode 100644 index 0000000..6710ed7 --- /dev/null +++ b/twirp/async_client.py @@ -0,0 +1,57 @@ +import asyncio +import json +from typing import Optional + +import aiohttp + +from . import exceptions +from . import errors + + +class AsyncTwirpClient: + def __init__( + self, address: str, session: Optional[aiohttp.ClientSession] = None + ) -> None: + self._address = address + self._session = session + + async def _make_request( + self, *, url, ctx, request, response_obj, session=None, **kwargs + ): + headers = ctx.get_headers() + if "headers" in kwargs: + headers.update(kwargs["headers"]) + kwargs["headers"] = headers + kwargs["headers"]["Content-Type"] = "application/protobuf" + + if session is None: + session = self._session + if not isinstance(session, aiohttp.ClientSession): + raise TypeError(f"invalid session type '{type(session).__name__}'") + + try: + async with await session.post( + url=url, data=request.SerializeToString(), **kwargs + ) as resp: + if resp.status == 200: + response = response_obj() + response.ParseFromString(await resp.read()) + return response + try: + raise exceptions.TwirpServerException.from_json(await resp.json()) + except (aiohttp.ContentTypeError, json.JSONDecodeError): + raise exceptions.twirp_error_from_intermediary( + resp.status, resp.reason, resp.headers, await resp.text() + ) from None + except asyncio.TimeoutError as e: + raise exceptions.TwirpServerException( + code=errors.Errors.DeadlineExceeded, + message=str(e) or "request timeout", + meta={"original_exception": e}, + ) + except aiohttp.ServerConnectionError as e: + raise exceptions.TwirpServerException( + code=errors.Errors.Unavailable, + message=str(e), + meta={"original_exception": e}, + ) diff --git a/twirp/base.py b/twirp/base.py index 6417206..ba0fbfa 100644 --- a/twirp/base.py +++ b/twirp/base.py @@ -101,6 +101,6 @@ def _get_encoder_decoder(self, endpoint, headers): else: raise exceptions.TwirpServerException( code=errors.Errors.BadRoute, - message="unexpected Content-Type: " + ctype + message="unexpected Content-Type: " + str(ctype) ) return encoder, decoder diff --git a/twirp/client.py b/twirp/client.py index fe34fc1..1fc9f8a 100644 --- a/twirp/client.py +++ b/twirp/client.py @@ -1,4 +1,3 @@ -import json import requests from . import exceptions @@ -25,8 +24,9 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs): return response try: raise exceptions.TwirpServerException.from_json(resp.json()) - except json.JSONDecodeError: - raise self._twirp_error_from_intermediary(resp) from None + except requests.JSONDecodeError: + raise exceptions.twirp_error_from_intermediary( + resp.status_code, resp.reason, resp.headers, resp.text) from None # Todo: handle error except requests.exceptions.Timeout as e: raise exceptions.TwirpServerException( @@ -40,42 +40,3 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs): message=str(e), meta={"original_exception": e}, ) - - @staticmethod - def _twirp_error_from_intermediary(resp): - # see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies - meta = { - 'http_error_from_intermediary': 'true', - 'status_code': str(resp.status_code), - } - - if resp.is_redirect: - # twirp uses POST which should not redirect - code = errors.Errors.Internal - location = resp.headers.get('location') - message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % ( - resp.status_code, - resp.reason, - location, - ) - meta['location'] = location - - else: - code = { - 400: errors.Errors.Internal, # JSON response should have been returned - 401: errors.Errors.Unauthenticated, - 403: errors.Errors.PermissionDenied, - 404: errors.Errors.BadRoute, - 429: errors.Errors.ResourceExhausted, - 502: errors.Errors.Unavailable, - 503: errors.Errors.Unavailable, - 504: errors.Errors.Unavailable, - }.get(resp.status_code, errors.Errors.Unknown) - - message = 'Error from intermediary with HTTP status code %d "%s"' % ( - resp.status_code, - resp.reason, - ) - meta['body'] = resp.text - - return exceptions.TwirpServerException(code=code, message=message, meta=meta) diff --git a/twirp/exceptions.py b/twirp/exceptions.py index fd5f59f..ec3d7f0 100644 --- a/twirp/exceptions.py +++ b/twirp/exceptions.py @@ -72,3 +72,42 @@ def RequiredArgument(*args, argument): argument=argument, error="is required" ) + + +def twirp_error_from_intermediary(status, reason, headers, body): + # see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies + meta = { + 'http_error_from_intermediary': 'true', + 'status_code': str(status), + } + + if 300 <= status < 400: + # twirp uses POST which should not redirect + code = errors.Errors.Internal + location = headers.get('location') + message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % ( + status, + reason, + location, + ) + meta['location'] = location + + else: + code = { + 400: errors.Errors.Internal, # JSON response should have been returned + 401: errors.Errors.Unauthenticated, + 403: errors.Errors.PermissionDenied, + 404: errors.Errors.BadRoute, + 429: errors.Errors.ResourceExhausted, + 502: errors.Errors.Unavailable, + 503: errors.Errors.Unavailable, + 504: errors.Errors.Unavailable, + }.get(status, errors.Errors.Unknown) + + message = 'Error from intermediary with HTTP status code %d "%s"' % ( + status, + reason, + ) + meta['body'] = body + + return TwirpServerException(code=code, message=message, meta=meta)