Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

async client #50

Merged
merged 9 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 80 additions & 9 deletions example/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,86 @@
try:
import asyncio
import aiohttp
except ModuleNotFoundError:
pass

from twirp.context import Context
from twirp.exceptions import TwirpServerException

from generated import haberdasher_twirp, haberdasher_pb2

client = haberdasher_twirp.HaberdasherClient("http://localhost:3000")

try:
response = client.MakeHat(
ctx=Context(), request=haberdasher_pb2.Size(inches=12), server_path_prefix="/twirpy")
if not response.HasField('name'):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())
server_url = "http://localhost:3000"
timeout_s = 5


def main():
client = haberdasher_twirp.HaberdasherClient(server_url, timeout_s)

try:
response = client.MakeHat(
ctx=Context(),
request=haberdasher_pb2.Size(inches=12),
server_path_prefix="/twirpy",
)
if not response.HasField("name"):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())


async def async_main():
client = haberdasher_twirp.AsyncHaberdasherClient(server_url, timeout_s)

try:
response = await client.MakeHat(
ctx=Context(),
request=haberdasher_pb2.Size(inches=12),
server_path_prefix="/twirpy",
)
if not response.HasField("name"):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())


async def async_with_session():
# It is optional but recommended to provide your own ClientSession to the twirp client
# either on init or per request, and ensure it is closed properly on app shutdown.
# Otherwise, the client will create its own session to use, which it will attempt to
# close in its __del__ method, but has no control over how or when that will get called.

# NOTE: ClientSession may only be created (or closed) within a coroutine.
session = aiohttp.ClientSession(
server_url, timeout=aiohttp.ClientTimeout(total=timeout_s)
)

# If session is provided, session controls the timeout. Timeout parameter to client init is unused
client = haberdasher_twirp.AsyncHaberdasherClient(server_url, session=session)

try:
response = await client.MakeHat(
ctx=Context(),
request=haberdasher_pb2.Size(inches=12),
server_path_prefix="/twirpy",
# Optionally provide a session per request
# session=session,
)
if not response.HasField("name"):
print("We didn't get a name!")
print(response)
except TwirpServerException as e:
print(e.code, e.message, e.meta, e.to_dict())
finally:
# Close the session (could also use a context manager)
await session.close()


if __name__ == "__main__":
if hasattr(haberdasher_twirp, "AsyncHaberdasherClient"):
print("using async client")
asyncio.run(async_main())
else:
main()
19 changes: 19 additions & 0 deletions example/generated/haberdasher_twirp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from twirp.base import Endpoint
from twirp.server import TwirpServer
from twirp.client import TwirpClient
try:
from twirp.async_client import AsyncTwirpClient
_async_available = True
except ModuleNotFoundError:
_async_available = False

_sym_db = _symbol_database.Default()

Expand Down Expand Up @@ -35,3 +40,17 @@ def MakeHat(self, *args, ctx, request, server_path_prefix="/twirp", **kwargs):
response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"),
**kwargs,
)


if _async_available:
class AsyncHaberdasherClient(AsyncTwirpClient):

async def MakeHat(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs):
return await self._make_request(
url=F"{server_path_prefix}/twitch.twirp.example.Haberdasher/MakeHat",
ctx=ctx,
request=request,
response_obj=_sym_db.GetSymbol("twitch.twirp.example.Hat"),
session=session,
**kwargs,
)
4 changes: 2 additions & 2 deletions example/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
twirp==0.0.2
uvicorn==0.12.2
twirp==0.0.8
uvicorn==0.23.2
19 changes: 19 additions & 0 deletions protoc-gen-twirpy/generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ from google.protobuf import symbol_database as _symbol_database
from twirp.base import Endpoint
from twirp.server import TwirpServer
from twirp.client import TwirpClient
try:
from twirp.async_client import AsyncTwirpClient
_async_available = True
except ModuleNotFoundError:
_async_available = False

_sym_db = _symbol_database.Default()
{{range .Services}}
Expand Down Expand Up @@ -66,4 +71,18 @@ class {{.Name}}Client(TwirpClient):
response_obj=_sym_db.GetSymbol("{{.Output}}"),
**kwargs,
)
{{end}}

