From 4923604f36ba79655322d6d96b0fb4d5db6a575d Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 24 Aug 2023 16:43:31 +0530 Subject: [PATCH 01/12] feat: Add 429 rate limting from SaaS --- supertokens_python/constants.py | 1 + supertokens_python/querier.py | 32 ++++++- tests/test_querier.py | 150 ++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 3 deletions(-) create mode 100644 tests/test_querier.py diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 070cfb626..b192cfda4 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -29,3 +29,4 @@ API_VERSION_HEADER = "cdi-version" DASHBOARD_VERSION = "0.6" HUNDRED_YEARS_IN_MS = 3153600000000 +RATE_LIMIT_STATUS_CODE = 429 diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index e3da29362..79f6ae3dc 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -13,9 +13,11 @@ # under the License. from __future__ import annotations +import asyncio + from json import JSONDecodeError from os import environ -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from httpx import AsyncClient, ConnectTimeout, NetworkError, Response @@ -25,6 +27,7 @@ API_VERSION_HEADER, RID_KEY_HEADER, SUPPORTED_CDI_VERSIONS, + RATE_LIMIT_STATUS_CODE, ) from .normalised_url_path import NormalisedURLPath @@ -222,6 +225,7 @@ async def __send_request_helper( method: str, http_function: Callable[[str], Awaitable[Response]], no_of_tries: int, + retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: if no_of_tries == 0: raise_general_exception("No SuperTokens core available to query") @@ -238,6 +242,14 @@ async def __send_request_helper( Querier.__last_tried_index %= len(self.__hosts) url = current_host + path.get_as_string_dangerous() + max_retries = 5 + + if retry_info_map is None: + retry_info_map = {} + + if retry_info_map.get(url) is None: + retry_info_map[url] = max_retries + ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) @@ -247,6 +259,20 @@ async def __send_request_helper( ): Querier.__hosts_alive_for_testing.add(current_host) + if response.status_code == RATE_LIMIT_STATUS_CODE: + retries_left = retry_info_map[url] + + if retries_left > 0: + retry_info_map[url] = retries_left - 1 + + attempts_made = max_retries - retries_left + delay = (10 + attempts_made * 250) / 1000 + + await asyncio.sleep(delay) + return await self.__send_request_helper( + path, method, http_function, no_of_tries, retry_info_map + ) + if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore raise_general_exception( "SuperTokens core threw an error for a " @@ -264,9 +290,9 @@ async def __send_request_helper( except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout): + except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( - path, method, http_function, no_of_tries - 1 + path, method, http_function, no_of_tries - 1, retry_info_map ) except Exception as e: raise_general_exception(e) diff --git a/tests/test_querier.py b/tests/test_querier.py new file mode 100644 index 000000000..86d12aaf0 --- /dev/null +++ b/tests/test_querier.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from pytest import mark +from supertokens_python.recipe import ( + session, + emailpassword, + emailverification, + dashboard, +) +import asyncio +import respx +import httpx +from supertokens_python import init, SupertokensConfig +from supertokens_python.querier import Querier, NormalisedURLPath + +from tests.utils import get_st_init_args +from tests.utils import ( + setup_function, + teardown_function, + start_st, +) + +_ = setup_function +_ = teardown_function + +pytestmark = mark.asyncio +respx_mock = respx.MockRouter + + +async def test_network_call_is_retried_as_expected(): + # Test that network call is retried properly + # Test that rate limiting errors are thrown back to the user + args = get_st_init_args( + [ + session.init(), + emailpassword.init(), + emailverification.init(mode="OPTIONAL"), + dashboard.init(), + ] + ) + args["supertokens_config"] = SupertokensConfig("http://localhost:6789") + init(**args) # type: ignore + start_st() + + Querier.api_version = "3.0" + q = Querier.get_instance() + + api2_call_count = 0 + + def api2_side_effect(_: httpx.Request): + nonlocal api2_call_count + api2_call_count += 1 + + if api2_call_count == 3: + return httpx.Response(200) + + return httpx.Response(429, json={}) + + with respx_mock() as mocker: + api1 = mocker.get("http://localhost:6789/api1").mock( + httpx.Response(429, json={"status": "RATE_ERROR"}) + ) + api2 = mocker.get("http://localhost:6789/api2").mock( + side_effect=api2_side_effect + ) + api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200)) + + try: + await q.send_get_request(NormalisedURLPath("/api1"), {}) + except Exception as e: + if "with status code: 429" in str( + e + ) and 'message: {"status": "RATE_ERROR"}' in str(e): + pass + else: + raise e + + await q.send_get_request(NormalisedURLPath("/api2"), {}) + await q.send_get_request(NormalisedURLPath("/api3"), {}) + + # 1 initial request + 5 retries + assert api1.call_count == 6 + # 2 403 and 1 200 + assert api2.call_count == 3 + # 200 in the first attempt + assert api3.call_count == 1 + + +async def test_parallel_calls_have_independent_counters(): + args = get_st_init_args( + [ + session.init(), + emailpassword.init(), + emailverification.init(mode="OPTIONAL"), + dashboard.init(), + ] + ) + init(**args) # type: ignore + start_st() + + Querier.api_version = "3.0" + q = Querier.get_instance() + + call_count1 = 0 + call_count2 = 0 + + def api_side_effect(r: httpx.Request): + nonlocal call_count1, call_count2 + + id_ = int(r.url.params.get("id")) + if id_ == 1: + call_count1 += 1 + elif id_ == 2: + call_count2 += 1 + + return httpx.Response(429, json={}) + + with respx_mock() as mocker: + api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect) + + async def call_api(id_: int): + try: + await q.send_get_request(NormalisedURLPath("/api"), {"id": id_}) + except Exception as e: + if "with status code: 429" in str(e): + pass + else: + raise e + + _ = await asyncio.gather( + call_api(1), + call_api(2), + ) + + # 1 initial request + 5 retries + assert call_count1 == 6 + assert call_count2 == 6 + + assert api.call_count == 12 From 4d67c0448f04686ad990b7a925e7578fe16875e5 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 28 Aug 2023 17:25:54 +0530 Subject: [PATCH 02/12] feat: Add retry logic for 429 from SaaS instances --- CHANGELOG.md | 6 +++++- setup.py | 2 +- supertokens_python/constants.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d4e53a53..4bc719ed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.14.9] - 2023-09-28 + +- Add logic to retry network calls if the core returns status 429 + ## [0.14.8] - 2023-07-07 ## Fixes @@ -148,7 +152,7 @@ if (accessTokenPayload.sub !== undefined) { ```python from supertokens_python.recipe.session.interfaces import SessionContainer -session: SessionContainer = ... +session: SessionContainer = ... access_token_payload = await session.get_access_token_payload() if access_token_payload.get('sub') is not None: diff --git a/setup.py b/setup.py index 7e83f47c4..0fff1a728 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.14.8", + version="0.14.9", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index b192cfda4..32e942570 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["2.21"] -VERSION = "0.14.8" +VERSION = "0.14.9" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" From 025f6eb9b0c576513ddcedd1a9404622a465b9d9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 29 Aug 2023 11:08:51 +0530 Subject: [PATCH 03/12] mods to add dev tag --- addDevTag | 6 ------ 1 file changed, 6 deletions(-) diff --git a/addDevTag b/addDevTag index 1fd4f670e..cdde45c47 100755 --- a/addDevTag +++ b/addDevTag @@ -1,11 +1,5 @@ #!/bin/bash -# check if we need to merge master into this branch------------ -if [[ $(git log origin/master ^HEAD) ]]; then - echo "You need to merge master into this branch. Exiting" - exit 1 -fi - # get version------------ version=`cat setup.py | grep -e 'version='` while IFS='"' read -ra ADDR; do From 33a5d6baa351992c4a2a9a7fe40ce37e2c2fb069 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 29 Aug 2023 11:09:49 +0530 Subject: [PATCH 04/12] adding dev-v0.14.9 tag to this commit to ensure building --- html/supertokens_python/constants.html | 5 ++- html/supertokens_python/querier.html | 59 +++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/html/supertokens_python/constants.html b/html/supertokens_python/constants.html index c022003ff..dc1308bec 100644 --- a/html/supertokens_python/constants.html +++ b/html/supertokens_python/constants.html @@ -42,7 +42,7 @@

