Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

async client #50

Merged
merged 9 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions example/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,66 @@
try:
import asyncio
import aiohttp
except ModuleNotFoundError:
pass

from twirp.context import Context
from twirp.exceptions import TwirpServerException

from generated import haberdasher_twirp, haberdasher_pb2

client = haberdasher_twirp.HaberdasherClient("http://localhost:3000")

try:
response = client.MakeHat(
ctx=Context(), request=haberdasher_pb2.Size(inches=12), server_path_prefix="/twirpy")
if not response.HasField('name'):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())
server_url = "http://localhost:3000"
timeout_s = 5


def main():
client = haberdasher_twirp.HaberdasherClient(server_url, timeout_s)

try:
response = client.MakeHat(
ctx=Context(),
request=haberdasher_pb2.Size(inches=12),
server_path_prefix="/twirpy",
)
if not response.HasField("name"):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())


async def async_main():
# The caller must provide their own ClientSession to the twirp client
# either on init or per request, and ensure it is closed properly on app shutdown.

# NOTE: ClientSession may only be created (or closed) within a coroutine.
session = aiohttp.ClientSession(
server_url, timeout=aiohttp.ClientTimeout(total=timeout_s)
)
client = haberdasher_twirp.AsyncHaberdasherClient(server_url, session=session)

try:
response = await client.MakeHat(
ctx=Context(),
request=haberdasher_pb2.Size(inches=12),
server_path_prefix="/twirpy",
# Optionally provide a session per request
# session=session,
)
if not response.HasField("name"):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())
finally:
# Close the session (could also use a context manager)
await session.close()


if __name__ == "__main__":
if hasattr(haberdasher_twirp, "AsyncHaberdasherClient"):
print("using async client")
asyncio.run(async_main())
else:
main()
19 changes: 19 additions & 0 deletions example/generated/haberdasher_twirp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from twirp.base import Endpoint
from twirp.server import TwirpServer
from twirp.client import TwirpClient
try:
from twirp.async_client import AsyncTwirpClient
_async_available = True
except ModuleNotFoundError:
_async_available = False

_sym_db = _symbol_database.Default()

Expand Down Expand Up @@ -35,3 +40,17 @@ def MakeHat(self, *args, ctx, request, server_path_prefix="/twirp", **kwargs):
response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"),
**kwargs,
)


if _async_available:
class AsyncHaberdasherClient(AsyncTwirpClient):

async def MakeHat(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs):
return await self._make_request(
url=F"{server_path_prefix}/twitch.twirp.example.Haberdasher/MakeHat",
ctx=ctx,
request=request,
response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"),
session=session,
**kwargs,
)
4 changes: 2 additions & 2 deletions example/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
twirp==0.0.2
uvicorn==0.12.2
twirp==0.0.8
uvicorn==0.23.2
19 changes: 19 additions & 0 deletions protoc-gen-twirpy/generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ from google.protobuf import symbol_database as _symbol_database
from twirp.base import Endpoint
from twirp.server import TwirpServer
from twirp.client import TwirpClient
try:
from twirp.async_client import AsyncTwirpClient
_async_available = True
except ModuleNotFoundError:
_async_available = False

_sym_db = _symbol_database.Default()
{{range .Services}}
Expand Down Expand Up @@ -66,4 +71,18 @@ class {{.Name}}Client(TwirpClient):
response_obj=_sym_db.GetSymbol("{{.Output}}"),
**kwargs,
)
{{end}}

