diff --git a/CHANGELOG.md b/CHANGELOG.md index ca30a82d3..f380d3c67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -541,7 +541,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.14.10] - 2023-09-31 -- Uses nest_asyncio patch in event loop - sync to async +- Uses `nest_asyncio` patch in event loop - sync to async ## [0.14.9] - 2023-09-28 diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index 9c623bf51..0e3d27486 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -19,17 +19,18 @@ _T = TypeVar("_T") -def check_event_loop(): +def create_or_get_event_loop() -> asyncio.AbstractEventLoop: try: - asyncio.get_event_loop() - except RuntimeError as ex: + return asyncio.get_event_loop() + except Exception as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) + return loop + raise ex def sync(co: Coroutine[Any, Any, _T]) -> _T: - check_event_loop() - loop = asyncio.get_event_loop() + loop = create_or_get_event_loop() return loop.run_until_complete(co) diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 79f6ae3dc..db520f11a 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -39,6 +39,8 @@ from .exceptions import raise_general_exception from .process_state import AllowedProcessStates, ProcessState from .utils import find_max_version, is_4xx_error, is_5xx_error +from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop +from sniffio import AsyncLibraryNotFoundError class Querier: @@ -71,6 +73,35 @@ def get_hosts_alive_for_testing(): raise_general_exception("calling testing function in non testing env") return Querier.__hosts_alive_for_testing + async def api_request( + self, + url: str, + method: str, + attempts_remaining: int, + *args: Any, + **kwargs: Any, + ) -> Response: + if attempts_remaining == 0: + raise_general_exception("Retry request failed") + + try: + async with AsyncClient() as client: + if method == "GET": + return await client.get(url, *args, **kwargs) # type: ignore + if method == "POST": + return await client.post(url, *args, **kwargs) # type: ignore + if method == "PUT": + return await client.put(url, *args, **kwargs) # type: ignore + if method == "DELETE": + return await client.delete(url, *args, **kwargs) # type: ignore + raise Exception("Shouldn't come here") + except AsyncLibraryNotFoundError: + # Retry + loop = create_or_get_event_loop() + return loop.run_until_complete( + self.api_request(url, method, attempts_remaining - 1, *args, **kwargs) + ) + async def get_api_version(self): if Querier.api_version is not None: return Querier.api_version @@ -79,12 +110,11 @@ async def get_api_version(self): AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION ) - async def f(url: str) -> Response: + async def f(url: str, method: str) -> Response: headers = {} if Querier.__api_key is not None: headers = {API_KEY_HEADER: Querier.__api_key} - async with AsyncClient() as client: - return await client.get(url, headers=headers) # type:ignore + return await self.api_request(url, method, 2, headers=headers) response = await self.__send_request_helper( NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts) @@ -134,13 +164,14 @@ async def send_get_request( if params is None: params = {} - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.get( # type:ignore - url, - params=params, - headers=await self.__get_headers_with_api_version(path), - ) + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + params=params, + ) return await self.__send_request_helper(path, "GET", f, len(self.__hosts)) @@ -163,9 +194,14 @@ async def send_post_request( headers = await self.__get_headers_with_api_version(path) headers["content-type"] = "application/json; charset=utf-8" - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.post(url, json=data, headers=headers) # type: ignore + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + json=data, + ) return await self.__send_request_helper(path, "POST", f, len(self.__hosts)) @@ -175,13 +211,14 @@ async def send_delete_request( if params is None: params = {} - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.delete( # type:ignore - url, - params=params, - headers=await self.__get_headers_with_api_version(path), - ) + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + params=params, + ) return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts)) @@ -194,9 +231,8 @@ async def send_put_request( headers = await self.__get_headers_with_api_version(path) headers["content-type"] = "application/json; charset=utf-8" - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.put(url, json=data, headers=headers) # type: ignore + async def f(url: str, method: str) -> Response: + return await self.api_request(url, method, 2, headers=headers, json=data) return await self.__send_request_helper(path, "PUT", f, len(self.__hosts)) @@ -223,7 +259,7 @@ async def __send_request_helper( self, path: NormalisedURLPath, method: str, - http_function: Callable[[str], Awaitable[Response]], + http_function: Callable[[str, str], Awaitable[Response]], no_of_tries: int, retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: @@ -253,7 +289,7 @@ async def __send_request_helper( ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) - response = await http_function(url) + response = await http_function(url, method) if ("SUPERTOKENS_ENV" in environ) and ( environ["SUPERTOKENS_ENV"] == "testing" ): @@ -289,7 +325,6 @@ async def __send_request_helper( return response.json() except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( path, method, http_function, no_of_tries - 1, retry_info_map diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index a79d182c1..d504b8fe3 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -39,7 +39,7 @@ from httpx import HTTPStatusError, Response from tldextract import extract # type: ignore -from supertokens_python.async_to_sync_wrapper import check_event_loop +from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop from supertokens_python.framework.django.framework import DjangoFramework from supertokens_python.framework.fastapi.framework import FastapiFramework from supertokens_python.framework.flask.framework import FlaskFramework @@ -212,8 +212,7 @@ def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]): if real_mode == "wsgi": asyncio.run(func()) else: - check_event_loop() - loop = asyncio.get_event_loop() + loop = create_or_get_event_loop() loop.create_task(func())