Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Async lib not found error #440

Merged
merged 7 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

- Retry Querier request on `AsyncLibraryNotFoundError`

## [0.14.10] - 2023-09-31

- Uses nest_asyncio patch in event loop - sync to async
- Uses `nest_asyncio` patch in event loop - sync to async

## [0.14.9] - 2023-09-28

Expand Down
11 changes: 6 additions & 5 deletions supertokens_python/async_to_sync_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@
_T = TypeVar("_T")


def check_event_loop():
def create_or_get_event_loop() -> asyncio.AbstractEventLoop:
try:
asyncio.get_event_loop()
except RuntimeError as ex:
return asyncio.get_event_loop()
except Exception as ex:
if "There is no current event loop in thread" in str(ex):
loop = asyncio.new_event_loop()
nest_asyncio.apply(loop) # type: ignore
asyncio.set_event_loop(loop)
return loop
raise ex


def sync(co: Coroutine[Any, Any, _T]) -> _T:
check_event_loop()
loop = asyncio.get_event_loop()
loop = create_or_get_event_loop()
return loop.run_until_complete(co)
87 changes: 61 additions & 26 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from .exceptions import raise_general_exception
from .process_state import AllowedProcessStates, ProcessState
from .utils import find_max_version, is_4xx_error, is_5xx_error
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
from sniffio import AsyncLibraryNotFoundError


class Querier:
Expand Down Expand Up @@ -71,6 +73,35 @@ def get_hosts_alive_for_testing():
raise_general_exception("calling testing function in non testing env")
return Querier.__hosts_alive_for_testing

async def api_request(
self,
url: str,
method: str,
attempts_remaining: int,
*args: Any,
**kwargs: Any,
) -> Response:
if attempts_remaining == 0:
raise_general_exception("Retry request failed")

try:
async with AsyncClient() as client:
if method == "GET":
return await client.get(url, *args, **kwargs) # type: ignore
if method == "POST":
return await client.post(url, *args, **kwargs) # type: ignore
if method == "PUT":
return await client.put(url, *args, **kwargs) # type: ignore
if method == "DELETE":
return await client.delete(url, *args, **kwargs) # type: ignore
raise Exception("Shouldn't come here")
except AsyncLibraryNotFoundError:
# Retry
loop = create_or_get_event_loop()
return loop.run_until_complete(
self.api_request(url, method, attempts_remaining - 1, *args, **kwargs)
)

async def get_api_version(self):
if Querier.api_version is not None:
return Querier.api_version
Expand All @@ -79,12 +110,11 @@ async def get_api_version(self):
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
)

async def f(url: str) -> Response:
async def f(url: str, method: str) -> Response:
headers = {}
if Querier.__api_key is not None:
headers = {API_KEY_HEADER: Querier.__api_key}
async with AsyncClient() as client:
return await client.get(url, headers=headers) # type:ignore
return await self.api_request(url, method, 2, headers=headers)

response = await self.__send_request_helper(
NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts)
Expand Down Expand Up @@ -134,13 +164,14 @@ async def send_get_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.get( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "GET", f, len(self.__hosts))

Expand All @@ -163,9 +194,14 @@ async def send_post_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.post(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
json=data,
)

return await self.__send_request_helper(path, "POST", f, len(self.__hosts))

Expand All @@ -175,13 +211,14 @@ async def send_delete_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.delete( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts))

Expand All @@ -194,9 +231,8 @@ async def send_put_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.put(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(url, method, 2, headers=headers, json=data)

return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))

Expand All @@ -223,7 +259,7 @@ async def __send_request_helper(
self,
path: NormalisedURLPath,
method: str,
http_function: Callable[[str], Awaitable[Response]],
http_function: Callable[[str, str], Awaitable[Response]],
no_of_tries: int,
retry_info_map: Optional[Dict[str, int]] = None,
) -> Any:
Expand Down Expand Up @@ -253,7 +289,7 @@ async def __send_request_helper(
ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
)
response = await http_function(url)
response = await http_function(url, method)
if ("SUPERTOKENS_ENV" in environ) and (
environ["SUPERTOKENS_ENV"] == "testing"
):
Expand Down Expand Up @@ -289,7 +325,6 @@ async def __send_request_helper(
return response.json()
except JSONDecodeError:
return response.text

except (ConnectionError, NetworkError, ConnectTimeout) as _:
return await self.__send_request_helper(
path, method, http_function, no_of_tries - 1, retry_info_map
Expand Down
5 changes: 2 additions & 3 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from httpx import HTTPStatusError, Response
from tldextract import extract # type: ignore

from supertokens_python.async_to_sync_wrapper import check_event_loop
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
from supertokens_python.framework.django.framework import DjangoFramework
from supertokens_python.framework.fastapi.framework import FastapiFramework
from supertokens_python.framework.flask.framework import FlaskFramework
Expand Down Expand Up @@ -212,8 +212,7 @@ def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]):
if real_mode == "wsgi":
asyncio.run(func())
else:
check_event_loop()
loop = asyncio.get_event_loop()
loop = create_or_get_event_loop()
loop.create_task(func())


Expand Down
Loading