From c1f197511940444b3f4123e66709c1db323f916f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 10 Oct 2024 12:27:57 +0530 Subject: [PATCH] fixes stuff --- .../recipe/passwordless/interfaces.py | 8 + .../recipe/session/asyncio/__init__.py | 4 +- .../recipe/session/interfaces.py | 5 +- .../recipe/session/recipe_implementation.py | 3 +- .../recipe/session/session_class.py | 1 + .../claims/test_primitive_array_claim.py | 57 +++--- tests/sessions/claims/test_primitive_claim.py | 45 +++-- tests/sessions/claims/utils.py | 10 +- tests/test-server/app.py | 5 +- tests/test-server/passwordless.py | 163 ++++++++++++++++++ tests/test-server/session.py | 39 +++++ tests/test-server/utils.py | 89 ++++++++-- 12 files changed, 351 insertions(+), 78 deletions(-) create mode 100644 tests/test-server/passwordless.py diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 4a92ae6e..87753906 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -111,6 +111,14 @@ def from_json(json: Dict[str, Any]) -> ConsumedDevice: phone_number=json["phoneNumber"] if "phoneNumber" in json else None, ) + def to_json(self) -> Dict[str, Any]: + return { + "preAuthSessionId": self.pre_auth_session_id, + "failedCodeInputAttemptCount": self.failed_code_input_attempt_count, + "email": self.email, + "phoneNumber": self.phone_number, + } + class ConsumeCodeOkResult: def __init__( diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index 346d111e..da67f969 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -123,7 +123,9 @@ async def create_new_session_without_request_response( user_id = user.id for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, recipe_user_id, tenant_id, user_context) + update = await claim.build( + user_id, recipe_user_id, tenant_id, final_access_token_payload, user_context + ) final_access_token_payload = {**final_access_token_payload, **update} return await SessionRecipe.get_instance().recipe_implementation.create_new_session( diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 38945584..4075f848 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -677,11 +677,8 @@ async def build( recipe_user_id: RecipeUserId, tenant_id: str, current_payload: Dict[str, Any], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any], ) -> JSONObject: - if user_context is None: - user_context = {} - value = await resolve( self.fetch_value( user_id, recipe_user_id, tenant_id, current_payload, user_context diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index 9a5fd447..363259ef 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -139,7 +139,7 @@ async def validate_claims( log_debug_message( "update_claims_in_payload_if_needed %s refetch result %s", validator.id, - json.dumps(value), + value, ) if value is not None: access_token_payload = validator.claim.add_to_payload_( @@ -425,6 +425,7 @@ async def fetch_and_set_claim( session_info.user_id, session_info.recipe_user_id, session_info.tenant_id, + session_info.custom_claims_in_access_token_payload, user_context, ) return await self.merge_into_access_token_payload( diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index cf4ea2b4..6150fb0e 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -240,6 +240,7 @@ async def fetch_and_set_claim( self.get_user_id(user_context=user_context), self.get_recipe_user_id(user_context=user_context), self.get_tenant_id(user_context=user_context), + self.get_access_token_payload(user_context=user_context), user_context, ) return await self.merge_into_access_token_payload(update, user_context) diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index 35c930d9..ef2b8ef7 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -1,5 +1,5 @@ import math -from typing import List, Tuple +from typing import Any, Dict, List, Tuple from unittest.mock import MagicMock from pytest import fixture, mark @@ -58,30 +58,31 @@ def patch_get_timestamp_ms(pac_time_patch: Tuple[MockerFixture, int]): async def test_primitive_claim(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveArrayClaim("key", async_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - user_id, ctx = "user_id", {} + user_id = "user_id" + ctx: Dict[str, Any] = {} recipe_user_id = RecipeUserId(user_id) - await claim.build(user_id, recipe_user_id, DEFAULT_TENANT_ID, ctx) + await claim.build(user_id, recipe_user_id, DEFAULT_TENANT_ID, {}, ctx) assert sync_fetch_value.call_count == 1 assert ( user_id, @@ -99,8 +100,8 @@ async def test_primitive_claim_fetch_value_none(): fetch_value_none.return_value = None claim = PrimitiveArrayClaim("key", fetch_value_none) - user_id, ctx = "user_id", {} - res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert res == {} @@ -129,7 +130,7 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert claim.get_last_refetch_time(payload) == timestamp @@ -153,7 +154,7 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes(excluded_item).validate(payload, {}) @@ -168,7 +169,7 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes(included_item).validate(payload, {}) @@ -178,7 +179,7 @@ async def test_validator_should_validate_matching_payload(): async def test_should_not_validate_old_values(patch_get_timestamp_ms: MagicMock): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -198,7 +199,7 @@ async def test_should_validate_old_values_if_max_age_is_none_and_default_is_inf( ): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -218,7 +219,7 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert ( await resolve( @@ -233,7 +234,7 @@ async def test_validator_should_refetch_if_value_is_old( ): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -252,7 +253,7 @@ async def test_validator_should_not_refetch_if_max_age_is_none_and_default_is_in ): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -271,7 +272,7 @@ async def test_validator_should_validate_values_with_default_max_age( ): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 MINS: @@ -286,7 +287,7 @@ async def test_validator_should_not_refetch_if_max_age_overrides_to_inf( ): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -321,7 +322,7 @@ async def test_validator_excludes_should_not_validate_empty_payload(): async def test_validator_excludes_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes(included_item).validate(payload, {}) @@ -336,7 +337,7 @@ async def test_validator_excludes_should_not_validate_mismatching_payload(): async def test_validator_excludes_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes(excluded_item).validate(payload, {}) @@ -361,7 +362,7 @@ async def test_validator_includes_all_should_not_validate_empty_payload(): async def test_validator_includes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes_all(excluded_item).validate(payload, {}) @@ -376,7 +377,7 @@ async def test_validator_includes_all_should_not_validate_mismatching_payload(): async def test_validator_includes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes_all(included_item).validate(payload, {}) @@ -401,7 +402,7 @@ async def test_validator_excludes_all_should_not_validate_empty_payload(): async def test_validator_excludes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes_all(included_item).validate(payload, {}) @@ -416,7 +417,7 @@ async def test_validator_excludes_all_should_not_validate_mismatching_payload(): async def test_validator_excludes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes_all(excluded_item).validate(payload, {}) @@ -428,7 +429,7 @@ async def test_validator_should_not_validate_older_values_with_5min_default_max_ ): claim = PrimitiveArrayClaim("key", sync_fetch_value, 300) # 5 mins payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 MINS: diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 3fde9dd3..bbfa2d81 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -25,35 +25,32 @@ def teardown_function(_): async def test_primitive_claim(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveClaim("key", async_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveClaim("key", sync_fetch_value) - user_id, ctx = "user_id", {} - await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert sync_fetch_value.call_count == 1 assert ( user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, - ctx, + {}, {}, ) == sync_fetch_value.call_args_list[0][ 0 @@ -65,8 +62,8 @@ async def test_primitive_claim_fetch_value_none(): fetch_value_none.return_value = None claim = PrimitiveClaim("key", fetch_value_none) - user_id, ctx = "user_id", {} - res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert res == {} @@ -97,7 +94,7 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert claim.get_last_refetch_time(payload) == timestamp @@ -121,7 +118,7 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val2).validate(payload, {}) @@ -136,7 +133,7 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val).validate(payload, {}) @@ -146,7 +143,7 @@ async def test_validator_should_validate_matching_payload(): async def test_should_validate_old_values_as_well(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 mins: @@ -166,7 +163,7 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert ( await resolve(claim.validators.has_value(val2).should_refetch(payload, {})) @@ -192,7 +189,7 @@ async def test_should_not_validate_empty_payload(): async def test_has_fresh_value_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val2, 600).validate(payload, {}) assert res.is_valid is False @@ -206,7 +203,7 @@ async def test_has_fresh_value_should_not_validate_mismatching_payload(): async def test_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val, 600).validate(payload, {}) assert res.is_valid is True @@ -218,7 +215,7 @@ async def test_should_not_validate_old_values_as_well( claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 mins: @@ -236,14 +233,16 @@ async def test_should_refetch_if_value_is_not_set(): async def test_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", RecipeUserId("userId"), "public", {}) + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}, {}) assert claim.validators.has_value(val2, 600).should_refetch(payload, {}) is False async def test_should_refetch_if_value_is_old(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", RecipeUserId("userId"), DEFAULT_TENANT_ID, {}) + payload = await claim.build( + "userId", RecipeUserId("userId"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -256,7 +255,7 @@ async def test_should_not_validate_old_values_as_well_with_default_max_age_provi ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 mins: @@ -275,7 +274,7 @@ async def test_should_refetch_if_value_is_old_with_default_max_age_provided( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("userId", RecipeUserId("userId"), "public", {}) + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}, {}) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 92a78c96..b66fd6d6 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -31,11 +31,15 @@ async def new_create_new_session( tenant_id: str, user_context: Dict[str, Any], ): - payload_update = await claim.build( - user_id, RecipeUserId(user_id), tenant_id, user_context - ) if access_token_payload is None: access_token_payload = {} + payload_update = await claim.build( + user_id, + RecipeUserId(user_id), + tenant_id, + access_token_payload, + user_context, + ) access_token_payload = { **access_token_payload, **payload_update, diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 713debd9..8a4a3a02 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -7,7 +7,7 @@ from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.totp.recipe import TOTPRecipe -from utils import init_test_claims # pylint: disable=import-error +from passwordless import add_passwordless_routes # pylint: disable=import-error from supertokens_python.process_state import ProcessState from supertokens_python.recipe.dashboard.recipe import DashboardRecipe from supertokens_python.recipe.emailpassword.recipe import EmailPasswordRecipe @@ -590,8 +590,7 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument add_emailverification_routes(app) add_thirdparty_routes(app) add_accountlinking_routes(app) - -init_test_claims() +add_passwordless_routes(app) if __name__ == "__main__": default_st_init() diff --git a/tests/test-server/passwordless.py b/tests/test-server/passwordless.py new file mode 100644 index 00000000..4edb510f --- /dev/null +++ b/tests/test-server/passwordless.py @@ -0,0 +1,163 @@ +from flask import Flask, request, jsonify +from supertokens_python import convert_to_recipe_user_id +from supertokens_python.recipe.passwordless.interfaces import ( + ConsumeCodeExpiredUserInputCodeError, + ConsumeCodeIncorrectUserInputCodeError, + ConsumeCodeOkResult, + ConsumeCodeRestartFlowError, + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, +) +from supertokens_python.recipe.passwordless.syncio import ( + signinup, + create_code, + update_user, + consume_code, +) +from utils import ( # pylint: disable=import-error + serialize_user, + serialize_recipe_user_id, +) # pylint: disable=import-error +from session import convert_session_to_container # pylint: disable=import-error + + +def add_passwordless_routes(app: Flask): + @app.route("/test/passwordless/signinup", methods=["POST"]) # type: ignore + def sign_in_up_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = signinup( + email=body.get("email", None), + phone_number=body.get("phoneNumber", None), + tenant_id=body.get("tenantId", "public"), + user_context=body.get("userContext"), + session=session, + ) + return jsonify( + { + "status": "OK", + "createdNewRecipeUser": response.created_new_recipe_user, + "consumedDevice": response.consumed_device.to_json(), + **serialize_user(response.user, request.headers.get("fdi-version", "")), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + + @app.route("/test/passwordless/createcode", methods=["POST"]) # type: ignore + def create_code_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = create_code( + email=body.get("email"), + phone_number=body.get("phoneNumber"), + tenant_id=body.get("tenantId", "public"), + user_input_code=body.get("userInputCode"), + user_context=body.get("userContext"), + session=session, + ) + return jsonify( + { + "codeId": response.code_id, + "preAuthSessionId": response.pre_auth_session_id, + "codeLifeTime": response.code_life_time, + "deviceId": response.device_id, + "linkCode": response.link_code, + "timeCreated": response.time_created, + "userInputCode": response.user_input_code, + } + ) + + @app.route("/test/passwordless/consumecode", methods=["POST"]) # type: ignore + def consume_code_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = consume_code( + device_id=body["deviceId"], + pre_auth_session_id=body["preAuthSessionId"], + user_input_code=body.get("userInputCode"), + link_code=body["linkCode"], + tenant_id=body.get("tenantId", "public"), + user_context=body.get("userContext"), + session=session, + ) + + if isinstance(response, ConsumeCodeOkResult): + return jsonify( + { + "status": "OK", + "createdNewRecipeUser": response.created_new_recipe_user, + "consumedDevice": response.consumed_device.to_json(), + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + elif isinstance(response, ConsumeCodeIncorrectUserInputCodeError): + return jsonify( + { + "status": "INCORRECT_USER_INPUT_CODE_ERROR", + "failedCodeInputAttemptCount": response.failed_code_input_attempt_count, + "maximumCodeInputAttempts": response.maximum_code_input_attempts, + } + ) + elif isinstance(response, ConsumeCodeExpiredUserInputCodeError): + return jsonify( + { + "status": "EXPIRED_USER_INPUT_CODE_ERROR", + "failedCodeInputAttemptCount": response.failed_code_input_attempt_count, + "maximumCodeInputAttempts": response.maximum_code_input_attempts, + } + ) + elif isinstance(response, ConsumeCodeRestartFlowError): + return jsonify({"status": "RESTART_FLOW_ERROR"}) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) + + @app.route("/test/passwordless/updateuser", methods=["POST"]) # type: ignore + def update_user_api(): # type: ignore + assert request.json is not None + body = request.json + response = update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + user_context=body.get("userContext"), + ) + + if isinstance(response, UpdateUserOkResult): + return jsonify({"status": "OK"}) + elif isinstance(response, UpdateUserUnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(response, UpdateUserEmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, UpdateUserPhoneNumberAlreadyExistsError): + return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, EmailChangeNotAllowedError): + return jsonify({"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR"}) + else: + return jsonify({"status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR"}) diff --git a/tests/test-server/session.py b/tests/test-server/session.py index be43bcb4..00c5dcda 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -1,5 +1,6 @@ from typing import Any from flask import Flask, request, jsonify +from override_logging import log_override_event # pylint: disable=import-error from supertokens_python.recipe.session.interfaces import TokenInfo from supertokens_python.recipe.session.jwt import ( parse_jwt_without_signature_verification, @@ -12,6 +13,7 @@ from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session +from utils import deserialize_claim # pylint: disable=import-error def add_session_routes(app: Flask): @@ -173,6 +175,43 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore } ) + @app.route("/test/session/sessionobject/fetchandsetclaim", methods=["POST"]) # type: ignore + def fetch_and_set_claim_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("sessionobject.fetchandsetclaim", "CALL", data) + session = convert_session_to_container(data) + + claim = deserialize_claim(data["claim"]) + user_context = data.get("userContext", {}) + + session.sync_fetch_and_set_claim(claim, user_context) + response = { + "updatedSession": { + "sessionHandle": session.get_handle(), + "userId": session.get_user_id(), + "recipeUserId": session.get_recipe_user_id().get_as_string(), + "tenantId": session.get_tenant_id(), + "userDataInAccessToken": session.get_access_token_payload(), + "accessToken": session.get_access_token(), + "frontToken": session.get_all_session_tokens_dangerously()[ + "frontToken" + ], + "refreshToken": session.get_all_session_tokens_dangerously()[ + "refreshToken" + ], + "antiCsrfToken": session.get_all_session_tokens_dangerously()[ + "antiCsrfToken" + ], + "accessTokenUpdated": session.get_all_session_tokens_dangerously()[ + "accessAndFrontTokenUpdated" + ], + } + } + return jsonify(response) + def convert_session_to_container(data: Any) -> Session: jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index 1791d24f..1d70e803 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -1,26 +1,85 @@ -from typing import Any, Dict +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) from supertokens_python.recipe.session.claims import SessionClaim from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.types import RecipeUserId, User - -test_claims: Dict[str, SessionClaim] = {} # type: ignore - - -def init_test_claims(): - add_builtin_claims() - - -def add_builtin_claims(): - from supertokens_python.recipe.emailverification import EmailVerificationClaim - - test_claims[EmailVerificationClaim.key] = EmailVerificationClaim +from override_logging import log_override_event # pylint: disable=import-error +from supertokens_python.recipe.session.claims import BooleanClaim +from supertokens_python.recipe.emailverification import EmailVerificationClaim +from supertokens_python.recipe.userroles import UserRoleClaim +from supertokens_python.recipe.userroles import PermissionClaim +from typing import Any, Dict +from supertokens_python.recipe.session.claims import PrimitiveClaim + + +def mock_claim_builder(key: str, values: Any) -> PrimitiveClaim[Any]: + def fetch_value( + user_id: str, + recipe_user_id: RecipeUserId, + tenant_id: str, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], + ) -> Any: + log_override_event( + f"claim-{key}.fetchValue", + "CALL", + { + "userId": user_id, + "recipeUserId": recipe_user_id.get_as_string(), + "tenantId": tenant_id, + "currentPayload": current_payload, + "userContext": user_context, + }, + ) + + ret_val: Any = user_context.get("st-stub-arr-value") or ( + values[0] + if isinstance(values, list) and isinstance(values[0], list) + else values + ) + log_override_event(f"claim-{key}.fetchValue", "RES", ret_val) + + return ret_val + + return PrimitiveClaim(key=key or "st-stub-primitive", fetch_value=fetch_value) + + +test_claim_setups: Dict[str, SessionClaim[Any]] = { + "st-true": BooleanClaim( + key="st-true", + fetch_value=lambda *_args, **_kwargs: True, # type: ignore + ), + "st-undef": BooleanClaim( + key="st-undef", + fetch_value=lambda *_args, **_kwargs: None, # type: ignore + ), +} + +# Add all built-in claims +for claim in [ + EmailVerificationClaim, + MultiFactorAuthClaim, + UserRoleClaim, + PermissionClaim, +]: + test_claim_setups[claim.key] = claim # type: ignore + + +def deserialize_claim(serialized_claim: Dict[str, Any]) -> SessionClaim[Any]: + key = serialized_claim["key"] + + if key.startswith("st-stub-"): + return mock_claim_builder(key.replace("st-stub-", "", 1), serialized_claim) + + return test_claim_setups[key] def deserialize_validator(validatorsInput: Any) -> SessionClaimValidator: # type: ignore key = validatorsInput["key"] - if key in test_claims: - claim = test_claims[key] # type: ignore + if key in test_claim_setups: + claim = test_claim_setups[key] validator_name = validatorsInput["validatorName"] if hasattr(claim.validators, toSnakeCase(validator_name)): # type: ignore validator_func = getattr(claim.validators, toSnakeCase(validator_name)) # type: ignore