diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 1533f75f..87ce2767 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify from supertokens_python import process_state -from supertokens_python.framework import BaseRequest +from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig from supertokens_python.recipe import ( @@ -53,7 +53,7 @@ thirdparty, emailverification, ) -from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session import InputErrorHandlers, SessionContainer from supertokens_python.recipe.session.framework.flask import verify_session from supertokens_python.recipe.thirdparty.provider import UserFields, UserInfoMap from supertokens_python.recipe_module import RecipeModule @@ -267,6 +267,14 @@ def init_st(config: Dict[str, Any]): ) elif recipe_id == "session": + + async def custom_unauthorised_callback( + _: BaseRequest, __: str, response: BaseResponse + ) -> BaseResponse: + response.set_status_code(401) + response.set_json_content(content={"type": "UNAUTHORISED"}) + return response + recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( session.init( @@ -302,6 +310,9 @@ def init_st(config: Dict[str, Any]): ), ), ), + error_handlers=InputErrorHandlers( + on_unauthorised=custom_unauthorised_callback + ), ) ) elif recipe_id == "accountlinking": @@ -437,7 +448,28 @@ def init_st(config: Dict[str, Any]): ev_config["override"] = emailverification.InputOverrideConfig( functions=override_functions ) - recipe_list.append(emailverification.init(**ev_config)) + from supertokens_python.recipe.emailverification.interfaces import ( + UnknownUserIdError, + ) + + recipe_list.append( + emailverification.init( + **ev_config, + get_email_for_recipe_user_id=callback_with_log( + "EmailVerification.getEmailForRecipeUserId", + recipe_config_json.get("getEmailForRecipeUserId"), + UnknownUserIdError(), + ), + email_delivery=EmailDeliveryConfig( + override=override_builder_with_logging( + "EmailVerification.emailDelivery.override", + recipe_config_json.get("emailDelivery", {}).get( + "override", None + ), + ) + ), + ) + ) elif recipe_id == "multifactorauth": recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( diff --git a/tests/test-server/emailverification.py b/tests/test-server/emailverification.py index 208c4a9f..220b87a9 100644 --- a/tests/test-server/emailverification.py +++ b/tests/test-server/emailverification.py @@ -113,21 +113,16 @@ def update_session_if_required_post_email_verification(): # type: ignore recipe_user_id_whose_email_got_verified = RecipeUserId( data["recipeUserIdWhoseEmailGotVerified"]["recipeUserId"] ) - session = ( - convert_session_to_container(data["session"]) if "session" in data else None - ) - - try: - session_resp = async_to_sync_wrapper.sync( - EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( - recipe_user_id_whose_email_got_verified=recipe_user_id_whose_email_got_verified, - session=session, - req=FlaskRequest(request), - user_context=data.get("userContext", {}), - ) - ) - return jsonify( - None if session_resp is None else convert_session_to_json(session_resp) + session = convert_session_to_container(data) if "session" in data else None + + session_resp = async_to_sync_wrapper.sync( + EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + recipe_user_id_whose_email_got_verified=recipe_user_id_whose_email_got_verified, + session=session, + req=FlaskRequest(request), + user_context=data.get("userContext", {}), ) - except Exception as e: - return jsonify({"status": "ERROR", "message": str(e)}), 500 + ) + return jsonify( + None if session_resp is None else convert_session_to_json(session_resp) + ) diff --git a/tests/test-server/supertokens.py b/tests/test-server/supertokens.py index 3cc40d50..b7ecc752 100644 --- a/tests/test-server/supertokens.py +++ b/tests/test-server/supertokens.py @@ -22,7 +22,7 @@ def delete_user_api(): # type: ignore assert request.json is not None delete_user( request.json["userId"], - request.json["removeAllLinkedAccounts"], + request.json.get("removeAllLinkedAccounts", True), request.json.get("userContext"), ) return jsonify({"status": "OK"}) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index c4f5d472..07be9b24 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -14,12 +14,17 @@ PasswordResetTokenInvalidError, SignUpPostNotAllowedResponse, SignUpPostOkResult, + UnknownUserIdError, ) from supertokens_python.recipe.emailpassword.types import ( EmailDeliveryOverrideInput, EmailTemplateVars, FormField, ) +from supertokens_python.recipe.emailverification.interfaces import ( + EmailDoesNotExistError, + GetEmailForUserIdOkResult, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo from supertokens_python.recipe.thirdparty.types import ( @@ -47,6 +52,39 @@ def func(*args): # type: ignore return func # type: ignore + elif eval_str.startswith("emailverification.init.emailDelivery.override"): + from supertokens_python.recipe.emailverification.types import ( + EmailDeliveryOverrideInput as EVEmailDeliveryOverrideInput, + EmailTemplateVars as EVEmailTemplateVars, + ) + + def custom_email_delivery_override( + original_implementation: EVEmailDeliveryOverrideInput, + ) -> EVEmailDeliveryOverrideInput: + original_send_email = original_implementation.send_email + + async def send_email( + template_vars: EVEmailTemplateVars, user_context: Dict[str, Any] + ) -> None: + global userInCallback # pylint: disable=global-variable-not-assigned + global token # pylint: disable=global-variable-not-assigned + + if template_vars.user: + userInCallback = template_vars.user + + if template_vars.email_verify_link: + token = template_vars.email_verify_link.split("?token=")[1].split( + "&tenantId=" + )[0] + + # Call the original implementation + await original_send_email(template_vars, user_context) + + original_implementation.send_email = send_email + return original_implementation + + return custom_email_delivery_override + elif eval_str.startswith("emailpassword.init.emailDelivery.override"): def custom_email_deliver( @@ -500,6 +538,27 @@ async def get_user_info5( return custom_provider + if eval_str.startswith("emailverification.init.getEmailForRecipeUserId"): + + async def get_email_for_recipe_user_id( + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], + ) -> Union[ + GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError + ]: + if "random@example.com" in eval_str: + return GetEmailForUserIdOkResult(email="random@example.com") + + if ( + hasattr(recipe_user_id, "get_as_string") + and recipe_user_id.get_as_string() == "random" + ): + return GetEmailForUserIdOkResult(email="test@example.com") + + return UnknownUserIdError() + + return get_email_for_recipe_user_id + raise Exception("Unknown eval string: " + eval_str)