Skip to content

Commit

Permalink
Added sync_enabled property to rest instance, updated decorator accor…
Browse files Browse the repository at this point in the history
…dingly
  • Loading branch information
sacOO7 committed Oct 1, 2023
1 parent 0f777c9 commit cbf171a
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 51 deletions.
52 changes: 41 additions & 11 deletions ably/executer/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
from ably.executer.eventloop import AppEventLoop


def run_safe(fn):
def force_sync(fn):
'''
USAGE :
If called from an eventloop or coroutine, returns a future, doesn't block external eventloop.
If called from a regular function, returns a blocking result.
Also makes async/sync workflow thread safe.
Forces async function to be used as sync function.
Blocks execution of caller till result is returned.
This decorator should only be used on async methods/coroutines.
Completely safe to use for existing async users.
'''
import asyncio

Expand All @@ -30,14 +27,47 @@ def wrapper(*args, **kwargs):
if caller_eventloop is not None and caller_eventloop == app_loop:
return app_loop.create_task(res)

# Block the caller till result is returned
future = asyncio.run_coroutine_threadsafe(res, app_loop)
return future.result()
return res

return wrapper

# Handle calls from external eventloop, post them on app eventloop
# Return awaitable back to external_eventloop
# if caller_eventloop is not None and caller_eventloop.is_running():
# return asyncio.wrap_future(future)

# If called from regular function, return blocking result
def optional_sync(fn):
'''
Executes async function as a sync function if sync_enabled property on the given instance is true.
Blocks execution of caller till result is returned.
This decorator should only be used on async methods/coroutines.
'''
import asyncio

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
if not hasattr(self, 'sync_enabled'):
raise Exception("sync_enabled property should exist on instance to enable this feature")

# Return awaitable like a normal async method if sync is not enabled
if not self.sync_enabled:
return asyncio.create_task(fn(self, *args, **kwargs))

# Handle result of the given async method, with blocking behaviour
caller_eventloop = None
try:
caller_eventloop: events = asyncio.get_running_loop()
except Exception:
pass
app_loop: events = AppEventLoop.current().loop

res = fn(self, *args, **kwargs)
if asyncio.iscoroutine(res):
# Handle calls from app eventloop on the same loop, return awaitable
if caller_eventloop is not None and caller_eventloop == app_loop:
return app_loop.create_task(res)

# Block the caller till result is returned
future = asyncio.run_coroutine_threadsafe(res, app_loop)
return future.result()
return res

Expand Down
20 changes: 12 additions & 8 deletions ably/http/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import httpx
import msgpack

from ably.executer.decorator import run_safe
from ably.executer.decorator import optional_sync
from ably.rest.auth import Auth
from ably.http.httputils import HttpUtils
from ably.transport.defaults import Defaults
Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(self, ably, options):
self.__host_expires = None
self.__client = httpx.AsyncClient(http2=True)

@run_safe
@optional_sync
async def close(self):
await self.__client.aclose()

Expand All @@ -159,7 +159,7 @@ def get_rest_hosts(self):
hosts.insert(0, host)
return hosts

