From 5665509654f1a7ed9903907aeea6a34328895c30 Mon Sep 17 00:00:00 2001 From: sacOO7 Date: Sat, 30 Sep 2023 12:39:06 +0530 Subject: [PATCH] Refactored code to run async safe --- ably/decorator/sync.py | 4 +++- ably/http/http.py | 6 +++--- ably/http/paginatedresult.py | 10 +++++----- ably/rest/auth.py | 12 ++++++------ ably/rest/channel.py | 14 +++++++------- ably/rest/rest.py | 6 +++--- 6 files changed, 27 insertions(+), 25 deletions(-) diff --git a/ably/decorator/sync.py b/ably/decorator/sync.py index f3577072..34a61271 100644 --- a/ably/decorator/sync.py +++ b/ably/decorator/sync.py @@ -4,9 +4,11 @@ from ably.executer.eventloop import AppEventLoop -def optional_sync(fn): +def run_async_safe(fn): ''' Enables async function to be used as both sync and async function. + If called from a eventloop or coroutine, returns a future, doesn't block external eventloop. + If called from a regular function, returns a blocking result. Also makes async/sync workflow thread safe. This decorator should only be used on async methods/coroutines. ''' diff --git a/ably/http/http.py b/ably/http/http.py index 7d550dd5..81cbdde2 100644 --- a/ably/http/http.py +++ b/ably/http/http.py @@ -7,7 +7,7 @@ import httpx import msgpack -from ably.decorator.sync import optional_sync +from ably.decorator.sync import run_async_safe from ably.rest.auth import Auth from ably.http.httputils import HttpUtils from ably.transport.defaults import Defaults @@ -132,7 +132,7 @@ def __init__(self, ably, options): self.__host_expires = None self.__client = httpx.AsyncClient(http2=True) - @optional_sync + @run_async_safe async def close(self): await self.__client.aclose() @@ -158,7 +158,7 @@ def get_rest_hosts(self): hosts.insert(0, host) return hosts - @optional_sync + @run_async_safe @reauth_if_expired async def make_request(self, method, path, version=None, headers=None, body=None, skip_auth=False, timeout=None, raise_on_error=True): diff --git a/ably/http/paginatedresult.py b/ably/http/paginatedresult.py index 39b10cd6..23780e66 100644 --- a/ably/http/paginatedresult.py +++ b/ably/http/paginatedresult.py @@ -2,7 +2,7 @@ import logging from urllib.parse import urlencode -from ably.decorator.sync import optional_sync +from ably.decorator.sync import run_async_safe from ably.http.http import Request from ably.util import case @@ -66,11 +66,11 @@ def has_next(self): def is_last(self): return not self.has_next() - @optional_sync + @run_async_safe async def first(self): return await self.__get_rel(self.__rel_first) if self.__rel_first else None - @optional_sync + @run_async_safe async def next(self): return await self.__get_rel(self.__rel_next) if self.__rel_next else None @@ -80,7 +80,7 @@ async def __get_rel(self, rel_req): return await self.paginated_query_with_request(self.__http, rel_req, self.__response_processor) @classmethod - @optional_sync + @run_async_safe async def paginated_query(cls, http, method='GET', url='/', version=None, body=None, headers=None, response_processor=None, raise_on_error=True): @@ -90,7 +90,7 @@ async def paginated_query(cls, http, method='GET', url='/', version=None, body=N return await cls.paginated_query_with_request(http, req, response_processor) @classmethod - @optional_sync + @run_async_safe async def paginated_query_with_request(cls, http, request, response_processor, raise_on_error=True): response = await http.make_request( diff --git a/ably/rest/auth.py b/ably/rest/auth.py index e753d521..83bf26c8 100644 --- a/ably/rest/auth.py +++ b/ably/rest/auth.py @@ -9,7 +9,7 @@ import uuid import httpx -from ably.decorator.sync import optional_sync +from ably.decorator.sync import run_async_safe from ably.types.options import Options if TYPE_CHECKING: @@ -87,7 +87,7 @@ def __init__(self, ably: Union[AblyRest, AblyRealtime], options: Options): raise ValueError("Can't authenticate via token, must provide " "auth_callback, auth_url, key, token or a TokenDetail") - @optional_sync + @run_async_safe async def get_auth_transport_param(self): auth_credentials = {} if self.auth_options.client_id: @@ -155,11 +155,11 @@ def token_details_has_expired(self): return expires < timestamp + token_details.TOKEN_EXPIRY_BUFFER - @optional_sync + @run_async_safe async def authorize(self, token_params: Optional[dict] = None, auth_options=None): return await self.__authorize_when_necessary(token_params, auth_options, force=True) - @optional_sync + @run_async_safe async def request_token(self, token_params: Optional[dict] = None, # auth_options key_name: Optional[str] = None, key_secret: Optional[str] = None, auth_callback=None, @@ -238,7 +238,7 @@ async def request_token(self, token_params: Optional[dict] = None, log.debug("Token: %s" % str(response_dict.get("token"))) return TokenDetails.from_dict(response_dict) - @optional_sync + @run_async_safe async def create_token_request(self, token_params: Optional[dict] = None, key_name: Optional[str] = None, key_secret: Optional[str] = None, query_time=None): token_params = token_params or {} @@ -396,7 +396,7 @@ def _timestamp(self): def _random_nonce(self): return uuid.uuid4().hex[:16] - @optional_sync + @run_async_safe async def token_request_from_auth_url(self, method: str, url: str, token_params, headers, auth_params): body = None diff --git a/ably/rest/channel.py b/ably/rest/channel.py index 71ec21fc..f368803a 100644 --- a/ably/rest/channel.py +++ b/ably/rest/channel.py @@ -9,7 +9,7 @@ from methoddispatch import SingleDispatch, singledispatch import msgpack -from ably.decorator.sync import optional_sync +from ably.decorator.sync import run_async_safe from ably.http.paginatedresult import PaginatedResult, format_params from ably.types.channeldetails import ChannelDetails from ably.types.message import Message, make_message_response_handler @@ -30,7 +30,7 @@ def __init__(self, ably, name, options): self.__presence = Presence(self) @catch_all - @optional_sync + @run_async_safe async def history(self, direction=None, limit: int = None, start=None, end=None): """Returns the history for this channel""" params = format_params({}, direction=direction, start=start, end=end, limit=limit) @@ -83,12 +83,12 @@ def _publish(self, arg, *args, **kwargs): raise TypeError('Unexpected type %s' % type(arg)) @_publish.register(Message) - @optional_sync + @run_async_safe async def publish_message(self, message, params=None, timeout=None): return await self.publish_messages([message], params, timeout=timeout) @_publish.register(list) - @optional_sync + @run_async_safe async def publish_messages(self, messages, params=None, timeout=None): request_body = self.__publish_request_body(messages) if not self.ably.options.use_binary_protocol: @@ -103,12 +103,12 @@ async def publish_messages(self, messages, params=None, timeout=None): return await self.ably.http.post(path, body=request_body, timeout=timeout) @_publish.register(str) - @optional_sync + @run_async_safe async def publish_name_data(self, name, data, timeout=None): messages = [Message(name, data)] return await self.publish_messages(messages, timeout=timeout) - @optional_sync + @run_async_safe async def publish(self, *args, **kwargs): """Publishes a message on this channel. @@ -137,7 +137,7 @@ async def publish(self, *args, **kwargs): return await self._publish(*args, **kwargs) - @optional_sync + @run_async_safe async def status(self): """Retrieves current channel active status with no. of publishers, subscribers, presence_members etc""" diff --git a/ably/rest/rest.py b/ably/rest/rest.py index 8ceb67f2..099e0355 100644 --- a/ably/rest/rest.py +++ b/ably/rest/rest.py @@ -2,7 +2,7 @@ from typing import Optional from urllib.parse import urlencode -from ably.decorator.sync import optional_sync, close_app_eventloop +from ably.decorator.sync import run_async_safe, close_app_eventloop from ably.http.http import Http from ably.http.paginatedresult import PaginatedResult, HttpPaginatedResponse from ably.http.paginatedresult import format_params @@ -92,7 +92,7 @@ async def stats(self, direction: Optional[str] = None, start=None, end=None, par self.http, url=url, response_processor=stats_response_processor) @catch_all - @optional_sync + @run_async_safe async def time(self, timeout: Optional[float] = None) -> float: """Returns the current server time in ms since the unix epoch""" r = await self.http.get('/time', skip_auth=True, timeout=timeout) @@ -124,7 +124,7 @@ def options(self): def push(self): return self.__push - @optional_sync + @run_async_safe async def request(self, method: str, path: str, version: str, params: Optional[dict] = None, body=None, headers=None): if version is None: