Skip to content

Commit

Permalink
async client (#50)
Browse files Browse the repository at this point in the history
* add async client

* create/close single aiohttp session

* update template

* update example

* fix bug in missing content-type exception

* session param in init, better teardown and examples

* formatting

* don't auto-create client session

* update example
  • Loading branch information
chadawagner authored Dec 12, 2023
1 parent d35e284 commit 69b39fb
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 54 deletions.
69 changes: 60 additions & 9 deletions example/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,66 @@
try:
import asyncio
import aiohttp
except ModuleNotFoundError:
pass

from twirp.context import Context
from twirp.exceptions import TwirpServerException

from generated import haberdasher_twirp, haberdasher_pb2

client = haberdasher_twirp.HaberdasherClient("http://localhost:3000")

try:
response = client.MakeHat(
ctx=Context(), request=haberdasher_pb2.Size(inches=12), server_path_prefix="/twirpy")
if not response.HasField('name'):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())
server_url = "http://localhost:3000"
timeout_s = 5


def main():
client = haberdasher_twirp.HaberdasherClient(server_url, timeout_s)

try:
response = client.MakeHat(
ctx=Context(),
request=haberdasher_pb2.Size(inches=12),
server_path_prefix="/twirpy",
)
if not response.HasField("name"):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())


async def async_main():
# The caller must provide their own ClientSession to the twirp client
# either on init or per request, and ensure it is closed properly on app shutdown.

# NOTE: ClientSession may only be created (or closed) within a coroutine.
session = aiohttp.ClientSession(
server_url, timeout=aiohttp.ClientTimeout(total=timeout_s)
)
client = haberdasher_twirp.AsyncHaberdasherClient(server_url, session=session)

try:
response = await client.MakeHat(
ctx=Context(),
request=haberdasher_pb2.Size(inches=12),
server_path_prefix="/twirpy",
# Optionally provide a session per request
# session=session,
)
if not response.HasField("name"):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())
finally:
# Close the session (could also use a context manager)
await session.close()


if __name__ == "__main__":
if hasattr(haberdasher_twirp, "AsyncHaberdasherClient"):
print("using async client")
asyncio.run(async_main())
else:
main()
19 changes: 19 additions & 0 deletions example/generated/haberdasher_twirp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from twirp.base import Endpoint
from twirp.server import TwirpServer
from twirp.client import TwirpClient
try:
from twirp.async_client import AsyncTwirpClient
_async_available = True
except ModuleNotFoundError:
_async_available = False

_sym_db = _symbol_database.Default()

Expand Down Expand Up @@ -35,3 +40,17 @@ def MakeHat(self, *args, ctx, request, server_path_prefix="/twirp", **kwargs):
response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"),
**kwargs,
)


if _async_available:
class AsyncHaberdasherClient(AsyncTwirpClient):