if _async_available:
class Async{{.Name}}Client(AsyncTwirpClient):
{{range .Methods}}
async def {{.Name}}(self, *, ctx, request, server_path_prefix="/twirp", session=None, **kwargs):
return await self._make_request(
url=F"{server_path_prefix}/{{.ServiceURL}}/{{.Name}}",
ctx=ctx,
request=request,
response_obj=_sym_db.GetSymbol("{{.Output}}"),
session=session,
**kwargs,
)
{{end}}{{end}}`))
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
'structlog',
'protobuf'
],
extras_require={
'async': ['aiohttp'],
},
test_requires=[
],
zip_safe=False)
69 changes: 69 additions & 0 deletions twirp/async_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import asyncio
import json
import aiohttp

from . import exceptions
from . import errors


class AsyncTwirpClient:
def __init__(self, address, timeout=5, session=None):
self._address = address
self._timeout = timeout
self._session = session
self._should_close_session = False

def __del__(self):
if self._should_close_session:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self._session.close())
elif not loop.is_closed():
loop.run_until_complete(self._session.close())
except RuntimeError:
pass

@property
def session(self):
if self._session is None:
self._session = aiohttp.ClientSession(
self._address, timeout=aiohttp.ClientTimeout(total=self._timeout)
)
self._should_close_session = True
return self._session
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The go implementation doesn't create a client automatically.

I think it should be the same here.

Suggested change
def __init__(self, address, timeout=5, session=None):
self._address = address
self._timeout = timeout
self._session = session
self._should_close_session = False
def __del__(self):
if self._should_close_session:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self._session.close())
elif not loop.is_closed():
loop.run_until_complete(self._session.close())
except RuntimeError:
pass
@property
def session(self):
if self._session is None:
self._session = aiohttp.ClientSession(
self._address, timeout=aiohttp.ClientTimeout(total=self._timeout)
)
self._should_close_session = True
return self._session
def __init__(self, address: str, session: ClientSession) -> None:
self._address = address
self._session = session
async def aclose(self) -> None:
await self._session.close()

It can be used like this:

http_client_options = {
    "timeout":  aiohttp.ClientTimeout(total=timeout_s)
}

async with aiohttp.ClientSession(**http_client_options) as session:
    client = AsyncHaberdasherClient(address, session)
    await client.MakeHat(...)
    ...

async with contextlib.aclosing(AsyncHaberdasherClient(
    address, 
    aiohttp.ClientSession(**http_client_options)
)) as client:
    await client.MakeHat(...)
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Yes, I agree, and I didn't love the idea of creating the session automatically, but I wasn't sure if some would prefer a simpler way to use it without having to provide and manage the underlying session. (Because the existing sync implementation with requests has no such session requirement, I was hesitant to add one, and was trying to keep the usage similar for those who wanted to keep it that way.)

However:

  • aiohttp.ClientSession can only be created within a coroutine, so if you are required to provide a session when you init the twirp client then you can only do so within a coroutine. (This is not how we currently init our client dependencies within our app, so I had planned to pass the client session with each request, but that might feel cumbersome to some folks)
  • If the caller creates and manages the client session, I'd think they should probably handle the closing of it also, rather than having the twirp client close the session

Thoughts?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can solve your problem by turning the client into an async context manager.
It would be something like that:

def init():
     client = AsyncHaberdasherClient(address, session=None)
...
async def use(client: AsyncHaberdasherClient):
    async with (
        aiohttp.ClientSession() as session,
        client.with_session(session),
    ):
        await client.MakeHat(...)

with_session can be implemented like this:

@asynccontextmanager
async def with_session(session: ClientSession):
  backup, self._session = self._session, session
  yield
  backup, self._session = None, backup

Or using a factory? (need testing)

def init():
     async def session_factory():
         return aiohttp.ClientSession()
     client = AsyncHaberdasherClient(address, session_factory=session_factory)
...
async def use(client: AsyncHaberdasherClient):
    async with client:
        await client.MakeHat(...)

You would need to implement __aenter__ and __aexit__ on AsyncTwirpClient to call the functions on the session.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After thinking a bit more about this, maybe you could just pass the http_client_options instead of a factory to the constructor when creating the twirp client.

Then you can use __aenter__ and __aexit__ to create a temporary session which can be reused for several calls.
But it could also create a session directly inside _make_request on each call for simplicity when there is no need to reuse a client session.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions @MiLk! I don't mind passing the session in each request, but do you think I should add with_session or session_factory as something that's generally useful for others? I pushed some changes removing the auto-created client, and would be fine merging as-is but would like to cover other common use cases.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see that the session could be passed directly with each request.
You can disregard my previous suggestions.
I think that's enough, and there is no need to make the code more complex.


async def _make_request(
self, *, url, ctx, request, response_obj, session=None, **kwargs
):
headers = ctx.get_headers()
if "headers" in kwargs:
headers.update(kwargs["headers"])
kwargs["headers"] = headers
kwargs["headers"]["Content-Type"] = "application/protobuf"
try:
async with await (session or self.session).post(
url=url, data=request.SerializeToString(), **kwargs
) as resp:
if resp.status == 200:
response = response_obj()
response.ParseFromString(await resp.read())
return response
try:
raise exceptions.TwirpServerException.from_json(await resp.json())
except (aiohttp.ContentTypeError, json.JSONDecodeError):
raise exceptions.twirp_error_from_intermediary(
resp.status, resp.reason, resp.headers, await resp.text()
) from None
except asyncio.TimeoutError as e:
raise exceptions.TwirpServerException(
code=errors.Errors.DeadlineExceeded,
message=str(e) or "request timeout",
meta={"original_exception": e},
)
except aiohttp.ServerConnectionError as e:
raise exceptions.TwirpServerException(
code=errors.Errors.Unavailable,
message=str(e),
meta={"original_exception": e},
)
2 changes: 1 addition & 1 deletion twirp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ def _get_encoder_decoder(self, endpoint, headers):
else:
raise exceptions.TwirpServerException(
code=errors.Errors.BadRoute,
message="unexpected Content-Type: " + ctype
message="unexpected Content-Type: " + str(ctype)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#46

)
return encoder, decoder
45 changes: 3 additions & 42 deletions twirp/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import requests

from . import exceptions
Expand All @@ -25,8 +24,9 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs):
return response
try:
raise exceptions.TwirpServerException.from_json(resp.json())
except json.JSONDecodeError:
raise self._twirp_error_from_intermediary(resp) from None
except requests.JSONDecodeError:
raise exceptions.twirp_error_from_intermediary(
resp.status_code, resp.reason, resp.headers, resp.text) from None
# Todo: handle error
except requests.exceptions.Timeout as e:
raise exceptions.TwirpServerException(
Expand All @@ -40,42 +40,3 @@ def _make_request(self, *args, url, ctx, request, response_obj, **kwargs):
message=str(e),
meta={"original_exception": e},
)

@staticmethod
def _twirp_error_from_intermediary(resp):
# see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies
meta = {
'http_error_from_intermediary': 'true',
'status_code': str(resp.status_code),
}

if resp.is_redirect:
# twirp uses POST which should not redirect
code = errors.Errors.Internal
location = resp.headers.get('location')
message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % (
resp.status_code,
resp.reason,
location,
)
meta['location'] = location

else:
code = {
400: errors.Errors.Internal, # JSON response should have been returned
401: errors.Errors.Unauthenticated,
403: errors.Errors.PermissionDenied,
404: errors.Errors.BadRoute,
429: errors.Errors.ResourceExhausted,
502: errors.Errors.Unavailable,
503: errors.Errors.Unavailable,
504: errors.Errors.Unavailable,
}.get(resp.status_code, errors.Errors.Unknown)

message = 'Error from intermediary with HTTP status code %d "%s"' % (
resp.status_code,
resp.reason,
)
meta['body'] = resp.text

return exceptions.TwirpServerException(code=code, message=message, meta=meta)
39 changes: 39 additions & 0 deletions twirp/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,42 @@ def RequiredArgument(*args, argument):
argument=argument,
error="is required"
)


def twirp_error_from_intermediary(status, reason, headers, body):
# see https://twitchtv.github.io/twirp/docs/errors.html#http-errors-from-intermediary-proxies
meta = {
'http_error_from_intermediary': 'true',
'status_code': str(status),
}

if 300 <= status < 400:
# twirp uses POST which should not redirect
code = errors.Errors.Internal
location = headers.get('location')
message = 'unexpected HTTP status code %d "%s" received, Location="%s"' % (
status,
reason,
location,
)
meta['location'] = location

else:
code = {
400: errors.Errors.Internal, # JSON response should have been returned
401: errors.Errors.Unauthenticated,
403: errors.Errors.PermissionDenied,
404: errors.Errors.BadRoute,
429: errors.Errors.ResourceExhausted,
502: errors.Errors.Unavailable,
503: errors.Errors.Unavailable,
504: errors.Errors.Unavailable,
}.get(status, errors.Errors.Unknown)

message = 'Error from intermediary with HTTP status code %d "%s"' % (
status,
reason,
)
meta['body'] = body

return TwirpServerException(code=code, message=message, meta=meta)