Module supertokens_python.constants

from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["2.21"] -VERSION = "0.14.8" +VERSION = "0.14.9" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" @@ -56,7 +56,8 @@

Module supertokens_python.constants

API_VERSION = "/apiversion" API_VERSION_HEADER = "cdi-version" DASHBOARD_VERSION = "0.6" -HUNDRED_YEARS_IN_MS = 3153600000000 +HUNDRED_YEARS_IN_MS = 3153600000000 +RATE_LIMIT_STATUS_CODE = 429
diff --git a/html/supertokens_python/querier.html b/html/supertokens_python/querier.html index 8636020dc..f95a55798 100644 --- a/html/supertokens_python/querier.html +++ b/html/supertokens_python/querier.html @@ -41,9 +41,11 @@

Module supertokens_python.querier

# under the License. from __future__ import annotations +import asyncio + from json import JSONDecodeError from os import environ -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from httpx import AsyncClient, ConnectTimeout, NetworkError, Response @@ -53,6 +55,7 @@

Module supertokens_python.querier

API_VERSION_HEADER, RID_KEY_HEADER, SUPPORTED_CDI_VERSIONS, + RATE_LIMIT_STATUS_CODE, ) from .normalised_url_path import NormalisedURLPath @@ -250,6 +253,7 @@