@run_safe
@optional_sync
@reauth_if_expired
async def make_request(self, method, path, version=None, headers=None, body=None,
skip_auth=False, timeout=None, raise_on_error=True):
Expand Down Expand Up @@ -231,36 +231,40 @@ async def make_request(self, method, path, version=None, headers=None, body=None
if retry_count == len(hosts) - 1 or time_passed > http_max_retry_duration:
raise e

@run_safe
@optional_sync
async def delete(self, url, headers=None, skip_auth=False, timeout=None):
result = await self.make_request('DELETE', url, headers=headers,
skip_auth=skip_auth, timeout=timeout)
return result

@run_safe
@optional_sync
async def get(self, url, headers=None, skip_auth=False, timeout=None):
result = await self.make_request('GET', url, headers=headers,
skip_auth=skip_auth, timeout=timeout)
return result

@run_safe
@optional_sync
async def patch(self, url, headers=None, body=None, skip_auth=False, timeout=None):
result = await self.make_request('PATCH', url, headers=headers, body=body,
skip_auth=skip_auth, timeout=timeout)
return result

@run_safe
@optional_sync
async def post(self, url, headers=None, body=None, skip_auth=False, timeout=None):
result = await self.make_request('POST', url, headers=headers, body=body,
skip_auth=skip_auth, timeout=timeout)
return result

@run_safe
@optional_sync
async def put(self, url, headers=None, body=None, skip_auth=False, timeout=None):
result = await self.make_request('PUT', url, headers=headers, body=body,
skip_auth=skip_auth, timeout=timeout)
return result

@property
def sync_enabled(self):
return self.__ably.sync_enabled

@property
def auth(self):
return self.__auth
Expand Down
12 changes: 7 additions & 5 deletions ably/http/paginatedresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from urllib.parse import urlencode

from ably.executer.decorator import run_safe
from ably.executer.decorator import optional_sync
from ably.http.http import Request
from ably.util import case

Expand Down Expand Up @@ -53,6 +53,10 @@ def __init__(self, http, items, content_type, rel_first, rel_next,
self.__response_processor = response_processor
self.response = response

@property
def sync_enabled(self):
return self.__http.ably

@property
def items(self):
return self.__items
Expand All @@ -66,11 +70,11 @@ def has_next(self):
def is_last(self):
return not self.has_next()

@run_safe
@optional_sync
async def first(self):
return await self.__get_rel(self.__rel_first) if self.__rel_first else None

@run_safe
@optional_sync
async def next(self):
return await self.__get_rel(self.__rel_next) if self.__rel_next else None

Expand All @@ -80,7 +84,6 @@ async def __get_rel(self, rel_req):
return await self.paginated_query_with_request(self.__http, rel_req, self.__response_processor)

@classmethod
@run_safe
async def paginated_query(cls, http, method='GET', url='/', version=None, body=None,
headers=None, response_processor=None,
raise_on_error=True):
Expand All @@ -90,7 +93,6 @@ async def paginated_query(cls, http, method='GET', url='/', version=None, body=N
return await cls.paginated_query_with_request(http, req, response_processor)

@classmethod
@run_safe
async def paginated_query_with_request(cls, http, request, response_processor,
raise_on_error=True):
response = await http.make_request(
Expand Down
16 changes: 10 additions & 6 deletions ably/rest/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uuid
import httpx

from ably.executer.decorator import run_safe
from ably.executer.decorator import optional_sync
from ably.types.options import Options

if TYPE_CHECKING:
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, ably: Union[AblyRest, AblyRealtime], options: Options):
raise ValueError("Can't authenticate via token, must provide "
"auth_callback, auth_url, key, token or a TokenDetail")

@run_safe
@optional_sync
async def get_auth_transport_param(self):
auth_credentials = {}
if self.auth_options.client_id:
Expand Down Expand Up @@ -155,11 +155,11 @@ def token_details_has_expired(self):

return expires < timestamp + token_details.TOKEN_EXPIRY_BUFFER

@run_safe
@optional_sync
async def authorize(self, token_params: Optional[dict] = None, auth_options=None):
return await self.__authorize_when_necessary(token_params, auth_options, force=True)

@run_safe
@optional_sync
async def request_token(self, token_params: Optional[dict] = None,
# auth_options
key_name: Optional[str] = None, key_secret: Optional[str] = None, auth_callback=None,
Expand Down Expand Up @@ -243,7 +243,7 @@ async def request_token(self, token_params: Optional[dict] = None,
log.debug("Token: %s" % str(response_dict.get("token")))
return TokenDetails.from_dict(response_dict)

@run_safe
@optional_sync
async def create_token_request(self, token_params: Optional[dict] = None, key_name: Optional[str] = None,
key_secret: Optional[str] = None, query_time=None):
token_params = token_params or {}
Expand Down Expand Up @@ -308,6 +308,10 @@ async def create_token_request(self, token_params: Optional[dict] = None, key_na

return token_req

@property
def sync_enabled(self):
return self.ably.sync_enabled

@property
def ably(self):
return self.__ably
Expand Down Expand Up @@ -401,7 +405,7 @@ def _timestamp(self):
def _random_nonce(self):
return uuid.uuid4().hex[:16]

@run_safe
@optional_sync
async def token_request_from_auth_url(self, method: str, url: str, token_params,
headers, auth_params):
body = None
Expand Down
12 changes: 8 additions & 4 deletions ably/rest/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from methoddispatch import SingleDispatch, singledispatch
import msgpack

from ably.executer.decorator import run_safe
from ably.executer.decorator import optional_sync
from ably.http.paginatedresult import PaginatedResult, format_params
from ably.types.channeldetails import ChannelDetails
from ably.types.message import Message, make_message_response_handler
Expand All @@ -29,7 +29,7 @@ def __init__(self, ably, name, options):
self.options = options
self.__presence = Presence(self)

@run_safe
@optional_sync
@catch_all
async def history(self, direction=None, limit: int = None, start=None, end=None):
"""Returns the history for this channel"""
Expand Down Expand Up @@ -105,7 +105,7 @@ async def publish_name_data(self, name, data, timeout=None):
messages = [Message(name, data)]
return await self.publish_messages(messages, timeout=timeout)

@run_safe
@optional_sync
async def publish(self, *args, **kwargs):
"""Publishes a message on this channel.
Expand Down Expand Up @@ -134,7 +134,7 @@ async def publish(self, *args, **kwargs):

return await self._publish(*args, **kwargs)

@run_safe
@optional_sync
async def status(self):
"""Retrieves current channel active status with no. of publishers, subscribers, presence_members etc"""

Expand All @@ -143,6 +143,10 @@ async def status(self):
obj = response.to_native()
return ChannelDetails.from_dict(obj)

@property
def sync_enabled(self):
return self.ably.sync_enabled

@property
def ably(self):
return self.__ably
Expand Down
20 changes: 14 additions & 6 deletions ably/rest/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional
from urllib.parse import urlencode

from ably.executer.decorator import run_safe, close_app_eventloop
from ably.executer.decorator import optional_sync, close_app_eventloop
from ably.http.http import Http
from ably.http.paginatedresult import PaginatedResult, HttpPaginatedResponse
from ably.http.paginatedresult import format_params
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(self, key: Optional[str] = None, token: Optional[str] = None,
self._is_realtime
except AttributeError:
self._is_realtime = False

self.__sync_enabled = False
self.__http = Http(self, options)
self.__auth = Auth(self, options)
self.__http.auth = self.__auth
Expand All @@ -82,7 +82,7 @@ async def __aenter__(self):
def __enter__(self):
return self

@run_safe
@optional_sync
@catch_all
async def stats(self, direction: Optional[str] = None, start=None, end=None, params: Optional[dict] = None,
limit: Optional[int] = None, paginated=None, unit=None, timeout=None):
Expand All @@ -92,7 +92,7 @@ async def stats(self, direction: Optional[str] = None, start=None, end=None, par
return await PaginatedResult.paginated_query(
self.http, url=url, response_processor=stats_response_processor)

@run_safe
@optional_sync
@catch_all
async def time(self, timeout: Optional[float] = None) -> float:
"""Returns the current server time in ms since the unix epoch"""
Expand All @@ -113,6 +113,14 @@ def channels(self):
def auth(self):
return self.__auth

@property
def sync_enabled(self):
return self.__sync_enabled

@sync_enabled.setter
def sync_enabled(self, enable_sync):
self.__sync_enabled = enable_sync

@property
def http(self):
return self.__http
Expand All @@ -125,9 +133,9 @@ def options(self):
def push(self):
return self.__push

@run_safe
@optional_sync
async def request(self, method: str, path: str, version: str, params:
Optional[dict] = None, body=None, headers=None):
Optional[dict] = None, body=None, headers=None):
if version is None:
raise AblyException("No version parameter", 400, 40000)

Expand Down
Loading

0 comments on commit cbf171a

Please sign in to comment.