Skip to content

Commit

Permalink
Refactored code to run async safe
Browse files Browse the repository at this point in the history
  • Loading branch information
sacOO7 committed Sep 30, 2023
1 parent f4608cd commit 5665509
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 25 deletions.
4 changes: 3 additions & 1 deletion ably/decorator/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from ably.executer.eventloop import AppEventLoop


def optional_sync(fn):
def run_async_safe(fn):
'''
Enables async function to be used as both sync and async function.
If called from a eventloop or coroutine, returns a future, doesn't block external eventloop.
If called from a regular function, returns a blocking result.
Also makes async/sync workflow thread safe.
This decorator should only be used on async methods/coroutines.
'''
Expand Down
6 changes: 3 additions & 3 deletions ably/http/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import httpx
import msgpack

from ably.decorator.sync import optional_sync
from ably.decorator.sync import run_async_safe
from ably.rest.auth import Auth
from ably.http.httputils import HttpUtils
from ably.transport.defaults import Defaults
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(self, ably, options):
self.__host_expires = None
self.__client = httpx.AsyncClient(http2=True)

@optional_sync
@run_async_safe
async def close(self):
await self.__client.aclose()

Expand All @@ -158,7 +158,7 @@ def get_rest_hosts(self):
hosts.insert(0, host)
return hosts

@optional_sync
@run_async_safe
@reauth_if_expired
async def make_request(self, method, path, version=None, headers=None, body=None,
skip_auth=False, timeout=None, raise_on_error=True):
Expand Down
10 changes: 5 additions & 5 deletions ably/http/paginatedresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from urllib.parse import urlencode

from ably.decorator.sync import optional_sync
from ably.decorator.sync import run_async_safe
from ably.http.http import Request
from ably.util import case

Expand Down Expand Up @@ -66,11 +66,11 @@ def has_next(self):
def is_last(self):
return not self.has_next()

@optional_sync
@run_async_safe
async def first(self):
return await self.__get_rel(self.__rel_first) if self.__rel_first else None

@optional_sync
@run_async_safe
async def next(self):
return await self.__get_rel(self.__rel_next) if self.__rel_next else None

Expand All @@ -80,7 +80,7 @@ async def __get_rel(self, rel_req):
return await self.paginated_query_with_request(self.__http, rel_req, self.__response_processor)

