diff --git a/.vscode/launch.json b/.vscode/launch.json index 9c01eaf8..58eed6c4 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -30,6 +30,18 @@ "FLASK_DEBUG": "1" }, "jinja": true + }, + { + "name": "Python: FastAPI, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/fastapi-server/app.py", + "args": [ + "--port", + "8083" + ], + "cwd": "${workspaceFolder}/tests/auth-react/fastapi-server", + "jinja": true } ] } \ No newline at end of file diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index 3ba30b4a..9d686962 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -13,7 +13,7 @@ # under the License. import os import typing -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union import uvicorn # type: ignore from dotenv import load_dotenv @@ -26,12 +26,14 @@ from starlette.responses import Response from starlette.types import ASGIApp from typing_extensions import Literal -from supertokens_python.recipe import multitenancy +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe import multifactorauth, multitenancy, totp from supertokens_python import ( InputAppInfo, Supertokens, SupertokensConfig, + convert_to_recipe_user_id, get_all_cors_headers, init, ) @@ -45,10 +47,18 @@ thirdparty, userroles, ) +from supertokens_python.recipe import accountlinking +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe +from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface as EmailPasswordAPIInterface, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, ) from supertokens_python.recipe.emailpassword.interfaces import ( APIOptions as EPAPIOptions, @@ -72,18 +82,51 @@ APIOptions as EVAPIOptions, ) from supertokens_python.recipe.jwt import JWTRecipe +from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user, +) +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe.multitenancy.asyncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.passwordless import ( ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig, PasswordlessRecipe, ) +from supertokens_python.recipe.passwordless.asyncio import update_user from supertokens_python.recipe.passwordless.interfaces import ( APIInterface as PasswordlessAPIInterface, + PhoneNumberChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, ) from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionContainer, SessionRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) from supertokens_python.recipe.session.framework.fastapi import verify_session from supertokens_python.recipe.session.interfaces import ( APIInterface as SessionAPIInterface, @@ -91,13 +134,19 @@ from supertokens_python.recipe.session.interfaces import APIOptions as SAPIOptions from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.recipe.thirdparty import ( + ProviderConfig, ThirdPartyRecipe, ) +from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface as ThirdpartyAPIInterface, + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import ( PermissionClaim, @@ -108,8 +157,13 @@ add_role_to_user, create_new_role_or_add_permissions, ) -from supertokens_python.types import AccountInfo, GeneralErrorResponse -from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.types import ( + AccountInfo, + GeneralErrorResponse, + RecipeUserId, + User, +) +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.asyncio import delete_user load_dotenv() @@ -119,6 +173,14 @@ os.environ.setdefault("SUPERTOKENS_ENV", "testing") code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None class CustomPlessEmailService( @@ -145,9 +207,7 @@ async def send_email( code_store[template_vars.pre_auth_session_id] = codes -class CustomPlessSMSService( - passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars] -): +class CustomSMSService(passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars]): async def send_sms( self, template_vars: passwordless.SMSTemplateVars, user_context: Dict[str, Any] ) -> None: @@ -225,7 +285,7 @@ def get_website_domain(): return "http://localhost:" + get_website_port() -latest_url_with_token = None +latest_url_with_token = "" async def validate_age(value: Any, tenant_id: str): @@ -265,75 +325,48 @@ async def get_user_info( # pylint: disable=no-self-use return oi -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + +def custom_init(): + global contact_method + global flow_type + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() EmailVerificationRecipe.reset() SessionRecipe.reset() ThirdPartyRecipe.reset() - EmailVerificationRecipe.reset() EmailPasswordRecipe.reset() + EmailVerificationRecipe.reset() DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() - - providers_list: List[thirdparty.ProviderInput] = [ - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="google", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["GOOGLE_CLIENT_ID"], - client_secret=os.environ["GOOGLE_CLIENT_SECRET"], - ), - ], - ), - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="github", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["GITHUB_CLIENT_ID"], - client_secret=os.environ["GITHUB_CLIENT_SECRET"], - ), - ], - ) - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="facebook", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["FACEBOOK_CLIENT_ID"], - client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], - ), - ], - ) - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="auth0", - name="Auth0", - authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", - authorization_endpoint_query_params={"scope": "openid profile"}, - token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["AUTH0_CLIENT_ID"], - client_secret=os.environ["AUTH0_CLIENT_SECRET"], - ) - ], - ), - override=auth0_provider_override, - ), - ] + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -358,7 +391,11 @@ async def email_verify_post( if is_general_error: return GeneralErrorResponse("general error from API email verify") return await original_email_verify_post( - token, session, tenant_id, api_options, user_context + token, + session, + tenant_id, + api_options, + user_context, ) async def generate_email_verify_token_post( @@ -374,9 +411,7 @@ async def generate_email_verify_token_post( "general error from API email verification code" ) return await original_generate_email_verify_token_post( - session, - api_options, - user_context, + session, api_options, user_context ) original_implementation_email_verification.email_verify_post = email_verify_post @@ -656,12 +691,87 @@ async def resend_code_post( original_implementation.resend_code_post = resend_code_post return original_implementation + providers_list: List[thirdparty.ProviderInput] = [ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="google", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GOOGLE_CLIENT_ID"], + client_secret=os.environ["GOOGLE_CLIENT_SECRET"], + ), + ], + ), + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="github", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GITHUB_CLIENT_ID"], + client_secret=os.environ["GITHUB_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="facebook", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["FACEBOOK_CLIENT_ID"], + client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="auth0", + name="Auth0", + authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", + authorization_endpoint_query_params={"scope": "openid profile"}, + token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["AUTH0_CLIENT_ID"], + client_secret=os.environ["AUTH0_CLIENT_SECRET"], + ) + ], + ), + override=auth0_provider_override, + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), + ] + + global enabled_providers + if enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in enabled_providers + ] + if contact_method is not None and flow_type is not None: if contact_method == "PHONE": passwordless_init = passwordless.init( contact_config=ContactPhoneOnlyConfig(), flow_type=flow_type, - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -684,7 +794,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -694,7 +804,7 @@ async def resend_code_post( contact_config=ContactEmailOrPhoneConfig(), flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), ) @@ -703,32 +813,243 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + global mfa_info + + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if "alreadySetup" in mfa_info: + return mfa_info["alreadySetup"] + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if "allowedToSetup" in mfa_info: + if factor_id not in mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if "requirements" in mfa_info: + return mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if "alreadySetup" in mfa_info: + res.factors.already_setup = mfa_info["alreadySetup"][:] + + if "noContacts" in mfa_info: + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailpassword.EmailDeliveryConfig(CustomEPEmailService()), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), + ), + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + ), + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + + global accountlinking_config + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + global enabled_recipes + if enabled_recipes is not None: + recipe_list = [ + item["init"] for item in recipe_list if item["id"] in enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( @@ -757,20 +1078,272 @@ async def exception_handler(a, b): # type: ignore @app.post("/beforeeach") def before_each(): global code_store + global accountlinking_config + global enabled_providers + global enabled_recipes + global mfa_info + global latest_url_with_token + global contact_method + global flow_type + contact_method = "EMAIL_OR_PHONE" + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + latest_url_with_token = "" code_store = dict() + accountlinking_config = {} + enabled_providers = None + enabled_recipes = None + mfa_info = {} custom_init() return PlainTextResponse("") +@app.post("/changeEmail") +async def change_email(request: Request): + body: Union[dict[str, Any], None] = await request.json() + if body is None: + raise Exception("Should never come here") + + if body["rid"] == "emailpassword": + resp = await update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return JSONResponse({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JSONResponse(resp.to_json()) + elif body["rid"] == "thirdparty": + user = await get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = await manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return JSONResponse( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return JSONResponse( + {"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason} + ) + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = await update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return JSONResponse({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return JSONResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + if isinstance(resp, PhoneNumberChangeNotAllowedError): + return JSONResponse( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": resp.reason, + } + ) + + raise Exception("Should not come here") + + +@app.post("/setupTenant") +async def setup_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = body.get("coreConfig", {}) + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = await create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + await create_or_update_third_party_config( + tenant_id, + config=ProviderConfig.from_json(provider), + ) + + return JSONResponse({"status": "OK", "createdNew": core_resp.created_new}) + + +@app.post("/addUserToTenant") +async def add_user_to_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return JSONResponse( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return JSONResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return JSONResponse({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return JSONResponse( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +@app.post("/removeUserFromTenant") +async def remove_user_from_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await disassociate_user_from_tenant( + tenant_id, RecipeUserId(recipe_user_id) + ) + + return JSONResponse({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +@app.post("/removeTenant") +async def remove_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = await delete_tenant(tenant_id) + + return JSONResponse({"status": "OK", "didExist": core_resp.did_exist}) + + @app.post("/test/setFlow") async def test_set_flow(request: Request): body = await request.json() + global contact_method + global flow_type contact_method = body["contactMethod"] flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + custom_init() return PlainTextResponse("") +@app.post("/test/setAccountLinkingConfig") +async def test_set_account_linking_config(request: Request): + global accountlinking_config + body = await request.json() + if body is None: + raise Exception("Invalid request body") + accountlinking_config = body + custom_init() + return PlainTextResponse("", status_code=200) + + +@app.post("/setMFAInfo") +async def set_mfa_info(request: Request): + global mfa_info + body = await request.json() + if body is None: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + mfa_info = body + return JSONResponse({"status": "OK"}) + + +@app.post("/addRequiredFactor") +async def add_required_factor( + request: Request, session: SessionContainer = Depends(verify_session()) +): + body = await request.json() + if body is None or "factorId" not in body: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + + await add_to_required_secondary_factors_for_user( + session.get_user_id(), body["factorId"] + ) + + return JSONResponse({"status": "OK"}) + + +@app.post("/test/setEnabledRecipes") +async def test_set_enabled_recipes(request: Request): + global enabled_recipes + global enabled_providers + body = await request.json() + if body is None: + raise Exception("Invalid request body") + enabled_recipes = body.get("enabledRecipes") + enabled_providers = body.get("enabledProviders") + custom_init() + return PlainTextResponse("", status_code=200) + + +@app.post("/test/getTOTPCode") +async def test_get_totp_code(request: Request): + from pyotp import TOTP + + body = await request.json() + if body is None or "secret" not in body: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return JSONResponse({"totp": code}) + + @app.get("/test/getDevice") def test_get_device(request: Request): global code_store @@ -789,6 +1362,11 @@ def test_feature_flags(request: Request): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] return JSONResponse({"available": available}) @@ -922,4 +1500,4 @@ def preflight_response(self, request_headers: Headers) -> Response: ) if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=get_api_port()) # type: ignore + uvicorn.run(app, host="0.0.0.0", port=int(get_api_port())) # type: ignore