Module supertokens_python.querier

method: str, http_function: Callable[[str], Awaitable[Response]], no_of_tries: int, + retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: if no_of_tries == 0: raise_general_exception("No SuperTokens core available to query") @@ -266,6 +270,14 @@

Module supertokens_python.querier

Querier.__last_tried_index %= len(self.__hosts) url = current_host + path.get_as_string_dangerous() + max_retries = 5 + + if retry_info_map is None: + retry_info_map = {} + + if retry_info_map.get(url) is None: + retry_info_map[url] = max_retries + ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) @@ -275,6 +287,20 @@

Module supertokens_python.querier

): Querier.__hosts_alive_for_testing.add(current_host) + if response.status_code == RATE_LIMIT_STATUS_CODE: + retries_left = retry_info_map[url] + + if retries_left > 0: + retry_info_map[url] = retries_left - 1 + + attempts_made = max_retries - retries_left + delay = (10 + attempts_made * 250) / 1000 + + await asyncio.sleep(delay) + return await self.__send_request_helper( + path, method, http_function, no_of_tries, retry_info_map + ) + if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore raise_general_exception( "SuperTokens core threw an error for a " @@ -292,9 +318,9 @@

Module supertokens_python.querier

except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout): + except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( - path, method, http_function, no_of_tries - 1 + path, method, http_function, no_of_tries - 1, retry_info_map ) except Exception as e: raise_general_exception(e) @@ -503,6 +529,7 @@

Classes

method: str, http_function: Callable[[str], Awaitable[Response]], no_of_tries: int, + retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: if no_of_tries == 0: raise_general_exception("No SuperTokens core available to query") @@ -519,6 +546,14 @@

Classes

Querier.__last_tried_index %= len(self.__hosts) url = current_host + path.get_as_string_dangerous() + max_retries = 5 + + if retry_info_map is None: + retry_info_map = {} + + if retry_info_map.get(url) is None: + retry_info_map[url] = max_retries + ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) @@ -528,6 +563,20 @@