if _async_available:
class Async{{.Name}}Client(AsyncTwirpClient):
{{range .Methods}}
async def {{.Name}}(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs):
return await self._make_request(
url=F"{server_path_prefix}/{{.ServiceURL}}/{{.Name}}",
ctx=ctx,
request=request,
response_obj=_sym_db.GetSymbol("{{.Output}}"),
session=session,
**kwargs,
)
{{end}}{{end}}`))
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
'structlog',
'protobuf'
],
extras_require={
'async': ['aiohttp'],
},
test_requires=[
],
zip_safe=False)
57 changes: 57 additions & 0 deletions twirp/async_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import asyncio
import json
from typing import Optional

import aiohttp

from . import exceptions
from . import errors


class AsyncTwirpClient:
def __init__(
self, address: str, session: Optional[aiohttp.ClientSession] = None
) -> None:
self._address = address
self._session = session

async def _make_request(
self, *, url, ctx, request, response_obj, session=None, **kwargs
):
headers = ctx.get_headers()
if "headers" in kwargs:
headers.update(kwargs["headers"])
kwargs["headers"] = headers
kwargs["headers"]["Content-Type"] = "application/protobuf"

if session is None:
session = self._session
if not isinstance(session, aiohttp.ClientSession):
raise TypeError(f"invalid session type '{type(session).__name__}'")

try:
async with await session.post(
url=url, data=request.SerializeToString(), **kwargs
) as resp:
if resp.status == 200:
response = response_obj()
response.ParseFromString(await resp.read())
return response
try:
raise exceptions.TwirpServerException.from_json(await resp.json())
except (aiohttp.ContentTypeError, json.JSONDecodeError):
raise exceptions.twirp_error_from_intermediary(
resp.status, resp.reason, resp.headers, await resp.text()
) from None
except asyncio.TimeoutError as e:
raise exceptions.TwirpServerException(
code=errors.Errors.DeadlineExceeded,
message=str(e) or "request timeout",
meta={"original_exception": e},
)
except aiohttp.ServerConnectionError as e:
raise exceptions.TwirpServerException(
code=errors.Errors.Unavailable,
message=str(e),
meta={"original_exception": e},
)
2 changes: 1 addition & 1 deletion twirp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ def _get_encoder_decoder(self, endpoint, headers):
else:
raise exceptions.TwirpServerException(
code=errors.Errors.BadRoute,
message="unexpected Content-Type: " + ctype
message="unexpected Content-Type: " + str(ctype)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#46

)
return encoder, decoder
45 changes: 3 additions & 42 deletions twirp/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import requests

from . import exceptions
Expand All @@ -25,8 +24,9 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs):
return response
try:
raise exceptions.TwirpServerException.from_json(resp.json())
except json.JSONDecodeError:
raise self._twirp_error_from_intermediary(resp) from None
except requests.JSONDecodeError:
raise exceptions.twirp_error_from_intermediary(
resp.status_code, resp.reason, resp.headers, resp.text) from None
# Todo: handle error
except requests.exceptions.Timeout as e:
raise exceptions.TwirpServerException(
Expand All @@ -40,42 +40,3 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs):
message=str(e),
meta={"original_exception": e},
)

@staticmethod
def _twirp_error_from_intermediary(resp):
# see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies
meta = {
'http_error_from_intermediary': 'true',
'status_code': str(resp.status_code),
}

if resp.is_redirect:
# twirp uses POST which should not redirect
code = errors.Errors.Internal
location = resp.headers.get('location')
message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % (
resp.status_code,
resp.reason,
location,
)
meta['location'] = location

else:
code = {
400: errors.Errors.Internal, # JSON response should have been returned
401: errors.Errors.Unauthenticated,
403: errors.Errors.PermissionDenied,
404: errors.Errors.BadRoute,
429: errors.Errors.ResourceExhausted,
502: errors.Errors.Unavailable,
503: errors.Errors.Unavailable,
504: errors.Errors.Unavailable,
}.get(resp.status_code, errors.Errors.Unknown)

message = 'Error from intermediary with HTTP status code %d "%s"' % (
resp.status_code,
resp.reason,
)
meta['body'] = resp.text

return exceptions.TwirpServerException(code=code, message=message, meta=meta)
39 changes: 39 additions & 0 deletions twirp/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,42 @@ def RequiredArgument(*args, argument):
argument=argument,
error="is required"
)


def twirp_error_from_intermediary(status, reason, headers, body):
# see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies
meta = {
'http_error_from_intermediary': 'true',
'status_code': str(status),
}

if 300 <= status < 400:
# twirp uses POST which should not redirect
code = errors.Errors.Internal
location = headers.get('location')
message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % (
status,
reason,
location,
)
meta['location'] = location

else:
code = {
400: errors.Errors.Internal, # JSON response should have been returned
401: errors.Errors.Unauthenticated,
403: errors.Errors.PermissionDenied,
404: errors.Errors.BadRoute,
429: errors.Errors.ResourceExhausted,
502: errors.Errors.Unavailable,
503: errors.Errors.Unavailable,
504: errors.Errors.Unavailable,
}.get(status, errors.Errors.Unknown)

message = 'Error from intermediary with HTTP status code %d "%s"' % (
status,
reason,
)
meta['body'] = body

return TwirpServerException(code=code, message=message, meta=meta)