diff --git a/.pylintrc b/.pylintrc index 69ed5aab3..221068cf7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -124,6 +124,7 @@ disable=raw-checker-failed, no-else-raise, too-many-nested-blocks, broad-exception-raised, + too-many-public-methods, # Enable the message, report, category or checker with the given id(s). You can diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 18b7fbeb1..ab0061f9a 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -536,7 +536,7 @@ async def check_auth_type_and_linking_status( if session_user_result.status == "SHOULD_AUTOMATICALLY_LINK_FALSE": if should_try_linking_with_session_user is True: raise BadInputError( - "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + "shouldDoAutomaticAccountLinking returned false when making the session user primary but shouldTryLinkingWithSessionUser is true" ) return OkFirstFactorResponse() elif ( @@ -565,7 +565,7 @@ async def check_auth_type_and_linking_status( if isinstance(should_link, ShouldNotAutomaticallyLink): if should_try_linking_with_session_user is True: raise BadInputError( - "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + "shouldDoAutomaticAccountLinking returned false when making the session user primary but shouldTryLinkingWithSessionUser is true" ) return OkFirstFactorResponse() else: diff --git a/supertokens_python/framework/django/django_response.py b/supertokens_python/framework/django/django_response.py index 1029c1a55..f839b9231 100644 --- a/supertokens_python/framework/django/django_response.py +++ b/supertokens_python/framework/django/django_response.py @@ -88,3 +88,9 @@ def set_json_content(self, content: Dict[str, Any]): separators=(",", ":"), ).encode("utf-8") self.response_sent = True + + def redirect(self, url: str) -> BaseResponse: + if not self.response_sent: + self.set_header("Location", url) + self.set_status_code(302) + return self diff --git a/supertokens_python/framework/fastapi/fastapi_response.py b/supertokens_python/framework/fastapi/fastapi_response.py index 3f6078af3..76e6f349c 100644 --- a/supertokens_python/framework/fastapi/fastapi_response.py +++ b/supertokens_python/framework/fastapi/fastapi_response.py @@ -94,3 +94,9 @@ def set_json_content(self, content: Dict[str, Any]): self.set_header("Content-Length", str(len(body))) self.response.body = body self.response_sent = True + + def redirect(self, url: str) -> BaseResponse: + if not self.response_sent: + self.set_header("Location", url) + self.set_status_code(302) + return self diff --git a/supertokens_python/framework/flask/flask_response.py b/supertokens_python/framework/flask/flask_response.py index 647f8d3df..025538ef4 100644 --- a/supertokens_python/framework/flask/flask_response.py +++ b/supertokens_python/framework/flask/flask_response.py @@ -85,3 +85,10 @@ def set_json_content(self, content: Dict[str, Any]): separators=(",", ":"), ).encode("utf-8") self.response_sent = True + + def redirect(self, url: str) -> BaseResponse: + self.set_header("Location", url) + self.set_status_code(302) + self.response.data = b"" + self.response_sent = True + return self diff --git a/supertokens_python/framework/request.py b/supertokens_python/framework/request.py index 6f6533185..f53e3d1f2 100644 --- a/supertokens_python/framework/request.py +++ b/supertokens_python/framework/request.py @@ -47,6 +47,14 @@ async def json(self) -> Union[Any, None]: async def form_data(self) -> Dict[str, Any]: pass + async def get_json_or_form_data(self) -> Union[Dict[str, Any], None]: + content_type = self.get_header("Content-Type") + if content_type is None: + return None + if content_type.startswith("application/json"): + return await self.json() + return await self.form_data() + @abstractmethod def method(self) -> str: pass diff --git a/supertokens_python/framework/response.py b/supertokens_python/framework/response.py index fb5cc3477..c28a24104 100644 --- a/supertokens_python/framework/response.py +++ b/supertokens_python/framework/response.py @@ -61,3 +61,7 @@ def set_json_content(self, content: Dict[str, Any]): @abstractmethod def set_html_content(self, content: str): pass + + @abstractmethod + def redirect(self, url: str) -> "BaseResponse": + pass diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 69945493d..5b5b4cf38 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -402,28 +402,33 @@ async def send_put_request( self, path: NormalisedURLPath, data: Union[Dict[str, Any], None], + query_params: Union[Dict[str, Any], None], user_context: Union[Dict[str, Any], None], ) -> Dict[str, Any]: self.invalidate_core_call_cache(user_context) if data is None: data = {} + if query_params is None: + query_params = {} headers = await self.__get_headers_with_api_version(path, user_context) headers["content-type"] = "application/json; charset=utf-8" async def f(url: str, method: str) -> Response: - nonlocal headers, data + nonlocal headers, data, query_params if Querier.network_interceptor is not None: ( url, method, headers, - _, + query_params, data, ) = Querier.network_interceptor( # pylint:disable=not-callable - url, method, headers, {}, data, user_context + url, method, headers, query_params, data, user_context ) - return await self.api_request(url, method, 2, headers=headers, json=data) + return await self.api_request( + url, method, 2, headers=headers, json=data, params=query_params + ) return await self.__send_request_helper(path, "PUT", f, len(self.__hosts)) diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 283569a5f..c17c14b5f 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -297,6 +297,7 @@ async def update_email_or_password( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, + None, user_context=user_context, ) diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 66ea38cb5..a48d4ec75 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -147,6 +147,7 @@ async def create_or_update_tenant( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/multitenancy/tenant/v2"), json_body, + None, user_context=user_context, ) return CreateOrUpdateTenantOkResult( @@ -217,6 +218,7 @@ async def create_or_update_third_party_config( "config": config.to_json(), "skipValidation": skip_validation is True, }, + None, user_context=user_context, ) diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py new file mode 100644 index 000000000..3397b3430 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Union + +from . import exceptions as ex +from . import recipe, utils + +exceptions = ex +InputOverrideConfig = utils.InputOverrideConfig + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + from ...recipe_module import RecipeModule + + +def init( + override: Union[InputOverrideConfig, None] = None, +) -> Callable[[AppInfo], RecipeModule]: + return recipe.OAuth2ProviderRecipe.init(override) diff --git a/supertokens_python/recipe/oauth2provider/api/__init__.py b/supertokens_python/recipe/oauth2provider/api/__init__.py new file mode 100644 index 000000000..fdef2a2bc --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from .auth import auth_get # type: ignore +from .end_session import end_session_get, end_session_post # type: ignore +from .introspect_token import introspect_token_post # type: ignore +from .login_info import login_info_get # type: ignore +from .login import login # type: ignore +from .logout import logout_post # type: ignore +from .revoke_token import revoke_token_post # type: ignore +from .token import token_post # type: ignore +from .user_info import user_info_get # type: ignore diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py new file mode 100644 index 000000000..57850ec0f --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from http.cookies import SimpleCookie +from typing import TYPE_CHECKING, Any, Dict +from urllib.parse import parse_qsl +from dateutil import parser + +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.recipe.session.exceptions import TryRefreshTokenError +from supertokens_python.utils import send_200_response, send_non_200_response + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + + +async def auth_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + from ..interfaces import ( + RedirectResponse, + ErrorOAuth2Response, + ) + + if api_implementation.disable_auth_get is True: + return None + + original_url = api_options.request.get_original_url() + split_url = original_url.split("?", 1) + params = dict(parse_qsl(split_url[1], True)) if len(split_url) > 1 else {} + + session = None + should_try_refresh = False + try: + session = await get_session( + api_options.request, + session_required=False, + user_context=user_context, + ) + should_try_refresh = False + except Exception as error: + session = None + + # should_try_refresh = False should generally not happen, but we can handle this as if the session is not present, + # because then we redirect to the frontend, which should handle the validation error + should_try_refresh = isinstance(error, TryRefreshTokenError) + + response = await api_implementation.auth_get( + params=params, + cookie=api_options.request.get_header("cookie"), + session=session, + should_try_refresh=should_try_refresh, + options=api_options, + user_context=user_context, + ) + + if isinstance(response, RedirectResponse): + if response.cookies: + for cookie_string in response.cookies: + cookie = SimpleCookie() + cookie.load(cookie_string) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=parser.parse(morsel.get("expires", "")).timestamp() * 1000, # type: ignore + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax"), + ) + return api_options.response.redirect(response.redirect_to) + elif isinstance(response, ErrorOAuth2Response): + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py new file mode 100644 index 000000000..13fedab34 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union +import urllib.parse + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.framework import BaseResponse +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.recipe.session.exceptions import TryRefreshTokenError +from supertokens_python.types import GeneralErrorResponse +from supertokens_python.utils import send_200_response, send_non_200_response + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + RedirectResponse, + ErrorOAuth2Response, + ) + + EndSessionCallable = Callable[ + [Dict[str, str], APIOptions, Optional[SessionContainer], bool, Dict[str, Any]], + Awaitable[Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]], + ] + + +async def end_session_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_end_session_get is True: + return None + + orig_url = api_options.request.get_original_url() + split_url = orig_url.split("?", 1) + params = ( + dict(urllib.parse.parse_qsl(split_url[1], True)) if len(split_url) > 1 else {} + ) + + return await end_session_common( + params, api_implementation.end_session_get, api_options, user_context + ) + + +async def end_session_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_end_session_post is True: + return None + + params = await api_options.request.get_json_or_form_data() + if params is None: + raise_bad_input_exception("Please provide a JSON body or form data") + + return await end_session_common( + params, api_implementation.end_session_post, api_options, user_context + ) + + +async def end_session_common( + params: Dict[str, str], + api_implementation: Optional[EndSessionCallable], + options: APIOptions, + user_context: Dict[str, Any], +) -> Optional[BaseResponse]: + from ..interfaces import RedirectResponse, ErrorOAuth2Response + + if api_implementation is None: + return None + + session = None + should_try_refresh = False + try: + session = await get_session( + options.request, + False, + user_context=user_context, + ) + should_try_refresh = False + except Exception as error: + # We can handle this as if the session is not present, because then we redirect to the frontend, + # which should handle the validation error + session = None + should_try_refresh = isinstance(error, TryRefreshTokenError) + + response = await api_implementation( + params, + options, + session, + should_try_refresh, + user_context, + ) + + if isinstance(response, RedirectResponse): + return options.response.redirect(response.redirect_to) + elif isinstance(response, ErrorOAuth2Response): + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + options.response, + ) + else: + if isinstance(response, dict): + return send_200_response(response, options.response) + else: + return send_200_response(response.to_json(), options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/implementation.py b/supertokens_python/recipe/oauth2provider/api/implementation.py new file mode 100644 index 000000000..cc06b7961 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/implementation.py @@ -0,0 +1,288 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, List, Optional, Union + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import GeneralErrorResponse, User + +from .utils import ( + handle_login_internal_redirects, + handle_logout_internal_redirects, + login_get, +) +from ..interfaces import ( + APIInterface, + APIOptions, + ActiveTokenResponse, + ErrorOAuth2Response, + FrontendRedirectResponse, + InactiveTokenResponse, + LoginInfo, + RedirectResponse, + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + TokenInfo, +) + + +class APIImplementation(APIInterface): + async def login_get( + self, + login_challenge: str, + options: APIOptions, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await login_get( + recipe_implementation=options.recipe_implementation, + login_challenge=login_challenge, + session=session, + should_try_refresh=should_try_refresh, + is_direct_call=True, + cookies=None, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + resp_after_internal_redirects = await handle_login_internal_redirects( + response=response, + cookie=options.request.get_header("cookie") or "", + recipe_implementation=options.recipe_implementation, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(resp_after_internal_redirects, ErrorOAuth2Response): + return resp_after_internal_redirects + + return FrontendRedirectResponse( + frontend_redirect_to=resp_after_internal_redirects.redirect_to, + cookies=resp_after_internal_redirects.cookies, + ) + + async def auth_get( + self, + params: Any, + cookie: Optional[str], + session: Optional[SessionContainer], + should_try_refresh: bool, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await options.recipe_implementation.authorization( + params=params, + cookies=cookie, + session=session, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + return await handle_login_internal_redirects( + response=response, + recipe_implementation=options.recipe_implementation, + cookie=cookie or "", + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + async def token_post( + self, + authorization_header: Optional[str], + body: Any, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]: + return await options.recipe_implementation.token_exchange( + authorization_header=authorization_header, + body=body, + user_context=user_context, + ) + + async def login_info_get( + self, + login_challenge: str, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[LoginInfo, ErrorOAuth2Response, GeneralErrorResponse]: + login_res = await options.recipe_implementation.get_login_request( + challenge=login_challenge, + user_context=user_context, + ) + + if isinstance(login_res, ErrorOAuth2Response): + return login_res + + client = login_res.client + + return LoginInfo( + client_id=client.client_id, + client_name=client.client_name, + tos_uri=client.tos_uri, + policy_uri=client.policy_uri, + logo_uri=client.logo_uri, + client_uri=client.client_uri, + metadata=client.metadata, + ) + + async def user_info_get( + self, + access_token_payload: Dict[str, Any], + user: User, + scopes: List[str], + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[Dict[str, Any], GeneralErrorResponse]: + return await options.recipe_implementation.build_user_info( + user=user, + access_token_payload=access_token_payload, + scopes=scopes, + tenant_id=tenant_id, + user_context=user_context, + ) + + async def revoke_token_post( + self, + options: APIOptions, + token: str, + authorization_header: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + user_context: Dict[str, Any], + ) -> Union[None, ErrorOAuth2Response, GeneralErrorResponse]: + if authorization_header is not None: + return await options.recipe_implementation.revoke_token( + params=RevokeTokenUsingAuthorizationHeader( + token=token, + authorization_header=authorization_header, + ), + user_context=user_context, + ) + elif client_id is not None: + if client_secret is None: + raise Exception("client_secret is required") + + return await options.recipe_implementation.revoke_token( + params=RevokeTokenUsingClientIDAndClientSecret( + token=token, + client_id=client_id, + client_secret=client_secret, + ), + user_context=user_context, + ) + else: + raise Exception( + "Either of 'authorization_header' or 'client_id' must be provided" + ) + + async def introspect_token_post( + self, + token: str, + scopes: Optional[List[str]], + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[ActiveTokenResponse, InactiveTokenResponse, GeneralErrorResponse]: + return await options.recipe_implementation.introspect_token( + token=token, + scopes=scopes, + user_context=user_context, + ) + + async def end_session_get( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await options.recipe_implementation.end_session( + params=params, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + return await handle_logout_internal_redirects( + response=response, + session=session, + recipe_implementation=options.recipe_implementation, + user_context=user_context, + ) + + async def end_session_post( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await options.recipe_implementation.end_session( + params=params, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + return await handle_logout_internal_redirects( + response=response, + session=session, + recipe_implementation=options.recipe_implementation, + user_context=user_context, + ) + + async def logout_post( + self, + logout_challenge: str, + options: APIOptions, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + if session is not None: + await session.revoke_session(user_context) + + response = await options.recipe_implementation.accept_logout_request( + challenge=logout_challenge, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + res = await handle_logout_internal_redirects( + response=response, + recipe_implementation=options.recipe_implementation, + session=session, + user_context=user_context, + ) + + if isinstance(res, ErrorOAuth2Response): + return res + + return FrontendRedirectResponse(frontend_redirect_to=res.redirect_to) diff --git a/supertokens_python/recipe/oauth2provider/api/introspect_token.py b/supertokens_python/recipe/oauth2provider/api/introspect_token.py new file mode 100644 index 000000000..cb7364498 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/introspect_token.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List + +from supertokens_python.utils import send_200_response, send_non_200_response + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + + +async def introspect_token_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_introspect_token_post is True: + return None + + body = await api_options.request.get_json_or_form_data() + if body is None or "token" not in body: + return send_non_200_response( + {"message": "token is required in the request body"}, + 400, + api_options.response, + ) + + scopes: List[str] = body.get("scope", "").split(" ") if "scope" in body else [] + + response = await api_implementation.introspect_token_post( + body["token"], + scopes, + api_options, + user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py new file mode 100644 index 000000000..ce5b7dbd6 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -0,0 +1,114 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from dateutil import parser + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.framework import BaseResponse +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.recipe.session.exceptions import TryRefreshTokenError +from supertokens_python.utils import send_200_response, send_non_200_response +from http.cookies import SimpleCookie + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + + +async def login( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Optional[BaseResponse]: + from ..interfaces import ( + FrontendRedirectResponse, + ErrorOAuth2Response, + ) + + if api_implementation.disable_login_get is True: + return None + + session = None + should_try_refresh = False + try: + session = await get_session( + api_options.request, + False, + user_context=user_context, + ) + should_try_refresh = False + except Exception as error: + # We can handle this as if the session is not present, because then we redirect to the frontend, + # which should handle the validation error + session = None + should_try_refresh = isinstance(error, TryRefreshTokenError) + + login_challenge = api_options.request.get_query_param( + "login_challenge" + ) or api_options.request.get_query_param("loginChallenge") + if login_challenge is None: + raise_bad_input_exception("Missing input param: loginChallenge") + + response = await api_implementation.login_get( + login_challenge=login_challenge, + options=api_options, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(response, FrontendRedirectResponse): + if response.cookies: + for cookie_string in response.cookies: + cookie = SimpleCookie() + cookie.load(cookie_string) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=parser.parse(morsel.get("expires", "")).timestamp() * 1000, # type: ignore + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax").lower(), + ) + + return send_200_response( + {"frontendRedirectTo": response.frontend_redirect_to}, + api_options.response, + ) + + elif isinstance(response, ErrorOAuth2Response): + # We want to avoid returning a 401 to the frontend, as it may trigger a refresh loop + if response.status_code == 401: + response.status_code = 400 + + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/login_info.py b/supertokens_python/recipe/oauth2provider/api/login_info.py new file mode 100644 index 000000000..532e55f37 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/login_info.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.utils import send_200_response, send_non_200_response + +if TYPE_CHECKING: + from ..interfaces import APIOptions, APIInterface + + +async def login_info_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + from ..interfaces import ErrorOAuth2Response + + if api_implementation.disable_login_info_get is True: + return None + + login_challenge = api_options.request.get_query_param( + "login_challenge" + ) or api_options.request.get_query_param("loginChallenge") + + if login_challenge is None: + raise_bad_input_exception("Missing input param: loginChallenge") + + response = await api_implementation.login_info_get( + login_challenge, + api_options, + user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + # We want to avoid returning a 401 to the frontend, as it may trigger a refresh loop + if response.status_code == 401: + response.status_code = 400 + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/logout.py b/supertokens_python/recipe/oauth2provider/api/logout.py new file mode 100644 index 000000000..8cb4f349e --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/logout.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.framework import BaseResponse +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.utils import send_200_response, send_non_200_response + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + + +async def logout_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Optional[BaseResponse]: + from ..interfaces import ( + FrontendRedirectResponse, + ErrorOAuth2Response, + ) + + if api_implementation.disable_logout_post is True: + return None + + session = None + try: + session = await get_session( + api_options.request, session_required=False, user_context=user_context + ) + except Exception as _: + pass + + body = await api_options.request.json() + + if body is None or "logoutChallenge" not in body: + raise_bad_input_exception("Missing body param: logoutChallenge") + + response = await api_implementation.logout_post( + logout_challenge=body["logoutChallenge"], + options=api_options, + session=session, + user_context=user_context, + ) + + if isinstance(response, FrontendRedirectResponse): + return send_200_response(response.to_json(), api_options.response) + elif isinstance(response, ErrorOAuth2Response): + # We want to avoid returning a 401 to the frontend, as it may trigger a refresh loop + if response.status_code == 401: + response.status_code = 400 + + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code if response.status_code is not None else 400, + api_options.response, + ) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/revoke_token.py b/supertokens_python/recipe/oauth2provider/api/revoke_token.py new file mode 100644 index 000000000..331ecfaa1 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/revoke_token.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response, send_non_200_response + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + + +async def revoke_token_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Optional[BaseResponse]: + from ..interfaces import ( + ErrorOAuth2Response, + GeneralErrorResponse, + ) + + if api_implementation.disable_revoke_token_post is True: + return None + + body = await api_options.request.get_json_or_form_data() + + if body is None or "token" not in body: + return send_non_200_response( + {"message": "token is required in the request body"}, + 400, + api_options.response, + ) + + authorization_header = api_options.request.get_header("authorization") + + if authorization_header is not None and ( + "client_id" in body or "client_secret" in body + ): + return send_non_200_response( + { + "message": "Only one of authorization header or client_id and client_secret can be provided" + }, + 400, + api_options.response, + ) + + response = await api_implementation.revoke_token_post( + token=body["token"], + options=api_options, + authorization_header=authorization_header, + client_id=body.get("client_id"), + client_secret=body.get("client_secret"), + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code if response.status_code is not None else 400, + api_options.response, + ) + elif isinstance(response, GeneralErrorResponse): + return send_200_response( + response.to_json(), + api_options.response, + ) + + return send_200_response({"status": "OK"}, api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/token.py b/supertokens_python/recipe/oauth2provider/api/token.py new file mode 100644 index 000000000..f88495a74 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/token.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +from supertokens_python.utils import send_200_response, send_non_200_response + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + + +async def token_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + from ..interfaces import ( + ErrorOAuth2Response, + ) + + if api_implementation.disable_token_post is True: + return None + + authorization_header = api_options.request.get_header("authorization") + + body = await api_options.request.get_json_or_form_data() + + response = await api_implementation.token_post( + authorization_header=authorization_header, + body=body, + options=api_options, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + # We do not need to normalize as this is not expected to be called by frontends where interception is enabled + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/user_info.py b/supertokens_python/recipe/oauth2provider/api/user_info.py new file mode 100644 index 000000000..9fbdeaf1f --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/user_info.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +from supertokens_python.asyncio import get_user +from supertokens_python.utils import ( + send_200_response, + send_non_200_response_with_message, +) + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + + +async def user_info_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_user_info_get is True: + return None + + authorization_header = api_options.request.get_header("authorization") + + if authorization_header is None or not authorization_header.startswith("Bearer "): + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Missing or invalid Authorization header", + 401, + api_options.response, + ) + + access_token = authorization_header.replace("Bearer ", "").strip() + + payload: Optional[Dict[str, Any]] = None + + try: + payload = await api_options.recipe_implementation.validate_oauth2_access_token( + token=access_token, + requirements=None, + check_database=None, + user_context=user_context, + ) + + except Exception: + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Invalid or expired OAuth2 access token", + 401, + api_options.response, + ) + + if not isinstance(payload.get("sub"), str) or not isinstance( + payload.get("scp"), list + ): + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Malformed access token payload", + 401, + api_options.response, + ) + + user_id = payload["sub"] + + user = await get_user(user_id, user_context) + + if user is None: + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Couldn't find any user associated with the access token", + 401, + api_options.response, + ) + + response = await api_implementation.user_info_get( + access_token_payload=payload, + user=user, + tenant_id=_tenant_id, + scopes=payload["scp"], + options=api_options, + user_context=user_context, + ) + + if isinstance(response, dict): + return send_200_response(response, api_options.response) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py new file mode 100644 index 000000000..166ae2f3b --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -0,0 +1,342 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from urllib.parse import parse_qs, urlparse +import time + +from supertokens_python import Supertokens +from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from supertokens_python.recipe.session.asyncio import get_session_information +from ..constants import LOGIN_PATH, AUTH_PATH, END_SESSION_PATH + +if TYPE_CHECKING: + from ..interfaces import ( + RecipeInterface, + ErrorOAuth2Response, + RedirectResponse, + ) + from supertokens_python.recipe.session.interfaces import SessionContainer + + +async def login_get( + recipe_implementation: RecipeInterface, + login_challenge: str, + session: Optional[SessionContainer], + should_try_refresh: bool, + cookies: Optional[List[str]], + is_direct_call: bool, + user_context: Dict[str, Any], +) -> Union[RedirectResponse, ErrorOAuth2Response]: + from ..interfaces import ( + ErrorOAuth2Response, + RedirectResponse, + FrontendRedirectionURLTypeTryRefresh, + FrontendRedirectionURLTypeLogin, + ) + + login_request = await recipe_implementation.get_login_request( + challenge=login_challenge, + user_context=user_context, + ) + + if isinstance(login_request, ErrorOAuth2Response): + return login_request + + session_info = ( + await get_session_information(session.get_handle()) if session else None + ) + if not session_info: + session = None + + incoming_auth_url_query_params = parse_qs(urlparse(login_request.request_url).query) + prompt_param = ( + incoming_auth_url_query_params.get("prompt", [None])[0] + or incoming_auth_url_query_params.get("st_prompt", [None])[0] + ) + max_age_param = incoming_auth_url_query_params.get("max_age", [None])[0] + + if max_age_param is not None: + try: + max_age_parsed = int(max_age_param) + + if max_age_parsed < 0: + reject = await recipe_implementation.reject_login_request( + challenge=login_challenge, + error=ErrorOAuth2Response( + error="invalid_request", + error_description="max_age cannot be negative", + ), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=reject.redirect_to, + cookies=cookies, + ) + + except ValueError: + reject = await recipe_implementation.reject_login_request( + challenge=login_challenge, + error=ErrorOAuth2Response( + error="invalid_request", + error_description="max_age must be an integer", + ), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=reject.redirect_to, + cookies=cookies, + ) + + tenant_id_param = incoming_auth_url_query_params.get("tenant_id", [None])[0] + + if ( + session + and session_info + and ( + not login_request.subject or session.get_user_id() == login_request.subject + ) + and (not tenant_id_param or session.get_tenant_id() == tenant_id_param) + and (prompt_param != "login" or is_direct_call) + and ( + max_age_param is None + or (max_age_param == "0" and is_direct_call) + or int(max_age_param) * 1000 + > time.time() * 1000 - session_info.time_created + ) + ): + accept = await recipe_implementation.accept_login_request( + challenge=login_challenge, + acr=None, + amr=None, + context=None, + extend_session_lifespan=None, + subject=session.get_user_id(), + identity_provider_session_id=session.get_handle(), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=accept.redirect_to, + cookies=cookies, + ) + + if should_try_refresh and prompt_param != "login": + return RedirectResponse( + redirect_to=await recipe_implementation.get_frontend_redirection_url( + params=FrontendRedirectionURLTypeTryRefresh( + login_challenge=login_challenge, + ), + user_context=user_context, + ), + cookies=cookies, + ) + + if prompt_param == "none": + reject = await recipe_implementation.reject_login_request( + challenge=login_challenge, + error=ErrorOAuth2Response( + error="login_required", + error_description="The Authorization Server requires End-User authentication. Prompt 'none' was requested, but no existing or expired login session was found.", + ), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=reject.redirect_to, + cookies=cookies, + ) + + return RedirectResponse( + redirect_to=await recipe_implementation.get_frontend_redirection_url( + params=FrontendRedirectionURLTypeLogin( + login_challenge=login_challenge, + force_fresh_auth=session is not None or prompt_param == "login", + tenant_id=tenant_id_param or DEFAULT_TENANT_ID, + hint=( + login_request.oidc_context.get("login_hint") + if login_request.oidc_context + else None + ), + ), + user_context=user_context, + ), + cookies=cookies, + ) + + +def get_merged_cookies(orig_cookies: str, new_cookies: Optional[List[str]]) -> str: + if not new_cookies: + return orig_cookies + + cookie_map: Dict[str, str] = {} + for cookie in orig_cookies.split(";"): + if "=" in cookie: + name, value = cookie.split("=", 1) + cookie_map[name.strip()] = value + + # Note: This is a simplified version. In production code you'd want to use a proper + # cookie parsing library to handle all cookie attributes correctly + if new_cookies: + for cookie_str in new_cookies: + cookie = cookie_str.split(";")[0].strip() + if "=" in cookie: + name, value = cookie.split("=", 1) + cookie_map[name.strip()] = value + + return ";".join(f"{key}={value}" for key, value in cookie_map.items()) + + +def merge_set_cookie_headers( + set_cookie1: Optional[List[str]] = None, set_cookie2: Optional[List[str]] = None +) -> List[str]: + if not set_cookie1: + return set_cookie2 or [] + if not set_cookie2 or set(set_cookie1) == set(set_cookie2): + return set_cookie1 + return set_cookie1 + set_cookie2 + + +def is_login_internal_redirect(redirect_to: str) -> bool: + instance = Supertokens.get_instance() + api_domain = instance.app_info.api_domain.get_as_string_dangerous() + api_base_path = instance.app_info.api_base_path.get_as_string_dangerous() + base_path = f"{api_domain}{api_base_path}" + + return any( + redirect_to.startswith(f"{base_path}{path}") for path in [LOGIN_PATH, AUTH_PATH] + ) + + +def is_logout_internal_redirect(redirect_to: str) -> bool: + instance = Supertokens.get_instance() + api_domain = instance.app_info.api_domain.get_as_string_dangerous() + api_base_path = instance.app_info.api_base_path.get_as_string_dangerous() + base_path = f"{api_domain}{api_base_path}" + return redirect_to.startswith(f"{base_path}{END_SESSION_PATH}") + + +async def handle_login_internal_redirects( + response: RedirectResponse, + recipe_implementation: RecipeInterface, + session: Optional[SessionContainer], + should_try_refresh: bool, + cookie: str, + user_context: Dict[str, Any], +) -> Union[RedirectResponse, ErrorOAuth2Response]: + from ..interfaces import RedirectResponse, ErrorOAuth2Response + + if not is_login_internal_redirect(response.redirect_to): + return response + + max_redirects = 10 + redirect_count = 0 + + while redirect_count < max_redirects and is_login_internal_redirect( + response.redirect_to + ): + cookie = get_merged_cookies(cookie, response.cookies) + + query_string = ( + response.redirect_to.split("?", 1)[1] if "?" in response.redirect_to else "" + ) + params = parse_qs(query_string) + + if LOGIN_PATH in response.redirect_to: + login_challenge = ( + params.get("login_challenge", [None])[0] + or params.get("loginChallenge", [None])[0] + ) + if not login_challenge: + raise Exception(f"Expected loginChallenge in {response.redirect_to}") + + login_res = await login_get( + recipe_implementation=recipe_implementation, + login_challenge=login_challenge, + session=session, + should_try_refresh=should_try_refresh, + cookies=response.cookies, + is_direct_call=False, + user_context=user_context, + ) + + if isinstance(login_res, ErrorOAuth2Response): + return login_res + + response = RedirectResponse( + redirect_to=login_res.redirect_to, + cookies=merge_set_cookie_headers(login_res.cookies, response.cookies), + ) + + elif AUTH_PATH in response.redirect_to: + auth_res = await recipe_implementation.authorization( + params={k: v[0] for k, v in params.items()}, + cookies=cookie, + session=session, + user_context=user_context, + ) + + if isinstance(auth_res, ErrorOAuth2Response): + return auth_res + + response = RedirectResponse( + redirect_to=auth_res.redirect_to, + cookies=merge_set_cookie_headers(auth_res.cookies, response.cookies), + ) + + else: + raise Exception(f"Unexpected internal redirect {response.redirect_to}") + + redirect_count += 1 + + return response + + +async def handle_logout_internal_redirects( + response: RedirectResponse, + recipe_implementation: RecipeInterface, + session: Optional[SessionContainer], + user_context: Dict[str, Any], +) -> Union[RedirectResponse, ErrorOAuth2Response]: + if not is_logout_internal_redirect(response.redirect_to): + return response + + max_redirects = 10 + redirect_count = 0 + + while redirect_count < max_redirects and is_logout_internal_redirect( + response.redirect_to + ): + query_string = ( + response.redirect_to.split("?", 1)[1] if "?" in response.redirect_to else "" + ) + params = parse_qs(query_string) + + if END_SESSION_PATH in response.redirect_to: + end_session_res = await recipe_implementation.end_session( + params={k: v[0] for k, v in params.items()}, + session=session, + should_try_refresh=False, + user_context=user_context, + ) + if isinstance(end_session_res, ErrorOAuth2Response): + return end_session_res + response = end_session_res + else: + raise Exception(f"Unexpected internal redirect {response.redirect_to}") + + redirect_count += 1 + + return response diff --git a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py new file mode 100644 index 000000000..3cdf1b523 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py @@ -0,0 +1,246 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import base64 +from typing import Any, Dict, Union, Optional, List + +from ..interfaces import ( + ActiveTokenResponse, + CreateOAuth2ClientInput, + CreateOAuth2ClientOkResult, + DeleteOAuth2ClientOkResult, + ErrorOAuth2Response, + GetOAuth2ClientOkResult, + GetOAuth2ClientsOkResult, + InactiveTokenResponse, + OAuth2TokenValidationRequirements, + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + TokenInfo, + UpdateOAuth2ClientInput, + UpdateOAuth2ClientOkResult, +) + + +async def get_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.get_oauth2_client( + client_id=client_id, user_context=user_context + ) + + +async def get_oauth2_clients( + page_size: Optional[int] = None, + pagination_token: Optional[str] = None, + client_name: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.get_oauth2_clients( + page_size=page_size, + pagination_token=pagination_token, + client_name=client_name, + user_context=user_context, + ) + + +async def create_oauth2_client( + params: CreateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.create_oauth2_client( + params=params, + user_context=user_context, + ) + + +async def update_oauth2_client( + params: UpdateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.update_oauth2_client( + params=params, + user_context=user_context, + ) + + +async def delete_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.delete_oauth2_client( + client_id=client_id, user_context=user_context + ) + + +async def validate_oauth2_access_token( + token: str, + requirements: Optional[OAuth2TokenValidationRequirements] = None, + check_database: Optional[bool] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.validate_oauth2_access_token( + token=token, + requirements=requirements, + check_database=check_database, + user_context=user_context, + ) + + +async def create_token_for_client_credentials( + client_id: str, + client_secret: Optional[str] = None, + scope: Optional[List[str]] = None, + audience: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[TokenInfo, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + body: Dict[str, Any] = { + "grant_type": "client_credentials", + "client_id": client_id, + } + if client_secret: + body["client_secret"] = client_secret + if scope: + body["scope"] = " ".join(scope) + if audience: + body["audience"] = audience + + return ( + await OAuth2ProviderRecipe.get_instance().recipe_implementation.token_exchange( + authorization_header=None, + body=body, + user_context=user_context, + ) + ) + + +async def revoke_token( + token: str, + client_id: str, + client_secret: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Optional[ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + recipe = OAuth2ProviderRecipe.get_instance() + + client_info = await recipe.recipe_implementation.get_oauth2_client( + client_id=client_id, user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + raise Exception( + f"Failed to get OAuth2 client with id {client_id}: {client_info.error}" + ) + + token_endpoint_auth_method = client_info.client.token_endpoint_auth_method + + if token_endpoint_auth_method == "none": + auth_header = "Basic " + base64.b64encode(f"{client_id}:".encode()).decode() + return await recipe.recipe_implementation.revoke_token( + RevokeTokenUsingAuthorizationHeader( + token=token, + authorization_header=auth_header, + ), + user_context=user_context, + ) + elif token_endpoint_auth_method == "client_secret_basic" and client_secret: + auth_header = ( + "Basic " + + base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + ) + return await recipe.recipe_implementation.revoke_token( + RevokeTokenUsingAuthorizationHeader( + token=token, + authorization_header=auth_header, + ), + user_context=user_context, + ) + + return await recipe.recipe_implementation.revoke_token( + RevokeTokenUsingClientIDAndClientSecret( + token=token, + client_id=client_id, + client_secret=client_secret, + ), + user_context=user_context, + ) + + +async def revoke_tokens_by_client_id( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.revoke_tokens_by_client_id( + client_id=client_id, user_context=user_context + ) + + +async def revoke_tokens_by_session_handle( + session_handle: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.revoke_tokens_by_session_handle( + session_handle=session_handle, user_context=user_context + ) + + +async def validate_oauth2_refresh_token( + token: str, + scopes: Optional[List[str]] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ActiveTokenResponse, InactiveTokenResponse]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.introspect_token( + token=token, scopes=scopes, user_context=user_context + ) diff --git a/supertokens_python/recipe/oauth2provider/constants.py b/supertokens_python/recipe/oauth2provider/constants.py new file mode 100644 index 000000000..edcd4f04b --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/constants.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +LOGIN_PATH = "/oauth/login" +AUTH_PATH = "/oauth/auth" +TOKEN_PATH = "/oauth/token" +LOGIN_INFO_PATH = "/oauth/login/info" +USER_INFO_PATH = "/oauth/userinfo" +REVOKE_TOKEN_PATH = "/oauth/revoke" +INTROSPECT_TOKEN_PATH = "/oauth/introspect" +END_SESSION_PATH = "/oauth/end_session" +LOGOUT_PATH = "/oauth/logout" diff --git a/supertokens_python/recipe/oauth2provider/exceptions.py b/supertokens_python/recipe/oauth2provider/exceptions.py new file mode 100644 index 000000000..9c3b5dcf6 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/exceptions.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from supertokens_python.exceptions import SuperTokensError + + +class OAuth2ProviderError(SuperTokensError): + pass diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py new file mode 100644 index 000000000..6d07fa89c --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -0,0 +1,1381 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union +from typing_extensions import Literal +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import ( + APIResponse, + GeneralErrorResponse, + RecipeUserId, + User, +) + +from .oauth2_client import OAuth2Client + + +if TYPE_CHECKING: + from supertokens_python.framework import BaseRequest, BaseResponse + from .utils import OAuth2ProviderConfig + + +class ErrorOAuth2Response(APIResponse): + def __init__( + self, + error: str, # OAuth2 error format (e.g. invalid_request, login_required) + error_description: str, # Human readable error description + status_code: Optional[int] = None, # HTTP status code (e.g. 401 or 403) + ): + self.status: Literal["ERROR"] = "ERROR" + self.error = error + self.error_description = error_description + self.status_code = status_code + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "status": self.status, + "error": self.error, + "errorDescription": self.error_description, + } + if self.status_code is not None: + result["statusCode"] = self.status_code + return result + + @staticmethod + def from_json(json: Dict[str, Any]): + return ErrorOAuth2Response( + error=json["error"], + error_description=json["errorDescription"], + status_code=json["statusCode"], + ) + + +class ConsentRequest: + def __init__( + self, + challenge: str, # ID/identifier of the consent authorization request + acr: Optional[str] = None, # Authentication Context Class Reference value + amr: Optional[List[str]] = None, # List of strings + client: Optional[OAuth2Client] = None, + context: Optional[Any] = None, # Any JSON serializable object + login_challenge: Optional[str] = None, # Associated login challenge + login_session_id: Optional[str] = None, + oidc_context: Optional[Any] = None, # Optional OpenID Connect request info + requested_access_token_audience: Optional[List[str]] = None, + requested_scope: Optional[List[str]] = None, + skip: Optional[bool] = None, + subject: Optional[str] = None, + ): + self.challenge = challenge + self.acr = acr + self.amr = amr + self.client = client + self.context = context + self.login_challenge = login_challenge + self.login_session_id = login_session_id + self.oidc_context = oidc_context + self.requested_access_token_audience = requested_access_token_audience + self.requested_scope = requested_scope + self.skip = skip + self.subject = subject + + @staticmethod + def from_json(json: Dict[str, Any]): + return ConsentRequest( + acr=json["acr"], + amr=json["amr"], + challenge=json["challenge"], + client=OAuth2Client.from_json(json["client"]), + context=json["context"], + login_challenge=json["loginChallenge"], + login_session_id=json["loginSessionId"], + oidc_context=json["oidcContext"], + requested_access_token_audience=json["requestedAccessTokenAudience"], + requested_scope=json["requestedScope"], + skip=json["skip"], + subject=json["subject"], + ) + + +class LoginRequest: + def __init__( + self, + challenge: str, # ID/identifier of the login request + client: OAuth2Client, + request_url: str, # Original OAuth 2.0 Authorization URL + skip: bool, + subject: str, + oidc_context: Optional[Any] = None, # Optional OpenID Connect request info + requested_access_token_audience: Optional[List[str]] = None, + requested_scope: Optional[List[str]] = None, + session_id: Optional[str] = None, + ): + self.challenge = challenge + self.client = client + self.oidc_context = oidc_context + self.request_url = request_url + self.requested_access_token_audience = requested_access_token_audience + self.requested_scope = requested_scope + self.session_id = session_id + self.skip = skip + self.subject = subject + + @staticmethod + def from_json(json: Dict[str, Any]): + return LoginRequest( + challenge=json["challenge"], + client=OAuth2Client.from_json(json["client"]), + request_url=json["requestUrl"], + skip=json["skip"], + subject=json["subject"], + oidc_context=json["oidcContext"], + requested_access_token_audience=json["requestedAccessTokenAudience"], + requested_scope=json["requestedScope"], + session_id=json["sessionId"], + ) + + +class TokenInfo: + def __init__( + self, + expires_in: int, # Lifetime in seconds of the access token + scope: str, + token_type: str, + access_token: Optional[str] = None, + id_token: Optional[str] = None, # Requires id_token scope + refresh_token: Optional[str] = None, # Requires offline scope + ): + self.access_token = access_token + self.expires_in = expires_in + self.id_token = id_token + self.refresh_token = refresh_token + self.scope = scope + self.token_type = token_type + + @staticmethod + def from_json(json: Dict[str, Any]): + return TokenInfo( + access_token=json.get("access_token"), + expires_in=json["expires_in"], + id_token=json.get("id_token"), + refresh_token=json.get("refresh_token"), + scope=json["scope"], + token_type=json["token_type"], + ) + + def to_json(self) -> Dict[str, Any]: + result = { + "status": "OK", + "expires_in": self.expires_in, + "scope": self.scope, + "token_type": self.token_type, + } + if self.access_token is not None: + result["access_token"] = self.access_token + if self.id_token is not None: + result["id_token"] = self.id_token + if self.refresh_token is not None: + result["refresh_token"] = self.refresh_token + return result + + +class LoginInfo: + def __init__( + self, + client_id: str, + client_name: str, + tos_uri: Optional[str] = None, + policy_uri: Optional[str] = None, + logo_uri: Optional[str] = None, + client_uri: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + self.client_id = client_id + self.client_name = client_name + self.tos_uri = tos_uri + self.policy_uri = policy_uri + self.logo_uri = logo_uri + self.client_uri = client_uri + self.metadata = metadata + + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "info": { + "clientId": self.client_id, + "clientName": self.client_name, + "tosUri": self.tos_uri, + "policyUri": self.policy_uri, + "logoUri": self.logo_uri, + "clientUri": self.client_uri, + "metadata": self.metadata, + }, + } + + +class RedirectResponse: + def __init__(self, redirect_to: str, cookies: Optional[List[str]] = None): + self.redirect_to = redirect_to + self.cookies = cookies + + +class FrontendRedirectResponse: + def __init__(self, frontend_redirect_to: str, cookies: Optional[List[str]] = None): + self.frontend_redirect_to = frontend_redirect_to + self.cookies = cookies + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "frontendRedirectTo": self.frontend_redirect_to, + } + return result + + +class GetOAuth2ClientsOkResult: + def __init__( + self, clients: List[OAuth2Client], next_pagination_token: Optional[str] + ): + self.clients = clients + self.next_pagination_token = next_pagination_token + + @staticmethod + def from_json(json: Dict[str, Any]): + return GetOAuth2ClientsOkResult( + clients=[OAuth2Client.from_json(client) for client in json["clients"]], + next_pagination_token=json["nextPaginationToken"], + ) + + def to_json(self) -> Dict[str, Any]: + result = { + "status": "OK", + "clients": [client.to_json() for client in self.clients], + } + if self.next_pagination_token is not None: + result["nextPaginationToken"] = self.next_pagination_token + return result + + +class GetOAuth2ClientOkResult: + def __init__(self, client: OAuth2Client): + self.client = client + + @staticmethod + def from_json(json: Dict[str, Any]): + return GetOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + + +class CreateOAuth2ClientOkResult: + def __init__(self, client: OAuth2Client): + self.client = client + + @staticmethod + def from_json(json: Dict[str, Any]): + return CreateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "client": self.client.to_json(), + } + + +class UpdateOAuth2ClientOkResult: + def __init__(self, client: OAuth2Client): + self.client = client + + @staticmethod + def from_json(json: Dict[str, Any]): + return UpdateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "client": self.client.to_json(), + } + + +class DeleteOAuth2ClientOkResult: + def __init__(self): + pass + + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + } + + +PayloadBuilderFunction = Callable[ + [User, List[str], str, Dict[str, Any]], Awaitable[Dict[str, Any]] +] + +UserInfoBuilderFunction = Callable[ + [User, Dict[str, Any], List[str], str, Dict[str, Any]], Awaitable[Dict[str, Any]] +] + + +class OAuth2TokenValidationRequirements: + def __init__( + self, + client_id: Optional[str] = None, + scopes: Optional[List[str]] = None, + audience: Optional[str] = None, + ): + self.client_id = client_id + self.scopes = scopes + self.audience = audience + + @staticmethod + def from_json(json: Dict[str, Any]): + return OAuth2TokenValidationRequirements( + client_id=json.get("clientId"), + scopes=json.get("scopes"), + audience=json.get("audience"), + ) + + +class FrontendRedirectionURLTypeLogin: + def __init__( + self, + login_challenge: str, + tenant_id: str, + force_fresh_auth: bool, + hint: Optional[str] = None, + ): + self.login_challenge = login_challenge + self.tenant_id = tenant_id + self.force_fresh_auth = force_fresh_auth + self.hint = hint + + +class FrontendRedirectionURLTypeTryRefresh: + def __init__(self, login_challenge: str): + self.login_challenge = login_challenge + + +class FrontendRedirectionURLTypeLogoutConfirmation: + def __init__(self, logout_challenge: str): + self.logout_challenge = logout_challenge + + +class FrontendRedirectionURLTypePostLogoutFallback: + pass + + +class RevokeTokenUsingAuthorizationHeader: + def __init__(self, token: str, authorization_header: str): + self.token = token + self.authorization_header = authorization_header + + +class RevokeTokenUsingClientIDAndClientSecret: + def __init__(self, token: str, client_id: str, client_secret: Optional[str]): + self.token = token + self.client_id = client_id + self.client_secret = client_secret + + +class InactiveTokenResponse: + def to_json(self): + return {"active": False} + + +class ActiveTokenResponse: + def __init__(self, payload: Dict[str, Any]): + self.payload = payload + + def to_json(self): + return {"active": True, **self.payload} + + +class OAuth2ClientOptions: + def __init__( + self, + client_id: str, + client_secret: Optional[str], + created_at: str, + updated_at: str, + client_name: str, + scope: str, + redirect_uris: Optional[List[str]], + post_logout_redirect_uris: Optional[List[str]], + authorization_code_grant_access_token_lifespan: Optional[str], + authorization_code_grant_id_token_lifespan: Optional[str], + authorization_code_grant_refresh_token_lifespan: Optional[str], + client_credentials_grant_access_token_lifespan: Optional[str], + implicit_grant_access_token_lifespan: Optional[str], + implicit_grant_id_token_lifespan: Optional[str], + refresh_token_grant_access_token_lifespan: Optional[str], + refresh_token_grant_id_token_lifespan: Optional[str], + refresh_token_grant_refresh_token_lifespan: Optional[str], + token_endpoint_auth_method: str, + audience: Optional[List[str]], + grant_types: Optional[List[str]], + response_types: Optional[List[str]], + client_uri: Optional[str], + logo_uri: Optional[str], + policy_uri: Optional[str], + tos_uri: Optional[str], + metadata: Optional[Dict[str, Any]], + enable_refresh_token_rotation: Optional[bool], + ): + self.client_id = client_id + self.client_secret = client_secret + self.created_at = created_at + self.updated_at = updated_at + self.client_name = client_name + self.scope = scope + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.token_endpoint_auth_method = token_endpoint_auth_method + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.client_uri = client_uri + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "clientId": self.client_id, + "createdAt": self.created_at, + "updatedAt": self.updated_at, + "clientName": self.client_name, + "scope": self.scope, + "tokenEndpointAuthMethod": self.token_endpoint_auth_method, + } + if self.client_secret is not None: + result["clientSecret"] = self.client_secret + if self.redirect_uris is not None: + result["redirectUris"] = self.redirect_uris + if self.post_logout_redirect_uris is not None: + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + if self.authorization_code_grant_access_token_lifespan is not None: + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + if self.authorization_code_grant_id_token_lifespan is not None: + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + if self.authorization_code_grant_refresh_token_lifespan is not None: + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + if self.client_credentials_grant_access_token_lifespan is not None: + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + if self.implicit_grant_access_token_lifespan is not None: + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + if self.implicit_grant_id_token_lifespan is not None: + result["implicitGrantIdTokenLifespan"] = ( + self.implicit_grant_id_token_lifespan + ) + if self.refresh_token_grant_access_token_lifespan is not None: + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + if self.refresh_token_grant_id_token_lifespan is not None: + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + if self.refresh_token_grant_refresh_token_lifespan is not None: + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + if self.audience is not None: + result["audience"] = self.audience + if self.grant_types is not None: + result["grantTypes"] = self.grant_types + if self.response_types is not None: + result["responseTypes"] = self.response_types + if self.client_uri is not None: + result["clientUri"] = self.client_uri + if self.logo_uri is not None: + result["logoUri"] = self.logo_uri + if self.policy_uri is not None: + result["policyUri"] = self.policy_uri + if self.tos_uri is not None: + result["tosUri"] = self.tos_uri + if self.metadata is not None: + result["metadata"] = self.metadata + if self.enable_refresh_token_rotation is not None: + result["enableRefreshTokenRotation"] = self.enable_refresh_token_rotation + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> "OAuth2ClientOptions": + return OAuth2ClientOptions( + client_id=json["clientId"], + client_secret=json["clientSecret"], + created_at=json["createdAt"], + updated_at=json["updatedAt"], + client_name=json["clientName"], + scope=json["scope"], + redirect_uris=json["redirectUris"], + post_logout_redirect_uris=json["postLogoutRedirectUris"], + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan" + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan" + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan" + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan" + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan" + ), + implicit_grant_id_token_lifespan=json.get("implicitGrantIdTokenLifespan"), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan" + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan" + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan" + ), + token_endpoint_auth_method=json["tokenEndpointAuthMethod"], + audience=json.get("audience"), + grant_types=json.get("grantTypes"), + response_types=json.get("responseTypes"), + client_uri=json.get("clientUri"), + logo_uri=json.get("logoUri"), + policy_uri=json.get("policyUri"), + tos_uri=json.get("tosUri"), + metadata=json.get("metadata"), + enable_refresh_token_rotation=json.get("enableRefreshTokenRotation"), + ) + + +class CreateOAuth2ClientInput: + def __init__( + self, + client_id: Optional[str], + client_secret: Optional[str], + client_name: Optional[str], + scope: Optional[str], + redirect_uris: Optional[List[str]], + post_logout_redirect_uris: Optional[List[str]], + authorization_code_grant_access_token_lifespan: Optional[str], + authorization_code_grant_id_token_lifespan: Optional[str], + authorization_code_grant_refresh_token_lifespan: Optional[str], + client_credentials_grant_access_token_lifespan: Optional[str], + implicit_grant_access_token_lifespan: Optional[str], + implicit_grant_id_token_lifespan: Optional[str], + refresh_token_grant_access_token_lifespan: Optional[str], + refresh_token_grant_id_token_lifespan: Optional[str], + refresh_token_grant_refresh_token_lifespan: Optional[str], + token_endpoint_auth_method: Optional[str], + audience: Optional[List[str]], + grant_types: Optional[List[str]], + response_types: Optional[List[str]], + client_uri: Optional[str], + logo_uri: Optional[str], + policy_uri: Optional[str], + tos_uri: Optional[str], + metadata: Optional[Dict[str, Any]], + enable_refresh_token_rotation: Optional[bool], + ): + self.client_id = client_id + self.client_secret = client_secret + self.client_name = client_name + self.scope = scope + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.token_endpoint_auth_method = token_endpoint_auth_method + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.client_uri = client_uri + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} + if self.client_id is not None: + result["clientId"] = self.client_id + if self.client_name is not None: + result["clientName"] = self.client_name + if self.scope is not None: + result["scope"] = self.scope + if self.token_endpoint_auth_method is not None: + result["tokenEndpointAuthMethod"] = self.token_endpoint_auth_method + if self.client_secret is not None: + result["clientSecret"] = self.client_secret + if self.redirect_uris is not None: + result["redirectUris"] = self.redirect_uris + if self.post_logout_redirect_uris is not None: + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + if self.authorization_code_grant_access_token_lifespan is not None: + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + if self.authorization_code_grant_id_token_lifespan is not None: + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + if self.authorization_code_grant_refresh_token_lifespan is not None: + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + if self.client_credentials_grant_access_token_lifespan is not None: + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + if self.implicit_grant_access_token_lifespan is not None: + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + if self.implicit_grant_id_token_lifespan is not None: + result["implicitGrantIdTokenLifespan"] = ( + self.implicit_grant_id_token_lifespan + ) + if self.refresh_token_grant_access_token_lifespan is not None: + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + if self.refresh_token_grant_id_token_lifespan is not None: + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + if self.refresh_token_grant_refresh_token_lifespan is not None: + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + if self.audience is not None: + result["audience"] = self.audience + if self.grant_types is not None: + result["grantTypes"] = self.grant_types + if self.response_types is not None: + result["responseTypes"] = self.response_types + if self.client_uri is not None: + result["clientUri"] = self.client_uri + if self.logo_uri is not None: + result["logoUri"] = self.logo_uri + if self.policy_uri is not None: + result["policyUri"] = self.policy_uri + if self.tos_uri is not None: + result["tosUri"] = self.tos_uri + if self.metadata is not None: + result["metadata"] = self.metadata + if self.enable_refresh_token_rotation is not None: + result["enableRefreshTokenRotation"] = self.enable_refresh_token_rotation + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> "CreateOAuth2ClientInput": + return CreateOAuth2ClientInput( + client_id=json.get("clientId"), + client_secret=json.get("clientSecret"), + client_name=json.get("clientName"), + scope=json.get("scope"), + redirect_uris=json.get("redirectUris"), + post_logout_redirect_uris=json.get("postLogoutRedirectUris"), + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan" + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan" + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan" + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan" + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan" + ), + implicit_grant_id_token_lifespan=json.get("implicitGrantIdTokenLifespan"), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan" + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan" + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan" + ), + token_endpoint_auth_method=json.get("tokenEndpointAuthMethod"), + audience=json.get("audience"), + grant_types=json.get("grantTypes"), + response_types=json.get("responseTypes"), + client_uri=json.get("clientUri"), + logo_uri=json.get("logoUri"), + policy_uri=json.get("policyUri"), + tos_uri=json.get("tosUri"), + metadata=json.get("metadata"), + enable_refresh_token_rotation=json.get("enableRefreshTokenRotation"), + ) + + +class NotSet: + pass + + +class UpdateOAuth2ClientInput: + def __init__( + self, + client_id: str, + client_secret: Union[Optional[str], NotSet] = NotSet(), + client_name: Union[Optional[str], NotSet] = NotSet(), + scope: Union[Optional[str], NotSet] = NotSet(), + redirect_uris: Union[Optional[List[str]], NotSet] = NotSet(), + post_logout_redirect_uris: Union[Optional[List[str]], NotSet] = NotSet(), + authorization_code_grant_access_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + authorization_code_grant_id_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + authorization_code_grant_refresh_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + client_credentials_grant_access_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + implicit_grant_access_token_lifespan: Union[Optional[str], NotSet] = NotSet(), + implicit_grant_id_token_lifespan: Union[Optional[str], NotSet] = NotSet(), + refresh_token_grant_access_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + refresh_token_grant_id_token_lifespan: Union[Optional[str], NotSet] = NotSet(), + refresh_token_grant_refresh_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + token_endpoint_auth_method: Union[Optional[str], NotSet] = NotSet(), + audience: Union[Optional[List[str]], NotSet] = NotSet(), + grant_types: Union[Optional[List[str]], NotSet] = NotSet(), + response_types: Union[Optional[List[str]], NotSet] = NotSet(), + client_uri: Union[Optional[str], NotSet] = NotSet(), + logo_uri: Union[Optional[str], NotSet] = NotSet(), + policy_uri: Union[Optional[str], NotSet] = NotSet(), + tos_uri: Union[Optional[str], NotSet] = NotSet(), + metadata: Union[Optional[Dict[str, Any]], NotSet] = NotSet(), + enable_refresh_token_rotation: Union[Optional[bool], NotSet] = NotSet(), + ): + self.client_id = client_id + self.client_secret = client_secret + self.client_name = client_name + self.scope = scope + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.token_endpoint_auth_method = token_endpoint_auth_method + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.client_uri = client_uri + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} + result["clientId"] = self.client_id + + if not isinstance(self.client_name, NotSet): + result["clientName"] = self.client_name + if not isinstance(self.scope, NotSet): + result["scope"] = self.scope + if not isinstance(self.token_endpoint_auth_method, NotSet): + result["tokenEndpointAuthMethod"] = self.token_endpoint_auth_method + if not isinstance(self.client_secret, NotSet): + result["clientSecret"] = self.client_secret + if not isinstance(self.redirect_uris, NotSet): + result["redirectUris"] = self.redirect_uris + if not isinstance(self.post_logout_redirect_uris, NotSet): + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + if not isinstance(self.authorization_code_grant_access_token_lifespan, NotSet): + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + if not isinstance(self.authorization_code_grant_id_token_lifespan, NotSet): + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + if not isinstance(self.authorization_code_grant_refresh_token_lifespan, NotSet): + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + if not isinstance(self.client_credentials_grant_access_token_lifespan, NotSet): + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + if not isinstance(self.implicit_grant_access_token_lifespan, NotSet): + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + if not isinstance(self.implicit_grant_id_token_lifespan, NotSet): + result["implicitGrantIdTokenLifespan"] = ( + self.implicit_grant_id_token_lifespan + ) + if not isinstance(self.refresh_token_grant_access_token_lifespan, NotSet): + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + if not isinstance(self.refresh_token_grant_id_token_lifespan, NotSet): + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + if not isinstance(self.refresh_token_grant_refresh_token_lifespan, NotSet): + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + if not isinstance(self.audience, NotSet): + result["audience"] = self.audience + if not isinstance(self.grant_types, NotSet): + result["grantTypes"] = self.grant_types + if not isinstance(self.response_types, NotSet): + result["responseTypes"] = self.response_types + if not isinstance(self.client_uri, NotSet): + result["clientUri"] = self.client_uri + if not isinstance(self.logo_uri, NotSet): + result["logoUri"] = self.logo_uri + if not isinstance(self.policy_uri, NotSet): + result["policyUri"] = self.policy_uri + if not isinstance(self.tos_uri, NotSet): + result["tosUri"] = self.tos_uri + if not isinstance(self.metadata, NotSet): + result["metadata"] = self.metadata + if not isinstance(self.enable_refresh_token_rotation, NotSet): + result["enableRefreshTokenRotation"] = self.enable_refresh_token_rotation + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> "UpdateOAuth2ClientInput": + return UpdateOAuth2ClientInput( + client_id=json["clientId"], + client_secret=json.get("clientSecret", NotSet()), + client_name=json.get("clientName", NotSet()), + scope=json.get("scope", NotSet()), + redirect_uris=json.get("redirectUris", NotSet()), + post_logout_redirect_uris=json.get("postLogoutRedirectUris", NotSet()), + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan", NotSet() + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan", NotSet() + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan", NotSet() + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan", NotSet() + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan", NotSet() + ), + implicit_grant_id_token_lifespan=json.get( + "implicitGrantIdTokenLifespan", NotSet() + ), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan", NotSet() + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan", NotSet() + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan", NotSet() + ), + token_endpoint_auth_method=json.get("tokenEndpointAuthMethod", NotSet()), + audience=json.get("audience", NotSet()), + grant_types=json.get("grantTypes", NotSet()), + response_types=json.get("responseTypes", NotSet()), + client_uri=json.get("clientUri", NotSet()), + logo_uri=json.get("logoUri", NotSet()), + policy_uri=json.get("policyUri", NotSet()), + tos_uri=json.get("tosUri", NotSet()), + metadata=json.get("metadata", NotSet()), + enable_refresh_token_rotation=json.get( + "enableRefreshTokenRotation", NotSet() + ), + ) + + +class RecipeInterface(ABC): + @abstractmethod + async def authorization( + self, + params: Dict[str, str], + cookies: Optional[str], + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + @abstractmethod + async def token_exchange( + self, + authorization_header: Optional[str], + body: Dict[str, Optional[str]], + user_context: Dict[str, Any], + ) -> Union[TokenInfo, ErrorOAuth2Response]: + pass + + @abstractmethod + async def get_consent_request( + self, challenge: str, user_context: Dict[str, Any] + ) -> ConsentRequest: + pass + + @abstractmethod + async def accept_consent_request( + self, + challenge: str, + context: Optional[Any], + grant_access_token_audience: Optional[List[str]], + grant_scope: Optional[List[str]], + handled_at: Optional[str], + tenant_id: str, + rsub: str, + session_handle: str, + initial_access_token_payload: Optional[Dict[str, Any]], + initial_id_token_payload: Optional[Dict[str, Any]], + user_context: Dict[str, Any], + ) -> RedirectResponse: + pass + + @abstractmethod + async def reject_consent_request( + self, challenge: str, error: ErrorOAuth2Response, user_context: Dict[str, Any] + ) -> RedirectResponse: + pass + + @abstractmethod + async def get_login_request( + self, challenge: str, user_context: Dict[str, Any] + ) -> Union[LoginRequest, ErrorOAuth2Response]: + pass + + @abstractmethod + async def accept_login_request( + self, + challenge: str, + acr: Optional[str], + amr: Optional[List[str]], + context: Optional[Any], + extend_session_lifespan: Optional[bool], + identity_provider_session_id: Optional[str], + subject: str, + user_context: Dict[str, Any], + ) -> RedirectResponse: + pass + + @abstractmethod + async def reject_login_request( + self, + challenge: str, + error: ErrorOAuth2Response, + user_context: Dict[str, Any], + ) -> RedirectResponse: + pass + + @abstractmethod + async def get_oauth2_clients( + self, + page_size: Optional[int], + pagination_token: Optional[str], + client_name: Optional[str], + user_context: Dict[str, Any], + ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: + pass + + @abstractmethod + async def get_oauth2_client( + self, + client_id: str, + user_context: Dict[str, Any], + ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: + pass + + @abstractmethod + async def create_oauth2_client( + self, + params: CreateOAuth2ClientInput, + user_context: Dict[str, Any], + ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: + pass + + @abstractmethod + async def update_oauth2_client( + self, + params: UpdateOAuth2ClientInput, + user_context: Dict[str, Any], + ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: + pass + + @abstractmethod + async def delete_oauth2_client( + self, + client_id: str, + user_context: Dict[str, Any], + ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: + pass + + @abstractmethod + async def validate_oauth2_access_token( + self, + token: str, + requirements: Optional[OAuth2TokenValidationRequirements], + check_database: Optional[bool], + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def get_requested_scopes( + self, + recipe_user_id: Optional[RecipeUserId], + session_handle: Optional[str], + scope_param: List[str], + client_id: str, + user_context: Dict[str, Any], + ) -> List[str]: + pass + + @abstractmethod + async def build_access_token_payload( + self, + user: Optional[User], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def build_id_token_payload( + self, + user: Optional[User], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def build_user_info( + self, + user: User, + access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def get_frontend_redirection_url( + self, + params: Union[ + FrontendRedirectionURLTypeLogin, + FrontendRedirectionURLTypeTryRefresh, + FrontendRedirectionURLTypeLogoutConfirmation, + FrontendRedirectionURLTypePostLogoutFallback, + ], + user_context: Dict[str, Any], + ) -> str: + pass + + @abstractmethod + async def revoke_token( + self, + params: Union[ + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + ], + user_context: Dict[str, Any], + ) -> Optional[ErrorOAuth2Response]: + pass + + @abstractmethod + async def revoke_tokens_by_client_id( + self, + client_id: str, + user_context: Dict[str, Any], + ): + pass + + @abstractmethod + async def revoke_tokens_by_session_handle( + self, + session_handle: str, + user_context: Dict[str, Any], + ): + pass + + @abstractmethod + async def introspect_token( + self, + token: str, + scopes: Optional[List[str]], + user_context: Dict[str, Any], + ) -> Union[ActiveTokenResponse, InactiveTokenResponse]: + pass + + @abstractmethod + async def end_session( + self, + params: Dict[str, str], + should_try_refresh: bool, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + @abstractmethod + async def accept_logout_request( + self, + challenge: str, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + @abstractmethod + async def reject_logout_request( + self, + challenge: str, + user_context: Dict[str, Any], + ): + pass + + +class APIOptions: + def __init__( + self, + request: BaseRequest, + response: BaseResponse, + recipe_id: str, + config: OAuth2ProviderConfig, + recipe_implementation: RecipeInterface, + ): + self.request: BaseRequest = request + self.response: BaseResponse = response + self.recipe_id: str = recipe_id + self.config: OAuth2ProviderConfig = config + self.recipe_implementation: RecipeInterface = recipe_implementation + + +class APIInterface: + def __init__(self): + self.disable_login_get = False + self.disable_auth_get = False + self.disable_token_post = False + self.disable_login_info_get = False + self.disable_user_info_get = False + self.disable_revoke_token_post = False + self.disable_introspect_token_post = False + self.disable_end_session_get = False + self.disable_end_session_post = False + self.disable_logout_post = False + + @abstractmethod + async def login_get( + self, + login_challenge: str, + options: APIOptions, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def auth_get( + self, + params: Any, + cookie: Optional[str], + session: Optional[SessionContainer], + should_try_refresh: bool, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def token_post( + self, + authorization_header: Optional[str], + body: Any, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def login_info_get( + self, + login_challenge: str, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[ + LoginInfo, + ErrorOAuth2Response, + GeneralErrorResponse, + ]: + pass + + @abstractmethod + async def user_info_get( + self, + access_token_payload: Dict[str, Any], + user: User, + scopes: List[str], + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[Dict[str, Any], GeneralErrorResponse]: + pass + + @abstractmethod + async def revoke_token_post( + self, + options: APIOptions, + token: str, + authorization_header: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + user_context: Dict[str, Any], + ) -> Union[None, ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def introspect_token_post( + self, + token: str, + scopes: Optional[List[str]], + options: APIOptions, + user_context: Dict[str, Any], + ) -> Union[ActiveTokenResponse, InactiveTokenResponse, GeneralErrorResponse]: + pass + + @abstractmethod + async def end_session_get( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def end_session_post( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def logout_post( + self, + logout_challenge: str, + options: APIOptions, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + pass diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py new file mode 100644 index 000000000..684562878 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -0,0 +1,297 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from typing import Dict, Any, List, Optional + + +class OAuth2Client: + # OAuth 2.0 Client ID + # The ID is immutable. If no ID is provided, a UUID4 will be generated. + client_id: str + + # OAuth 2.0 Client Name + # The human-readable name of the client to be presented to the end-user during authorization. + client_name: str + + # OAuth 2.0 Client Scope + # Scope is a string containing a space-separated list of scope values that the client + # can use when requesting access tokens. + scope: str + + # OAuth 2.0 Token Endpoint Authentication Method + # Requested Client Authentication method for the Token Endpoint. + token_endpoint_auth_method: str + + # OAuth 2.0 Client Creation Date + # CreatedAt returns the timestamp of the client's creation. + created_at: str + + # OAuth 2.0 Client Last Update Date + # UpdatedAt returns the timestamp of the last update. + updated_at: str + + # OAuth 2.0 Client Secret + client_secret: Optional[str] = None + + # Array of redirect URIs + redirect_uris: Optional[List[str]] = None + + # Array of post logout redirect URIs + post_logout_redirect_uris: Optional[List[str]] = None + + # Authorization Code Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + authorization_code_grant_access_token_lifespan: Optional[str] = None + + # Authorization Code Grant ID Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + authorization_code_grant_id_token_lifespan: Optional[str] = None + + # Authorization Code Grant Refresh Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + authorization_code_grant_refresh_token_lifespan: Optional[str] = None + + # Client Credentials Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + client_credentials_grant_access_token_lifespan: Optional[str] = None + + # Implicit Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + implicit_grant_access_token_lifespan: Optional[str] = None + + # Implicit Grant ID Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + implicit_grant_id_token_lifespan: Optional[str] = None + + # Refresh Token Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + refresh_token_grant_access_token_lifespan: Optional[str] = None + + # Refresh Token Grant ID Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + refresh_token_grant_id_token_lifespan: Optional[str] = None + + # Refresh Token Grant Refresh Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + refresh_token_grant_refresh_token_lifespan: Optional[str] = None + + # OAuth 2.0 Client URI + # ClientURI is a URL string of a web page providing information about the client. + client_uri: str = "" + + # Array of audiences + audience: List[str] = [] + + # Array of grant types + grant_types: Optional[List[str]] = None + + # Array of response types + response_types: Optional[List[str]] = None + + # OAuth 2.0 Client Logo URI + # A URL string referencing the client's logo. + logo_uri: str = "" + + # OAuth 2.0 Client Policy URI + # PolicyURI is a URL string that points to a human-readable privacy policy document + # that describes how the deployment organization collects, uses, + # retains, and discloses personal data. + policy_uri: str = "" + + # OAuth 2.0 Client Terms of Service URI + # A URL string pointing to a human-readable terms of service + # document for the client that describes a contractual relationship + # between the end-user and the client that the end-user accepts when + # authorizing the client. + tos_uri: str = "" + + # Metadata - JSON object + metadata: Dict[str, Any] = {} + + # This flag is set to true if refresh tokens are updated upon use + enable_refresh_token_rotation: bool = False + + def __init__( + self, + client_id: str, + client_name: str, + scope: str, + token_endpoint_auth_method: str, + created_at: str, + updated_at: str, + client_secret: Optional[str], + redirect_uris: Optional[List[str]], + post_logout_redirect_uris: Optional[List[str]], + authorization_code_grant_access_token_lifespan: Optional[str], + authorization_code_grant_id_token_lifespan: Optional[str], + authorization_code_grant_refresh_token_lifespan: Optional[str], + client_credentials_grant_access_token_lifespan: Optional[str], + implicit_grant_access_token_lifespan: Optional[str], + implicit_grant_id_token_lifespan: Optional[str], + refresh_token_grant_access_token_lifespan: Optional[str], + refresh_token_grant_id_token_lifespan: Optional[str], + refresh_token_grant_refresh_token_lifespan: Optional[str], + client_uri: str, + audience: List[str], + grant_types: Optional[List[str]], + response_types: Optional[List[str]], + logo_uri: str, + policy_uri: str, + tos_uri: str, + metadata: Dict[str, Any], + enable_refresh_token_rotation: bool, + ): + self.client_id = client_id + self.client_name = client_name + self.scope = scope + self.token_endpoint_auth_method = token_endpoint_auth_method + self.created_at = created_at + self.updated_at = updated_at + self.client_secret = client_secret + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.client_uri = client_uri + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + @staticmethod + def from_json(json: Dict[str, Any]) -> "OAuth2Client": + # Transform keys from snake_case to camelCase + return OAuth2Client( + client_id=json["clientId"], + client_secret=json.get("clientSecret"), + client_name=json["clientName"], + scope=json["scope"], + redirect_uris=json.get("redirectUris"), + post_logout_redirect_uris=json.get("postLogoutRedirectUris"), + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan" + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan" + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan" + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan" + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan" + ), + implicit_grant_id_token_lifespan=json.get("implicitGrantIdTokenLifespan"), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan" + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan" + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan" + ), + token_endpoint_auth_method=json["tokenEndpointAuthMethod"], + client_uri=json.get("clientUri", ""), + audience=json.get("audience", []), + grant_types=json.get("grantTypes"), + response_types=json.get("responseTypes"), + logo_uri=json.get("logoUri", ""), + policy_uri=json.get("policyUri", ""), + tos_uri=json.get("tosUri", ""), + created_at=json["createdAt"], + updated_at=json["updatedAt"], + metadata=json.get("metadata", {}), + enable_refresh_token_rotation=json.get("enableRefreshTokenRotation", False), + ) + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "clientId": self.client_id, + "clientName": self.client_name, + "scope": self.scope, + "tokenEndpointAuthMethod": self.token_endpoint_auth_method, + "createdAt": self.created_at, + "updatedAt": self.updated_at, + "clientUri": self.client_uri, + "audience": self.audience, + "logoUri": self.logo_uri, + "policyUri": self.policy_uri, + "tosUri": self.tos_uri, + "metadata": self.metadata, + "enableRefreshTokenRotation": self.enable_refresh_token_rotation, + } + + if self.client_secret is not None: + result["clientSecret"] = self.client_secret + result["redirectUris"] = self.redirect_uris + if self.post_logout_redirect_uris is not None: + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + result["implicitGrantIdTokenLifespan"] = self.implicit_grant_id_token_lifespan + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + result["grantTypes"] = self.grant_types + result["responseTypes"] = self.response_types + + return result diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py new file mode 100644 index 000000000..aa9ec036e --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -0,0 +1,449 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from os import environ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.recipe.oauth2provider.api.introspect_token import ( + introspect_token_post, +) +from supertokens_python.recipe.oauth2provider.api.login import login +from supertokens_python.recipe.oauth2provider.api.login_info import login_info_get +from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.types import User + +from .interfaces import ( + APIInterface, + APIOptions, + PayloadBuilderFunction, + UserInfoBuilderFunction, + RecipeInterface, +) + + +if TYPE_CHECKING: + from supertokens_python.framework.request import BaseRequest + from supertokens_python.framework.response import BaseResponse + from supertokens_python.supertokens import AppInfo + + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.querier import Querier +from supertokens_python.recipe.oauth2provider.api.implementation import ( + APIImplementation, +) + + +from .api import ( + auth_get, + login, + end_session_get, + end_session_post, + logout_post, + revoke_token_post, + token_post, + user_info_get, +) +from .constants import ( + LOGIN_PATH, + AUTH_PATH, + TOKEN_PATH, + LOGIN_INFO_PATH, + USER_INFO_PATH, + REVOKE_TOKEN_PATH, + INTROSPECT_TOKEN_PATH, + END_SESSION_PATH, + LOGOUT_PATH, +) +from .utils import ( + InputOverrideConfig, + OAuth2ProviderConfig, + validate_and_normalise_user_input, +) + + +class OAuth2ProviderRecipe(RecipeModule): + recipe_id = "oauth2provider" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + override: Union[InputOverrideConfig, None] = None, + ) -> None: + super().__init__(recipe_id, app_info) + self.config: OAuth2ProviderConfig = validate_and_normalise_user_input( + override, + ) + + from .recipe_implementation import RecipeImplementation + + recipe_implementation: RecipeInterface = RecipeImplementation( + Querier.get_instance(recipe_id), + app_info, + self.get_default_access_token_payload, + self.get_default_id_token_payload, + self.get_default_user_info_payload, + ) + self.recipe_implementation: RecipeInterface = ( + self.config.override.functions(recipe_implementation) + if self.config.override is not None + and self.config.override.functions is not None + else recipe_implementation + ) + + api_implementation = APIImplementation() + self.api_implementation: APIInterface = ( + self.config.override.apis(api_implementation) + if self.config.override is not None + and self.config.override.apis is not None + else api_implementation + ) + + self._access_token_builders: List[PayloadBuilderFunction] = [] + self._id_token_builders: List[PayloadBuilderFunction] = [] + self._user_info_builders: List[UserInfoBuilderFunction] = [] + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return isinstance(err, OAuth2ProviderError) + + def get_apis_handled(self) -> List[APIHandled]: + return [ + APIHandled( + NormalisedURLPath(LOGIN_PATH), + "get", + LOGIN_PATH, + self.api_implementation.disable_login_get, + ), + APIHandled( + NormalisedURLPath(TOKEN_PATH), + "post", + TOKEN_PATH, + self.api_implementation.disable_token_post, + ), + APIHandled( + NormalisedURLPath(AUTH_PATH), + "get", + AUTH_PATH, + self.api_implementation.disable_auth_get, + ), + APIHandled( + NormalisedURLPath(LOGIN_INFO_PATH), + "get", + LOGIN_INFO_PATH, + self.api_implementation.disable_login_info_get, + ), + APIHandled( + NormalisedURLPath(USER_INFO_PATH), + "get", + USER_INFO_PATH, + self.api_implementation.disable_user_info_get, + ), + APIHandled( + NormalisedURLPath(REVOKE_TOKEN_PATH), + "post", + REVOKE_TOKEN_PATH, + self.api_implementation.disable_revoke_token_post, + ), + APIHandled( + NormalisedURLPath(INTROSPECT_TOKEN_PATH), + "post", + INTROSPECT_TOKEN_PATH, + self.api_implementation.disable_introspect_token_post, + ), + APIHandled( + NormalisedURLPath(END_SESSION_PATH), + "get", + END_SESSION_PATH, + self.api_implementation.disable_end_session_get, + ), + APIHandled( + NormalisedURLPath(END_SESSION_PATH), + "post", + END_SESSION_PATH, + self.api_implementation.disable_end_session_post, + ), + APIHandled( + NormalisedURLPath(LOGOUT_PATH), + "post", + LOGOUT_PATH, + self.api_implementation.disable_logout_post, + ), + ] + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> Union[BaseResponse, None]: + api_options = APIOptions( + request, + response, + self.recipe_id, + self.config, + self.recipe_implementation, + ) + if request_id == LOGIN_PATH: + return await login( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == TOKEN_PATH: + return await token_post( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == AUTH_PATH: + return await auth_get( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == LOGIN_INFO_PATH: + return await login_info_get( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == USER_INFO_PATH: + return await user_info_get( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == REVOKE_TOKEN_PATH: + return await revoke_token_post( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == INTROSPECT_TOKEN_PATH: + return await introspect_token_post( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == END_SESSION_PATH and method == "get": + return await end_session_get( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == END_SESSION_PATH and method == "post": + return await end_session_post( + tenant_id, self.api_implementation, api_options, user_context + ) + + if request_id == LOGOUT_PATH and method == "post": + return await logout_post( + tenant_id, self.api_implementation, api_options, user_context + ) + + raise Exception( + "Should never come here: handle_api_request called with unknown id" + ) + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + override: Union[InputOverrideConfig, None] = None, + ): + def func(app_info: AppInfo): + if OAuth2ProviderRecipe.__instance is None: + OAuth2ProviderRecipe.__instance = OAuth2ProviderRecipe( + OAuth2ProviderRecipe.recipe_id, + app_info, + override, + ) + + return OAuth2ProviderRecipe.__instance + raise_general_exception( + "OAuth2Provider recipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def get_instance() -> OAuth2ProviderRecipe: + if OAuth2ProviderRecipe.__instance is not None: + return OAuth2ProviderRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def get_instance_optional() -> Optional[OAuth2ProviderRecipe]: + return OAuth2ProviderRecipe.__instance + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + OAuth2ProviderRecipe.__instance = None + + def add_user_info_builder_from_other_recipe( + self, user_info_builder_fn: UserInfoBuilderFunction + ) -> None: + self._user_info_builders.append(user_info_builder_fn) + + def add_access_token_builder_from_other_recipe( + self, access_token_builder: PayloadBuilderFunction + ) -> None: + self._access_token_builders.append(access_token_builder) + + def add_id_token_builder_from_other_recipe( + self, id_token_builder: PayloadBuilderFunction + ) -> None: + self._id_token_builders.append(id_token_builder) + + async def get_default_access_token_payload( + self, + user: User, + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + + if "email" in scopes: + payload["email"] = user.emails[0] if user.emails else None + payload["email_verified"] = ( + any( + lm.has_same_email_as(user.emails[0]) and lm.verified + for lm in user.login_methods + ) + if user.emails + else False + ) + payload["emails"] = user.emails + + if "phoneNumber" in scopes: + if user.phone_numbers: + payload["phoneNumber"] = user.phone_numbers[0] + payload["phoneNumber_verified"] = ( + any( + lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified + for lm in user.login_methods + ) + if user.phone_numbers + else False + ) + payload["phoneNumbers"] = user.phone_numbers + + for fn in self._access_token_builders: + builder_payload = await fn(user, scopes, session_handle, user_context) + payload.update(builder_payload) + + return payload + + async def get_default_id_token_payload( + self, + user: User, + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + + if "email" in scopes: + payload["email"] = user.emails[0] if user.emails else None + payload["email_verified"] = ( + any( + lm.has_same_email_as(user.emails[0]) and lm.verified + for lm in user.login_methods + ) + if user.emails + else False + ) + payload["emails"] = user.emails + + if "phoneNumber" in scopes: + if user.phone_numbers: + payload["phoneNumber"] = user.phone_numbers[0] + payload["phoneNumber_verified"] = ( + any( + lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified + for lm in user.login_methods + ) + if user.phone_numbers + else False + ) + payload["phoneNumbers"] = user.phone_numbers + + for fn in self._id_token_builders: + builder_payload = await fn(user, scopes, session_handle, user_context) + payload.update(builder_payload) + + return payload + + async def get_default_user_info_payload( + self, + user: User, + access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {"sub": access_token_payload["sub"]} + + if "email" in scopes: + payload["email"] = user.emails[0] if user.emails else None + payload["email_verified"] = ( + any( + lm.has_same_email_as(user.emails[0]) and lm.verified + for lm in user.login_methods + ) + if user.emails + else False + ) + payload["emails"] = user.emails + + if "phoneNumber" in scopes: + payload["phoneNumber"] = ( + user.phone_numbers[0] if user.phone_numbers else None + ) + payload["phoneNumber_verified"] = ( + any( + lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified + for lm in user.login_methods + ) + if user.phone_numbers + else False + ) + payload["phoneNumbers"] = user.phone_numbers + + for fn in self._user_info_builders: + builder_payload = await fn( + user, access_token_payload, scopes, tenant_id, user_context + ) + payload.update(builder_payload) + + return payload diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py new file mode 100644 index 000000000..f14fb9449 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -0,0 +1,1035 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import base64 +from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List +from urllib.parse import parse_qs, urlparse +import urllib.parse + +import jwt + +from supertokens_python.asyncio import get_user +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.openid.recipe import OpenIdRecipe +from supertokens_python.recipe.session.interfaces import SessionContainer +from supertokens_python.recipe.session.jwks import get_latest_keys +from supertokens_python.recipe.session.jwt import ( + parse_jwt_without_signature_verification, +) +from supertokens_python.recipe.session.recipe import SessionRecipe +from supertokens_python.types import RecipeUserId, User + +from .interfaces import ( + CreateOAuth2ClientInput, + FrontendRedirectionURLTypeLogin, + FrontendRedirectionURLTypeLogoutConfirmation, + FrontendRedirectionURLTypePostLogoutFallback, + FrontendRedirectionURLTypeTryRefresh, + OAuth2TokenValidationRequirements, + PayloadBuilderFunction, + RecipeInterface, + RedirectResponse, + ErrorOAuth2Response, + GetOAuth2ClientOkResult, + GetOAuth2ClientsOkResult, + CreateOAuth2ClientOkResult, + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + UpdateOAuth2ClientInput, + UpdateOAuth2ClientOkResult, + DeleteOAuth2ClientOkResult, + ConsentRequest, + LoginRequest, + OAuth2Client, + TokenInfo, + UserInfoBuilderFunction, + ActiveTokenResponse, + InactiveTokenResponse, +) + + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + from supertokens_python import AppInfo + + +def get_updated_redirect_to(app_info: AppInfo, redirect_to: str) -> str: + return redirect_to.replace( + "{apiDomain}", + app_info.api_domain.get_as_string_dangerous() + + app_info.api_base_path.get_as_string_dangerous(), + ) + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + app_info: AppInfo, + get_default_access_token_payload: PayloadBuilderFunction, + get_default_id_token_payload: PayloadBuilderFunction, + get_default_user_info_payload: UserInfoBuilderFunction, + ): + super().__init__() + self.querier = querier + self.app_info = app_info + self._get_default_access_token_payload = get_default_access_token_payload + self._get_default_id_token_payload = get_default_id_token_payload + self._get_default_user_info_payload = get_default_user_info_payload + + async def get_login_request( + self, challenge: str, user_context: Dict[str, Any] + ) -> Union[LoginRequest, ErrorOAuth2Response]: + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/auth/requests/login"), + {"loginChallenge": challenge}, + user_context=user_context, + ) + if response["status"] != "OK": + return ErrorOAuth2Response( + response["error"], + response["errorDescription"], + response["statusCode"], + ) + + return LoginRequest.from_json(response) + + async def accept_login_request( + self, + challenge: str, + acr: Optional[str], + amr: Optional[List[str]], + context: Optional[Any], + extend_session_lifespan: Optional[bool], + identity_provider_session_id: Optional[str], + subject: str, + user_context: Dict[str, Any], + ) -> RedirectResponse: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/login/accept"), + { + "acr": acr, + "amr": amr, + "context": context, + "extendSessionLifespan": extend_session_lifespan, + "identityProviderSessionId": identity_provider_session_id, + "subject": subject, + }, + { + "loginChallenge": challenge, + }, + user_context=user_context, + ) + + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) + + async def reject_login_request( + self, + challenge: str, + error: ErrorOAuth2Response, + user_context: Dict[str, Any], + ) -> RedirectResponse: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/login/reject"), + { + "error": error.error, + "errorDescription": error.error_description, + "statusCode": error.status_code, + }, + { + "loginChallenge": challenge, + }, + user_context=user_context, + ) + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) + + async def get_consent_request( + self, challenge: str, user_context: Dict[str, Any] + ) -> ConsentRequest: + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/auth/requests/consent"), + {"consentChallenge": challenge}, + user_context=user_context, + ) + + return ConsentRequest.from_json(response) + + async def accept_consent_request( + self, + challenge: str, + context: Optional[Any], + grant_access_token_audience: Optional[List[str]], + grant_scope: Optional[List[str]], + handled_at: Optional[str], + tenant_id: str, + rsub: str, + session_handle: str, + initial_access_token_payload: Optional[Dict[str, Any]], + initial_id_token_payload: Optional[Dict[str, Any]], + user_context: Dict[str, Any], + ) -> RedirectResponse: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/consent/accept"), + { + "context": context, + "grantAccessTokenAudience": grant_access_token_audience, + "grantScope": grant_scope, + "handledAt": handled_at, + "iss": await OpenIdRecipe.get_issuer(user_context), + "tId": tenant_id, + "rsub": rsub, + "sessionHandle": session_handle, + "initialAccessTokenPayload": initial_access_token_payload, + "initialIdTokenPayload": initial_id_token_payload, + }, + { + "consentChallenge": challenge, + }, + user_context=user_context, + ) + + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) + + async def reject_consent_request( + self, challenge: str, error: ErrorOAuth2Response, user_context: Dict[str, Any] + ) -> RedirectResponse: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/consent/reject"), + { + "error": error.error, + "errorDescription": error.error_description, + "statusCode": error.status_code, + }, + { + "consentChallenge": challenge, + }, + user_context=user_context, + ) + + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) + + async def authorization( + self, + params: Dict[str, str], + cookies: Optional[str], + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + # we handle this in the backend SDK level + if params.get("prompt") == "none": + params["st_prompt"] = "none" + del params["prompt"] + + payloads = None + + if params.get("client_id") is None or not isinstance( + params.get("client_id"), str + ): + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="client_id is required and must be a string", + ) + + scopes = await self.get_requested_scopes( + scope_param=params.get("scope", "").split() if params.get("scope") else [], + client_id=params["client_id"], + recipe_user_id=( + session.get_recipe_user_id() if session is not None else None + ), + session_handle=session.get_handle() if session else None, + user_context=user_context, + ) + + response_types = ( + params.get("response_type", "").split() + if params.get("response_type") + else [] + ) + + if session is not None: + client_info = await self.get_oauth2_client( + client_id=params["client_id"], user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + return ErrorOAuth2Response( + status_code=400, + error=client_info.error, + error_description=client_info.error_description, + ) + + client = client_info.client + + user = await get_user(session.get_user_id()) + if not user: + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="User deleted", + ) + + # These default to empty dicts, because we want to keep them as required input + # but they'll not be actually used in flows where we are not building them + id_token = {} + if "openid" in scopes and ( + "id_token" in response_types or "code" in response_types + ): + id_token = await self.build_id_token_payload( + user=user, + client=client, + session_handle=session.get_handle(), + scopes=scopes, + user_context=user_context, + ) + + access_token = {} + if "token" in response_types or "code" in response_types: + access_token = await self.build_access_token_payload( + user=user, + client=client, + session_handle=session.get_handle(), + scopes=scopes, + user_context=user_context, + ) + + payloads = {"idToken": id_token, "accessToken": access_token} + + request_body = { + "params": {**params, "scope": " ".join(scopes)}, + "iss": await OpenIdRecipe.get_issuer(user_context), + "session": payloads, + } + if cookies is not None: + request_body["cookies"] = cookies + resp = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/auth"), + request_body, + user_context, + ) + + if resp["status"] == "CLIENT_NOT_FOUND_ERROR": + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="The provided client_id is not valid", + ) + + if resp["status"] != "OK": + return ErrorOAuth2Response( + status_code=resp["statusCode"], + error=resp["error"], + error_description=resp["errorDescription"], + ) + + if resp.get("redirectTo") is None: + raise Exception(resp) + redirect_to = get_updated_redirect_to(self.app_info, resp["redirectTo"]) + + redirect_to_query_params_str = urlparse(redirect_to).query + redirect_to_query_params: Dict[str, List[str]] = parse_qs( + redirect_to_query_params_str + ) + consent_challenge: Optional[str] = None + + if "consent_challenge" in redirect_to_query_params: + if len(redirect_to_query_params["consent_challenge"]) > 0: + consent_challenge = redirect_to_query_params["consent_challenge"][0] + + if consent_challenge is not None and session is not None: + consent_request = await self.get_consent_request( + challenge=consent_challenge, user_context=user_context + ) + + consent_res = await self.accept_consent_request( + challenge=consent_request.challenge, + context=None, + grant_access_token_audience=consent_request.requested_access_token_audience, + grant_scope=consent_request.requested_scope, + handled_at=None, + tenant_id=session.get_tenant_id(), + rsub=session.get_recipe_user_id().get_as_string(), + session_handle=session.get_handle(), + initial_access_token_payload=( + payloads.get("accessToken") if payloads else None + ), + initial_id_token_payload=payloads.get("idToken") if payloads else None, + user_context=user_context, + ) + + return RedirectResponse( + redirect_to=consent_res.redirect_to, cookies=resp["cookies"] + ) + + return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"]) + + async def token_exchange( + self, + authorization_header: Optional[str], + body: Dict[str, Optional[str]], + user_context: Dict[str, Any], + ) -> Union[TokenInfo, ErrorOAuth2Response]: + request_body = { + "iss": await OpenIdRecipe.get_issuer(user_context), + "inputBody": body, + } + + if body.get("grant_type") == "password": + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="Unsupported grant type: password", + ) + + if body.get("grant_type") == "client_credentials": + client_id = None + if authorization_header: + # Extract client_id from Basic auth header + decoded = base64.b64decode( + authorization_header.replace("Basic ", "").strip() + ).decode() + client_id = decoded.split(":")[0] + else: + client_id = body.get("client_id") + + if not client_id: + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="client_id is required", + ) + + scopes = str(body.get("scope", "")).split() if body.get("scope") else [] + + client_info = await self.get_oauth2_client( + client_id=client_id, user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + return ErrorOAuth2Response( + status_code=400, + error=client_info.error, + error_description=client_info.error_description, + ) + + client = client_info.client + request_body["id_token"] = await self.build_id_token_payload( + user=None, + client=client, + session_handle=None, + scopes=scopes, + user_context=user_context, + ) + request_body["access_token"] = await self.build_access_token_payload( + user=None, + client=client, + session_handle=None, + scopes=scopes, + user_context=user_context, + ) + + if body.get("grant_type") == "refresh_token": + scopes = str(body.get("scope", "")).split() if body.get("scope") else [] + token_info = await self.introspect_token( + token=str(body["refresh_token"]), + scopes=scopes, + user_context=user_context, + ) + + if isinstance(token_info, ActiveTokenResponse): + session_handle = token_info.payload["sessionHandle"] + + client_info = await self.get_oauth2_client( + client_id=token_info.payload["client_id"], user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + return ErrorOAuth2Response( + status_code=400, + error=client_info.error, + error_description=client_info.error_description, + ) + + client = client_info.client + user = await get_user(token_info.payload["sub"]) + + if not user: + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="User not found", + ) + + request_body["id_token"] = await self.build_id_token_payload( + user=user, + client=client, + session_handle=session_handle, + scopes=scopes, + user_context=user_context, + ) + request_body["access_token"] = await self.build_access_token_payload( + user=user, + client=client, + session_handle=session_handle, + scopes=scopes, + user_context=user_context, + ) + + if authorization_header: + request_body["authorizationHeader"] = authorization_header + + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/token"), + request_body, + user_context=user_context, + ) + + if response["status"] == "CLIENT_NOT_FOUND_ERROR": + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="client_id not found", + ) + + if response["status"] != "OK": + return ErrorOAuth2Response( + status_code=response["statusCode"], + error=response["error"], + error_description=response["errorDescription"], + ) + + return TokenInfo.from_json(response) + + async def get_oauth2_clients( + self, + page_size: Optional[int], + pagination_token: Optional[str], + client_name: Optional[str], + user_context: Dict[str, Any], + ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: + body: Dict[str, Any] = {} + if page_size is not None: + body["pageSize"] = page_size + if pagination_token is not None: + body["pageToken"] = pagination_token + if client_name is not None: + body["clientName"] = client_name + + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/clients/list"), + body, + user_context=user_context, + ) + + if response["status"] == "OK": + return GetOAuth2ClientsOkResult( + clients=[ + OAuth2Client.from_json(client) for client in response["clients"] + ], + next_pagination_token=response.get("nextPaginationToken"), + ) + + return ErrorOAuth2Response( + error=response["error"], + error_description=response["errorDescription"], + status_code=response["statusCode"], + ) + + async def get_oauth2_client( + self, client_id: str, user_context: Dict[str, Any] + ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/clients"), + {"clientId": client_id}, + user_context=user_context, + ) + + if response["status"] == "OK": + return GetOAuth2ClientOkResult(client=OAuth2Client.from_json(response)) + elif response["status"] == "CLIENT_NOT_FOUND_ERROR": + return ErrorOAuth2Response( + error="invalid_request", + error_description="The provided client_id is not valid or unknown", + ) + else: + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) + + async def create_oauth2_client( + self, + params: CreateOAuth2ClientInput, + user_context: Dict[str, Any], + ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/clients"), + params.to_json(), + user_context=user_context, + ) + + if response["status"] == "OK": + return CreateOAuth2ClientOkResult(client=OAuth2Client.from_json(response)) + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) + + async def update_oauth2_client( + self, + params: UpdateOAuth2ClientInput, + user_context: Dict[str, Any], + ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/clients"), + params.to_json(), + None, + user_context=user_context, + ) + + if response["status"] == "OK": + return UpdateOAuth2ClientOkResult(client=OAuth2Client.from_json(response)) + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) + + async def delete_oauth2_client( + self, + client_id: str, + user_context: Dict[str, Any], + ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/clients/remove"), + {"clientId": client_id}, + user_context=user_context, + ) + + if response["status"] == "OK": + return DeleteOAuth2ClientOkResult() + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) + + async def validate_oauth2_access_token( + self, + token: str, + requirements: Optional[OAuth2TokenValidationRequirements], + check_database: Optional[bool], + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + access_token_obj = parse_jwt_without_signature_verification(token) + + # Verify token signature using session recipe's JWKS + session_recipe = SessionRecipe.get_instance() + matching_keys = get_latest_keys(session_recipe.config, access_token_obj.kid) + err: Optional[Exception] = None + + payload: Dict[str, Any] = {} + + for matching_key in matching_keys: + err = None + try: + payload = jwt.decode( + token, + matching_key.key, + algorithms=["RS256"], + options={ + "verify_signature": True, + "verify_exp": True, + "verify_aud": False, + }, + ) + except Exception as e: + err = e + continue + break + + if err is not None: + raise err + + if payload.get("stt") != 1: + raise Exception("Wrong token type") + + if requirements is not None and requirements.client_id is not None: + if payload.get("client_id") != requirements.client_id: + raise Exception( + f"The token doesn't belong to the specified client ({requirements.client_id} !== {payload.get('client_id')})" + ) + + if requirements is not None and requirements.scopes is not None: + token_scopes = payload.get("scp", []) + if not isinstance(token_scopes, list): + token_scopes = [token_scopes] + + if any(scope not in token_scopes for scope in requirements.scopes): + raise Exception("The token is missing some required scopes") + + aud = payload.get("aud", []) + if not isinstance(aud, list): + aud = [aud] + + if requirements is not None and requirements.audience is not None: + if requirements.audience not in aud: + raise Exception("The token doesn't belong to the specified audience") + + if check_database: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/introspect"), + { + "token": token, + }, + user_context=user_context, + ) + + if response.get("active") is not True: + raise Exception("The token is expired, invalid or has been revoked") + + return payload + + async def get_requested_scopes( + self, + recipe_user_id: Optional[RecipeUserId], + session_handle: Optional[str], + scope_param: List[str], + client_id: str, + user_context: Dict[str, Any], + ) -> List[str]: + _ = recipe_user_id + _ = session_handle + _ = client_id + _ = user_context + + return scope_param + + async def build_access_token_payload( + self, + user: Optional[User], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + if user is None or session_handle is None: + return {} + + _ = client + + return await self._get_default_access_token_payload( + user, scopes, session_handle, user_context + ) + + async def build_id_token_payload( + self, + user: Optional[User], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + if user is None or session_handle is None: + return {} + + _ = client + + return await self._get_default_id_token_payload( + user, scopes, session_handle, user_context + ) + + async def build_user_info( + self, + user: User, + access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return await self._get_default_user_info_payload( + user, access_token_payload, scopes, tenant_id, user_context + ) + + async def get_frontend_redirection_url( + self, + params: Union[ + FrontendRedirectionURLTypeLogin, + FrontendRedirectionURLTypeTryRefresh, + FrontendRedirectionURLTypeLogoutConfirmation, + FrontendRedirectionURLTypePostLogoutFallback, + ], + user_context: Dict[str, Any], + ) -> str: + website_domain = self.app_info.get_origin( + None, user_context + ).get_as_string_dangerous() + website_base_path = self.app_info.api_base_path.get_as_string_dangerous() + + if isinstance(params, FrontendRedirectionURLTypeLogin): + query_params: Dict[str, str] = {"loginChallenge": params.login_challenge} + if params.tenant_id != "public": # DEFAULT_TENANT_ID is "public" + query_params["tenantId"] = params.tenant_id + if params.hint is not None: + query_params["hint"] = params.hint + if params.force_fresh_auth: + query_params["forceFreshAuth"] = "true" + + query_string = "&".join( + f"{k}={urllib.parse.quote(str(v))}" for k, v in query_params.items() + ) + return f"{website_domain}{website_base_path}?{query_string}" + + elif isinstance(params, FrontendRedirectionURLTypeTryRefresh): + return f"{website_domain}{website_base_path}/try-refresh?loginChallenge={params.login_challenge}" + + elif isinstance(params, FrontendRedirectionURLTypePostLogoutFallback): + return f"{website_domain}{website_base_path}" + + else: # isinstance(params, FrontendRedirectionURLTypeLogoutConfirmation) + return f"{website_domain}{website_base_path}/oauth/logout?logoutChallenge={params.logout_challenge}" + + async def revoke_token( + self, + params: Union[ + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + ], + user_context: Dict[str, Any], + ) -> Optional[ErrorOAuth2Response]: + request_body = {"token": params.token} + + if isinstance(params, RevokeTokenUsingAuthorizationHeader): + request_body["authorizationHeader"] = params.authorization_header + else: + request_body["client_id"] = params.client_id + if params.client_secret is not None: + request_body["client_secret"] = params.client_secret + + res = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/token/revoke"), + request_body, + user_context=user_context, + ) + + if res.get("status") != "OK": + return ErrorOAuth2Response( + status_code=res.get("statusCode"), + error=str(res.get("error")), + error_description=str(res.get("errorDescription")), + ) + + return None + + async def revoke_tokens_by_client_id( + self, + client_id: str, + user_context: Dict[str, Any], + ): + await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/session/revoke"), + {"client_id": client_id}, + user_context=user_context, + ) + + async def revoke_tokens_by_session_handle( + self, + session_handle: str, + user_context: Dict[str, Any], + ): + await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/session/revoke"), + {"sessionHandle": session_handle}, + user_context=user_context, + ) + + async def introspect_token( + self, + token: str, + scopes: Optional[List[str]], + user_context: Dict[str, Any], + ) -> Union[ActiveTokenResponse, InactiveTokenResponse]: + # Determine if the token is an access token by checking if it doesn't start with "st_rt" + is_access_token = not token.startswith("st_rt") + + # Attempt to validate the access token locally + # If it fails, the token is not active, and we return early + if is_access_token: + try: + await self.validate_oauth2_access_token( + token=token, + requirements=( + OAuth2TokenValidationRequirements(scopes=scopes) + if scopes + else None + ), + check_database=False, + user_context=user_context, + ) + + except Exception: + return InactiveTokenResponse() + + # For tokens that passed local validation or if it's a refresh token, + # validate the token with the database by calling the core introspection endpoint + request_body = {"token": token} + if scopes: + request_body["scope"] = " ".join(scopes) + + res = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/introspect"), + request_body, + user_context=user_context, + ) + + if res.get("active"): + return ActiveTokenResponse(payload=res) + else: + return InactiveTokenResponse() + + async def end_session( + self, + params: Dict[str, str], + should_try_refresh: bool, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + # NOTE: The API response has 3 possible cases: + # + # CASE 1: end_session request with a valid id_token_hint + # - Redirects to /oauth/logout with a logout_challenge. + # + # CASE 2: end_session request with an already logged out id_token_hint + # - Redirects to the post_logout_redirect_uri or the default logout fallback page. + # + # CASE 3: end_session request with a logout_verifier (after accepting the logout request) + # - Redirects to the post_logout_redirect_uri or the default logout fallback page. + + request_body: Dict[str, Any] = {} + + if params.get("client_id") is not None: + request_body["clientId"] = params.get("client_id") + if params.get("id_token_hint") is not None: + request_body["idTokenHint"] = params.get("id_token_hint") + if params.get("post_logout_redirect_uri") is not None: + request_body["postLogoutRedirectUri"] = params.get( + "post_logout_redirect_uri" + ) + if params.get("state") is not None: + request_body["state"] = params.get("state") + if params.get("logout_verifier") is not None: + request_body["logoutVerifier"] = params.get("logout_verifier") + + resp = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/sessions/logout"), + request_body, + user_context=user_context, + ) + + if "error" in resp: + return ErrorOAuth2Response( + status_code=resp["statusCode"], + error=resp["error"], + error_description=resp["errorDescription"], + ) + + redirect_to = get_updated_redirect_to(self.app_info, resp["redirectTo"]) + + initial_redirect_url = urlparse(redirect_to) + query_params = parse_qs(initial_redirect_url.query) + logout_challenge = query_params.get("logout_challenge", [None])[0] + + # CASE 1 (See above notes) + if logout_challenge is not None: + # Redirect to the frontend to ask for logout confirmation if there is a valid or expired supertokens session + if session is not None or should_try_refresh: + return RedirectResponse( + redirect_to=await self.get_frontend_redirection_url( + FrontendRedirectionURLTypeLogoutConfirmation( + logout_challenge=logout_challenge + ), + user_context=user_context, + ) + ) + else: + # Accept the logout challenge immediately as there is no supertokens session + accept_logout_response = await self.accept_logout_request( + challenge=logout_challenge, user_context=user_context + ) + if isinstance(accept_logout_response, ErrorOAuth2Response): + return accept_logout_response + return RedirectResponse(redirect_to=accept_logout_response.redirect_to) + + # CASE 2 or 3 (See above notes) + + # NOTE: If no post_logout_redirect_uri is provided, Hydra redirects to a fallback page. + # In this case, we redirect the user to the /auth page. + if redirect_to.endswith("/fallbacks/logout/callback"): + return RedirectResponse( + redirect_to=await self.get_frontend_redirection_url( + FrontendRedirectionURLTypePostLogoutFallback(), + user_context=user_context, + ) + ) + + return RedirectResponse(redirect_to=redirect_to) + + async def accept_logout_request( + self, + challenge: str, + user_context: Dict[str, Any], + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + resp = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/logout/accept"), + {"challenge": challenge}, + None, + user_context=user_context, + ) + + if resp["status"] != "OK": + return ErrorOAuth2Response( + status_code=resp["statusCode"], + error=resp["error"], + error_description=resp["errorDescription"], + ) + + redirect_to = get_updated_redirect_to(self.app_info, resp["redirectTo"]) + + if redirect_to.endswith("/fallbacks/logout/callback"): + return RedirectResponse( + redirect_to=await self.get_frontend_redirection_url( + FrontendRedirectionURLTypePostLogoutFallback(), + user_context=user_context, + ) + ) + + return RedirectResponse(redirect_to=redirect_to) + + async def reject_logout_request( + self, + challenge: str, + user_context: Dict[str, Any], + ): + resp = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/logout/reject"), + {}, + {"challenge": challenge}, + user_context=user_context, + ) + + if resp["status"] != "OK": + raise Exception(resp["error"]) diff --git a/supertokens_python/recipe/oauth2provider/syncio/__init__.py b/supertokens_python/recipe/oauth2provider/syncio/__init__.py new file mode 100644 index 000000000..671e614f9 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/syncio/__init__.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, Dict, Union, Optional, List + +from supertokens_python.async_to_sync_wrapper import sync + +from ..interfaces import ( + ActiveTokenResponse, + CreateOAuth2ClientInput, + CreateOAuth2ClientOkResult, + DeleteOAuth2ClientOkResult, + ErrorOAuth2Response, + GetOAuth2ClientOkResult, + GetOAuth2ClientsOkResult, + InactiveTokenResponse, + OAuth2TokenValidationRequirements, + TokenInfo, + UpdateOAuth2ClientInput, + UpdateOAuth2ClientOkResult, +) + + +def get_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import get_oauth2_client + + return sync(get_oauth2_client(client_id, user_context)) + + +def get_oauth2_clients( + page_size: Optional[int] = None, + pagination_token: Optional[str] = None, + client_name: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import get_oauth2_clients + + return sync( + get_oauth2_clients(page_size, pagination_token, client_name, user_context) + ) + + +def create_oauth2_client( + params: CreateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..asyncio import create_oauth2_client + + return sync(create_oauth2_client(params, user_context)) + + +def update_oauth2_client( + params: UpdateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import update_oauth2_client + + return sync(update_oauth2_client(params, user_context)) + + +def delete_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..asyncio import delete_oauth2_client + + return sync(delete_oauth2_client(client_id, user_context)) + + +def validate_oauth2_access_token( + token: str, + requirements: Optional[OAuth2TokenValidationRequirements] = None, + check_database: Optional[bool] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + if user_context is None: + user_context = {} + + from ..asyncio import validate_oauth2_access_token + + return sync( + validate_oauth2_access_token(token, requirements, check_database, user_context) + ) + + +def create_token_for_client_credentials( + client_id: str, + client_secret: str, + scope: Optional[List[str]] = None, + audience: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[TokenInfo, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import create_token_for_client_credentials + + return sync( + create_token_for_client_credentials( + client_id, client_secret, scope, audience, user_context + ) + ) + + +def revoke_token( + token: str, + client_id: str, + client_secret: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Optional[ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..asyncio import revoke_token + + return sync(revoke_token(token, client_id, client_secret, user_context)) + + +def revoke_tokens_by_client_id( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + + from ..asyncio import revoke_tokens_by_client_id + + return sync(revoke_tokens_by_client_id(client_id, user_context)) + + +def revoke_tokens_by_session_handle( + session_handle: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + + from ..asyncio import revoke_tokens_by_session_handle + + return sync(revoke_tokens_by_session_handle(session_handle, user_context)) + + +def validate_oauth2_refresh_token( + token: str, + scopes: Optional[List[str]] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ActiveTokenResponse, InactiveTokenResponse]: + if user_context is None: + user_context = {} + + from ..asyncio import validate_oauth2_refresh_token + + return sync(validate_oauth2_refresh_token(token, scopes, user_context)) diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py new file mode 100644 index 000000000..c7f49c623 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +from typing import Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Union + from .interfaces import APIInterface, RecipeInterface + + +class InputOverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + apis: Union[Callable[[APIInterface], APIInterface], None] = None, + ): + self.functions = functions + self.apis = apis + + +class OverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + apis: Union[Callable[[APIInterface], APIInterface], None] = None, + ): + self.functions = functions + self.apis = apis + + +class OAuth2ProviderConfig: + def __init__(self, override: Union[OverrideConfig, None] = None): + self.override = override + + +def validate_and_normalise_user_input( + override: Union[InputOverrideConfig, None] = None +): + if override is None: + return OAuth2ProviderConfig(OverrideConfig()) + return OAuth2ProviderConfig(OverrideConfig(override.functions, override.apis)) diff --git a/supertokens_python/recipe/openid/api/implementation.py b/supertokens_python/recipe/openid/api/implementation.py index 6d6440c5a..6214f7639 100644 --- a/supertokens_python/recipe/openid/api/implementation.py +++ b/supertokens_python/recipe/openid/api/implementation.py @@ -30,5 +30,15 @@ async def open_id_discovery_configuration_get( ) ) return OpenIdDiscoveryConfigurationGetResponse( - response.issuer, response.jwks_uri + issuer=response.issuer, + jwks_uri=response.jwks_uri, + authorization_endpoint=response.authorization_endpoint, + token_endpoint=response.token_endpoint, + userinfo_endpoint=response.userinfo_endpoint, + revocation_endpoint=response.revocation_endpoint, + token_introspection_endpoint=response.token_introspection_endpoint, + end_session_endpoint=response.end_session_endpoint, + subject_types_supported=response.subject_types_supported, + id_token_signing_alg_values_supported=response.id_token_signing_alg_values_supported, + response_types_supported=response.response_types_supported, ) diff --git a/supertokens_python/recipe/openid/interfaces.py b/supertokens_python/recipe/openid/interfaces.py index 738242d6d..963360616 100644 --- a/supertokens_python/recipe/openid/interfaces.py +++ b/supertokens_python/recipe/openid/interfaces.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Union, Optional +from typing import Any, Dict, List, Union, Optional from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.jwt.interfaces import ( @@ -26,9 +26,48 @@ class GetOpenIdDiscoveryConfigurationResult: - def __init__(self, issuer: str, jwks_uri: str): + def __init__( + self, + issuer: str, + jwks_uri: str, + authorization_endpoint: str, + token_endpoint: str, + userinfo_endpoint: str, + revocation_endpoint: str, + token_introspection_endpoint: str, + end_session_endpoint: str, + subject_types_supported: List[str], + id_token_signing_alg_values_supported: List[str], + response_types_supported: List[str], + ): self.issuer = issuer self.jwks_uri = jwks_uri + self.authorization_endpoint = authorization_endpoint + self.token_endpoint = token_endpoint + self.userinfo_endpoint = userinfo_endpoint + self.revocation_endpoint = revocation_endpoint + self.token_introspection_endpoint = token_introspection_endpoint + self.end_session_endpoint = end_session_endpoint + self.subject_types_supported = subject_types_supported + self.id_token_signing_alg_values_supported = ( + id_token_signing_alg_values_supported + ) + self.response_types_supported = response_types_supported + + def to_json(self) -> Dict[str, Any]: + return { + "issuer": self.issuer, + "jwks_uri": self.jwks_uri, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "userinfo_endpoint": self.userinfo_endpoint, + "revocation_endpoint": self.revocation_endpoint, + "token_introspection_endpoint": self.token_introspection_endpoint, + "end_session_endpoint": self.end_session_endpoint, + "subject_types_supported": self.subject_types_supported, + "id_token_signing_alg_values_supported": self.id_token_signing_alg_values_supported, + "response_types_supported": self.response_types_supported, + } class RecipeInterface(ABC): @@ -75,12 +114,49 @@ def __init__( class OpenIdDiscoveryConfigurationGetResponse(APIResponse): status: str = "OK" - def __init__(self, issuer: str, jwks_uri: str): + def __init__( + self, + issuer: str, + jwks_uri: str, + authorization_endpoint: str, + token_endpoint: str, + userinfo_endpoint: str, + revocation_endpoint: str, + token_introspection_endpoint: str, + end_session_endpoint: str, + subject_types_supported: List[str], + id_token_signing_alg_values_supported: List[str], + response_types_supported: List[str], + ): self.issuer = issuer self.jwks_uri = jwks_uri + self.authorization_endpoint = authorization_endpoint + self.token_endpoint = token_endpoint + self.userinfo_endpoint = userinfo_endpoint + self.revocation_endpoint = revocation_endpoint + self.token_introspection_endpoint = token_introspection_endpoint + self.end_session_endpoint = end_session_endpoint + self.subject_types_supported = subject_types_supported + self.id_token_signing_alg_values_supported = ( + id_token_signing_alg_values_supported + ) + self.response_types_supported = response_types_supported def to_json(self): - return {"status": self.status, "issuer": self.issuer, "jwks_uri": self.jwks_uri} + return { + "status": self.status, + "issuer": self.issuer, + "jwks_uri": self.jwks_uri, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "userinfo_endpoint": self.userinfo_endpoint, + "revocation_endpoint": self.revocation_endpoint, + "token_introspection_endpoint": self.token_introspection_endpoint, + "end_session_endpoint": self.end_session_endpoint, + "subject_types_supported": self.subject_types_supported, + "id_token_signing_alg_values_supported": self.id_token_signing_alg_values_supported, + "response_types_supported": self.response_types_supported, + } class APIInterface: diff --git a/supertokens_python/recipe/openid/recipe.py b/supertokens_python/recipe/openid/recipe.py index 0d3799ae4..cda0a300d 100644 --- a/supertokens_python/recipe/openid/recipe.py +++ b/supertokens_python/recipe/openid/recipe.py @@ -170,3 +170,10 @@ def reset(): ): raise_general_exception("calling testing function in non testing env") OpenIdRecipe.__instance = None + + @staticmethod + async def get_issuer(user_context: Dict[str, Any]) -> str: + open_id_config = await OpenIdRecipe.get_instance().recipe_implementation.get_open_id_discovery_configuration( + user_context + ) + return open_id_config.issuer diff --git a/supertokens_python/recipe/openid/recipe_implementation.py b/supertokens_python/recipe/openid/recipe_implementation.py index 32bda3926..f117c5007 100644 --- a/supertokens_python/recipe/openid/recipe_implementation.py +++ b/supertokens_python/recipe/openid/recipe_implementation.py @@ -39,19 +39,45 @@ class RecipeImplementation(RecipeInterface): async def get_open_id_discovery_configuration( self, user_context: Dict[str, Any] ) -> GetOpenIdDiscoveryConfigurationResult: + from ..oauth2provider.constants import ( + AUTH_PATH, + TOKEN_PATH, + USER_INFO_PATH, + REVOKE_TOKEN_PATH, + INTROSPECT_TOKEN_PATH, + END_SESSION_PATH, + ) + issuer = ( - self.config.issuer_domain.get_as_string_dangerous() - + self.config.issuer_path.get_as_string_dangerous() + self.app_info.api_domain.get_as_string_dangerous() + + self.app_info.api_base_path.get_as_string_dangerous() ) jwks_uri = ( - self.config.issuer_domain.get_as_string_dangerous() - + self.config.issuer_path.append( + self.app_info.api_domain.get_as_string_dangerous() + + self.app_info.api_base_path.append( NormalisedURLPath(GET_JWKS_API) ).get_as_string_dangerous() ) - return GetOpenIdDiscoveryConfigurationResult(issuer, jwks_uri) + api_base_path: str = ( + self.app_info.api_domain.get_as_string_dangerous() + + self.app_info.api_base_path.get_as_string_dangerous() + ) + + return GetOpenIdDiscoveryConfigurationResult( + issuer=issuer, + jwks_uri=jwks_uri, + authorization_endpoint=api_base_path + AUTH_PATH, + token_endpoint=api_base_path + TOKEN_PATH, + userinfo_endpoint=api_base_path + USER_INFO_PATH, + revocation_endpoint=api_base_path + REVOKE_TOKEN_PATH, + token_introspection_endpoint=api_base_path + INTROSPECT_TOKEN_PATH, + end_session_endpoint=api_base_path + END_SESSION_PATH, + subject_types_supported=["public"], + id_token_signing_alg_values_supported=["RS256"], + response_types_supported=["code", "id_token", "id_token token"], + ) def __init__( self, diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index 48904e095..070873122 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -447,6 +447,7 @@ async def delete_email_for_user( result = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, + None, user_context=user_context, ) if result["status"] == "OK": @@ -460,6 +461,7 @@ async def delete_phone_number_for_user( result = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, + None, user_context=user_context, ) if result["status"] == "OK": @@ -524,6 +526,7 @@ async def update_user( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), input_dict, + None, user_context=user_context, ) if response["status"] == "UNKNOWN_USER_ID_ERROR": diff --git a/supertokens_python/recipe/session/session_functions.py b/supertokens_python/recipe/session/session_functions.py index bef9a6c49..45a936442 100644 --- a/supertokens_python/recipe/session/session_functions.py +++ b/supertokens_python/recipe/session/session_functions.py @@ -518,6 +518,7 @@ async def update_session_data_in_database( response = await recipe_implementation.querier.send_put_request( NormalisedURLPath("/recipe/session/data"), {"sessionHandle": session_handle, "userDataInDatabase": new_session_data}, + None, user_context=user_context, ) if response["status"] == "UNAUTHORISED": @@ -535,6 +536,7 @@ async def update_access_token_payload( response = await recipe_implementation.querier.send_put_request( NormalisedURLPath("/recipe/jwt/data"), {"sessionHandle": session_handle, "userDataInJWT": new_access_token_payload}, + None, user_context=user_context, ) if response["status"] == "UNAUTHORISED": diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py index d19a03eee..a010c0bd5 100644 --- a/supertokens_python/recipe/totp/recipe_implementation.py +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -157,6 +157,7 @@ async def update_device( resp = await self.querier.send_put_request( NormalisedURLPath("/recipe/totp/device"), data, + None, user_context=user_context, ) diff --git a/supertokens_python/recipe/usermetadata/recipe_implementation.py b/supertokens_python/recipe/usermetadata/recipe_implementation.py index 7485ddbe5..e65134301 100644 --- a/supertokens_python/recipe/usermetadata/recipe_implementation.py +++ b/supertokens_python/recipe/usermetadata/recipe_implementation.py @@ -47,6 +47,7 @@ async def update_user_metadata( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user/metadata"), params, + None, user_context=user_context, ) return MetadataResult(metadata=response["metadata"]) diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index 844b41895..7a2d830d5 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -21,19 +21,20 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier +from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.recipe.userroles.recipe_implementation import ( RecipeImplementation, ) from supertokens_python.recipe.userroles.utils import validate_and_normalise_user_input from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.supertokens import AppInfo -from supertokens_python.types import RecipeUserId +from supertokens_python.types import RecipeUserId, User from ...post_init_callbacks import PostSTInitCallbacks from ..session import SessionRecipe from ..session.claim_base_classes.primitive_array_claim import PrimitiveArrayClaim from .exceptions import SuperTokensUserRolesError -from .interfaces import GetPermissionsForRoleOkResult +from .interfaces import GetPermissionsForRoleOkResult, UnknownRoleError from .utils import InputOverrideConfig @@ -49,6 +50,8 @@ def __init__( skip_adding_permissions_to_access_token: Optional[bool] = None, override: Union[InputOverrideConfig, None] = None, ): + from ..oauth2provider.recipe import OAuth2ProviderRecipe + super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( self, @@ -72,6 +75,109 @@ def callback(): PermissionClaim ) + async def token_payload_builder( + user: User, + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {"roles": None, "permissions": None} + + session_info = await get_session_information( + session_handle, user_context + ) + + if session_info is None: + raise Exception("should never come here") + + user_roles: List[str] = [] + + if "roles" in scopes or "permissions" in scopes: + res = await self.recipe_implementation.get_roles_for_user( + tenant_id=session_info.tenant_id, + user_id=user.id, + user_context=user_context, + ) + + user_roles = res.roles + + if "roles" in scopes: + payload["roles"] = user_roles + + if "permissions" in scopes: + user_permissions: Set[str] = set() + for role in user_roles: + role_permissions = ( + await self.recipe_implementation.get_permissions_for_role( + role=role, + user_context=user_context, + ) + ) + + if isinstance(role_permissions, UnknownRoleError): + raise Exception("Failed to fetch permissions for the role") + + for perm in role_permissions.permissions: + user_permissions.add(perm) + + payload["permissions"] = list(user_permissions) + + return payload + + OAuth2ProviderRecipe.get_instance().add_access_token_builder_from_other_recipe( + token_payload_builder + ) + OAuth2ProviderRecipe.get_instance().add_id_token_builder_from_other_recipe( + token_payload_builder + ) + + async def user_info_builder( + user: User, + _access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + user_info: Dict[str, Any] = {"roles": None, "permissions": None} + + user_roles = [] + + if "roles" in scopes or "permissions" in scopes: + res = await self.recipe_implementation.get_roles_for_user( + tenant_id=tenant_id, + user_id=user.id, + user_context=user_context, + ) + + user_roles = res.roles + + if "roles" in scopes: + user_info["roles"] = user_roles + + if "permissions" in scopes: + user_permissions: Set[str] = set() + for role in user_roles: + role_permissions = ( + await self.recipe_implementation.get_permissions_for_role( + role=role, + user_context=user_context, + ) + ) + + if isinstance(role_permissions, UnknownRoleError): + raise Exception("Failed to fetch permissions for the role") + + for perm in role_permissions.permissions: + user_permissions.add(perm) + + user_info["permissions"] = list(user_permissions) + + return user_info + + OAuth2ProviderRecipe.get_instance().add_user_info_builder_from_other_recipe( + user_info_builder + ) + PostSTInitCallbacks.add_post_init_callback(callback) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: diff --git a/supertokens_python/recipe/userroles/recipe_implementation.py b/supertokens_python/recipe/userroles/recipe_implementation.py index c4939085b..01d85334f 100644 --- a/supertokens_python/recipe/userroles/recipe_implementation.py +++ b/supertokens_python/recipe/userroles/recipe_implementation.py @@ -50,6 +50,7 @@ async def add_role_to_user( response = await self.querier.send_put_request( NormalisedURLPath(f"{tenant_id}/recipe/user/role"), params, + None, user_context=user_context, ) if response["status"] == "OK": @@ -108,6 +109,7 @@ async def create_new_role_or_add_permissions( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/role"), params, + None, user_context=user_context, ) return CreateNewRoleOrAddPermissionsOkResult( diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 53964da2e..6632850b4 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -259,9 +259,12 @@ def __init__( totp_found = False user_metadata_found = False multi_factor_auth_found = False + oauth2_found = False + openid_found = False + jwt_found = False def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: - nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found + nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found, oauth2_found, openid_found, jwt_found recipe_module = recipe(self.app_info) if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True @@ -271,21 +274,46 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: multi_factor_auth_found = True elif recipe_module.get_recipe_id() == "totp": totp_found = True + elif recipe_module.get_recipe_id() == "oauth2provider": + oauth2_found = True + elif recipe_module.get_recipe_id() == "openid": + openid_found = True + elif recipe_module.get_recipe_id() == "jwt": + jwt_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) + if not jwt_found: + from supertokens_python.recipe.jwt.recipe import JWTRecipe + + self.recipe_modules.append(JWTRecipe.init()(self.app_info)) + + if not openid_found: + from supertokens_python.recipe.openid.recipe import OpenIdRecipe + + self.recipe_modules.append(OpenIdRecipe.init()(self.app_info)) + if not multitenancy_found: from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + if totp_found and not multi_factor_auth_found: raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") + if not user_metadata_found: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info)) + if not oauth2_found: + from supertokens_python.recipe.oauth2provider.recipe import ( + OAuth2ProviderRecipe, + ) + + self.recipe_modules.append(OAuth2ProviderRecipe.init()(self.app_info)) + self.telemetry = ( telemetry if telemetry is not None diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index cfafd8777..ae8e5872d 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -10,6 +10,7 @@ emailpassword, emailverification, multifactorauth, + oauth2provider, passwordless, session, thirdparty, @@ -40,6 +41,8 @@ APIOptions as EVAPIOptions, ) from supertokens_python.recipe.jwt import JWTRecipe +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.passwordless import ( ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, @@ -315,6 +318,8 @@ def custom_init(): Supertokens.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OpenIdRecipe.reset() + OAuth2ProviderRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -942,6 +947,10 @@ async def resync_session_and_fetch_mfa_info_put( ) ), }, + { + "id": "oauth2provider", + "init": oauth2provider.init(), + }, ] accountlinking_config_input: Dict[str, Any] = { @@ -1002,7 +1011,7 @@ async def should_do_automatic_account_linking( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( app_name="SuperTokens Demo", - api_domain="0.0.0.0:" + get_api_port(), + api_domain="localhost:" + get_api_port(), website_domain=get_website_domain(), ), framework="django", diff --git a/tests/auth-react/django3x/polls/urls.py b/tests/auth-react/django3x/polls/urls.py index 8b3cf999e..cff4e60f7 100644 --- a/tests/auth-react/django3x/polls/urls.py +++ b/tests/auth-react/django3x/polls/urls.py @@ -35,6 +35,7 @@ name="setEnabledRecipes", ), path("test/getTOTPCode", views.test_get_totp_code, name="getTotpCode"), # type: ignore + path("test/create-oauth2-client", views.test_create_oauth2_client, name="createOAuth2Client"), # type: ignore path("test/getDevice", views.test_get_device, name="getDevice"), # type: ignore path("test/featureFlags", views.test_feature_flags, name="featureFlags"), # type: ignore path("beforeeach", views.before_each, name="beforeeach"), # type: ignore diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index cd6b7c9be..08c28afcf 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -27,6 +27,8 @@ from supertokens_python.recipe.multifactorauth.asyncio import ( add_to_required_secondary_factors_for_user, ) +from supertokens_python.recipe.oauth2provider.syncio import create_oauth2_client +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.recipe.thirdparty import ProviderConfig @@ -457,6 +459,14 @@ def test_get_totp_code(request: HttpRequest): return JsonResponse({"totp": code}) +def test_create_oauth2_client(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + client = create_oauth2_client(CreateOAuth2ClientInput.from_json(body)) + return JsonResponse(client.to_json()) + + def before_each(request: HttpRequest): import mysite.store @@ -486,6 +496,7 @@ def test_feature_flags(request: HttpRequest): "mfa", "recipeConfig", "accountlinking-fixes", + "oauth2", ] } ) diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index cfdaaec10..7a727da93 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -43,6 +43,7 @@ from supertokens_python.recipe import ( emailpassword, emailverification, + oauth2provider, passwordless, session, thirdparty, @@ -106,6 +107,10 @@ AssociateUserToTenantUnknownUserIdError, TenantConfigCreateOrUpdate, ) +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.oauth2provider.asyncio import create_oauth2_client +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.passwordless import ( ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, @@ -391,6 +396,8 @@ def custom_init(): Supertokens.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OpenIdRecipe.reset() + OAuth2ProviderRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -1021,6 +1028,10 @@ async def resync_session_and_fetch_mfa_info_put( ) ), }, + { + "id": "oauth2provider", + "init": oauth2provider.init(), + }, ] global accountlinking_config @@ -1084,7 +1095,7 @@ async def should_do_automatic_account_linking( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( app_name="SuperTokens Demo", - api_domain="0.0.0.0:" + get_api_port(), + api_domain="localhost:" + get_api_port(), website_domain=get_website_domain(), ), framework="fastapi", @@ -1374,6 +1385,15 @@ async def test_get_totp_code(request: Request): return JSONResponse({"totp": code}) +@app.post("/test/create-oauth2-client") +async def test_create_oauth2_client(request: Request): + body = await request.json() + if body is None: + raise Exception("Invalid request body") + client = await create_oauth2_client(CreateOAuth2ClientInput.from_json(body)) + return JSONResponse(client.to_json()) + + @app.get("/test/getDevice") def test_get_device(request: Request): global code_store @@ -1397,6 +1417,7 @@ def test_feature_flags(request: Request): "mfa", "recipeConfig", "accountlinking-fixes", + "oauth2", ] return JSONResponse({"available": available}) diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 6c840c9e3..6dbb3ce81 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -35,6 +35,7 @@ accountlinking, emailpassword, emailverification, + oauth2provider, passwordless, session, thirdparty, @@ -77,6 +78,10 @@ delete_tenant, disassociate_user_from_tenant, ) +from supertokens_python.recipe.oauth2provider.syncio import create_oauth2_client +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.passwordless.syncio import update_user from supertokens_python.recipe.session.exceptions import ( ClaimValidationError, @@ -374,6 +379,8 @@ def custom_init(): Supertokens.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OpenIdRecipe.reset() + OAuth2ProviderRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -1004,6 +1011,10 @@ async def resync_session_and_fetch_mfa_info_put( ) ), }, + { + "id": "oauth2provider", + "init": oauth2provider.init(), + }, ] global accountlinking_config @@ -1066,7 +1077,7 @@ async def should_do_automatic_account_linking( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( app_name="SuperTokens Demo", - api_domain="0.0.0.0:" + get_api_port(), + api_domain="localhost:" + get_api_port(), website_domain=get_website_domain(), ), framework="flask", @@ -1401,6 +1412,15 @@ def test_get_totp_code(): return jsonify({"totp": code}) +@app.post("/test/create-oauth2-client") # type: ignore +def test_create_oauth2_client(): + body = request.get_json() + if body is None: + raise Exception("Invalid request body") + client = create_oauth2_client(CreateOAuth2ClientInput.from_json(body)) + return jsonify(client.to_json()) + + @app.get("/test/getDevice") # type: ignore def test_get_device(): global code_store @@ -1424,6 +1444,7 @@ def test_feature_flags(): "mfa", "recipeConfig", "accountlinking-fixes", + "oauth2", ] return jsonify({"available": available}) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 8ee45dd7e..1ab834d85 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -28,6 +28,8 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.types import RecipeUserId from test_functions_mapper import ( # pylint: disable=import-error get_func, @@ -55,6 +57,7 @@ session, thirdparty, emailverification, + oauth2provider, ) from supertokens_python.recipe.session import InputErrorHandlers, SessionContainer from supertokens_python.recipe.session.framework.flask import verify_session @@ -222,6 +225,8 @@ def st_reset(): AccountLinkingRecipe.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OAuth2ProviderRecipe.reset() + OpenIdRecipe.reset() def init_st(config: Dict[str, Any]): @@ -278,7 +283,9 @@ async def custom_unauthorised_callback( _: BaseRequest, __: str, response: BaseResponse ) -> BaseResponse: response.set_status_code(401) - response.set_json_content(content={"type": "UNAUTHORISED"}) + response.set_json_content( + content={"type": "UNAUTHORISED", "message": "unauthorised"} + ) return response recipe_config_json = json.loads(recipe_config.get("config", "{}")) @@ -591,6 +598,22 @@ async def send_sms( ) ) ) + elif recipe_id == "oauth2provider": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + oauth2provider.init( + override=oauth2provider.InputOverrideConfig( + apis=override_builder_with_logging( + "OAuth2Provider.override.apis", + recipe_config_json.get("override", {}).get("apis"), + ), + functions=override_builder_with_logging( + "OAuth2Provider.override.functions", + recipe_config_json.get("override", {}).get("functions"), + ), + ) + ) + ) interceptor_func = None if config.get("supertokens", {}).get("networkInterceptor") is not None: @@ -820,6 +843,10 @@ def handle_exception(e: Exception): add_multifactorauth_routes(app) +from oauth2provider import add_oauth2provider_routes + +add_oauth2provider_routes(app) + if __name__ == "__main__": default_st_init() port = int(os.environ.get("API_PORT", api_port)) diff --git a/tests/test-server/oauth2provider.py b/tests/test-server/oauth2provider.py new file mode 100644 index 000000000..06aa21761 --- /dev/null +++ b/tests/test-server/oauth2provider.py @@ -0,0 +1,106 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.oauth2provider.interfaces import ( + CreateOAuth2ClientInput, + OAuth2TokenValidationRequirements, + UpdateOAuth2ClientInput, +) +import supertokens_python.recipe.oauth2provider.syncio as OAuth2Provider + + +def add_oauth2provider_routes(app: Flask): + @app.route("/test/oauth2provider/getoauth2clients", methods=["POST"]) # type: ignore + def get_oauth2_clients_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:getOAuth2Clients", request.json) + + data = request.json.get("input", {}) + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + response = OAuth2Provider.get_oauth2_clients( + page_size=data.get("pageSize"), + pagination_token=data.get("paginationToken"), + client_name=data.get("clientName"), + user_context=data.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/createoauth2client", methods=["POST"]) # type: ignore + def create_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:createOAuth2Client", request.json) + + response = OAuth2Provider.create_oauth2_client( + params=CreateOAuth2ClientInput.from_json(request.json.get("input", {})), + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/updateoauth2client", methods=["POST"]) # type: ignore + def update_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:updateOAuth2Client", request.json) + + response = OAuth2Provider.update_oauth2_client( + params=UpdateOAuth2ClientInput.from_json(request.json.get("input", {})), + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/deleteoauth2client", methods=["POST"]) # type: ignore + def delete_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:deleteOAuth2Client", request.json) + + data = request.json.get("input", {}) + + response = OAuth2Provider.delete_oauth2_client( + client_id=data.get("clientId"), + user_context=data.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/validateoauth2accesstoken", methods=["POST"]) # type: ignore + def validate_oauth2_access_token_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:validateOAuth2AccessToken", request.json) + + response = OAuth2Provider.validate_oauth2_access_token( + token=request.json["token"], + requirements=( + OAuth2TokenValidationRequirements.from_json( + request.json["requirements"] + ) + if "requirements" in request.json + else None + ), + check_database=request.json.get("checkDatabase"), + user_context=request.json.get("userContext"), + ) + return jsonify({"payload": response, "status": "OK"}) + + @app.route("/test/oauth2provider/validateoauth2refreshtoken", methods=["POST"]) # type: ignore + def validate_oauth2_refresh_token_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:validateOAuth2RefreshToken", request.json) + + response = OAuth2Provider.validate_oauth2_refresh_token( + token=request.json["token"], + scopes=request.json["scopes"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/createtokenforclientcredentials", methods=["POST"]) # type: ignore + def create_token_for_client_credentials_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:createTokenForClientCredentials", request.json) + + response = OAuth2Provider.create_token_for_client_credentials( + client_id=request.json["clientId"], + client_secret=request.json["clientSecret"], + scope=request.json["scope"], + audience=request.json["audience"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) diff --git a/tests/test-server/session.py b/tests/test-server/session.py index cf1dc174e..ea3dea1df 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -195,6 +195,31 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore } ) + @app.route("/test/session/sessionobject/revokesession", methods=["POST"]) # type: ignore + def revoke_session(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("sessionobject.revokesession", "CALL", data) + + try: + session = convert_session_to_container(data) + if not session: + raise Exception( + "This should never happen: failed to deserialize session" + ) + ret_val = session.sync_revoke_session(data.get("userContext", {})) + response = { + "retVal": ret_val, + "updatedSession": convert_session_to_json(session), + } + log_override_event("sessionobject.revokesession", "RES", ret_val) + return jsonify(response) + except Exception as e: + log_override_event("sessionobject.revokesession", "REJ", e) + return jsonify({"status": "ERROR", "message": str(e)}), 500 + @app.route("/test/session/mergeintoaccesspayload", methods=["POST"]) # type: ignore def merge_into_access_payload(): # type: ignore data = request.json