async def MakeHat(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs):
return await self._make_request(
url=F"{server_path_prefix}/twitch.twirp.example.Haberdasher/MakeHat",
ctx=ctx,
request=request,
response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"),
session=session,
**kwargs,
)
4 changes: 2 additions & 2 deletions example/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
twirp==0.0.2
uvicorn==0.12.2
twirp==0.0.8
uvicorn==0.23.2
19 changes: 19 additions & 0 deletions protoc-gen-twirpy/generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ from google.protobuf import symbol_database as _symbol_database
from twirp.base import Endpoint
from twirp.server import TwirpServer
from twirp.client import TwirpClient
try:
from twirp.async_client import AsyncTwirpClient
_async_available = True
except ModuleNotFoundError:
_async_available = False
_sym_db = _symbol_database.Default()
{{range .Services}}
Expand Down Expand Up @@ -66,4 +71,18 @@ class {{.Name}}Client(TwirpClient):
response_obj=_sym_db.GetSymbol("{{.Output}}"),
**kwargs,
)
{{end}}
if _async_available:
class Async{{.Name}}Client(AsyncTwirpClient):
{{range .Methods}}
async def {{.Name}}(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs):
return await self._make_request(
url=F"{server_path_prefix}/{{.ServiceURL}}/{{.Name}}",
ctx=ctx,
request=request,
response_obj=_sym_db.GetSymbol("{{.Output}}"),
session=session,
**kwargs,
)
{{end}}{{end}}`))
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
'structlog',
'protobuf'
],
extras_require={
'async': ['aiohttp'],
},
test_requires=[
],
zip_safe=False)
57 changes: 57 additions & 0 deletions twirp/async_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import asyncio
import json
from typing import Optional

import aiohttp

from . import exceptions
from . import errors


class AsyncTwirpClient:
def __init__(
self, address: str, session: Optional[aiohttp.ClientSession] = None
) -> None:
self._address = address
self._session = session

async def _make_request(
self, *, url, ctx, request, response_obj, session=None, **kwargs
):
headers = ctx.get_headers()
if "headers" in kwargs:
headers.update(kwargs["headers"])
kwargs["headers"] = headers
kwargs["headers"]["Content-Type"] = "application/protobuf"

if session is None:
session = self._session
if not isinstance(session, aiohttp.ClientSession):
raise TypeError(f"invalid session type '{type(session).__name__}'")

try:
async with await session.post(
url=url, data=request.SerializeToString(), **kwargs
) as resp:
if resp.status == 200:
response = response_obj()
response.ParseFromString(await resp.read())
return response
try:
raise exceptions.TwirpServerException.from_json(await resp.json())
except (aiohttp.ContentTypeError, json.JSONDecodeError):
raise exceptions.twirp_error_from_intermediary(
resp.status, resp.reason, resp.headers, await resp.text()
) from None
except asyncio.TimeoutError as e:
raise exceptions.TwirpServerException(
code=errors.Errors.DeadlineExceeded,
message=str(e) or "request timeout",
meta={"original_exception": e},
)
except aiohttp.ServerConnectionError as e:
raise exceptions.TwirpServerException(
code=errors.Errors.Unavailable,
message=str(e),
meta={"original_exception": e},
)
2 changes: 1 addition & 1 deletion twirp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ def _get_encoder_decoder(self, endpoint, headers):
else:
raise exceptions.TwirpServerException(
code=errors.Errors.BadRoute,
message="unexpected Content-Type: " + ctype
message="unexpected Content-Type: " + str(ctype)
)
return encoder, decoder
45 changes: 3 additions & 42 deletions twirp/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import requests

from . import exceptions
Expand All @@ -25,8 +24,9 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs):
return response
try:
raise exceptions.TwirpServerException.from_json(resp.json())
except json.JSONDecodeError:
raise self._twirp_error_from_intermediary(resp) from None
except requests.JSONDecodeError:
raise exceptions.twirp_error_from_intermediary(
resp.status_code, resp.reason, resp.headers, resp.text) from None
# Todo: handle error
except requests.exceptions.Timeout as e:
raise exceptions.TwirpServerException(
Expand All @@ -40,42 +40,3 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs):
message=str(e),
meta={"original_exception": e},
)

@staticmethod
def _twirp_error_from_intermediary(resp):
# see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies
meta = {
'http_error_from_intermediary': 'true',
'status_code': str(resp.status_code),
}

if resp.is_redirect:
# twirp uses POST which should not redirect
code = errors.Errors.Internal
location = resp.headers.get('location')
message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % (
resp.status_code,
resp.reason,
location,
)
meta['location'] = location

else:
code = {
400: errors.Errors.Internal, # JSON response should have been returned
401: errors.Errors.Unauthenticated,
403: errors.Errors.PermissionDenied,
404: errors.Errors.BadRoute,
429: errors.Errors.ResourceExhausted,
502: errors.Errors.Unavailable,
503: errors.Errors.Unavailable,
504: errors.Errors.Unavailable,
}.get(resp.status_code, errors.Errors.Unknown)

message = 'Error from intermediary with HTTP status code %d "%s"' % (
resp.status_code,
resp.reason,
)
meta['body'] = resp.text

return exceptions.TwirpServerException(code=code, message=message, meta=meta)
39 changes: 39 additions & 0 deletions twirp/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,42 @@ def RequiredArgument(*args, argument):
argument=argument,
error="is required"
)


def twirp_error_from_intermediary(status, reason, headers, body):
# see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies
meta = {
'http_error_from_intermediary': 'true',
'status_code': str(status),
}

if 300 <= status < 400:
# twirp uses POST which should not redirect
code = errors.Errors.Internal
location = headers.get('location')
message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % (
status,
reason,
location,
)
meta['location'] = location

else:
code = {
400: errors.Errors.Internal, # JSON response should have been returned
401: errors.Errors.Unauthenticated,
403: errors.Errors.PermissionDenied,
404: errors.Errors.BadRoute,
429: errors.Errors.ResourceExhausted,
502: errors.Errors.Unavailable,
503: errors.Errors.Unavailable,
504: errors.Errors.Unavailable,
}.get(status, errors.Errors.Unknown)

message = 'Error from intermediary with HTTP status code %d "%s"' % (
status,
reason,
)
meta['body'] = body

return TwirpServerException(code=code, message=message, meta=meta)

0 comments on commit 69b39fb

Please sign in to comment.