Skip to content

Commit

Permalink
Merge pull request #451 from supertokens/nest-asyncio-config
Browse files Browse the repository at this point in the history
feat: Use nest-asyncio when configured with env var
  • Loading branch information
rishabhpoddar authored Sep 21, 2023
2 parents 258790a + 45ae5c8 commit 2a5df36
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 66 deletions.
16 changes: 16 additions & 0 deletions .circleci/config_continue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ jobs:
- run: make with-django2x
- run: (cd .circleci/ && ./websiteDjango2x.sh)
- slack/status
test-website-flask-nest-asyncio:
docker:
- image: rishabhpoddar/supertokens_python_driver_testing
resource_class: large
environment:
SUPERTOKENS_NEST_ASYNCIO: "1"
steps:
- checkout
- run: update-alternatives --install "/usr/bin/java" "java" "/usr/java/jdk-15.0.1/bin/java" 2
- run: update-alternatives --install "/usr/bin/javac" "javac" "/usr/java/jdk-15.0.1/bin/javac" 2
- run: git config --global url."https://github.com/".insteadOf ssh://[email protected]/
- run: echo "127.0.0.1 localhost.org" >> /etc/hosts
- run: make with-flask
- run: python -m pip install nest-asyncio
- run: (cd .circleci/ && ./websiteFlask.sh)
- slack/status
test-authreact-fastapi:
docker:
- image: rishabhpoddar/supertokens_python_driver_testing
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.16.2] - 2023-09-20

- Allow use of [nest-asyncio](https://pypi.org/project/nest-asyncio/) when env var `SUPERTOKENS_NEST_ASYNCIO=1`.
- Retry Querier request on `AsyncLibraryNotFoundError`

## [0.16.1] - 2023-09-19
- Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute.

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.16.1",
version="0.16.2",
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
13 changes: 9 additions & 4 deletions supertokens_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.

from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

from typing_extensions import Literal

Expand All @@ -32,11 +32,16 @@ def init(
framework: Literal["fastapi", "flask", "django"],
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None] = None,
telemetry: Union[bool, None] = None,
mode: Optional[Literal["asgi", "wsgi"]] = None,
telemetry: Optional[bool] = None,
):
return Supertokens.init(
app_info, framework, supertokens_config, recipe_list, mode, telemetry
app_info,
framework,
supertokens_config,
recipe_list,
mode,
telemetry,
)


Expand Down
22 changes: 17 additions & 5 deletions supertokens_python/async_to_sync_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,32 @@

import asyncio
from typing import Any, Coroutine, TypeVar
from os import getenv

_T = TypeVar("_T")


def check_event_loop():
def nest_asyncio_enabled():
return getenv("SUPERTOKENS_NEST_ASYNCIO", "") == "1"


def create_or_get_event_loop() -> asyncio.AbstractEventLoop:
try:
asyncio.get_event_loop()
except RuntimeError as ex:
return asyncio.get_event_loop()
except Exception as ex:
if "There is no current event loop in thread" in str(ex):
loop = asyncio.new_event_loop()

if nest_asyncio_enabled():
import nest_asyncio # type: ignore

nest_asyncio.apply(loop) # type: ignore

asyncio.set_event_loop(loop)
return loop
raise ex


def sync(co: Coroutine[Any, Any, _T]) -> _T:
check_event_loop()
loop = asyncio.get_event_loop()
loop = create_or_get_event_loop()
return loop.run_until_complete(co)
2 changes: 1 addition & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.16.1"
VERSION = "0.16.2"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand Down
86 changes: 61 additions & 25 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from .exceptions import raise_general_exception
from .process_state import AllowedProcessStates, ProcessState
from .utils import find_max_version, is_4xx_error, is_5xx_error
from sniffio import AsyncLibraryNotFoundError
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop


class Querier:
Expand Down Expand Up @@ -71,6 +73,35 @@ def get_hosts_alive_for_testing():
raise_general_exception("calling testing function in non testing env")
return Querier.__hosts_alive_for_testing

async def api_request(
self,
url: str,
method: str,
attempts_remaining: int,
*args: Any,
**kwargs: Any,
) -> Response:
if attempts_remaining == 0:
raise_general_exception("Retry request failed")