@classmethod
@optional_sync
@run_async_safe
async def paginated_query(cls, http, method='GET', url='/', version=None, body=None,
headers=None, response_processor=None,
raise_on_error=True):
Expand All @@ -90,7 +90,7 @@ async def paginated_query(cls, http, method='GET', url='/', version=None, body=N
return await cls.paginated_query_with_request(http, req, response_processor)

@classmethod
@optional_sync
@run_async_safe
async def paginated_query_with_request(cls, http, request, response_processor,
raise_on_error=True):
response = await http.make_request(
Expand Down
12 changes: 6 additions & 6 deletions ably/rest/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uuid
import httpx

from ably.decorator.sync import optional_sync
from ably.decorator.sync import run_async_safe
from ably.types.options import Options

if TYPE_CHECKING:
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, ably: Union[AblyRest, AblyRealtime], options: Options):
raise ValueError("Can't authenticate via token, must provide "
"auth_callback, auth_url, key, token or a TokenDetail")

@optional_sync
@run_async_safe
async def get_auth_transport_param(self):
auth_credentials = {}
if self.auth_options.client_id:
Expand Down Expand Up @@ -155,11 +155,11 @@ def token_details_has_expired(self):

return expires < timestamp + token_details.TOKEN_EXPIRY_BUFFER

@optional_sync
@run_async_safe
async def authorize(self, token_params: Optional[dict] = None, auth_options=None):
return await self.__authorize_when_necessary(token_params, auth_options, force=True)

@optional_sync
@run_async_safe
async def request_token(self, token_params: Optional[dict] = None,
# auth_options
key_name: Optional[str] = None, key_secret: Optional[str] = None, auth_callback=None,
Expand Down Expand Up @@ -238,7 +238,7 @@ async def request_token(self, token_params: Optional[dict] = None,
log.debug("Token: %s" % str(response_dict.get("token")))
return TokenDetails.from_dict(response_dict)

@optional_sync
@run_async_safe
async def create_token_request(self, token_params: Optional[dict] = None, key_name: Optional[str] = None,
key_secret: Optional[str] = None, query_time=None):
token_params = token_params or {}
Expand Down Expand Up @@ -396,7 +396,7 @@ def _timestamp(self):
def _random_nonce(self):
return uuid.uuid4().hex[:16]

@optional_sync
@run_async_safe
async def token_request_from_auth_url(self, method: str, url: str, token_params,
headers, auth_params):
body = None
Expand Down
14 changes: 7 additions & 7 deletions ably/rest/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from methoddispatch import SingleDispatch, singledispatch
import msgpack

from ably.decorator.sync import optional_sync
from ably.decorator.sync import run_async_safe
from ably.http.paginatedresult import PaginatedResult, format_params
from ably.types.channeldetails import ChannelDetails
from ably.types.message import Message, make_message_response_handler
Expand All @@ -30,7 +30,7 @@ def __init__(self, ably, name, options):
self.__presence = Presence(self)

@catch_all
@optional_sync
@run_async_safe
async def history(self, direction=None, limit: int = None, start=None, end=None):
"""Returns the history for this channel"""
params = format_params({}, direction=direction, start=start, end=end, limit=limit)
Expand Down Expand Up @@ -83,12 +83,12 @@ def _publish(self, arg, *args, **kwargs):
raise TypeError('Unexpected type %s' % type(arg))

@_publish.register(Message)
@optional_sync
@run_async_safe
async def publish_message(self, message, params=None, timeout=None):
return await self.publish_messages([message], params, timeout=timeout)

@_publish.register(list)
@optional_sync
@run_async_safe
async def publish_messages(self, messages, params=None, timeout=None):
request_body = self.__publish_request_body(messages)
if not self.ably.options.use_binary_protocol:
Expand All @@ -103,12 +103,12 @@ async def publish_messages(self, messages, params=None, timeout=None):
return await self.ably.http.post(path, body=request_body, timeout=timeout)

@_publish.register(str)
@optional_sync
@run_async_safe
async def publish_name_data(self, name, data, timeout=None):
messages = [Message(name, data)]
return await self.publish_messages(messages, timeout=timeout)

@optional_sync
@run_async_safe
async def publish(self, *args, **kwargs):
"""Publishes a message on this channel.
Expand Down Expand Up @@ -137,7 +137,7 @@ async def publish(self, *args, **kwargs):

return await self._publish(*args, **kwargs)

@optional_sync
@run_async_safe
async def status(self):
"""Retrieves current channel active status with no. of publishers, subscribers, presence_members etc"""

Expand Down
6 changes: 3 additions & 3 deletions ably/rest/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional
from urllib.parse import urlencode

from ably.decorator.sync import optional_sync, close_app_eventloop
from ably.decorator.sync import run_async_safe, close_app_eventloop
from ably.http.http import Http
from ably.http.paginatedresult import PaginatedResult, HttpPaginatedResponse
from ably.http.paginatedresult import format_params
Expand Down Expand Up @@ -92,7 +92,7 @@ async def stats(self, direction: Optional[str] = None, start=None, end=None, par
self.http, url=url, response_processor=stats_response_processor)

@catch_all
@optional_sync
@run_async_safe
async def time(self, timeout: Optional[float] = None) -> float:
"""Returns the current server time in ms since the unix epoch"""
r = await self.http.get('/time', skip_auth=True, timeout=timeout)
Expand Down Expand Up @@ -124,7 +124,7 @@ def options(self):
def push(self):
return self.__push

@optional_sync
@run_async_safe
async def request(self, method: str, path: str, version: str, params:
Optional[dict] = None, body=None, headers=None):
if version is None:
Expand Down

0 comments on commit 5665509

Please sign in to comment.