Classes

): Querier.__hosts_alive_for_testing.add(current_host) + if response.status_code == RATE_LIMIT_STATUS_CODE: + retries_left = retry_info_map[url] + + if retries_left > 0: + retry_info_map[url] = retries_left - 1 + + attempts_made = max_retries - retries_left + delay = (10 + attempts_made * 250) / 1000 + + await asyncio.sleep(delay) + return await self.__send_request_helper( + path, method, http_function, no_of_tries, retry_info_map + ) + if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore raise_general_exception( "SuperTokens core threw an error for a " @@ -545,9 +594,9 @@

Classes

except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout): + except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( - path, method, http_function, no_of_tries - 1 + path, method, http_function, no_of_tries - 1, retry_info_map ) except Exception as e: raise_general_exception(e) From 54595d9fffe375eb1577e66302d750230ea12ca0 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 29 Aug 2023 13:53:32 +0530 Subject: [PATCH 05/12] test: Fix failing test for the 0.14 patch release --- tests/test_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_session.py b/tests/test_session.py index cc8441de4..469869307 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -794,7 +794,7 @@ async def test_expose_access_token_to_frontend_in_cookie_based_auth( assert response.status_code == 200 assert len(response.headers["st-access-token"]) > 0 - reset(stop_core=False) + reset() args = get_st_init_args([session.init(expose_access_token_to_frontend_in_cookie_based_auth=False, get_token_transfer_method=lambda *_: "cookie")]) # type: ignore init(**args) # type: ignore From 8fad5bcdd080d29591141888eef07716df627c33 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 29 Aug 2023 13:56:22 +0530 Subject: [PATCH 06/12] adding dev-v0.14.9 tag to this commit to ensure building From b4b0c4a8ad52838c7f39a5d2167e09aedc8580ac Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 17:14:38 +0530 Subject: [PATCH 07/12] adds nestjs patch --- supertokens_python/async_to_sync_wrapper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index 4a56ea31b..a8178f0e2 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import nest_asyncio # type: ignore import asyncio from typing import Any, Coroutine, TypeVar @@ -24,10 +25,12 @@ def check_event_loop(): except RuntimeError as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() + nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) def sync(co: Coroutine[Any, Any, _T]) -> _T: check_event_loop() loop = asyncio.get_event_loop() + nest_asyncio.apply(loop) # type: ignore return loop.run_until_complete(co) From c1215c8dfc96d805f6a4aa055c64f26603faf86d Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 17:24:18 +0530 Subject: [PATCH 08/12] bumps version --- CHANGELOG.md | 4 ++++ setup.py | 2 +- supertokens_python/constants.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bc719ed1..6ef325495 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.14.10] - 2023-09-31 + +- Uses nest_asyncio patch in event loop - sync to async + ## [0.14.9] - 2023-09-28 - Add logic to retry network calls if the core returns status 429 diff --git a/setup.py b/setup.py index 0fff1a728..bd1100a2f 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.14.9", + version="0.14.10", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 32e942570..b8886e6f2 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["2.21"] -VERSION = "0.14.9" +VERSION = "0.14.10" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" From 4d20336929ce23689f27e39992d5316522c96d15 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 17:35:32 +0530 Subject: [PATCH 09/12] adds missing dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index bd1100a2f..333c5cf59 100644 --- a/setup.py +++ b/setup.py @@ -111,6 +111,7 @@ "phonenumbers==8.12.48", "twilio==7.9.1", "aiosmtplib==1.1.6", + "nest-asyncio==1.5.1", ], python_requires=">=3.7", include_package_data=True, From d900f065687445c058131b859b00b34143e3248b Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 20:10:19 +0530 Subject: [PATCH 10/12] more changes --- supertokens_python/async_to_sync_wrapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index a8178f0e2..0e9286ee7 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -25,7 +25,6 @@ def check_event_loop(): except RuntimeError as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() - nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) From dbca4da04d3648143db567e6d327c6423593bdd9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 21:04:16 +0530 Subject: [PATCH 11/12] more changes --- supertokens_python/async_to_sync_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index 0e9286ee7..9c623bf51 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -25,11 +25,11 @@ def check_event_loop(): except RuntimeError as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() + nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) def sync(co: Coroutine[Any, Any, _T]) -> _T: check_event_loop() loop = asyncio.get_event_loop() - nest_asyncio.apply(loop) # type: ignore return loop.run_until_complete(co) From 077ff3b96dfe37d0e4eb9a0295a54e82cd25389e Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 18 Sep 2023 11:53:45 +0530 Subject: [PATCH 12/12] fix: Handle ec2 instances public url seperately when extracting TLDs --- CHANGELOG.md | 1 + supertokens_python/constants.py | 2 +- supertokens_python/utils.py | 5 +++ tests/test_config.py | 64 +++++++++++++++++++++++++++++++++ tests/test_utils.py | 28 ++++++++++++++- 5 files changed, 98 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca30a82d3..92651cc18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Uses `nest_asyncio` patch in event loop - sync to async - Retry Querier request on `AsyncLibraryNotFoundError` +- Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute. ## [0.16.0] - 2023-09-13 diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 075ff9309..292277329 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["3.0"] -VERSION = "0.16.0" +VERSION = "0.16.1" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index a79d182c1..1b8afd85b 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -299,8 +299,13 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str: if hostname.startswith("localhost") or is_an_ip_address(hostname): return "localhost" + parsed_url: Any = extract(hostname, include_psl_private_domains=True) if parsed_url.domain == "": # type: ignore + # We need to do this because of https://github.com/supertokens/supertokens-python/issues/394 + if hostname.endswith(".amazonaws.com") and parsed_url.suffix == hostname: + return hostname + raise Exception( "Please make sure that the apiDomain and websiteDomain have correct values" ) diff --git a/tests/test_config.py b/tests/test_config.py index 521827f99..2df57fe59 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -736,3 +736,67 @@ async def test_samesite_invalid_config(): ) else: assert False, "Exception not raised" + + +@mark.asyncio +async def test_cookie_samesite_with_ec2_public_url(): + start_st() + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="https://ec2-xx-yyy-zzz-0.compute-1.amazonaws.com:3001", + website_domain="https://blog.supertokens.com", + api_base_path="/", + ), + framework="fastapi", + recipe_list=[ + session.init(get_token_transfer_method=lambda _, __, ___: "cookie") + ], + ) + + # domain name isn't provided so browser decides to use the same host + # which will be ec2-xx-yyy-zzz-0.compute-1.amazonaws.com + assert SessionRecipe.get_instance().config.cookie_domain is None + assert SessionRecipe.get_instance().config.cookie_same_site == "none" + assert SessionRecipe.get_instance().config.cookie_secure is True + + reset() + + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://ec2-xx-yyy-zzz-0.compute-1.amazonaws.com:3001", + website_domain="http://ec2-aa-bbb-ccc-0.compute-1.amazonaws.com:3000", + api_base_path="/", + ), + framework="fastapi", + recipe_list=[ + session.init(get_token_transfer_method=lambda _, __, ___: "cookie") + ], + ) + + assert SessionRecipe.get_instance().config.cookie_domain is None + assert SessionRecipe.get_instance().config.cookie_same_site == "none" + assert SessionRecipe.get_instance().config.cookie_secure is False + + reset() + + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://ec2-xx-yyy-zzz-0.compute-1.amazonaws.com:3001", + website_domain="http://ec2-xx-yyy-zzz-0.compute-1.amazonaws.com:3000", + api_base_path="/", + ), + framework="fastapi", + recipe_list=[ + session.init(get_token_transfer_method=lambda _, __, ___: "cookie") + ], + ) + + assert SessionRecipe.get_instance().config.cookie_domain is None + assert SessionRecipe.get_instance().config.cookie_same_site == "lax" + assert SessionRecipe.get_instance().config.cookie_secure is False diff --git a/tests/test_utils.py b/tests/test_utils.py index 28b822539..db41552d2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,11 @@ import pytest import threading -from supertokens_python.utils import humanize_time, is_version_gte +from supertokens_python.utils import ( + humanize_time, + is_version_gte, + get_top_level_domain_for_same_site_resolution, +) from supertokens_python.utils import RWMutex from tests.utils import is_subset @@ -171,3 +175,25 @@ def balance_is_valid(): expected_balance -= 10 * 5 # 10 threads withdrawing 5 each actual_balance, _ = account.get_stats() assert actual_balance == expected_balance, "Incorrect account balance" + + +@pytest.mark.parametrize( + "url,res", + [ + ("http://localhost:3001", "localhost"), + ( + "https://ec2-xx-yyy-zzz-0.compute-1.amazonaws.com", + "ec2-xx-yyy-zzz-0.compute-1.amazonaws.com", + ), + ( + "https://foo.vercel.com", + "vercel.com", + ), + ( + "https://blog.supertokens.com", + "supertokens.com", + ), + ], +) +def test_tld_for_same_site(url: str, res: str): + assert get_top_level_domain_for_same_site_resolution(url) == res