try:
async with AsyncClient() as client:
if method == "GET":
return await client.get(url, *args, **kwargs) # type: ignore
if method == "POST":
return await client.post(url, *args, **kwargs) # type: ignore
if method == "PUT":
return await client.put(url, *args, **kwargs) # type: ignore
if method == "DELETE":
return await client.delete(url, *args, **kwargs) # type: ignore
raise Exception("Shouldn't come here")
except AsyncLibraryNotFoundError:
# Retry
loop = create_or_get_event_loop()
return loop.run_until_complete(
self.api_request(url, method, attempts_remaining - 1, *args, **kwargs)
)

async def get_api_version(self):
if Querier.api_version is not None:
return Querier.api_version
Expand All @@ -79,12 +110,11 @@ async def get_api_version(self):
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
)

async def f(url: str) -> Response:
async def f(url: str, method: str) -> Response:
headers = {}
if Querier.__api_key is not None:
headers = {API_KEY_HEADER: Querier.__api_key}
async with AsyncClient() as client:
return await client.get(url, headers=headers) # type:ignore
return await self.api_request(url, method, 2, headers=headers)

response = await self.__send_request_helper(
NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts)
Expand Down Expand Up @@ -134,13 +164,14 @@ async def send_get_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.get( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "GET", f, len(self.__hosts))

Expand All @@ -163,9 +194,14 @@ async def send_post_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.post(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
json=data,
)

return await self.__send_request_helper(path, "POST", f, len(self.__hosts))

Expand All @@ -175,13 +211,14 @@ async def send_delete_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.delete( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts))

Expand All @@ -194,9 +231,8 @@ async def send_put_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.put(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(url, method, 2, headers=headers, json=data)

return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))

Expand All @@ -223,7 +259,7 @@ async def __send_request_helper(
self,
path: NormalisedURLPath,
method: str,
http_function: Callable[[str], Awaitable[Response]],
http_function: Callable[[str, str], Awaitable[Response]],
no_of_tries: int,
retry_info_map: Optional[Dict[str, int]] = None,
) -> Any:
Expand Down Expand Up @@ -253,7 +289,7 @@ async def __send_request_helper(
ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
)
response = await http_function(url)
response = await http_function(url, method)
if ("SUPERTOKENS_ENV" in environ) and (
environ["SUPERTOKENS_ENV"] == "testing"
):
Expand Down
15 changes: 10 additions & 5 deletions supertokens_python/supertokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def __init__(
framework: Literal["fastapi", "flask", "django"],
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None],
telemetry: Union[bool, None],
mode: Optional[Literal["asgi", "wsgi"]],
telemetry: Optional[bool],
):
if not isinstance(app_info, InputAppInfo): # type: ignore
raise ValueError("app_info must be an instance of InputAppInfo")
Expand Down Expand Up @@ -215,12 +215,17 @@ def init(
framework: Literal["fastapi", "flask", "django"],
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None],
telemetry: Union[bool, None],
mode: Optional[Literal["asgi", "wsgi"]],
telemetry: Optional[bool],
):
if Supertokens.__instance is None:
Supertokens.__instance = Supertokens(
app_info, framework, supertokens_config, recipe_list, mode, telemetry
app_info,
framework,
supertokens_config,
recipe_list,
mode,
telemetry,
)
PostSTInitCallbacks.run_post_init_callbacks()

Expand Down
25 changes: 0 additions & 25 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations

import asyncio
import json
import threading
import warnings
Expand All @@ -27,7 +26,6 @@
Any,
Awaitable,
Callable,
Coroutine,
Dict,
List,
TypeVar,
Expand All @@ -39,7 +37,6 @@
from httpx import HTTPStatusError, Response
from tldextract import extract # type: ignore

from supertokens_python.async_to_sync_wrapper import check_event_loop
from supertokens_python.framework.django.framework import DjangoFramework
from supertokens_python.framework.fastapi.framework import FastapiFramework
from supertokens_python.framework.flask.framework import FlaskFramework
Expand Down Expand Up @@ -195,28 +192,6 @@ def find_first_occurrence_in_list(
return None


def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]):
real_mode = None
try:
asyncio.get_running_loop()
real_mode = "asgi"
except RuntimeError:
real_mode = "wsgi"

if mode != real_mode:
warnings.warn(
"Inconsistent mode detected, check if you are using the right asgi / wsgi mode",
category=RuntimeWarning,
)

if real_mode == "wsgi":
asyncio.run(func())
else:
check_event_loop()
loop = asyncio.get_event_loop()
loop.create_task(func())


def frontend_has_interceptor(request: BaseRequest) -> bool:
return get_rid_from_header(request) is not None

Expand Down

0 comments on commit 2a5df36

Please sign in to comment.