diff --git a/.vscode/launch.json b/.vscode/launch.json index 58eed6c4..32acced2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -42,6 +42,21 @@ ], "cwd": "${workspaceFolder}/tests/auth-react/fastapi-server", "jinja": true + }, + { + "name": "Python: Django, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/django3x/manage.py", + "args": [ + "runserver", + "0.0.0.0:8083" + ], + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "cwd": "${workspaceFolder}/tests/auth-react/django3x", + "jinja": true } ] } \ No newline at end of file diff --git a/tests/auth-react/django3x/manage.py b/tests/auth-react/django3x/manage.py index be146f80..ef5aa498 100644 --- a/tests/auth-react/django3x/manage.py +++ b/tests/auth-react/django3x/manage.py @@ -5,6 +5,7 @@ def main(): + os.environ.setdefault("SUPERTOKENS_ENV", "testing") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings") try: from django.core.management import execute_from_command_line diff --git a/tests/auth-react/django3x/mysite/settings.py b/tests/auth-react/django3x/mysite/settings.py index a8f7d368..6ad791b2 100644 --- a/tests/auth-react/django3x/mysite/settings.py +++ b/tests/auth-react/django3x/mysite/settings.py @@ -30,7 +30,7 @@ # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -custom_init(None, None) +custom_init() ALLOWED_HOSTS = ["localhost"] diff --git a/tests/auth-react/django3x/mysite/store.py b/tests/auth-react/django3x/mysite/store.py index 37f0dd2e..07da199b 100644 --- a/tests/auth-react/django3x/mysite/store.py +++ b/tests/auth-react/django3x/mysite/store.py @@ -1,18 +1,26 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Literal, Optional, Union -_LATEST_URL_WITH_TOKEN = None +latest_url_with_token = "" def save_url_with_token(url_with_token: str): - global _LATEST_URL_WITH_TOKEN - _LATEST_URL_WITH_TOKEN = url_with_token # type: ignore + global latest_url_with_token + latest_url_with_token = url_with_token def get_url_with_token() -> str: - return _LATEST_URL_WITH_TOKEN # type: ignore + return latest_url_with_token -_CODE_STORE: Dict[str, List[Dict[str, Any]]] = {} +code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None def save_code( @@ -20,8 +28,8 @@ def save_code( url_with_link_code: Union[str, None], user_input_code: Union[str, None], ): - global _CODE_STORE - codes = _CODE_STORE.get(pre_auth_session_id, []) + global code_store + codes = code_store.get(pre_auth_session_id, []) # replace sub string in url_with_link_code if url_with_link_code: url_with_link_code = url_with_link_code.replace( @@ -30,8 +38,8 @@ def save_code( codes.append( {"urlWithLinkCode": url_with_link_code, "userInputCode": user_input_code} ) - _CODE_STORE[pre_auth_session_id] = codes + code_store[pre_auth_session_id] = codes def get_codes(pre_auth_session_id: str) -> List[Dict[str, Any]]: - return _CODE_STORE.get(pre_auth_session_id, []) + return code_store.get(pre_auth_session_id, []) diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 27f98817..c3d03d50 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from dotenv import load_dotenv from typing_extensions import Literal @@ -9,11 +9,14 @@ from supertokens_python.recipe import ( emailpassword, emailverification, + multifactorauth, passwordless, session, thirdparty, + totp, userroles, ) +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.interfaces import ( @@ -51,6 +54,10 @@ from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionContainer, SessionRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) from supertokens_python.recipe.session.interfaces import ( APIInterface as SessionAPIInterface, ) @@ -61,12 +68,21 @@ ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import UserRolesRecipe -from supertokens_python.types import GeneralErrorResponse +from supertokens_python.types import GeneralErrorResponse, User from .store import save_code, save_url_with_token from supertokens_python.recipe import multitenancy +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe import accountlinking +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe load_dotenv() @@ -112,9 +128,7 @@ async def send_email( ) -class CustomPlessSMSService( - passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars] -): +class CustomSMSService(passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars]): async def send_sms( self, template_vars: passwordless.SMSTemplateVars, user_context: Dict[str, Any] ) -> None: @@ -260,12 +274,34 @@ async def get_user_info( # pylint: disable=no-self-use ] -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + +def custom_init(): + import mysite.store + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() @@ -277,6 +313,8 @@ def custom_init( DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -601,20 +639,94 @@ async def resend_code_post( original_implementation.resend_code_post = resend_code_post return original_implementation - if contact_method is not None and flow_type is not None: - if contact_method == "PHONE": + providers_list: List[thirdparty.ProviderInput] = [ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="google", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GOOGLE_CLIENT_ID"], + client_secret=os.environ["GOOGLE_CLIENT_SECRET"], + ), + ], + ), + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="github", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GITHUB_CLIENT_ID"], + client_secret=os.environ["GITHUB_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="facebook", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["FACEBOOK_CLIENT_ID"], + client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="auth0", + name="Auth0", + authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", + authorization_endpoint_query_params={"scope": "openid profile"}, + token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["AUTH0_CLIENT_ID"], + client_secret=os.environ["AUTH0_CLIENT_SECRET"], + ) + ], + ), + override=auth0_provider_override, + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), + ] + + if mysite.store.enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in mysite.store.enabled_providers + ] + + if mysite.store.contact_method is not None and mysite.store.flow_type is not None: + if mysite.store.contact_method == "PHONE": passwordless_init = passwordless.init( contact_config=ContactPhoneOnlyConfig(), - flow_type=flow_type, - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + flow_type=mysite.store.flow_type, + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), ) - elif contact_method == "EMAIL": + elif mysite.store.contact_method == "EMAIL": passwordless_init = passwordless.init( contact_config=ContactEmailOnlyConfig(), - flow_type=flow_type, + flow_type=mysite.store.flow_type, email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), @@ -625,11 +737,11 @@ async def resend_code_post( else: passwordless_init = passwordless.init( contact_config=ContactEmailOrPhoneConfig(), - flow_type=flow_type, + flow_type=mysite.store.flow_type, email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -639,7 +751,7 @@ async def resend_code_post( contact_config=ContactEmailOrPhoneConfig(), flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), ) @@ -648,34 +760,240 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if "alreadySetup" in mysite.store.mfa_info: + return mysite.store.mfa_info["alreadySetup"] + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if "allowedToSetup" in mysite.store.mfa_info: + if factor_id not in mysite.store.mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if "requirements" in mysite.store.mfa_info: + return mysite.store.mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if "alreadySetup" in mysite.store.mfa_info: + res.factors.already_setup = mysite.store.mfa_info["alreadySetup"][:] + + if "noContacts" in mysite.store.mfa_info: + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailverification.EmailDeliveryConfig( - CustomEPEmailService() + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), ), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + ), + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mysite.store.mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **mysite.store.accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + if mysite.store.enabled_recipes is not None: + recipe_list = [ + item["init"] + for item in recipe_list + if item["id"] in mysite.store.enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( diff --git a/tests/auth-react/django3x/polls/urls.py b/tests/auth-react/django3x/polls/urls.py index b9615191..8b3cf999 100644 --- a/tests/auth-react/django3x/polls/urls.py +++ b/tests/auth-react/django3x/polls/urls.py @@ -8,23 +8,49 @@ path("ping", views.ping, name="ping"), path("sessionInfo", views.session_info, name="sessionInfo"), path("token", views.token, name="token"), - path("test/setFlow", views.test_set_flow, name="setFlow"), - path("test/getDevice", views.test_get_device, name="getDevice"), - path("test/featureFlags", views.test_feature_flags, name="featureFlags"), - path("beforeeach", views.before_each, name="beforeeach"), + path("changeEmail", views.change_email, name="changeEmail"), # type: ignore + path("setupTenant", views.setup_tenant, name="setupTenant"), # type: ignore + path("removeTenant", views.remove_tenant, name="removeTenant"), # type: ignore + path( + "removeUserFromTenant", + views.remove_user_from_tenant, # type: ignore + name="removeUserFromTenant", + ), # type: ignore + path("addUserToTenant", views.add_user_to_tenant, name="addUserToTenant"), # type: ignore + path("test/setFlow", views.test_set_flow, name="setFlow"), # type: ignore + path( + "test/setAccountLinkingConfig", + views.test_set_account_linking_config, # type: ignore + name="setAccountLinkingConfig", + ), # type: ignore + path("setMFAInfo", views.set_mfa_info, name="setMfaInfo"), # type: ignore + path( + "addRequiredFactor", + views.add_required_factor, # type: ignore + name="addRequiredFactor", + ), # type: ignore + path( + "test/setEnabledRecipes", + views.test_set_enabled_recipes, # type: ignore + name="setEnabledRecipes", + ), + path("test/getTOTPCode", views.test_get_totp_code, name="getTotpCode"), # type: ignore + path("test/getDevice", views.test_get_device, name="getDevice"), # type: ignore + path("test/featureFlags", views.test_feature_flags, name="featureFlags"), # type: ignore + path("beforeeach", views.before_each, name="beforeeach"), # type: ignore ] mode = os.environ.get("APP_MODE", "asgi") if mode == "asgi": - urlpatterns += [ + urlpatterns += [ # type: ignore path("unverifyEmail", views.unverify_email_api, name="unverifyEmail"), # type: ignore path("setRole", views.set_role_api, name="setRole"), # type: ignore path("checkRole", views.check_role_api, name="checkRole"), # type: ignore path("deleteUser", views.delete_user, name="deleteUser"), # type: ignore ] else: - urlpatterns += [ + urlpatterns += [ # type: ignore path("unverifyEmail", views.sync_unverify_email_api, name="unverifyEmail"), path("setRole", views.sync_set_role_api, name="setRole"), path("checkRole", views.sync_check_role_api, name="checkRole"), diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index 550c43ce..c1da33fc 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -15,16 +15,57 @@ import os from typing import List, Dict, Any -from django.conf import settings from django.http import HttpRequest, HttpResponse, JsonResponse from mysite.store import get_codes, get_url_with_token from mysite.utils import custom_init +from supertokens_python import convert_to_recipe_user_id +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password from supertokens_python.recipe.emailverification import EmailVerificationClaim +from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator +from supertokens_python.recipe.thirdparty import ProviderConfig +from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user +from supertokens_python.recipe.thirdparty.interfaces import ( + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) from supertokens_python.recipe.userroles import UserRoleClaim, PermissionClaim -from supertokens_python.types import AccountInfo +from supertokens_python.types import AccountInfo, RecipeUserId +from supertokens_python.recipe.multitenancy.asyncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) +from supertokens_python.recipe.passwordless.asyncio import update_user +from supertokens_python.recipe.passwordless.interfaces import ( + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, +) +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, +) mode = os.environ.get("APP_MODE", "asgi") @@ -179,16 +220,247 @@ def test_get_device(request: HttpRequest): return JsonResponse({"preAuthSessionId": pre_auth_session_id, "codes": codes}) -def test_set_flow(request: HttpRequest): +async def change_email(request: HttpRequest): body = json.loads(request.body) - contact_method = body["contactMethod"] - flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + if body is None: + raise Exception("Should never come here") + + if body["rid"] == "emailpassword": + resp = await update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return JsonResponse({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JsonResponse(resp.to_json()) + elif body["rid"] == "thirdparty": + user = await get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = await manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return JsonResponse( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return JsonResponse( + {"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason} + ) + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = await update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return JsonResponse({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return JsonResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JsonResponse( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": resp.reason, + } + ) + + raise Exception("Should not come here") + + +async def setup_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = body.get("coreConfig", {}) + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = await create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + await create_or_update_third_party_config( + tenant_id, + config=ProviderConfig.from_json(provider), + ) + + return JsonResponse({"status": "OK", "createdNew": core_resp.created_new}) + + +async def add_user_to_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return JsonResponse( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return JsonResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return JsonResponse({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return JsonResponse( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +async def remove_user_from_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await disassociate_user_from_tenant( + tenant_id, RecipeUserId(recipe_user_id) + ) + + return JsonResponse({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +async def remove_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = await delete_tenant(tenant_id) + + return JsonResponse({"status": "OK", "didExist": core_resp.did_exist}) + + +async def test_set_flow(request: HttpRequest): + body = json.loads(request.body) + import mysite.store + + mysite.store.contact_method = body["contactMethod"] + mysite.store.flow_type = body["flowType"] + custom_init() + return HttpResponse("") + + +async def test_set_account_linking_config(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + mysite.store.accountlinking_config = body + custom_init() + return HttpResponse("") + + +async def set_mfa_info(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + mysite.store.mfa_info = body + return JsonResponse({"status": "OK"}) + + +@verify_session() +async def add_required_factor(request: HttpRequest): + session_: SessionContainer = request.supertokens # type: ignore + body = json.loads(request.body) + if body is None or "factorId" not in body: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + + await add_to_required_secondary_factors_for_user( + session_.get_user_id(), body["factorId"] + ) + + return JsonResponse({"status": "OK"}) + + +def test_set_enabled_recipes(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + mysite.store.enabled_recipes = body.get("enabledRecipes") + mysite.store.enabled_providers = body.get("enabledProviders") + custom_init() return HttpResponse("") +def test_get_totp_code(request: HttpRequest): + from pyotp import TOTP + + body = json.loads(request.body) + if body is None or "secret" not in body: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return JsonResponse({"totp": code}) + + def before_each(request: HttpRequest): - setattr(settings, "CODE_STORE", dict()) + import mysite.store + + mysite.store.code_store = dict() custom_init() return HttpResponse("") @@ -202,6 +474,11 @@ def test_feature_flags(request: HttpRequest): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] } )