From e2f316ff4639f2c3ca6746a6bb7c710ed6b64790 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Mon, 11 Nov 2024 17:00:50 +0530 Subject: [PATCH 01/38] fix: files for oauth2 providers --- .../multitenancy/get_third_party_config.py | 9 ++++--- .../recipe/oauth2provider/__init__.py | 13 ++++++++++ .../recipe/oauth2provider/api/__init__.py | 13 ++++++++++ .../recipe/oauth2provider/asyncio/__init__.py | 13 ++++++++++ .../recipe/oauth2provider/constants.py | 24 +++++++++++++++++++ .../recipe/oauth2provider/interfaces.py | 13 ++++++++++ .../recipe/oauth2provider/oauth2_client.py | 13 ++++++++++ .../recipe/oauth2provider/recipe.py | 13 ++++++++++ .../oauth2provider/recipe_implementation.py | 13 ++++++++++ .../recipe/oauth2provider/syncio/__init__.py | 13 ++++++++++ .../recipe/oauth2provider/utils.py | 0 11 files changed, 134 insertions(+), 3 deletions(-) create mode 100644 supertokens_python/recipe/oauth2provider/__init__.py create mode 100644 supertokens_python/recipe/oauth2provider/api/__init__.py create mode 100644 supertokens_python/recipe/oauth2provider/asyncio/__init__.py create mode 100644 supertokens_python/recipe/oauth2provider/constants.py create mode 100644 supertokens_python/recipe/oauth2provider/interfaces.py create mode 100644 supertokens_python/recipe/oauth2provider/oauth2_client.py create mode 100644 supertokens_python/recipe/oauth2provider/recipe.py create mode 100644 supertokens_python/recipe/oauth2provider/recipe_implementation.py create mode 100644 supertokens_python/recipe/oauth2provider/syncio/__init__.py create mode 100644 supertokens_python/recipe/oauth2provider/utils.py diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py index ceb5a9255..f2790bf4a 100644 --- a/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py +++ b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py @@ -143,7 +143,7 @@ async def get_third_party_config( for existing_client in providers_from_core[0].clients: existing_client.additional_config = { **(existing_client.additional_config or {}), - **additional_config + **additional_config, } # filter out other providers from static @@ -197,7 +197,10 @@ async def get_third_party_config( for merged_provider in merged_providers_from_core_and_static: if merged_provider.config.third_party_id == third_party_id: - if merged_provider.config.clients is None or len(merged_provider.config.clients) == 0: + if ( + merged_provider.config.clients is None + or len(merged_provider.config.clients) == 0 + ): merged_provider.config.clients = [ ProviderClientConfig( client_id="nonguessable-temporary-client-id", @@ -217,7 +220,7 @@ async def get_third_party_config( if provider.config.third_party_id == third_party_id: found_correct_config = False - for client in (provider.config.clients or []): + for client in provider.config.clients or []: try: provider_instance = await find_and_create_provider_instance( merged_providers_from_core_and_static, diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/api/__init__.py b/supertokens_python/recipe/oauth2provider/api/__init__.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/constants.py b/supertokens_python/recipe/oauth2provider/constants.py new file mode 100644 index 000000000..edcd4f04b --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/constants.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +LOGIN_PATH = "/oauth/login" +AUTH_PATH = "/oauth/auth" +TOKEN_PATH = "/oauth/token" +LOGIN_INFO_PATH = "/oauth/login/info" +USER_INFO_PATH = "/oauth/userinfo" +REVOKE_TOKEN_PATH = "/oauth/revoke" +INTROSPECT_TOKEN_PATH = "/oauth/introspect" +END_SESSION_PATH = "/oauth/end_session" +LOGOUT_PATH = "/oauth/logout" diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/syncio/__init__.py b/supertokens_python/recipe/oauth2provider/syncio/__init__.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/syncio/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py new file mode 100644 index 000000000..e69de29bb From 59109a8d7caa53c521a86490f6d1339c06ee76e2 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 12 Nov 2024 08:31:53 +0530 Subject: [PATCH 02/38] fix: interface --- .../recipe/oauth2provider/interfaces.py | 35 ++++++++++++++++ .../recipe/oauth2provider/utils.py | 42 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index a59c06f52..032c5a1bd 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -11,3 +11,38 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +from abc import ABC +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from supertokens_python.framework import BaseRequest, BaseResponse + + from .utils import OAuth2ProviderConfig + + +class RecipeInterface(ABC): + def __init__(self): + pass + + +class APIOptions: + def __init__( + self, + request: BaseRequest, + response: BaseResponse, + recipe_id: str, + config: OAuth2ProviderConfig, + recipe_implementation: RecipeInterface, + ): + self.request: BaseRequest = request + self.response: BaseResponse = response + self.recipe_id: str = recipe_id + self.config: OAuth2ProviderConfig = config + self.recipe_implementation: RecipeInterface = recipe_implementation + + +class APIInterface: + def __init__(self): + pass diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index e69de29bb..15b237e5b 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Callable, Union + +from .interfaces import APIInterface, RecipeInterface + + +class InputOverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + apis: Union[Callable[[APIInterface], APIInterface], None] = None, + ): + self.functions = functions + self.apis = apis + + +class OverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + apis: Union[Callable[[APIInterface], APIInterface], None] = None, + ): + self.functions = functions + self.apis = apis + + +class OAuth2ProviderConfig: + def __init__(self): + pass From f5775231f00626f4bc25b6dd30813f86ec1c0ca7 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 12 Nov 2024 12:05:01 +0530 Subject: [PATCH 03/38] fix: oauth2 interfaces --- .../recipe/oauth2provider/interfaces.py | 479 +++++++++++++++++- .../recipe/oauth2provider/oauth2_client.py | 5 + 2 files changed, 482 insertions(+), 2 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 032c5a1bd..ce2cb00dc 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -12,8 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -from abc import ABC -from typing import TYPE_CHECKING +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing_extensions import Literal +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import APIResponse, GeneralErrorResponse, User + +from .oauth2_client import OAuth2Client if TYPE_CHECKING: @@ -22,10 +27,372 @@ from .utils import OAuth2ProviderConfig +class ErrorOAuth2Response(APIResponse): + def __init__( + self, + error: str, # OAuth2 error format (e.g. invalid_request, login_required) + error_description: str, # Human readable error description + status_code: Optional[int] = None, # HTTP status code (e.g. 401 or 403) + ): + self.status: Literal["ERROR"] = "ERROR" + self.error = error + self.error_description = error_description + self.status_code = status_code + + def to_json(self): + return { + "status": self.status, + "error": self.error, + "errorDescription": self.error_description, + "statusCode": self.status_code, + } + + @staticmethod + def from_json(json: Dict[str, Any]): + return ErrorOAuth2Response( + error=json["error"], + error_description=json["errorDescription"], + status_code=json["statusCode"], + ) + + +class ConsentRequest: + def __init__( + self, + challenge: str, # ID/identifier of the consent authorization request + acr: Optional[str] = None, # Authentication Context Class Reference value + amr: Optional[List[str]] = None, # List of strings + client: Optional[OAuth2Client] = None, + context: Optional[Any] = None, # Any JSON serializable object + login_challenge: Optional[str] = None, # Associated login challenge + login_session_id: Optional[str] = None, + oidc_context: Optional[Any] = None, # Optional OpenID Connect request info + requested_access_token_audience: Optional[List[str]] = None, + requested_scope: Optional[List[str]] = None, + skip: Optional[bool] = None, + subject: Optional[str] = None, + ): + self.challenge = challenge + self.acr = acr + self.amr = amr + self.client = client + self.context = context + self.login_challenge = login_challenge + self.login_session_id = login_session_id + self.oidc_context = oidc_context + self.requested_access_token_audience = requested_access_token_audience + self.requested_scope = requested_scope + self.skip = skip + self.subject = subject + + +class LoginRequest: + def __init__( + self, + challenge: str, # ID/identifier of the login request + client: OAuth2Client, + request_url: str, # Original OAuth 2.0 Authorization URL + skip: bool, + subject: str, + oidc_context: Optional[Any] = None, # Optional OpenID Connect request info + requested_access_token_audience: Optional[List[str]] = None, + requested_scope: Optional[List[str]] = None, + session_id: Optional[str] = None, + ): + self.challenge = challenge + self.client = client + self.oidc_context = oidc_context + self.request_url = request_url + self.requested_access_token_audience = requested_access_token_audience + self.requested_scope = requested_scope + self.session_id = session_id + self.skip = skip + self.subject = subject + + +class TokenInfo: + def __init__( + self, + expires_in: int, # Lifetime in seconds of the access token + scope: str, + token_type: str, + access_token: Optional[str] = None, + id_token: Optional[str] = None, # Requires id_token scope + refresh_token: Optional[str] = None, # Requires offline scope + ): + self.access_token = access_token + self.expires_in = expires_in + self.id_token = id_token + self.refresh_token = refresh_token + self.scope = scope + self.token_type = token_type + + +class LoginInfo: + def __init__( + self, + client_id: str, + client_name: str, + tos_uri: Optional[str] = None, + policy_uri: Optional[str] = None, + logo_uri: Optional[str] = None, + client_uri: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + self.client_id = client_id + self.client_name = client_name + self.tos_uri = tos_uri + self.policy_uri = policy_uri + self.logo_uri = logo_uri + self.client_uri = client_uri + self.metadata = metadata + + +class RedirectResponse: + def __init__(self, redirect_to: str, cookies: Optional[str]): + self.redirect_to = redirect_to + self.cookies = cookies + + class RecipeInterface(ABC): def __init__(self): pass + @abstractmethod + async def authorization( + self, + params: Dict[str, str], + cookies: Optional[str], + session: Optional[SessionContainer], + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + @abstractmethod + async def token_exchange( + self, + authorization_header: Optional[str], + body: Dict[str, Optional[str]], + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[TokenInfo, ErrorOAuth2Response]: + pass + + @abstractmethod + async def get_consent_request( + self, challenge: str, user_context: Optional[Dict[str, Any]] = None + ) -> ConsentRequest: + pass + + @abstractmethod + async def accept_consent_request( + self, + challenge: str, + context: Optional[Any] = None, + grant_access_token_audience: Optional[List[str]] = None, + grant_scope: Optional[List[str]] = None, + handled_at: Optional[str] = None, + tenant_id: str = "", + rsub: str = "", + session_handle: str = "", + initial_access_token_payload: Optional[Dict[str, Any]] = None, + initial_id_token_payload: Optional[Dict[str, Any]] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + pass + + @abstractmethod + async def reject_consent_request( + self, challenge: str, error: ErrorOAuth2Response, user_context: Dict[str, Any] + ) -> RedirectResponse: + pass + + @abstractmethod + async def get_login_request( + self, challenge: str, user_context: Dict[str, Any] + ) -> Union[LoginRequest, ErrorOAuth2Response]: + pass + + @abstractmethod + async def accept_login_request( + self, + challenge: str, + acr: Optional[str] = None, + amr: Optional[List[str]] = None, + context: Optional[Any] = None, + extend_session_lifespan: Optional[bool] = None, + identity_provider_session_id: Optional[str] = None, + subject: str = "", + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + pass + + @abstractmethod + async def reject_login_request( + self, + challenge: str, + error: ErrorOAuth2Response, + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + pass + + @abstractmethod + async def get_oauth2_client( + self, client_id: str, user_context: Optional[Dict[str, Any]] = None + ) -> Union[GetOAuth2ClientOkResult, GetOAuth2ClientErrorResult]: + pass + + @abstractmethod + async def get_oauth2_clients( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[GetOAuth2ClientsOkResult, GetOAuth2ClientsErrorResult]: + pass + + @abstractmethod + async def create_oauth2_client( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[CreateOAuth2ClientOkResult, CreateOAuth2ClientErrorResult]: + pass + + @abstractmethod + async def update_oauth2_client( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[UpdateOAuth2ClientOkResult, UpdateOAuth2ClientErrorResult]: + pass + + @abstractmethod + async def delete_oauth2_client( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[DeleteOAuth2ClientOkResult, DeleteOAuth2ClientErrorResult]: + pass + + @abstractmethod + async def validate_oauth2_access_token( + self, + token: str, + requirements: Optional[Dict[str, Any]] = None, + check_database: Optional[bool] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def get_requested_scopes( + self, + recipe_user_id: Optional[str], + session_handle: Optional[str], + scope_param: List[str], + client_id: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> List[str]: + pass + + @abstractmethod + async def build_access_token_payload( + self, + user: Optional[Dict[str, Any]], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def build_id_token_payload( + self, + user: Optional[Dict[str, Any]], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def build_user_info( + self, + user: Dict[str, Any], + access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def get_frontend_redirection_url( + self, + input_type: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> str: + pass + + @abstractmethod + async def revoke_token( + self, + token: str, + authorization_header: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, str], ErrorOAuth2Response]: + pass + + @abstractmethod + async def revoke_tokens_by_client_id( + self, + client_id: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + pass + + @abstractmethod + async def revoke_tokens_by_session_handle( + self, + session_handle: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + pass + + @abstractmethod + async def introspect_token( + self, + token: str, + scopes: Optional[List[str]] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + @abstractmethod + async def end_session( + self, + params: Dict[str, str], + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + @abstractmethod + async def accept_logout_request( + self, + challenge: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + @abstractmethod + async def reject_logout_request( + self, + challenge: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + pass + class APIOptions: def __init__( @@ -46,3 +413,111 @@ def __init__( class APIInterface: def __init__(self): pass + + @abstractmethod + async def login_get( + self, + login_challenge: str, + options: APIOptions, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def auth_get( + self, + params: Any, + cookie: Optional[str], + session: Optional[SessionContainer], + should_try_refresh: bool, + options: APIOptions, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def token_post( + self, + authorization_header: Optional[str], + body: Any, + options: APIOptions, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def login_info_get( + self, + login_challenge: str, + options: APIOptions, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[Literal["status", "info"], Union[Literal["OK"], LoginInfo]], ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def user_info_get( + self, + access_token_payload: Dict[str, Any], + user: User, + scopes: List[str], + tenant_id: str, + options: APIOptions, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, Any], GeneralErrorResponse]: + pass + + @abstractmethod + async def revoke_token_post( + self, + token: str, + options: APIOptions, + user_context: Optional[Dict[str, Any]] = None, + authorization_header: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + ) -> Union[Dict[Literal["status"], Literal["OK"]], ErrorOAuth2Response]: + pass + + @abstractmethod + async def introspect_token_post( + self, + token: str, + scopes: Optional[List[str]], + options: APIOptions, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[IntrospectTokenResponse, GeneralErrorResponse]: + pass + + @abstractmethod + async def end_session_get( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def end_session_post( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]: + pass + + @abstractmethod + async def logout_post( + self, + logout_challenge: str, + options: APIOptions, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, Union[Literal["OK"], str]], ErrorOAuth2Response, GeneralErrorResponse]: + pass \ No newline at end of file diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py index a59c06f52..6074e4f44 100644 --- a/supertokens_python/recipe/oauth2provider/oauth2_client.py +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -11,3 +11,8 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + + +class OAuth2Client: + def __init__(self): + pass From 6479bdafcdfae3b1dd2027c9776f34d0b1413d12 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 26 Nov 2024 12:33:07 +0530 Subject: [PATCH 04/38] fix: update recipe.py --- .../oauth2provider/api/implementation.py | 2 + .../recipe/oauth2provider/exceptions.py | 20 ++ .../recipe/oauth2provider/interfaces.py | 20 +- .../recipe/oauth2provider/recipe.py | 258 +++++++++++++++++- .../oauth2provider/recipe_implementation.py | 242 ++++++++++++++++ .../recipe/oauth2provider/utils.py | 12 +- 6 files changed, 546 insertions(+), 8 deletions(-) create mode 100644 supertokens_python/recipe/oauth2provider/api/implementation.py create mode 100644 supertokens_python/recipe/oauth2provider/exceptions.py diff --git a/supertokens_python/recipe/oauth2provider/api/implementation.py b/supertokens_python/recipe/oauth2provider/api/implementation.py new file mode 100644 index 000000000..5a9b155e2 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/implementation.py @@ -0,0 +1,2 @@ +class APIImplementation: + pass diff --git a/supertokens_python/recipe/oauth2provider/exceptions.py b/supertokens_python/recipe/oauth2provider/exceptions.py new file mode 100644 index 000000000..9c3b5dcf6 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/exceptions.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from supertokens_python.exceptions import SuperTokensError + + +class OAuth2ProviderError(SuperTokensError): + pass diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index ce2cb00dc..5a09477cf 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -422,7 +422,9 @@ async def login_get( session: Optional[SessionContainer] = None, should_try_refresh: bool = False, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse]: + ) -> Union[ + Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse + ]: pass @abstractmethod @@ -434,7 +436,9 @@ async def auth_get( should_try_refresh: bool, options: APIOptions, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse]: + ) -> Union[ + Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse + ]: pass @abstractmethod @@ -453,7 +457,11 @@ async def login_info_get( login_challenge: str, options: APIOptions, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[Dict[Literal["status", "info"], Union[Literal["OK"], LoginInfo]], ErrorOAuth2Response, GeneralErrorResponse]: + ) -> Union[ + Dict[Literal["status", "info"], Union[Literal["OK"], LoginInfo]], + ErrorOAuth2Response, + GeneralErrorResponse, + ]: pass @abstractmethod @@ -519,5 +527,7 @@ async def logout_post( options: APIOptions, session: Optional[SessionContainer] = None, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[Dict[str, Union[Literal["OK"], str]], ErrorOAuth2Response, GeneralErrorResponse]: - pass \ No newline at end of file + ) -> Union[ + Dict[str, Union[Literal["OK"], str]], ErrorOAuth2Response, GeneralErrorResponse + ]: + pass diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index a59c06f52..67d447d6b 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. # # This software is licensed under the Apache License, Version 2.0 (the # "License") as published by the Apache Software Foundation. @@ -11,3 +11,259 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +from __future__ import annotations + +from os import environ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError +from supertokens_python.recipe_module import APIHandled, RecipeModule + +from .interfaces import ( + APIOptions, +) + +from .recipe_implementation import RecipeImplementation + +if TYPE_CHECKING: + from supertokens_python.framework.request import BaseRequest + from supertokens_python.framework.response import BaseResponse + from supertokens_python.supertokens import AppInfo + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.querier import Querier +from supertokens_python.recipe.oauth2provider.api.implementation import ( + APIImplementation, +) + + +from .api import * # TODO: fix this +from .constants import ( + LOGIN_PATH, + AUTH_PATH, + TOKEN_PATH, + LOGIN_INFO_PATH, + USER_INFO_PATH, + REVOKE_TOKEN_PATH, + INTROSPECT_TOKEN_PATH, + END_SESSION_PATH, + LOGOUT_PATH, +) +from .utils import ( + InputOverrideConfig, + OAuth2ProviderConfig, + validate_and_normalise_user_input, +) + + +class OAuth2ProviderRecipe(RecipeModule): + recipe_id = "oauth2provider" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + override: Union[InputOverrideConfig, None] = None, + ) -> None: + super().__init__(recipe_id, app_info) + self.config: OAuth2ProviderConfig = validate_and_normalise_user_input( + override, + ) + + recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) + self.recipe_implementation: RecipeImplementation = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) + + api_implementation = APIImplementation() + self.api_implementation: APIImplementation = ( + api_implementation + if self.config.override.apis is None + else self.config.override.apis(api_implementation) + ) + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return isinstance(err, OAuth2ProviderError) + + def get_apis_handled(self) -> List[APIHandled]: + return [ + APIHandled( + NormalisedURLPath(LOGIN_PATH), + "get", + LOGIN_PATH, + self.api_implementation.login_get, + ), + APIHandled( + NormalisedURLPath(TOKEN_PATH), + "post", + TOKEN_PATH, + self.api_implementation.token_post, + ), + APIHandled( + NormalisedURLPath(AUTH_PATH), + "get", + AUTH_PATH, + self.api_implementation.auth_get, + ), + APIHandled( + NormalisedURLPath(LOGIN_INFO_PATH), + "get", + LOGIN_INFO_PATH, + self.api_implementation.login_info_get, + ), + APIHandled( + NormalisedURLPath(USER_INFO_PATH), + "get", + USER_INFO_PATH, + self.api_implementation.user_info_get, + ), + APIHandled( + NormalisedURLPath(REVOKE_TOKEN_PATH), + "post", + REVOKE_TOKEN_PATH, + self.api_implementation.revoke_token_post, + ), + APIHandled( + NormalisedURLPath(INTROSPECT_TOKEN_PATH), + "post", + INTROSPECT_TOKEN_PATH, + self.api_implementation.introspect_token_post, + ), + APIHandled( + NormalisedURLPath(END_SESSION_PATH), + "get", + END_SESSION_PATH, + self.api_implementation.end_session_get, + ), + APIHandled( + NormalisedURLPath(END_SESSION_PATH), + "post", + END_SESSION_PATH, + self.api_implementation.end_session_post, + ), + APIHandled( + NormalisedURLPath(LOGOUT_PATH), + "post", + LOGOUT_PATH, + self.api_implementation.logout_post, + ), + ] + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> Union[BaseResponse, None]: + api_options = APIOptions( + request, + response, + self.recipe_id, + self.config, + self.recipe_implementation, + ) + if request_id == LOGIN_PATH: + return await login_api(self.api_implementation, api_options, user_context) + + if request_id == TOKEN_PATH: + return await token_post(self.api_implementation, api_options, user_context) + + if request_id == AUTH_PATH: + return await auth_get(self.api_implementation, api_options, user_context) + + if request_id == LOGIN_INFO_PATH: + return await login_info_get( + self.api_implementation, api_options, user_context + ) + + if request_id == USER_INFO_PATH: + return await user_info_get( + self.api_implementation, tenant_id, api_options, user_context + ) + + if request_id == REVOKE_TOKEN_PATH: + return await revoke_token_post( + self.api_implementation, api_options, user_context + ) + + if request_id == INTROSPECT_TOKEN_PATH: + return await introspect_token_post( + self.api_implementation, api_options, user_context + ) + + if request_id == END_SESSION_PATH and method == "get": + return await end_session_get( + self.api_implementation, api_options, user_context + ) + + if request_id == END_SESSION_PATH and method == "post": + return await end_session_post( + self.api_implementation, api_options, user_context + ) + + if request_id == LOGOUT_PATH and method == "post": + return await logout_post(self.api_implementation, api_options, user_context) + + raise Exception( + "Should never come here: handle_api_request called with unknown id" + ) + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + override: Union[InputOverrideConfig, None] = None, + ): + def func(app_info: AppInfo): + if OAuth2ProviderRecipe.__instance is None: + OAuth2ProviderRecipe.__instance = OAuth2ProviderRecipe( + OAuth2ProviderRecipe.recipe_id, + app_info, + override, + ) + + return OAuth2ProviderRecipe.__instance + raise_general_exception( + "OAuth2Provider recipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def get_instance() -> OAuth2ProviderRecipe: + if OAuth2ProviderRecipe.__instance is not None: + return OAuth2ProviderRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def get_instance_optional() -> Optional[OAuth2ProviderRecipe]: + return OAuth2ProviderRecipe.__instance + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + OAuth2ProviderRecipe.__instance = None diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index a59c06f52..ad57216fc 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -11,3 +11,245 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List + +from .interfaces import ( + RecipeInterface, + RedirectResponse, + ErrorOAuth2Response, + SessionContainer, + GetOAuth2ClientOkResult, + GetOAuth2ClientErrorResult, + GetOAuth2ClientsOkResult, + GetOAuth2ClientsErrorResult, + CreateOAuth2ClientOkResult, + CreateOAuth2ClientErrorResult, + UpdateOAuth2ClientOkResult, + UpdateOAuth2ClientErrorResult, + DeleteOAuth2ClientOkResult, + DeleteOAuth2ClientErrorResult, + ConsentRequest, + LoginRequest, + OAuth2Client, + TokenInfo, +) + + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + + +class RecipeImplementation(RecipeInterface): + def __init__(self, querier: Querier): + super().__init__() + self.querier = querier + + async def authorization( + self, + params: Dict[str, str], + cookies: Optional[str], + session: Optional[SessionContainer], + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + async def token_exchange( + self, + authorization_header: Optional[str], + body: Dict[str, Optional[str]], + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[TokenInfo, ErrorOAuth2Response]: + pass + + async def get_consent_request( + self, challenge: str, user_context: Optional[Dict[str, Any]] = None + ) -> ConsentRequest: + pass + + async def accept_consent_request( + self, + challenge: str, + context: Optional[Any] = None, + grant_access_token_audience: Optional[List[str]] = None, + grant_scope: Optional[List[str]] = None, + handled_at: Optional[str] = None, + tenant_id: str = "", + rsub: str = "", + session_handle: str = "", + initial_access_token_payload: Optional[Dict[str, Any]] = None, + initial_id_token_payload: Optional[Dict[str, Any]] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + pass + + async def reject_consent_request( + self, challenge: str, error: ErrorOAuth2Response, user_context: Dict[str, Any] + ) -> RedirectResponse: + pass + + async def get_login_request( + self, challenge: str, user_context: Dict[str, Any] + ) -> Union[LoginRequest, ErrorOAuth2Response]: + pass + + async def accept_login_request( + self, + challenge: str, + acr: Optional[str] = None, + amr: Optional[List[str]] = None, + context: Optional[Any] = None, + extend_session_lifespan: Optional[bool] = None, + identity_provider_session_id: Optional[str] = None, + subject: str = "", + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + pass + + async def reject_login_request( + self, + challenge: str, + error: ErrorOAuth2Response, + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + pass + + async def get_oauth2_client( + self, client_id: str, user_context: Optional[Dict[str, Any]] = None + ) -> Union[GetOAuth2ClientOkResult, GetOAuth2ClientErrorResult]: + pass + + async def get_oauth2_clients( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[GetOAuth2ClientsOkResult, GetOAuth2ClientsErrorResult]: + pass + + async def create_oauth2_client( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[CreateOAuth2ClientOkResult, CreateOAuth2ClientErrorResult]: + pass + + async def update_oauth2_client( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[UpdateOAuth2ClientOkResult, UpdateOAuth2ClientErrorResult]: + pass + + async def delete_oauth2_client( + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[DeleteOAuth2ClientOkResult, DeleteOAuth2ClientErrorResult]: + pass + + async def validate_oauth2_access_token( + self, + token: str, + requirements: Optional[Dict[str, Any]] = None, + check_database: Optional[bool] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + async def get_requested_scopes( + self, + recipe_user_id: Optional[str], + session_handle: Optional[str], + scope_param: List[str], + client_id: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> List[str]: + pass + + async def build_access_token_payload( + self, + user: Optional[Dict[str, Any]], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + async def build_id_token_payload( + self, + user: Optional[Dict[str, Any]], + client: OAuth2Client, + session_handle: Optional[str], + scopes: List[str], + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + async def build_user_info( + self, + user: Dict[str, Any], + access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + async def get_frontend_redirection_url( + self, + input_type: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> str: + pass + + async def revoke_token( + self, + token: str, + authorization_header: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[Dict[str, str], ErrorOAuth2Response]: + pass + + async def revoke_tokens_by_client_id( + self, + client_id: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + pass + + async def revoke_tokens_by_session_handle( + self, + session_handle: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + pass + + async def introspect_token( + self, + token: str, + scopes: Optional[List[str]] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + pass + + async def end_session( + self, + params: Dict[str, str], + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + async def accept_logout_request( + self, + challenge: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + async def reject_logout_request( + self, + challenge: str, + user_context: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + pass diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index 15b237e5b..cb7553dd1 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -38,5 +38,13 @@ def __init__( class OAuth2ProviderConfig: - def __init__(self): - pass + def __init__(self, override: Union[OverrideConfig, None] = None): + self.override = override + + +def validate_and_normalise_user_input( + override: Union[InputOverrideConfig, None] = None +): + if override is None: + return OAuth2ProviderConfig(OverrideConfig()) + return OAuth2ProviderConfig(OverrideConfig(override.functions, override.apis)) From 94fc82d1e4e8ae8170191c2d846df5e06459daba Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 26 Nov 2024 13:10:27 +0530 Subject: [PATCH 05/38] fix: login request impl --- .../recipe/oauth2provider/interfaces.py | 16 +- .../recipe/oauth2provider/oauth2_client.py | 231 +++++++++++++++++- .../oauth2provider/recipe_implementation.py | 97 ++++++-- 3 files changed, 314 insertions(+), 30 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 5a09477cf..78bfab706 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -109,6 +109,20 @@ def __init__( self.skip = skip self.subject = subject + @staticmethod + def from_json(json: Dict[str, Any]): + return LoginRequest( + challenge=json["challenge"], + client=OAuth2Client.from_json(json["client"]), + request_url=json["requestUrl"], + skip=json["skip"], + subject=json["subject"], + oidc_context=json["oidcContext"], + requested_access_token_audience=json["requestedAccessTokenAudience"], + requested_scope=json["requestedScope"], + session_id=json["sessionId"], + ) + class TokenInfo: def __init__( @@ -149,7 +163,7 @@ def __init__( class RedirectResponse: - def __init__(self, redirect_to: str, cookies: Optional[str]): + def __init__(self, redirect_to: str, cookies: Optional[str] = None): self.redirect_to = redirect_to self.cookies = cookies diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py index 6074e4f44..9b17718fe 100644 --- a/supertokens_python/recipe/oauth2provider/oauth2_client.py +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -13,6 +13,233 @@ # under the License. +from typing import Dict, Any, List, Optional + + class OAuth2Client: - def __init__(self): - pass + # OAuth 2.0 Client ID + # The ID is immutable. If no ID is provided, a UUID4 will be generated. + client_id: str + + # OAuth 2.0 Client Name + # The human-readable name of the client to be presented to the end-user during authorization. + client_name: str + + # OAuth 2.0 Client Scope + # Scope is a string containing a space-separated list of scope values that the client + # can use when requesting access tokens. + scope: str + + # OAuth 2.0 Token Endpoint Authentication Method + # Requested Client Authentication method for the Token Endpoint. + token_endpoint_auth_method: str + + # OAuth 2.0 Client Creation Date + # CreatedAt returns the timestamp of the client's creation. + created_at: str + + # OAuth 2.0 Client Last Update Date + # UpdatedAt returns the timestamp of the last update. + updated_at: str + + # OAuth 2.0 Client Secret + client_secret: Optional[str] = None + + # Array of redirect URIs + redirect_uris: Optional[List[str]] = None + + # Array of post logout redirect URIs + post_logout_redirect_uris: Optional[List[str]] = None + + # Authorization Code Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + authorization_code_grant_access_token_lifespan: Optional[str] = None + + # Authorization Code Grant ID Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + authorization_code_grant_id_token_lifespan: Optional[str] = None + + # Authorization Code Grant Refresh Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + authorization_code_grant_refresh_token_lifespan: Optional[str] = None + + # Client Credentials Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + client_credentials_grant_access_token_lifespan: Optional[str] = None + + # Implicit Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + implicit_grant_access_token_lifespan: Optional[str] = None + + # Implicit Grant ID Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + implicit_grant_id_token_lifespan: Optional[str] = None + + # Refresh Token Grant Access Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + refresh_token_grant_access_token_lifespan: Optional[str] = None + + # Refresh Token Grant ID Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + refresh_token_grant_id_token_lifespan: Optional[str] = None + + # Refresh Token Grant Refresh Token Lifespan + # NullDuration - ^[0-9]+(ns|us|ms|s|m|h)$ + refresh_token_grant_refresh_token_lifespan: Optional[str] = None + + # OAuth 2.0 Client URI + # ClientURI is a URL string of a web page providing information about the client. + client_uri: str = "" + + # Array of audiences + audience: List[str] = [] + + # Array of grant types + grant_types: Optional[List[str]] = None + + # Array of response types + response_types: Optional[List[str]] = None + + # OAuth 2.0 Client Logo URI + # A URL string referencing the client's logo. + logo_uri: str = "" + + # OAuth 2.0 Client Policy URI + # PolicyURI is a URL string that points to a human-readable privacy policy document + # that describes how the deployment organization collects, uses, + # retains, and discloses personal data. + policy_uri: str = "" + + # OAuth 2.0 Client Terms of Service URI + # A URL string pointing to a human-readable terms of service + # document for the client that describes a contractual relationship + # between the end-user and the client that the end-user accepts when + # authorizing the client. + tos_uri: str = "" + + # Metadata - JSON object + metadata: Dict[str, Any] = {} + + # This flag is set to true if refresh tokens are updated upon use + enable_refresh_token_rotation: bool = False + + def __init__( + self, + client_id: str, + client_name: str, + scope: str, + token_endpoint_auth_method: str, + created_at: str, + updated_at: str, + client_secret: Optional[str], + redirect_uris: Optional[List[str]], + post_logout_redirect_uris: Optional[List[str]], + authorization_code_grant_access_token_lifespan: Optional[str], + authorization_code_grant_id_token_lifespan: Optional[str], + authorization_code_grant_refresh_token_lifespan: Optional[str], + client_credentials_grant_access_token_lifespan: Optional[str], + implicit_grant_access_token_lifespan: Optional[str], + implicit_grant_id_token_lifespan: Optional[str], + refresh_token_grant_access_token_lifespan: Optional[str], + refresh_token_grant_id_token_lifespan: Optional[str], + refresh_token_grant_refresh_token_lifespan: Optional[str], + client_uri: str, + audience: List[str], + grant_types: Optional[List[str]], + response_types: Optional[List[str]], + logo_uri: str, + policy_uri: str, + tos_uri: str, + metadata: Dict[str, Any], + enable_refresh_token_rotation: bool, + ): + self.client_id = client_id + self.client_name = client_name + self.scope = scope + self.token_endpoint_auth_method = token_endpoint_auth_method + self.created_at = created_at + self.updated_at = updated_at + self.client_secret = client_secret + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.client_uri = client_uri + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + @staticmethod + def from_json(json: Dict[str, Any]) -> "OAuth2Client": + # Transform keys from snake_case to camelCase + return OAuth2Client( + client_id=json["clientId"], + client_secret=json.get("clientSecret"), + client_name=json["clientName"], + scope=json["scope"], + redirect_uris=json.get("redirectUris"), + post_logout_redirect_uris=json.get("postLogoutRedirectUris"), + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan" + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan" + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan" + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan" + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan" + ), + implicit_grant_id_token_lifespan=json.get("implicitGrantIdTokenLifespan"), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan" + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan" + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan" + ), + token_endpoint_auth_method=json["tokenEndpointAuthMethod"], + client_uri=json.get("clientUri", ""), + audience=json.get("audience", []), + grant_types=json.get("grantTypes"), + response_types=json.get("responseTypes"), + logo_uri=json.get("logoUri", ""), + policy_uri=json.get("policyUri", ""), + tos_uri=json.get("tosUri", ""), + created_at=json["createdAt"], + updated_at=json["updatedAt"], + metadata=json.get("metadata", {}), + enable_refresh_token_rotation=json.get("enableRefreshTokenRotation", False), + ) diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index ad57216fc..1b188b52d 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -14,6 +14,9 @@ from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List +from supertokens_python import AppInfo +from supertokens_python.normalised_url_path import NormalisedURLPath + from .interfaces import ( RecipeInterface, RedirectResponse, @@ -40,10 +43,76 @@ from supertokens_python.querier import Querier +def get_updated_redirect_to(app_info: AppInfo, redirect_to: str) -> str: + return redirect_to.replace( + "{apiDomain}", + app_info.api_domain.get_as_string_dangerous() + app_info.api_base_path.get_as_string_dangerous() + ) + + class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier): + def __init__(self, querier: Querier, app_info: AppInfo): super().__init__() self.querier = querier + self.app_info = app_info + + async def get_login_request( + self, challenge: str, user_context: Dict[str, Any] + ) -> Union[LoginRequest, ErrorOAuth2Response]: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/login"), + {"challenge": challenge}, + user_context=user_context, + ) + if response["status"] != "OK": + return ErrorOAuth2Response( + response["error"], + response["errorDescription"], + response["statusCode"], + ) + + return LoginRequest.from_json(response) + + async def accept_login_request( + self, + challenge: str, + acr: Optional[str] = None, + amr: Optional[List[str]] = None, + context: Optional[Any] = None, + extend_session_lifespan: Optional[bool] = None, + identity_provider_session_id: Optional[str] = None, + subject: str = "", + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/login/accept"), + { + "acr": acr, + "amr": amr, + "context": context, + "extendSessionLifespan": extend_session_lifespan, + "identityProviderSessionId": identity_provider_session_id, + "subject": subject, + "loginChallenge": challenge, + }, + user_context=user_context, + ) + + return RedirectResponse(redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"])) + + + async def reject_login_request( + self, + challenge: str, + error: ErrorOAuth2Response, + user_context: Optional[Dict[str, Any]] = None, + ) -> RedirectResponse: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/login/reject"), + {"error": error.error, "errorDescription": error.error_description, "statusCode": error.status_code}, + user_context=user_context, + ) + return RedirectResponse(redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"])) async def authorization( self, @@ -88,32 +157,6 @@ async def reject_consent_request( ) -> RedirectResponse: pass - async def get_login_request( - self, challenge: str, user_context: Dict[str, Any] - ) -> Union[LoginRequest, ErrorOAuth2Response]: - pass - - async def accept_login_request( - self, - challenge: str, - acr: Optional[str] = None, - amr: Optional[List[str]] = None, - context: Optional[Any] = None, - extend_session_lifespan: Optional[bool] = None, - identity_provider_session_id: Optional[str] = None, - subject: str = "", - user_context: Optional[Dict[str, Any]] = None, - ) -> RedirectResponse: - pass - - async def reject_login_request( - self, - challenge: str, - error: ErrorOAuth2Response, - user_context: Optional[Dict[str, Any]] = None, - ) -> RedirectResponse: - pass - async def get_oauth2_client( self, client_id: str, user_context: Optional[Dict[str, Any]] = None ) -> Union[GetOAuth2ClientOkResult, GetOAuth2ClientErrorResult]: From 1fe7e515f2e90e4188005569f48c27fdbb970342 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 26 Nov 2024 13:23:05 +0530 Subject: [PATCH 06/38] fix: query params for put request --- supertokens_python/querier.py | 13 ++++++--- .../emailpassword/recipe_implementation.py | 1 + .../multitenancy/recipe_implementation.py | 2 ++ .../oauth2provider/recipe_implementation.py | 27 ++++++++++++++----- .../passwordless/recipe_implementation.py | 2 ++ .../recipe/session/session_functions.py | 2 ++ .../recipe/totp/recipe_implementation.py | 1 + .../usermetadata/recipe_implementation.py | 1 + .../recipe/userroles/recipe_implementation.py | 2 ++ 9 files changed, 40 insertions(+), 11 deletions(-) diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 69945493d..5b5b4cf38 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -402,28 +402,33 @@ async def send_put_request( self, path: NormalisedURLPath, data: Union[Dict[str, Any], None], + query_params: Union[Dict[str, Any], None], user_context: Union[Dict[str, Any], None], ) -> Dict[str, Any]: self.invalidate_core_call_cache(user_context) if data is None: data = {} + if query_params is None: + query_params = {} headers = await self.__get_headers_with_api_version(path, user_context) headers["content-type"] = "application/json; charset=utf-8" async def f(url: str, method: str) -> Response: - nonlocal headers, data + nonlocal headers, data, query_params if Querier.network_interceptor is not None: ( url, method, headers, - _, + query_params, data, ) = Querier.network_interceptor( # pylint:disable=not-callable - url, method, headers, {}, data, user_context + url, method, headers, query_params, data, user_context ) - return await self.api_request(url, method, 2, headers=headers, json=data) + return await self.api_request( + url, method, 2, headers=headers, json=data, params=query_params + ) return await self.__send_request_helper(path, "PUT", f, len(self.__hosts)) diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 283569a5f..c17c14b5f 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -297,6 +297,7 @@ async def update_email_or_password( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, + None, user_context=user_context, ) diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 90578bacb..efabfb1c8 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -147,6 +147,7 @@ async def create_or_update_tenant( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/multitenancy/tenant/v2"), json_body, + None, user_context=user_context, ) return CreateOrUpdateTenantOkResult( @@ -217,6 +218,7 @@ async def create_or_update_third_party_config( "config": config.to_json(), "skipValidation": skip_validation is True, }, + None, user_context=user_context, ) diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 1b188b52d..5a206fb55 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -46,7 +46,8 @@ def get_updated_redirect_to(app_info: AppInfo, redirect_to: str) -> str: return redirect_to.replace( "{apiDomain}", - app_info.api_domain.get_as_string_dangerous() + app_info.api_base_path.get_as_string_dangerous() + app_info.api_domain.get_as_string_dangerous() + + app_info.api_base_path.get_as_string_dangerous(), ) @@ -59,9 +60,9 @@ def __init__(self, querier: Querier, app_info: AppInfo): async def get_login_request( self, challenge: str, user_context: Dict[str, Any] ) -> Union[LoginRequest, ErrorOAuth2Response]: - response = await self.querier.send_put_request( + response = await self.querier.send_get_request( NormalisedURLPath("/recipe/oauth/auth/requests/login"), - {"challenge": challenge}, + {"loginChallenge": challenge}, user_context=user_context, ) if response["status"] != "OK": @@ -93,13 +94,16 @@ async def accept_login_request( "extendSessionLifespan": extend_session_lifespan, "identityProviderSessionId": identity_provider_session_id, "subject": subject, + }, + { "loginChallenge": challenge, }, user_context=user_context, ) - return RedirectResponse(redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"])) - + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) async def reject_login_request( self, @@ -109,10 +113,19 @@ async def reject_login_request( ) -> RedirectResponse: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/login/reject"), - {"error": error.error, "errorDescription": error.error_description, "statusCode": error.status_code}, + { + "error": error.error, + "errorDescription": error.error_description, + "statusCode": error.status_code, + }, + { + "loginChallenge": challenge, + }, user_context=user_context, ) - return RedirectResponse(redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"])) + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) async def authorization( self, diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index 371c32a14..d58285a61 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -447,6 +447,7 @@ async def delete_email_for_user( result = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, + None, user_context=user_context, ) if result["status"] == "OK": @@ -460,6 +461,7 @@ async def delete_phone_number_for_user( result = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, + None, user_context=user_context, ) if result["status"] == "OK": diff --git a/supertokens_python/recipe/session/session_functions.py b/supertokens_python/recipe/session/session_functions.py index a09200633..651337fc9 100644 --- a/supertokens_python/recipe/session/session_functions.py +++ b/supertokens_python/recipe/session/session_functions.py @@ -517,6 +517,7 @@ async def update_session_data_in_database( response = await recipe_implementation.querier.send_put_request( NormalisedURLPath("/recipe/session/data"), {"sessionHandle": session_handle, "userDataInDatabase": new_session_data}, + None, user_context=user_context, ) if response["status"] == "UNAUTHORISED": @@ -534,6 +535,7 @@ async def update_access_token_payload( response = await recipe_implementation.querier.send_put_request( NormalisedURLPath("/recipe/jwt/data"), {"sessionHandle": session_handle, "userDataInJWT": new_access_token_payload}, + None, user_context=user_context, ) if response["status"] == "UNAUTHORISED": diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py index ce213dd88..e501ebb10 100644 --- a/supertokens_python/recipe/totp/recipe_implementation.py +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -149,6 +149,7 @@ async def update_device( resp = await self.querier.send_put_request( NormalisedURLPath("/recipe/totp/device"), data, + None, user_context=user_context, ) diff --git a/supertokens_python/recipe/usermetadata/recipe_implementation.py b/supertokens_python/recipe/usermetadata/recipe_implementation.py index 7485ddbe5..e65134301 100644 --- a/supertokens_python/recipe/usermetadata/recipe_implementation.py +++ b/supertokens_python/recipe/usermetadata/recipe_implementation.py @@ -47,6 +47,7 @@ async def update_user_metadata( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user/metadata"), params, + None, user_context=user_context, ) return MetadataResult(metadata=response["metadata"]) diff --git a/supertokens_python/recipe/userroles/recipe_implementation.py b/supertokens_python/recipe/userroles/recipe_implementation.py index 955a6ae9d..dfbc6db9a 100644 --- a/supertokens_python/recipe/userroles/recipe_implementation.py +++ b/supertokens_python/recipe/userroles/recipe_implementation.py @@ -50,6 +50,7 @@ async def add_role_to_user( response = await self.querier.send_put_request( NormalisedURLPath(f"{tenant_id}/recipe/user/role"), params, + None, user_context=user_context, ) if response["status"] == "OK": @@ -108,6 +109,7 @@ async def create_new_role_or_add_permissions( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/role"), params, + None, user_context=user_context, ) return CreateNewRoleOrAddPermissionsOkResult( From 8f96467f784fac81bc5d8e8d0ff672d5aafe9c4f Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 26 Nov 2024 15:27:45 +0530 Subject: [PATCH 07/38] fix: consent request --- .../recipe/oauth2provider/interfaces.py | 17 ++++ .../oauth2provider/recipe_implementation.py | 82 ++++++++++++++----- 2 files changed, 80 insertions(+), 19 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 78bfab706..fb1da5229 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -85,6 +85,23 @@ def __init__( self.skip = skip self.subject = subject + @staticmethod + def from_json(json: Dict[str, Any]): + return ConsentRequest( + acr=json["acr"], + amr=json["amr"], + challenge=json["challenge"], + client=OAuth2Client.from_json(json["client"]), + context=json["context"], + login_challenge=json["loginChallenge"], + login_session_id=json["loginSessionId"], + oidc_context=json["oidcContext"], + requested_access_token_audience=json["requestedAccessTokenAudience"], + requested_scope=json["requestedScope"], + skip=json["skip"], + subject=json["subject"], + ) + class LoginRequest: def __init__( diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 5a206fb55..96a9e27c7 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -16,6 +16,7 @@ from supertokens_python import AppInfo from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from .interfaces import ( RecipeInterface, @@ -127,27 +128,16 @@ async def reject_login_request( redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) ) - async def authorization( - self, - params: Dict[str, str], - cookies: Optional[str], - session: Optional[SessionContainer], - user_context: Optional[Dict[str, Any]] = None, - ) -> Union[RedirectResponse, ErrorOAuth2Response]: - pass - - async def token_exchange( - self, - authorization_header: Optional[str], - body: Dict[str, Optional[str]], - user_context: Optional[Dict[str, Any]] = None, - ) -> Union[TokenInfo, ErrorOAuth2Response]: - pass - async def get_consent_request( self, challenge: str, user_context: Optional[Dict[str, Any]] = None ) -> ConsentRequest: - pass + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/auth/requests/consent"), + {"consentChallenge": challenge}, + user_context=user_context, + ) + + return ConsentRequest.from_json(response) async def accept_consent_request( self, @@ -163,11 +153,65 @@ async def accept_consent_request( initial_id_token_payload: Optional[Dict[str, Any]] = None, user_context: Optional[Dict[str, Any]] = None, ) -> RedirectResponse: - pass + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/consent/accept"), + { + "context": context, + "grantAccessTokenAudience": grant_access_token_audience, + "grantScope": grant_scope, + "handledAt": handled_at, + "iss": await OpenIdRecipe.get_issuer(user_context), + "tId": tenant_id, + "rsub": rsub, + "sessionHandle": session_handle, + "initialAccessTokenPayload": initial_access_token_payload, + "initialIdTokenPayload": initial_id_token_payload, + }, + { + "consentChallenge": challenge, + }, + user_context=user_context, + ) + + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) async def reject_consent_request( self, challenge: str, error: ErrorOAuth2Response, user_context: Dict[str, Any] ) -> RedirectResponse: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/consent/reject"), + { + "error": error.error, + "errorDescription": error.error_description, + "statusCode": error.status_code, + }, + { + "consentChallenge": challenge, + }, + user_context=user_context, + ) + + return RedirectResponse( + redirect_to=get_updated_redirect_to(self.app_info, response["redirectTo"]) + ) + + async def authorization( + self, + params: Dict[str, str], + cookies: Optional[str], + session: Optional[SessionContainer], + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[RedirectResponse, ErrorOAuth2Response]: + pass + + async def token_exchange( + self, + authorization_header: Optional[str], + body: Dict[str, Optional[str]], + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[TokenInfo, ErrorOAuth2Response]: pass async def get_oauth2_client( From 804121de318d663c1ca8b8f8898612b1bcd12d4a Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 26 Nov 2024 16:13:48 +0530 Subject: [PATCH 08/38] fix: more impl --- .../recipe/oauth2provider/interfaces.py | 67 ++++++++- .../oauth2provider/recipe_implementation.py | 134 +++++++++++++++--- 2 files changed, 172 insertions(+), 29 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index fb1da5229..989c13c8b 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -185,6 +185,53 @@ def __init__(self, redirect_to: str, cookies: Optional[str] = None): self.cookies = cookies +class GetOAuth2ClientsOkResult: + def __init__( + self, clients: List[OAuth2Client], next_pagination_token: Optional[str] + ): + self.clients = clients + self.next_pagination_token = next_pagination_token + + @staticmethod + def from_json(json: Dict[str, Any]): + return GetOAuth2ClientsOkResult( + clients=[OAuth2Client.from_json(client) for client in json["clients"]], + next_pagination_token=json["nextPaginationToken"], + ) + + +class GetOAuth2ClientOkResult: + def __init__(self, client: OAuth2Client): + self.client = client + + @staticmethod + def from_json(json: Dict[str, Any]): + return GetOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + + +class CreateOAuth2ClientOkResult: + def __init__(self, client: OAuth2Client): + self.client = client + + @staticmethod + def from_json(json: Dict[str, Any]): + return CreateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + + +class UpdateOAuth2ClientOkResult: + def __init__(self, client: OAuth2Client): + self.client = client + + @staticmethod + def from_json(json: Dict[str, Any]): + return UpdateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + + +class DeleteOAuth2ClientOkResult: + def __init__(self): + pass + + class RecipeInterface(ABC): def __init__(self): pass @@ -267,16 +314,21 @@ async def reject_login_request( pass @abstractmethod - async def get_oauth2_client( - self, client_id: str, user_context: Optional[Dict[str, Any]] = None - ) -> Union[GetOAuth2ClientOkResult, GetOAuth2ClientErrorResult]: + async def get_oauth2_clients( + self, + page_size: Optional[int] = None, + pagination_token: Optional[str] = None, + client_name: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, + ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: pass @abstractmethod - async def get_oauth2_clients( + async def get_oauth2_client( self, + client_id: str, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[GetOAuth2ClientsOkResult, GetOAuth2ClientsErrorResult]: + ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: pass @abstractmethod @@ -290,14 +342,15 @@ async def create_oauth2_client( async def update_oauth2_client( self, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[UpdateOAuth2ClientOkResult, UpdateOAuth2ClientErrorResult]: + ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: pass @abstractmethod async def delete_oauth2_client( self, + client_id: str, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[DeleteOAuth2ClientOkResult, DeleteOAuth2ClientErrorResult]: + ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response\]: pass @abstractmethod diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 96a9e27c7..750a10687 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -214,34 +214,111 @@ async def token_exchange( ) -> Union[TokenInfo, ErrorOAuth2Response]: pass - async def get_oauth2_client( - self, client_id: str, user_context: Optional[Dict[str, Any]] = None - ) -> Union[GetOAuth2ClientOkResult, GetOAuth2ClientErrorResult]: - pass - async def get_oauth2_clients( self, + page_size: Optional[int] = None, + pagination_token: Optional[str] = None, + client_name: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[GetOAuth2ClientsOkResult, GetOAuth2ClientsErrorResult]: - pass + ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: + body: Dict[str, Any] = {} + if page_size is not None: + body["pageSize"] = page_size + if pagination_token is not None: + body["pageToken"] = pagination_token + if client_name is not None: + body["clientName"] = client_name + + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/clients/list"), + body, + user_context=user_context, + ) + + if response["status"] == "OK": + return GetOAuth2ClientsOkResult( + clients=[ + OAuth2Client.from_json(client) for client in response["clients"] + ], + next_pagination_token=response["nextPaginationToken"], + ) + + return ErrorOAuth2Response( + error=response["error"], + error_description=response["errorDescription"], + status_code=response["statusCode"], + ) + + async def get_oauth2_client( + self, client_id: str, user_context: Optional[Dict[str, Any]] = None + ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/clients"), + {"clientId": client_id}, + user_context=user_context, + ) + + if response["status"] == "OK": + return GetOAuth2ClientOkResult(client=OAuth2Client.from_json(response)) + elif response["status"] == "CLIENT_NOT_FOUND_ERROR": + return ErrorOAuth2Response( + error="invalid_request", + error_description="The provided client_id is not valid or unknown", + ) + else: + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) async def create_oauth2_client( self, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[CreateOAuth2ClientOkResult, CreateOAuth2ClientErrorResult]: - pass + ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/clients"), + {}, # Empty dict since no input params in function signature + user_context=user_context, + ) + + if response["status"] == "OK": + return CreateOAuth2ClientOkResult(client=OAuth2Client.from_json(response)) + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) async def update_oauth2_client( self, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[UpdateOAuth2ClientOkResult, UpdateOAuth2ClientErrorResult]: - pass + ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/clients"), + {}, # TODO update params + None, + user_context=user_context, + ) + + if response["status"] == "OK": + return UpdateOAuth2ClientOkResult(client=OAuth2Client.from_json(response)) + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) async def delete_oauth2_client( self, + client_id: str, user_context: Optional[Dict[str, Any]] = None, - ) -> Union[DeleteOAuth2ClientOkResult, DeleteOAuth2ClientErrorResult]: - pass + ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/clients/remove"), + {"clientId": client_id}, + user_context=user_context, + ) + + if response["status"] == "OK": + return DeleteOAuth2ClientOkResult() + return ErrorOAuth2Response( + error=response["error"], error_description=response["errorDescription"] + ) async def validate_oauth2_access_token( self, @@ -260,27 +337,35 @@ async def get_requested_scopes( client_id: str, user_context: Optional[Dict[str, Any]] = None, ) -> List[str]: - pass + return scope_param async def build_access_token_payload( self, - user: Optional[Dict[str, Any]], client: OAuth2Client, - session_handle: Optional[str], scopes: List[str], + user: Optional[Dict[str, Any]] = None, + session_handle: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - pass + if user is None or session_handle is None: + return {} + + return get_default_access_token_payload( + user, scopes, session_handle, user_context + ) async def build_id_token_payload( self, - user: Optional[Dict[str, Any]], client: OAuth2Client, - session_handle: Optional[str], scopes: List[str], + user: Optional[Dict[str, Any]] = None, + session_handle: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - pass + if user is None or session_handle is None: + return {} + + return get_default_id_token_payload(user, scopes, session_handle, user_context) async def build_user_info( self, @@ -290,14 +375,19 @@ async def build_user_info( tenant_id: str, user_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - pass + return get_default_user_info( + user, access_token_payload, scopes, tenant_id, user_context + ) async def get_frontend_redirection_url( self, input_type: str, user_context: Optional[Dict[str, Any]] = None, ) -> str: - pass + website_domain = self.app_info.get_origin( + None, user_context + ).get_as_string_dangerous() + website_base_path = self.app_info.api_base_path.get_as_string_dangerous() async def revoke_token( self, From 22ab47bd7c6f7e53742ee118fc245b40fefdc9e3 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Wed, 27 Nov 2024 16:28:00 +0530 Subject: [PATCH 09/38] fix: more impl --- .../recipe/oauth2provider/interfaces.py | 73 ++++++------ .../recipe/oauth2provider/recipe.py | 110 +++++++++++++++++- .../oauth2provider/recipe_implementation.py | 55 ++++----- supertokens_python/recipe/openid/recipe.py | 7 ++ 4 files changed, 174 insertions(+), 71 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 989c13c8b..577867a58 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -23,7 +23,6 @@ if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse - from .utils import OAuth2ProviderConfig @@ -242,7 +241,7 @@ async def authorization( params: Dict[str, str], cookies: Optional[str], session: Optional[SessionContainer], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass @@ -251,13 +250,13 @@ async def token_exchange( self, authorization_header: Optional[str], body: Dict[str, Optional[str]], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[TokenInfo, ErrorOAuth2Response]: pass @abstractmethod async def get_consent_request( - self, challenge: str, user_context: Optional[Dict[str, Any]] = None + self, challenge: str, user_context: Dict[str, Any] = {} ) -> ConsentRequest: pass @@ -274,7 +273,7 @@ async def accept_consent_request( session_handle: str = "", initial_access_token_payload: Optional[Dict[str, Any]] = None, initial_id_token_payload: Optional[Dict[str, Any]] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> RedirectResponse: pass @@ -300,7 +299,7 @@ async def accept_login_request( extend_session_lifespan: Optional[bool] = None, identity_provider_session_id: Optional[str] = None, subject: str = "", - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> RedirectResponse: pass @@ -309,7 +308,7 @@ async def reject_login_request( self, challenge: str, error: ErrorOAuth2Response, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> RedirectResponse: pass @@ -319,7 +318,7 @@ async def get_oauth2_clients( page_size: Optional[int] = None, pagination_token: Optional[str] = None, client_name: Optional[str] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: pass @@ -327,21 +326,21 @@ async def get_oauth2_clients( async def get_oauth2_client( self, client_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: pass @abstractmethod async def create_oauth2_client( self, - user_context: Optional[Dict[str, Any]] = None, - ) -> Union[CreateOAuth2ClientOkResult, CreateOAuth2ClientErrorResult]: + user_context: Dict[str, Any] = {}, + ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: pass @abstractmethod async def update_oauth2_client( self, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: pass @@ -349,8 +348,8 @@ async def update_oauth2_client( async def delete_oauth2_client( self, client_id: str, - user_context: Optional[Dict[str, Any]] = None, - ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response\]: + user_context: Dict[str, Any] = {}, + ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: pass @abstractmethod @@ -359,7 +358,7 @@ async def validate_oauth2_access_token( token: str, requirements: Optional[Dict[str, Any]] = None, check_database: Optional[bool] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: pass @@ -370,7 +369,7 @@ async def get_requested_scopes( session_handle: Optional[str], scope_param: List[str], client_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> List[str]: pass @@ -381,7 +380,7 @@ async def build_access_token_payload( client: OAuth2Client, session_handle: Optional[str], scopes: List[str], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: pass @@ -392,7 +391,7 @@ async def build_id_token_payload( client: OAuth2Client, session_handle: Optional[str], scopes: List[str], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: pass @@ -403,7 +402,7 @@ async def build_user_info( access_token_payload: Dict[str, Any], scopes: List[str], tenant_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: pass @@ -411,7 +410,7 @@ async def build_user_info( async def get_frontend_redirection_url( self, input_type: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> str: pass @@ -422,7 +421,7 @@ async def revoke_token( authorization_header: Optional[str] = None, client_id: Optional[str] = None, client_secret: Optional[str] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[Dict[str, str], ErrorOAuth2Response]: pass @@ -430,7 +429,7 @@ async def revoke_token( async def revoke_tokens_by_client_id( self, client_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, str]: pass @@ -438,7 +437,7 @@ async def revoke_tokens_by_client_id( async def revoke_tokens_by_session_handle( self, session_handle: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, str]: pass @@ -447,7 +446,7 @@ async def introspect_token( self, token: str, scopes: Optional[List[str]] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: pass @@ -457,7 +456,7 @@ async def end_session( params: Dict[str, str], session: Optional[SessionContainer] = None, should_try_refresh: bool = False, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass @@ -465,7 +464,7 @@ async def end_session( async def accept_logout_request( self, challenge: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass @@ -473,7 +472,7 @@ async def accept_logout_request( async def reject_logout_request( self, challenge: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, str]: pass @@ -505,7 +504,7 @@ async def login_get( options: APIOptions, session: Optional[SessionContainer] = None, should_try_refresh: bool = False, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[ Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse ]: @@ -519,7 +518,7 @@ async def auth_get( session: Optional[SessionContainer], should_try_refresh: bool, options: APIOptions, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[ Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse ]: @@ -531,7 +530,7 @@ async def token_post( authorization_header: Optional[str], body: Any, options: APIOptions, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -540,7 +539,7 @@ async def login_info_get( self, login_challenge: str, options: APIOptions, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[ Dict[Literal["status", "info"], Union[Literal["OK"], LoginInfo]], ErrorOAuth2Response, @@ -556,7 +555,7 @@ async def user_info_get( scopes: List[str], tenant_id: str, options: APIOptions, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[Dict[str, Any], GeneralErrorResponse]: pass @@ -565,7 +564,7 @@ async def revoke_token_post( self, token: str, options: APIOptions, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, authorization_header: Optional[str] = None, client_id: Optional[str] = None, client_secret: Optional[str] = None, @@ -578,7 +577,7 @@ async def introspect_token_post( token: str, scopes: Optional[List[str]], options: APIOptions, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[IntrospectTokenResponse, GeneralErrorResponse]: pass @@ -589,7 +588,7 @@ async def end_session_get( options: APIOptions, session: Optional[SessionContainer] = None, should_try_refresh: bool = False, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -600,7 +599,7 @@ async def end_session_post( options: APIOptions, session: Optional[SessionContainer] = None, should_try_refresh: bool = False, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -610,7 +609,7 @@ async def logout_post( logout_challenge: str, options: APIOptions, session: Optional[SessionContainer] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[ Dict[str, Union[Literal["OK"], str]], ErrorOAuth2Response, GeneralErrorResponse ]: diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index 67d447d6b..78b2a06d1 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -20,18 +20,19 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError from supertokens_python.recipe_module import APIHandled, RecipeModule - -from .interfaces import ( - APIOptions, -) +from supertokens_python.types import User from .recipe_implementation import RecipeImplementation +from .interfaces import APIOptions + + if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse from supertokens_python.supertokens import AppInfo + from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier from supertokens_python.recipe.oauth2provider.api.implementation import ( @@ -58,6 +59,76 @@ ) +async def get_default_id_token_payload( + user: Dict[str, Any], + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any] = {}, +) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + + if "email" in scopes: + payload["email"] = user.get("emails", [None])[0] + payload["email_verified"] = any( + lm.get("hasSameEmailAs")(user.get("emails", [None])[0]) + and lm.get("verified") + for lm in user.get("loginMethods", []) + ) + payload["emails"] = user.get("emails", []) + + if "phoneNumber" in scopes: + payload["phoneNumber"] = user.get("phoneNumbers", [None])[0] + payload["phoneNumber_verified"] = any( + lm.get("hasSamePhoneNumberAs")(user.get("phoneNumbers", [None])[0]) + and lm.get("verified") + for lm in user.get("loginMethods", []) + ) + payload["phoneNumbers"] = user.get("phoneNumbers", []) + + for fn in self.id_token_builders: + fn_payload = await fn(user, scopes, session_handle, user_context) + payload.update(fn_payload) + + return payload + + +async def get_default_user_info( + user: Dict[str, Any], + access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Dict[str, Any] = {}, +) -> Dict[str, Any]: + payload: Dict[str, Any] = {"sub": access_token_payload.get("sub")} + + if "email" in scopes: + # TODO: try and get the email based on the user id of the entire user object + payload["email"] = user.get("emails", [None])[0] + payload["email_verified"] = any( + lm.get("hasSameEmailAs")(user.get("emails", [None])[0]) + and lm.get("verified") + for lm in user.get("loginMethods", []) + ) + payload["emails"] = user.get("emails", []) + + if "phoneNumber" in scopes: + payload["phoneNumber"] = user.get("phoneNumbers", [None])[0] + payload["phoneNumber_verified"] = any( + lm.get("hasSamePhoneNumberAs")(user.get("phoneNumbers", [None])[0]) + and lm.get("verified") + for lm in user.get("loginMethods", []) + ) + payload["phoneNumbers"] = user.get("phoneNumbers", []) + + for fn in self.user_info_builders: + fn_payload = await fn( + user, access_token_payload, scopes, tenant_id, user_context + ) + payload.update(fn_payload) + + return payload + + class OAuth2ProviderRecipe(RecipeModule): recipe_id = "oauth2provider" __instance = None @@ -229,6 +300,37 @@ async def handle_error( def get_all_cors_headers(self) -> List[str]: return [] + async def get_default_access_token_payload( + self, + user: User, + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any] = {}, + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + + if "email" in scopes: + payload["email"] = user.emails[0] + payload["email_verified"] = any( + lm.has_same_email_as(user.emails[0]) and lm.verified + for lm in user.login_methods + ) + payload["emails"] = user.emails + + if "phoneNumber" in scopes: + payload["phoneNumber"] = user.phone_numbers[0] + payload["phoneNumber_verified"] = any( + lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified + for lm in user.login_methods + ) + payload["phoneNumbers"] = user.phone_numbers + + for fn in self.access_token_builders: + fn_payload = await fn(user, scopes, session_handle, user_context) + payload.update(fn_payload) + + return payload + @staticmethod def init( override: Union[InputOverrideConfig, None] = None, diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 750a10687..8c0da7606 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -17,22 +17,17 @@ from supertokens_python import AppInfo from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.openid.recipe import OpenIdRecipe +from supertokens_python.recipe.session.interfaces import SessionContainer from .interfaces import ( RecipeInterface, RedirectResponse, ErrorOAuth2Response, - SessionContainer, GetOAuth2ClientOkResult, - GetOAuth2ClientErrorResult, GetOAuth2ClientsOkResult, - GetOAuth2ClientsErrorResult, CreateOAuth2ClientOkResult, - CreateOAuth2ClientErrorResult, UpdateOAuth2ClientOkResult, - UpdateOAuth2ClientErrorResult, DeleteOAuth2ClientOkResult, - DeleteOAuth2ClientErrorResult, ConsentRequest, LoginRequest, OAuth2Client, @@ -84,7 +79,7 @@ async def accept_login_request( extend_session_lifespan: Optional[bool] = None, identity_provider_session_id: Optional[str] = None, subject: str = "", - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> RedirectResponse: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/login/accept"), @@ -110,7 +105,7 @@ async def reject_login_request( self, challenge: str, error: ErrorOAuth2Response, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> RedirectResponse: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/login/reject"), @@ -129,7 +124,7 @@ async def reject_login_request( ) async def get_consent_request( - self, challenge: str, user_context: Optional[Dict[str, Any]] = None + self, challenge: str, user_context: Dict[str, Any] = {} ) -> ConsentRequest: response = await self.querier.send_get_request( NormalisedURLPath("/recipe/oauth/auth/requests/consent"), @@ -151,7 +146,7 @@ async def accept_consent_request( session_handle: str = "", initial_access_token_payload: Optional[Dict[str, Any]] = None, initial_id_token_payload: Optional[Dict[str, Any]] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> RedirectResponse: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/consent/accept"), @@ -202,7 +197,7 @@ async def authorization( params: Dict[str, str], cookies: Optional[str], session: Optional[SessionContainer], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass @@ -210,7 +205,7 @@ async def token_exchange( self, authorization_header: Optional[str], body: Dict[str, Optional[str]], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[TokenInfo, ErrorOAuth2Response]: pass @@ -219,7 +214,7 @@ async def get_oauth2_clients( page_size: Optional[int] = None, pagination_token: Optional[str] = None, client_name: Optional[str] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: body: Dict[str, Any] = {} if page_size is not None: @@ -250,7 +245,7 @@ async def get_oauth2_clients( ) async def get_oauth2_client( - self, client_id: str, user_context: Optional[Dict[str, Any]] = None + self, client_id: str, user_context: Dict[str, Any] = {} ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_get_request( NormalisedURLPath("/recipe/oauth/clients"), @@ -272,7 +267,7 @@ async def get_oauth2_client( async def create_oauth2_client( self, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/clients"), @@ -288,7 +283,7 @@ async def create_oauth2_client( async def update_oauth2_client( self, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/clients"), @@ -306,7 +301,7 @@ async def update_oauth2_client( async def delete_oauth2_client( self, client_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/clients/remove"), @@ -325,7 +320,7 @@ async def validate_oauth2_access_token( token: str, requirements: Optional[Dict[str, Any]] = None, check_database: Optional[bool] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: pass @@ -335,7 +330,7 @@ async def get_requested_scopes( session_handle: Optional[str], scope_param: List[str], client_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> List[str]: return scope_param @@ -345,7 +340,7 @@ async def build_access_token_payload( scopes: List[str], user: Optional[Dict[str, Any]] = None, session_handle: Optional[str] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: if user is None or session_handle is None: return {} @@ -360,7 +355,7 @@ async def build_id_token_payload( scopes: List[str], user: Optional[Dict[str, Any]] = None, session_handle: Optional[str] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: if user is None or session_handle is None: return {} @@ -373,7 +368,7 @@ async def build_user_info( access_token_payload: Dict[str, Any], scopes: List[str], tenant_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: return get_default_user_info( user, access_token_payload, scopes, tenant_id, user_context @@ -382,7 +377,7 @@ async def build_user_info( async def get_frontend_redirection_url( self, input_type: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> str: website_domain = self.app_info.get_origin( None, user_context @@ -395,21 +390,21 @@ async def revoke_token( authorization_header: Optional[str] = None, client_id: Optional[str] = None, client_secret: Optional[str] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[Dict[str, str], ErrorOAuth2Response]: pass async def revoke_tokens_by_client_id( self, client_id: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, str]: pass async def revoke_tokens_by_session_handle( self, session_handle: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, str]: pass @@ -417,7 +412,7 @@ async def introspect_token( self, token: str, scopes: Optional[List[str]] = None, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: pass @@ -426,20 +421,20 @@ async def end_session( params: Dict[str, str], session: Optional[SessionContainer] = None, should_try_refresh: bool = False, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass async def accept_logout_request( self, challenge: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass async def reject_logout_request( self, challenge: str, - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any] = {}, ) -> Dict[str, str]: pass diff --git a/supertokens_python/recipe/openid/recipe.py b/supertokens_python/recipe/openid/recipe.py index 0d3799ae4..cda0a300d 100644 --- a/supertokens_python/recipe/openid/recipe.py +++ b/supertokens_python/recipe/openid/recipe.py @@ -170,3 +170,10 @@ def reset(): ): raise_general_exception("calling testing function in non testing env") OpenIdRecipe.__instance = None + + @staticmethod + async def get_issuer(user_context: Dict[str, Any]) -> str: + open_id_config = await OpenIdRecipe.get_instance().recipe_implementation.get_open_id_discovery_configuration( + user_context + ) + return open_id_config.issuer From e776828970ca475e9b898c4ba9a9199b2d89f322 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 11:12:59 +0530 Subject: [PATCH 10/38] fix: recipe impl --- .../recipe/oauth2provider/interfaces.py | 6 +- .../oauth2provider/recipe_implementation.py | 83 +++++++++++++++++-- .../passwordless/recipe_implementation.py | 1 + 3 files changed, 79 insertions(+), 11 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 577867a58..f0def6811 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -430,7 +430,7 @@ async def revoke_tokens_by_client_id( self, client_id: str, user_context: Dict[str, Any] = {}, - ) -> Dict[str, str]: + ): pass @abstractmethod @@ -438,7 +438,7 @@ async def revoke_tokens_by_session_handle( self, session_handle: str, user_context: Dict[str, Any] = {}, - ) -> Dict[str, str]: + ): pass @abstractmethod @@ -473,7 +473,7 @@ async def reject_logout_request( self, challenge: str, user_context: Dict[str, Any] = {}, - ) -> Dict[str, str]: + ): pass diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 8c0da7606..d05144f92 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -398,15 +398,23 @@ async def revoke_tokens_by_client_id( self, client_id: str, user_context: Dict[str, Any] = {}, - ) -> Dict[str, str]: - pass + ): + await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/session/revoke"), + {"client_id": client_id}, + user_context=user_context, + ) async def revoke_tokens_by_session_handle( self, session_handle: str, user_context: Dict[str, Any] = {}, - ) -> Dict[str, str]: - pass + ): + await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/session/revoke"), + {"sessionHandle": session_handle}, + user_context=user_context, + ) async def introspect_token( self, @@ -414,7 +422,34 @@ async def introspect_token( scopes: Optional[List[str]] = None, user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: - pass + # Determine if the token is an access token by checking if it doesn't start with "st_rt" + is_access_token = not token.startswith("st_rt") + + # Attempt to validate the access token locally + # If it fails, the token is not active, and we return early + if is_access_token: + try: + await self.validate_oauth2_access_token( + token=token, + requirements={"scopes": scopes}, + check_database=False, + user_context=user_context, + ) + except Exception: + return {"active": False} + + # For tokens that passed local validation or if it's a refresh token, + # validate the token with the database by calling the core introspection endpoint + res = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/introspect"), + { + "token": token, + "scope": " ".join(scopes) if scopes else None, + }, + user_context=user_context, + ) + + return res async def end_session( self, @@ -430,11 +465,43 @@ async def accept_logout_request( challenge: str, user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: - pass + resp = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/logout/accept"), + {"challenge": challenge}, + None, + user_context=user_context, + ) + + if resp["status"] != "OK": + return ErrorOAuth2Response( + status_code=resp["statusCode"], + error=resp["error"], + error_description=resp["errorDescription"], + ) + + redirect_to = get_updated_redirect_to(self.app_info, resp["redirectTo"]) + + if redirect_to.endswith("/fallbacks/logout/callback"): + return RedirectResponse( + redirect_to=await self.get_frontend_redirection_url( + type="post-logout-fallback", + user_context=user_context, + ) + ) + + return RedirectResponse(redirect_to=redirect_to) async def reject_logout_request( self, challenge: str, user_context: Dict[str, Any] = {}, - ) -> Dict[str, str]: - pass + ): + resp = await self.querier.send_put_request( + NormalisedURLPath("/recipe/oauth/auth/requests/logout/reject"), + {}, + {"challenge": challenge}, + user_context=user_context, + ) + + if resp["status"] != "OK": + raise Exception(resp["error"]) diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index d727d8f30..070873122 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -526,6 +526,7 @@ async def update_user( response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), input_dict, + None, user_context=user_context, ) if response["status"] == "UNKNOWN_USER_ID_ERROR": From f4622843a2c8860533ea9a2936e842f7a7df94ff Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 11:37:05 +0530 Subject: [PATCH 11/38] fix: recipe impl --- .../recipe/oauth2provider/interfaces.py | 17 +- .../recipe/oauth2provider/recipe.py | 256 +++++++++++------- .../oauth2provider/recipe_implementation.py | 37 ++- 3 files changed, 194 insertions(+), 116 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index f0def6811..8d3b222e9 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -13,7 +13,7 @@ # under the License. from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.recipe.session import SessionContainer from supertokens_python.types import APIResponse, GeneralErrorResponse, User @@ -231,6 +231,15 @@ def __init__(self): pass +PayloadBuilderFunction = Callable[ + [User, List[str], str, Dict[str, Any]], Awaitable[Dict[str, Any]] +] + +UserInfoBuilderFunction = Callable[ + [User, Dict[str, Any], List[str], str, Dict[str, Any]], Awaitable[Dict[str, Any]] +] + + class RecipeInterface(ABC): def __init__(self): pass @@ -376,7 +385,7 @@ async def get_requested_scopes( @abstractmethod async def build_access_token_payload( self, - user: Optional[Dict[str, Any]], + user: Optional[User], client: OAuth2Client, session_handle: Optional[str], scopes: List[str], @@ -387,7 +396,7 @@ async def build_access_token_payload( @abstractmethod async def build_id_token_payload( self, - user: Optional[Dict[str, Any]], + user: Optional[User], client: OAuth2Client, session_handle: Optional[str], scopes: List[str], @@ -398,7 +407,7 @@ async def build_id_token_payload( @abstractmethod async def build_user_info( self, - user: Dict[str, Any], + user: User, access_token_payload: Dict[str, Any], scopes: List[str], tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index 78b2a06d1..5059ca40b 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -24,7 +24,7 @@ from .recipe_implementation import RecipeImplementation -from .interfaces import APIOptions +from .interfaces import APIOptions, PayloadBuilderFunction, UserInfoBuilderFunction if TYPE_CHECKING: @@ -59,76 +59,6 @@ ) -async def get_default_id_token_payload( - user: Dict[str, Any], - scopes: List[str], - session_handle: str, - user_context: Dict[str, Any] = {}, -) -> Dict[str, Any]: - payload: Dict[str, Any] = {} - - if "email" in scopes: - payload["email"] = user.get("emails", [None])[0] - payload["email_verified"] = any( - lm.get("hasSameEmailAs")(user.get("emails", [None])[0]) - and lm.get("verified") - for lm in user.get("loginMethods", []) - ) - payload["emails"] = user.get("emails", []) - - if "phoneNumber" in scopes: - payload["phoneNumber"] = user.get("phoneNumbers", [None])[0] - payload["phoneNumber_verified"] = any( - lm.get("hasSamePhoneNumberAs")(user.get("phoneNumbers", [None])[0]) - and lm.get("verified") - for lm in user.get("loginMethods", []) - ) - payload["phoneNumbers"] = user.get("phoneNumbers", []) - - for fn in self.id_token_builders: - fn_payload = await fn(user, scopes, session_handle, user_context) - payload.update(fn_payload) - - return payload - - -async def get_default_user_info( - user: Dict[str, Any], - access_token_payload: Dict[str, Any], - scopes: List[str], - tenant_id: str, - user_context: Dict[str, Any] = {}, -) -> Dict[str, Any]: - payload: Dict[str, Any] = {"sub": access_token_payload.get("sub")} - - if "email" in scopes: - # TODO: try and get the email based on the user id of the entire user object - payload["email"] = user.get("emails", [None])[0] - payload["email_verified"] = any( - lm.get("hasSameEmailAs")(user.get("emails", [None])[0]) - and lm.get("verified") - for lm in user.get("loginMethods", []) - ) - payload["emails"] = user.get("emails", []) - - if "phoneNumber" in scopes: - payload["phoneNumber"] = user.get("phoneNumbers", [None])[0] - payload["phoneNumber_verified"] = any( - lm.get("hasSamePhoneNumberAs")(user.get("phoneNumbers", [None])[0]) - and lm.get("verified") - for lm in user.get("loginMethods", []) - ) - payload["phoneNumbers"] = user.get("phoneNumbers", []) - - for fn in self.user_info_builders: - fn_payload = await fn( - user, access_token_payload, scopes, tenant_id, user_context - ) - payload.update(fn_payload) - - return payload - - class OAuth2ProviderRecipe(RecipeModule): recipe_id = "oauth2provider" __instance = None @@ -144,7 +74,13 @@ def __init__( override, ) - recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) + recipe_implementation = RecipeImplementation( + Querier.get_instance(recipe_id), + app_info, + self.get_default_access_token_payload, + self.get_default_id_token_payload, + self.get_default_user_info_payload, + ) self.recipe_implementation: RecipeImplementation = ( recipe_implementation if self.config.override.functions is None @@ -158,6 +94,10 @@ def __init__( else self.config.override.apis(api_implementation) ) + self._access_token_builders: List[PayloadBuilderFunction] = [] + self._id_token_builders: List[PayloadBuilderFunction] = [] + self._user_info_builders: List[UserInfoBuilderFunction] = [] + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, OAuth2ProviderError) @@ -300,37 +240,6 @@ async def handle_error( def get_all_cors_headers(self) -> List[str]: return [] - async def get_default_access_token_payload( - self, - user: User, - scopes: List[str], - session_handle: str, - user_context: Dict[str, Any] = {}, - ) -> Dict[str, Any]: - payload: Dict[str, Any] = {} - - if "email" in scopes: - payload["email"] = user.emails[0] - payload["email_verified"] = any( - lm.has_same_email_as(user.emails[0]) and lm.verified - for lm in user.login_methods - ) - payload["emails"] = user.emails - - if "phoneNumber" in scopes: - payload["phoneNumber"] = user.phone_numbers[0] - payload["phoneNumber_verified"] = any( - lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified - for lm in user.login_methods - ) - payload["phoneNumbers"] = user.phone_numbers - - for fn in self.access_token_builders: - fn_payload = await fn(user, scopes, session_handle, user_context) - payload.update(fn_payload) - - return payload - @staticmethod def init( override: Union[InputOverrideConfig, None] = None, @@ -369,3 +278,144 @@ def reset(): ): raise_general_exception("calling testing function in non testing env") OAuth2ProviderRecipe.__instance = None + + def add_user_info_builder_from_other_recipe( + self, user_info_builder_fn: UserInfoBuilderFunction + ) -> None: + self._user_info_builders.append(user_info_builder_fn) + + def add_access_token_builder_from_other_recipe( + self, access_token_builder: PayloadBuilderFunction + ) -> None: + self._access_token_builders.append(access_token_builder) + + def add_id_token_builder_from_other_recipe( + self, id_token_builder: PayloadBuilderFunction + ) -> None: + self._id_token_builders.append(id_token_builder) + + async def get_default_access_token_payload( + self, + user: User, + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + + if "email" in scopes: + payload["email"] = user.emails[0] if user.emails else None + payload["email_verified"] = ( + any( + lm.has_same_email_as(user.emails[0]) and lm.verified + for lm in user.login_methods + ) + if user.emails + else False + ) + payload["emails"] = user.emails + + if "phoneNumber" in scopes: + payload["phoneNumber"] = ( + user.phone_numbers[0] if user.phone_numbers else None + ) + payload["phoneNumber_verified"] = ( + any( + lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified + for lm in user.login_methods + ) + if user.phone_numbers + else False + ) + payload["phoneNumbers"] = user.phone_numbers + + for fn in self._access_token_builders: + builder_payload = await fn(user, scopes, session_handle, user_context) + payload.update(builder_payload) + + return payload + + async def get_default_id_token_payload( + self, + user: User, + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + + if "email" in scopes: + payload["email"] = user.emails[0] if user.emails else None + payload["email_verified"] = ( + any( + lm.has_same_email_as(user.emails[0]) and lm.verified + for lm in user.login_methods + ) + if user.emails + else False + ) + payload["emails"] = user.emails + + if "phoneNumber" in scopes: + payload["phoneNumber"] = ( + user.phone_numbers[0] if user.phone_numbers else None + ) + payload["phoneNumber_verified"] = ( + any( + lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified + for lm in user.login_methods + ) + if user.phone_numbers + else False + ) + payload["phoneNumbers"] = user.phone_numbers + + for fn in self._id_token_builders: + builder_payload = await fn(user, scopes, session_handle, user_context) + payload.update(builder_payload) + + return payload + + async def get_default_user_info_payload( + self, + user: User, + access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {"sub": access_token_payload["sub"]} + + if "email" in scopes: + payload["email"] = user.emails[0] if user.emails else None + payload["email_verified"] = ( + any( + lm.has_same_email_as(user.emails[0]) and lm.verified + for lm in user.login_methods + ) + if user.emails + else False + ) + payload["emails"] = user.emails + + if "phoneNumber" in scopes: + payload["phoneNumber"] = ( + user.phone_numbers[0] if user.phone_numbers else None + ) + payload["phoneNumber_verified"] = ( + any( + lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified + for lm in user.login_methods + ) + if user.phone_numbers + else False + ) + payload["phoneNumbers"] = user.phone_numbers + + for fn in self._user_info_builders: + builder_payload = await fn( + user, access_token_payload, scopes, tenant_id, user_context + ) + payload.update(builder_payload) + + return payload diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index d05144f92..9164d75ff 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -18,8 +18,10 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.session.interfaces import SessionContainer +from supertokens_python.types import User from .interfaces import ( + PayloadBuilderFunction, RecipeInterface, RedirectResponse, ErrorOAuth2Response, @@ -32,6 +34,7 @@ LoginRequest, OAuth2Client, TokenInfo, + UserInfoBuilderFunction, ) @@ -48,10 +51,20 @@ def get_updated_redirect_to(app_info: AppInfo, redirect_to: str) -> str: class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier, app_info: AppInfo): + def __init__( + self, + querier: Querier, + app_info: AppInfo, + get_default_access_token_payload: PayloadBuilderFunction, + get_default_id_token_payload: PayloadBuilderFunction, + get_default_user_info_payload: UserInfoBuilderFunction, + ): super().__init__() self.querier = querier self.app_info = app_info + self._get_default_access_token_payload = get_default_access_token_payload + self._get_default_id_token_payload = get_default_id_token_payload + self._get_default_user_info_payload = get_default_user_info_payload async def get_login_request( self, challenge: str, user_context: Dict[str, Any] @@ -336,41 +349,47 @@ async def get_requested_scopes( async def build_access_token_payload( self, + user: Optional[User], client: OAuth2Client, + session_handle: Optional[str], scopes: List[str], - user: Optional[Dict[str, Any]] = None, - session_handle: Optional[str] = None, user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: if user is None or session_handle is None: return {} - return get_default_access_token_payload( + _ = client + + return await self._get_default_access_token_payload( user, scopes, session_handle, user_context ) async def build_id_token_payload( self, + user: Optional[User], client: OAuth2Client, + session_handle: Optional[str], scopes: List[str], - user: Optional[Dict[str, Any]] = None, - session_handle: Optional[str] = None, user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: if user is None or session_handle is None: return {} - return get_default_id_token_payload(user, scopes, session_handle, user_context) + _ = client + + return await self._get_default_id_token_payload( + user, scopes, session_handle, user_context + ) async def build_user_info( self, - user: Dict[str, Any], + user: User, access_token_payload: Dict[str, Any], scopes: List[str], tenant_id: str, user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: - return get_default_user_info( + return await self._get_default_user_info_payload( user, access_token_payload, scopes, tenant_id, user_context ) From c83ef6fa1defbaa80d3097f12c265478c39e17bf Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 12:29:47 +0530 Subject: [PATCH 12/38] fix: validate_oauth2_access_token --- .../recipe/oauth2provider/interfaces.py | 14 ++++- .../oauth2provider/recipe_implementation.py | 56 ++++++++++++++++++- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 8d3b222e9..b0611e74e 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -240,6 +240,18 @@ def __init__(self): ] +class OAuth2TokenValidationRequirements: + def __init__( + self, + client_id: Optional[str] = None, + scopes: Optional[List[str]] = None, + audience: Optional[str] = None, + ): + self.client_id = client_id + self.scopes = scopes + self.audience = audience + + class RecipeInterface(ABC): def __init__(self): pass @@ -365,7 +377,7 @@ async def delete_oauth2_client( async def validate_oauth2_access_token( self, token: str, - requirements: Optional[Dict[str, Any]] = None, + requirements: Optional[OAuth2TokenValidationRequirements] = None, check_database: Optional[bool] = None, user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 9164d75ff..2443a1f9f 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -14,13 +14,18 @@ from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List +import jwt + from supertokens_python import AppInfo from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.session.interfaces import SessionContainer +from supertokens_python.recipe.session.jwks import get_latest_keys +from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.types import User from .interfaces import ( + OAuth2TokenValidationRequirements, PayloadBuilderFunction, RecipeInterface, RedirectResponse, @@ -331,11 +336,58 @@ async def delete_oauth2_client( async def validate_oauth2_access_token( self, token: str, - requirements: Optional[Dict[str, Any]] = None, + requirements: Optional[OAuth2TokenValidationRequirements] = None, check_database: Optional[bool] = None, user_context: Dict[str, Any] = {}, ) -> Dict[str, Any]: - pass + # Verify token signature using session recipe's JWKS + session_recipe = SessionRecipe.get_instance() + matching_keys = get_latest_keys(session_recipe.config) + payload = jwt.decode( + token, + matching_keys[0].key, + algorithms=["RS256"], + options={"verify_signature": True, "verify_exp": True}, + ) + + if payload.get("stt") != 1: + raise Exception("Wrong token type") + + if requirements is not None and requirements.client_id is not None: + if payload.get("client_id") != requirements.client_id: + raise Exception( + f"The token doesn't belong to the specified client ({requirements.client_id} !== {payload.get('client_id')})" + ) + + if requirements is not None and requirements.scopes is not None: + token_scopes = payload.get("scp", []) + if not isinstance(token_scopes, list): + token_scopes = [token_scopes] + + if any(scope not in token_scopes for scope in requirements.scopes): + raise Exception("The token is missing some required scopes") + + aud = payload.get("aud", []) + if not isinstance(aud, list): + aud = [aud] + + if requirements is not None and requirements.audience is not None: + if requirements.audience not in aud: + raise Exception("The token doesn't belong to the specified audience") + + if check_database: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/introspect"), + { + "token": token, + }, + user_context=user_context, + ) + + if response.get("active") is not True: + raise Exception("The token is expired, invalid or has been revoked") + + return {"status": "OK", "payload": payload} async def get_requested_scopes( self, From 2f7e994950591f48920c1c0b5eb5d48c1bd1641d Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 12:49:51 +0530 Subject: [PATCH 13/38] fix: authorization --- .../recipe/oauth2provider/interfaces.py | 9 +- .../oauth2provider/recipe_implementation.py | 153 +++++++++++++++++- 2 files changed, 157 insertions(+), 5 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index b0611e74e..7e3b9b877 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -16,7 +16,12 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union from typing_extensions import Literal from supertokens_python.recipe.session import SessionContainer -from supertokens_python.types import APIResponse, GeneralErrorResponse, User +from supertokens_python.types import ( + APIResponse, + GeneralErrorResponse, + RecipeUserId, + User, +) from .oauth2_client import OAuth2Client @@ -386,7 +391,7 @@ async def validate_oauth2_access_token( @abstractmethod async def get_requested_scopes( self, - recipe_user_id: Optional[str], + recipe_user_id: Optional[RecipeUserId], session_handle: Optional[str], scope_param: List[str], client_id: str, diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 2443a1f9f..8b9204bfc 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -13,16 +13,18 @@ # under the License. from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List +from urllib.parse import parse_qs, urlparse import jwt from supertokens_python import AppInfo +from supertokens_python.asyncio import get_user from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.session.interfaces import SessionContainer from supertokens_python.recipe.session.jwks import get_latest_keys from supertokens_python.recipe.session.recipe import SessionRecipe -from supertokens_python.types import User +from supertokens_python.types import RecipeUserId, User from .interfaces import ( OAuth2TokenValidationRequirements, @@ -217,7 +219,147 @@ async def authorization( session: Optional[SessionContainer], user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: - pass + # we handle this in the backend SDK level + if params.get("prompt") == "none": + params["st_prompt"] = "none" + del params["prompt"] + + payloads = None + + if not params.get("client_id") or not isinstance(params.get("client_id"), str): + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="client_id is required and must be a string", + ) + + scopes = await self.get_requested_scopes( + scope_param=params.get("scope", "").split() if params.get("scope") else [], + client_id=params["client_id"], + recipe_user_id=( + session.get_recipe_user_id() if session is not None else None + ), + session_handle=session.get_handle() if session else None, + user_context=user_context, + ) + + response_types = ( + params.get("response_type", "").split() + if params.get("response_type") + else [] + ) + + if session is not None: + client_info = await self.get_oauth2_client( + client_id=params["client_id"], user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + return ErrorOAuth2Response( + status_code=400, + error=client_info.error, + error_description=client_info.error_description, + ) + + client = client_info.client + + user = await get_user(session.get_user_id()) + if not user: + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="User deleted", + ) + + # These default to empty dicts, because we want to keep them as required input + # but they'll not be actually used in flows where we are not building them + id_token = {} + if "openid" in scopes and ( + "id_token" in response_types or "code" in response_types + ): + id_token = await self.build_id_token_payload( + user=user, + client=client, + session_handle=session.get_handle(), + scopes=scopes, + user_context=user_context, + ) + + access_token = {} + if "token" in response_types or "code" in response_types: + access_token = await self.build_access_token_payload( + user=user, + client=client, + session_handle=session.get_handle(), + scopes=scopes, + user_context=user_context, + ) + + payloads = {"idToken": id_token, "accessToken": access_token} + + resp = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/auth"), + { + "params": {**params, "scope": " ".join(scopes)}, + "iss": await OpenIdRecipe.get_issuer(user_context), + "cookies": cookies, + "session": payloads, + }, + user_context, + ) + + if resp["status"] == "CLIENT_NOT_FOUND_ERROR": + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="The provided client_id is not valid", + ) + + if resp["status"] != "OK": + return ErrorOAuth2Response( + status_code=resp["statusCode"], + error=resp["error"], + error_description=resp["errorDescription"], + ) + + if resp.get("redirectTo") is None: + raise Exception(resp) + redirect_to = get_updated_redirect_to(self.app_info, resp["redirectTo"]) + + redirect_to_query_params_str = urlparse(redirect_to).query + redirect_to_query_params: Dict[str, List[str]] = parse_qs( + redirect_to_query_params_str + ) + consent_challenge: Optional[str] = None + + if "consent_challenge" in redirect_to_query_params: + if len(redirect_to_query_params["consent_challenge"]) > 0: + consent_challenge = redirect_to_query_params["consent_challenge"][0] + + if consent_challenge is not None and session is not None: + consent_request = await self.get_consent_request( + challenge=consent_challenge, user_context=user_context + ) + + consent_res = await self.accept_consent_request( + user_context=user_context, + challenge=consent_request.challenge, + grant_access_token_audience=consent_request.requested_access_token_audience, + grant_scope=consent_request.requested_scope, + tenant_id=session.get_tenant_id(), + rsub=session.get_recipe_user_id().get_as_string(), + session_handle=session.get_handle(), + initial_access_token_payload=( + payloads.get("accessToken") if payloads else None + ), + initial_id_token_payload=payloads.get("idToken") if payloads else None, + ) + + return RedirectResponse( + redirect_to=consent_res.redirect_to, cookies=resp["cookies"] + ) + + return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"]) async def token_exchange( self, @@ -391,12 +533,17 @@ async def validate_oauth2_access_token( async def get_requested_scopes( self, - recipe_user_id: Optional[str], + recipe_user_id: Optional[RecipeUserId], session_handle: Optional[str], scope_param: List[str], client_id: str, user_context: Dict[str, Any] = {}, ) -> List[str]: + _ = recipe_user_id + _ = session_handle + _ = client_id + _ = user_context + return scope_param async def build_access_token_payload( From 00a6128df81bdedfe058e2ac92c1426ffb654f51 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 12:58:18 +0530 Subject: [PATCH 14/38] fix: token exchange --- .../recipe/oauth2provider/interfaces.py | 11 ++ .../oauth2provider/recipe_implementation.py | 133 +++++++++++++++++- 2 files changed, 143 insertions(+), 1 deletion(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 7e3b9b877..faae56f7a 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -162,6 +162,17 @@ def __init__( self.scope = scope self.token_type = token_type + @staticmethod + def from_json(json: Dict[str, Any]): + return TokenInfo( + access_token=json["access_token"], + expires_in=json["expires_in"], + id_token=json["id_token"], + refresh_token=json["refresh_token"], + scope=json["scope"], + token_type=json["token_type"], + ) + class LoginInfo: def __init__( diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 8b9204bfc..d05d3890e 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import base64 from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List from urllib.parse import parse_qs, urlparse @@ -367,7 +368,137 @@ async def token_exchange( body: Dict[str, Optional[str]], user_context: Dict[str, Any] = {}, ) -> Union[TokenInfo, ErrorOAuth2Response]: - pass + request_body = { + "iss": await OpenIdRecipe.get_issuer(user_context), + "inputBody": body, + "authorizationHeader": authorization_header, + } + + if body.get("grant_type") == "password": + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="Unsupported grant type: password", + ) + + if body.get("grant_type") == "client_credentials": + client_id = None + if authorization_header: + # Extract client_id from Basic auth header + decoded = base64.b64decode( + authorization_header.replace("Basic ", "").strip() + ).decode() + client_id = decoded.split(":")[0] + else: + client_id = body.get("client_id") + + if not client_id: + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="client_id is required", + ) + + scopes = str(body.get("scope", "")).split() if body.get("scope") else [] + + client_info = await self.get_oauth2_client( + client_id=client_id, user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + return ErrorOAuth2Response( + status_code=400, + error=client_info.error, + error_description=client_info.error_description, + ) + + client = client_info.client + request_body["id_token"] = await self.build_id_token_payload( + user=None, + client=client, + session_handle=None, + scopes=scopes, + user_context=user_context, + ) + request_body["access_token"] = await self.build_access_token_payload( + user=None, + client=client, + session_handle=None, + scopes=scopes, + user_context=user_context, + ) + + if body.get("grant_type") == "refresh_token": + scopes = str(body.get("scope", "")).split() if body.get("scope") else [] + token_info = await self.introspect_token( + token=str(body["refresh_token"]), + scopes=scopes, + user_context=user_context, + ) + + if token_info.get("active"): + session_handle = token_info["sessionHandle"] + + client_info = await self.get_oauth2_client( + client_id=token_info["client_id"], user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + return ErrorOAuth2Response( + status_code=400, + error=client_info.error, + error_description=client_info.error_description, + ) + + client = client_info.client + user = await get_user(token_info["sub"]) + + if not user: + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="User not found", + ) + + request_body["id_token"] = await self.build_id_token_payload( + user=user, + client=client, + session_handle=session_handle, + scopes=scopes, + user_context=user_context, + ) + request_body["access_token"] = await self.build_access_token_payload( + user=user, + client=client, + session_handle=session_handle, + scopes=scopes, + user_context=user_context, + ) + + if authorization_header: + request_body["authorizationHeader"] = authorization_header + + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/token"), + request_body, + user_context=user_context, + ) + + if response["status"] == "CLIENT_NOT_FOUND_ERROR": + return ErrorOAuth2Response( + status_code=400, + error="invalid_request", + error_description="client_id not found", + ) + + if response["status"] != "OK": + return ErrorOAuth2Response( + status_code=response["statusCode"], + error=response["error"], + error_description=response["errorDescription"], + ) + + return TokenInfo.from_json(response) async def get_oauth2_clients( self, From 6f6b6e40b548bd33f34c0c7f7830625dfd3b90b1 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 15:34:18 +0530 Subject: [PATCH 15/38] fix: frontend redirection url --- .../recipe/oauth2provider/interfaces.py | 35 ++++++++++++++++++- .../oauth2provider/recipe_implementation.py | 35 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index faae56f7a..ebdfaa048 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -268,6 +268,34 @@ def __init__( self.audience = audience +class FrontendRedirectionURLTypeLogin: + def __init__( + self, + login_challenge: str, + tenant_id: str, + force_fresh_auth: bool, + hint: Optional[str] = None, + ): + self.login_challenge = login_challenge + self.tenant_id = tenant_id + self.force_fresh_auth = force_fresh_auth + self.hint = hint + + +class FrontendRedirectionURLTypeTryRefresh: + def __init__(self, login_challenge: str): + self.login_challenge = login_challenge + + +class FrontendRedirectionURLTypeLogoutConfirmation: + def __init__(self, logout_challenge: str): + self.logout_challenge = logout_challenge + + +class FrontendRedirectionURLTypePostLogoutFallback: + pass + + class RecipeInterface(ABC): def __init__(self): pass @@ -446,7 +474,12 @@ async def build_user_info( @abstractmethod async def get_frontend_redirection_url( self, - input_type: str, + input: Union[ + FrontendRedirectionURLTypeLogin, + FrontendRedirectionURLTypeTryRefresh, + FrontendRedirectionURLTypeLogoutConfirmation, + FrontendRedirectionURLTypePostLogoutFallback, + ], user_context: Dict[str, Any] = {}, ) -> str: pass diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index d05d3890e..226817fe7 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -15,6 +15,7 @@ import base64 from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List from urllib.parse import parse_qs, urlparse +import urllib.parse import jwt @@ -28,6 +29,10 @@ from supertokens_python.types import RecipeUserId, User from .interfaces import ( + FrontendRedirectionURLTypeLogin, + FrontendRedirectionURLTypeLogoutConfirmation, + FrontendRedirectionURLTypePostLogoutFallback, + FrontendRedirectionURLTypeTryRefresh, OAuth2TokenValidationRequirements, PayloadBuilderFunction, RecipeInterface, @@ -725,7 +730,12 @@ async def build_user_info( async def get_frontend_redirection_url( self, - input_type: str, + input: Union[ + FrontendRedirectionURLTypeLogin, + FrontendRedirectionURLTypeTryRefresh, + FrontendRedirectionURLTypeLogoutConfirmation, + FrontendRedirectionURLTypePostLogoutFallback, + ], user_context: Dict[str, Any] = {}, ) -> str: website_domain = self.app_info.get_origin( @@ -733,6 +743,29 @@ async def get_frontend_redirection_url( ).get_as_string_dangerous() website_base_path = self.app_info.api_base_path.get_as_string_dangerous() + if isinstance(input, FrontendRedirectionURLTypeLogin): + query_params: Dict[str, str] = {"loginChallenge": input.login_challenge} + if input.tenant_id != "public": # DEFAULT_TENANT_ID is "public" + query_params["tenantId"] = input.tenant_id + if input.hint is not None: + query_params["hint"] = input.hint + if input.force_fresh_auth: + query_params["forceFreshAuth"] = "true" + + query_string = "&".join( + f"{k}={urllib.parse.quote(str(v))}" for k, v in query_params.items() + ) + return f"{website_domain}{website_base_path}?{query_string}" + + elif isinstance(input, FrontendRedirectionURLTypeTryRefresh): + return f"{website_domain}{website_base_path}/try-refresh?loginChallenge={input.login_challenge}" + + elif isinstance(input, FrontendRedirectionURLTypePostLogoutFallback): + return f"{website_domain}{website_base_path}" + + else: # isinstance(input, FrontendRedirectionURLTypeLogoutConfirmation) + return f"{website_domain}{website_base_path}/oauth/logout?logoutChallenge={input.logout_challenge}" + async def revoke_token( self, token: str, From 07868f1fc09f91fc5d1484b65c6dc3022501d466 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 15:43:02 +0530 Subject: [PATCH 16/38] fix: revoke token --- .../recipe/oauth2provider/interfaces.py | 23 +++++++++--- .../oauth2provider/recipe_implementation.py | 35 +++++++++++++++---- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index ebdfaa048..a7fe53d82 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -296,6 +296,19 @@ class FrontendRedirectionURLTypePostLogoutFallback: pass +class RevokeTokenUsingAuthorizationHeader: + def __init__(self, token: str, authorization_header: str): + self.token = token + self.authorization_header = authorization_header + + +class RevokeTokenUsingClientIDAndClientSecret: + def __init__(self, token: str, client_id: str, client_secret: str): + self.token = token + self.client_id = client_id + self.client_secret = client_secret + + class RecipeInterface(ABC): def __init__(self): pass @@ -487,12 +500,12 @@ async def get_frontend_redirection_url( @abstractmethod async def revoke_token( self, - token: str, - authorization_header: Optional[str] = None, - client_id: Optional[str] = None, - client_secret: Optional[str] = None, + input: Union[ + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + ], user_context: Dict[str, Any] = {}, - ) -> Union[Dict[str, str], ErrorOAuth2Response]: + ) -> Optional[ErrorOAuth2Response]: pass @abstractmethod diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 226817fe7..5e935886e 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -41,6 +41,8 @@ GetOAuth2ClientOkResult, GetOAuth2ClientsOkResult, CreateOAuth2ClientOkResult, + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, UpdateOAuth2ClientOkResult, DeleteOAuth2ClientOkResult, ConsentRequest, @@ -768,13 +770,34 @@ async def get_frontend_redirection_url( async def revoke_token( self, - token: str, - authorization_header: Optional[str] = None, - client_id: Optional[str] = None, - client_secret: Optional[str] = None, + input: Union[ + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + ], user_context: Dict[str, Any] = {}, - ) -> Union[Dict[str, str], ErrorOAuth2Response]: - pass + ) -> Optional[ErrorOAuth2Response]: + request_body = {"token": input.token} + + if isinstance(input, RevokeTokenUsingAuthorizationHeader): + request_body["authorizationHeader"] = input.authorization_header + else: + request_body["client_id"] = input.client_id + request_body["client_secret"] = input.client_secret + + res = await self.querier.send_post_request( + NormalisedURLPath("/recipe/oauth/token/revoke"), + request_body, + user_context=user_context, + ) + + if res.get("status") != "OK": + return ErrorOAuth2Response( + status_code=res.get("statusCode"), + error=str(res.get("error")), + error_description=str(res.get("errorDescription")), + ) + + return None async def revoke_tokens_by_client_id( self, From 4386cb8b8af34ae656c0401e87e2cfdf79ffed2d Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 29 Nov 2024 15:48:01 +0530 Subject: [PATCH 17/38] fix: end session --- .../recipe/oauth2provider/interfaces.py | 2 +- .../oauth2provider/recipe_implementation.py | 81 ++++++++++++++++++- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index a7fe53d82..712385e7a 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -537,8 +537,8 @@ async def introspect_token( async def end_session( self, params: Dict[str, str], + should_try_refresh: bool, session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 5e935886e..7df7752eb 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -836,7 +836,11 @@ async def introspect_token( try: await self.validate_oauth2_access_token( token=token, - requirements={"scopes": scopes}, + requirements=( + OAuth2TokenValidationRequirements(scopes=scopes) + if scopes + else None + ), check_database=False, user_context=user_context, ) @@ -859,11 +863,80 @@ async def introspect_token( async def end_session( self, params: Dict[str, str], + should_try_refresh: bool, session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, user_context: Dict[str, Any] = {}, ) -> Union[RedirectResponse, ErrorOAuth2Response]: - pass + # NOTE: The API response has 3 possible cases: + # + # CASE 1: end_session request with a valid id_token_hint + # - Redirects to /oauth/logout with a logout_challenge. + # + # CASE 2: end_session request with an already logged out id_token_hint + # - Redirects to the post_logout_redirect_uri or the default logout fallback page. + # + # CASE 3: end_session request with a logout_verifier (after accepting the logout request) + # - Redirects to the post_logout_redirect_uri or the default logout fallback page. + + resp = await self.querier.send_get_request( + NormalisedURLPath("/recipe/oauth/sessions/logout"), + { + "clientId": params.get("client_id"), + "idTokenHint": params.get("id_token_hint"), + "postLogoutRedirectUri": params.get("post_logout_redirect_uri"), + "state": params.get("state"), + "logoutVerifier": params.get("logout_verifier"), + }, + user_context=user_context, + ) + + if "error" in resp: + return ErrorOAuth2Response( + status_code=resp["statusCode"], + error=resp["error"], + error_description=resp["errorDescription"], + ) + + redirect_to = get_updated_redirect_to(self.app_info, resp["redirectTo"]) + + initial_redirect_url = urlparse(redirect_to) + query_params = parse_qs(initial_redirect_url.query) + logout_challenge = query_params.get("logout_challenge", [None])[0] + + # CASE 1 (See above notes) + if logout_challenge is not None: + # Redirect to the frontend to ask for logout confirmation if there is a valid or expired supertokens session + if session is not None or should_try_refresh: + return RedirectResponse( + redirect_to=await self.get_frontend_redirection_url( + FrontendRedirectionURLTypeLogoutConfirmation( + logout_challenge=logout_challenge + ), + user_context=user_context, + ) + ) + else: + # Accept the logout challenge immediately as there is no supertokens session + accept_logout_response = await self.accept_logout_request( + challenge=logout_challenge, user_context=user_context + ) + if isinstance(accept_logout_response, ErrorOAuth2Response): + return accept_logout_response + return RedirectResponse(redirect_to=accept_logout_response.redirect_to) + + # CASE 2 or 3 (See above notes) + + # NOTE: If no post_logout_redirect_uri is provided, Hydra redirects to a fallback page. + # In this case, we redirect the user to the /auth page. + if redirect_to.endswith("/fallbacks/logout/callback"): + return RedirectResponse( + redirect_to=await self.get_frontend_redirection_url( + FrontendRedirectionURLTypePostLogoutFallback(), + user_context=user_context, + ) + ) + + return RedirectResponse(redirect_to=redirect_to) async def accept_logout_request( self, @@ -889,7 +962,7 @@ async def accept_logout_request( if redirect_to.endswith("/fallbacks/logout/callback"): return RedirectResponse( redirect_to=await self.get_frontend_redirection_url( - type="post-logout-fallback", + FrontendRedirectionURLTypePostLogoutFallback(), user_context=user_context, ) ) From 2c06ffbb054ea78e08c8a9e4b97e271224c3f968 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Mon, 2 Dec 2024 08:43:53 +0530 Subject: [PATCH 18/38] fix: api stubs --- .../recipe/oauth2provider/api/__init__.py | 7 +++ .../recipe/oauth2provider/api/auth.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/end_session.py | 49 +++++++++++++++++++ .../oauth2provider/api/introspect_token.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/login.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/login_info.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/logout.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/revoke_token.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/token.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/user_info.py | 37 ++++++++++++++ .../recipe/oauth2provider/api/utils.py | 13 +++++ .../recipe/oauth2provider/interfaces.py | 11 ++++- .../recipe/oauth2provider/recipe.py | 15 ++++-- 13 files changed, 387 insertions(+), 4 deletions(-) create mode 100644 supertokens_python/recipe/oauth2provider/api/auth.py create mode 100644 supertokens_python/recipe/oauth2provider/api/end_session.py create mode 100644 supertokens_python/recipe/oauth2provider/api/introspect_token.py create mode 100644 supertokens_python/recipe/oauth2provider/api/login.py create mode 100644 supertokens_python/recipe/oauth2provider/api/login_info.py create mode 100644 supertokens_python/recipe/oauth2provider/api/logout.py create mode 100644 supertokens_python/recipe/oauth2provider/api/revoke_token.py create mode 100644 supertokens_python/recipe/oauth2provider/api/token.py create mode 100644 supertokens_python/recipe/oauth2provider/api/user_info.py create mode 100644 supertokens_python/recipe/oauth2provider/api/utils.py diff --git a/supertokens_python/recipe/oauth2provider/api/__init__.py b/supertokens_python/recipe/oauth2provider/api/__init__.py index a59c06f52..1c33ac564 100644 --- a/supertokens_python/recipe/oauth2provider/api/__init__.py +++ b/supertokens_python/recipe/oauth2provider/api/__init__.py @@ -11,3 +11,10 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + +from .auth import auth_get # type: ignore +from .end_session import end_session_get, end_session_post # type: ignore +from .logout import logout_post # type: ignore +from .revoke_token import revoke_token_post # type: ignore +from .token import token_post # type: ignore +from .user_info import user_info_get # type: ignore diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py new file mode 100644 index 000000000..61bf9650b --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def auth_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_auth_get is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py new file mode 100644 index 000000000..98f0840b1 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def end_session_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_end_session_get is True: + return None + + raise NotImplementedError() + + +async def end_session_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_end_session_post is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/introspect_token.py b/supertokens_python/recipe/oauth2provider/api/introspect_token.py new file mode 100644 index 000000000..1595f9b89 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/introspect_token.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def introspect_token_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_introspect_token_post is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py new file mode 100644 index 000000000..30479cbc9 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def login_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_login_get is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/login_info.py b/supertokens_python/recipe/oauth2provider/api/login_info.py new file mode 100644 index 000000000..ae4f7bb3c --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/login_info.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def login_info_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_login_info_get is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/logout.py b/supertokens_python/recipe/oauth2provider/api/logout.py new file mode 100644 index 000000000..86c82bd2d --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/logout.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def logout_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_logout_post is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/revoke_token.py b/supertokens_python/recipe/oauth2provider/api/revoke_token.py new file mode 100644 index 000000000..fe801a98d --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/revoke_token.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def revoke_token_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_revoke_token_post is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/token.py b/supertokens_python/recipe/oauth2provider/api/token.py new file mode 100644 index 000000000..3ac9c1616 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/token.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def token_post( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_token_post is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/user_info.py b/supertokens_python/recipe/oauth2provider/api/user_info.py new file mode 100644 index 000000000..e1c6a3c90 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/user_info.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + + +async def user_info_get( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_user_info_get is True: + return None + + raise NotImplementedError() diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py new file mode 100644 index 000000000..a59c06f52 --- /dev/null +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 712385e7a..019c8f52a 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -578,7 +578,16 @@ def __init__( class APIInterface: def __init__(self): - pass + self.disable_login_get = False + self.disable_auth_get = False + self.disable_token_post = False + self.disable_login_info_get = False + self.disable_user_info_get = False + self.disable_revoke_token_post = False + self.disable_introspect_token_post = False + self.disable_end_session_get = False + self.disable_end_session_post = False + self.disable_logout_post = False @abstractmethod async def login_get( diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index 5059ca40b..89ab8d911 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.recipe.accountlinking.interfaces import RecipeInterface from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.types import User @@ -40,7 +41,15 @@ ) -from .api import * # TODO: fix this +from .api import ( + auth_get, + end_session_get, + end_session_post, + logout_post, + revoke_token_post, + token_post, + user_info_get, +) from .constants import ( LOGIN_PATH, AUTH_PATH, @@ -74,14 +83,14 @@ def __init__( override, ) - recipe_implementation = RecipeImplementation( + recipe_implementation: RecipeInterface = RecipeImplementation( Querier.get_instance(recipe_id), app_info, self.get_default_access_token_payload, self.get_default_id_token_payload, self.get_default_user_info_payload, ) - self.recipe_implementation: RecipeImplementation = ( + self.recipe_implementation: RecipeInterface = ( recipe_implementation if self.config.override.functions is None else self.config.override.functions(recipe_implementation) From eae13cc3046a1f4769dac7e69850cb704e942ece Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Mon, 2 Dec 2024 17:29:14 +0530 Subject: [PATCH 19/38] fix: api structures and lint fixes --- .../recipe/oauth2provider/api/__init__.py | 3 + .../oauth2provider/api/implementation.py | 19 +- .../recipe/oauth2provider/api/utils.py | 319 ++++++++++++++++++ .../recipe/oauth2provider/interfaces.py | 3 - .../recipe/oauth2provider/recipe.py | 82 +++-- 5 files changed, 391 insertions(+), 35 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/__init__.py b/supertokens_python/recipe/oauth2provider/api/__init__.py index 1c33ac564..03cd5b95a 100644 --- a/supertokens_python/recipe/oauth2provider/api/__init__.py +++ b/supertokens_python/recipe/oauth2provider/api/__init__.py @@ -14,6 +14,9 @@ from .auth import auth_get # type: ignore from .end_session import end_session_get, end_session_post # type: ignore +from .introspect_token import introspect_token_post # type: ignore +from .login_info import login_info_get # type: ignore +from .login import login_get, login_post # type: ignore from .logout import logout_post # type: ignore from .revoke_token import revoke_token_post # type: ignore from .token import token_post # type: ignore diff --git a/supertokens_python/recipe/oauth2provider/api/implementation.py b/supertokens_python/recipe/oauth2provider/api/implementation.py index 5a9b155e2..ad683174e 100644 --- a/supertokens_python/recipe/oauth2provider/api/implementation.py +++ b/supertokens_python/recipe/oauth2provider/api/implementation.py @@ -1,2 +1,19 @@ -class APIImplementation: +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from ..interfaces import APIInterface + + +class APIImplementation(APIInterface): pass diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py index a59c06f52..201b87270 100644 --- a/supertokens_python/recipe/oauth2provider/api/utils.py +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -11,3 +11,322 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from urllib.parse import parse_qs, urlparse +import time + +from supertokens_python import Supertokens +from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from supertokens_python.recipe.session.asyncio import get_session_information +from ..constants import LOGIN_PATH, AUTH_PATH, END_SESSION_PATH + +if TYPE_CHECKING: + from ..interfaces import ( + RecipeInterface, + ErrorOAuth2Response, + RedirectResponse, + FrontendRedirectionURLTypeTryRefresh, + FrontendRedirectionURLTypeLogin, + ) + from supertokens_python.recipe.session.interfaces import SessionContainer + + +async def login_get( + recipe_implementation: RecipeInterface, + login_challenge: str, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + cookies: Optional[str] = None, + is_direct_call: bool = False, + user_context: Dict[str, Any] = {}, +) -> Union[RedirectResponse, ErrorOAuth2Response]: + login_request = await recipe_implementation.get_login_request( + challenge=login_challenge, + user_context=user_context, + ) + + if isinstance(login_request, ErrorOAuth2Response): + return login_request + + session_info = ( + await get_session_information(session.get_handle()) if session else None + ) + if not session_info: + session = None + + incoming_auth_url_query_params = parse_qs(urlparse(login_request.request_url).query) + prompt_param = ( + incoming_auth_url_query_params.get("prompt", [None])[0] + or incoming_auth_url_query_params.get("st_prompt", [None])[0] + ) + max_age_param = incoming_auth_url_query_params.get("max_age", [None])[0] + + if max_age_param is not None: + try: + max_age_parsed = int(max_age_param) + + if max_age_parsed < 0: + reject = await recipe_implementation.reject_login_request( + challenge=login_challenge, + error=ErrorOAuth2Response( + error="invalid_request", + error_description="max_age cannot be negative", + ), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=reject.redirect_to, + cookies=cookies, + ) + + except ValueError: + reject = await recipe_implementation.reject_login_request( + challenge=login_challenge, + error=ErrorOAuth2Response( + error="invalid_request", + error_description="max_age must be an integer", + ), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=reject.redirect_to, + cookies=cookies, + ) + + tenant_id_param = incoming_auth_url_query_params.get("tenant_id", [None])[0] + + if ( + session + and session_info + and ( + not login_request.subject or session.get_user_id() == login_request.subject + ) + and (not tenant_id_param or session.get_tenant_id() == tenant_id_param) + and (prompt_param != "login" or is_direct_call) + and ( + max_age_param is None + or (max_age_param == "0" and is_direct_call) + or int(max_age_param) * 1000 + > time.time() * 1000 - session_info.time_created + ) + ): + accept = await recipe_implementation.accept_login_request( + challenge=login_challenge, + subject=session.get_user_id(), + identity_provider_session_id=session.get_handle(), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=accept.redirect_to, + cookies=cookies, + ) + + if should_try_refresh and prompt_param != "login": + return RedirectResponse( + redirect_to=await recipe_implementation.get_frontend_redirection_url( + input=FrontendRedirectionURLTypeTryRefresh( + login_challenge=login_challenge, + ), + user_context=user_context, + ), + cookies=cookies, + ) + + if prompt_param == "none": + reject = await recipe_implementation.reject_login_request( + challenge=login_challenge, + error=ErrorOAuth2Response( + error="login_required", + error_description="The Authorization Server requires End-User authentication. Prompt 'none' was requested, but no existing or expired login session was found.", + ), + user_context=user_context, + ) + return RedirectResponse( + redirect_to=reject.redirect_to, + cookies=cookies, + ) + + return RedirectResponse( + redirect_to=await recipe_implementation.get_frontend_redirection_url( + input=FrontendRedirectionURLTypeLogin( + login_challenge=login_challenge, + force_fresh_auth=session is not None or prompt_param == "login", + tenant_id=tenant_id_param or DEFAULT_TENANT_ID, + hint=( + login_request.oidc_context.get("login_hint") + if login_request.oidc_context + else None + ), + ), + user_context=user_context, + ), + cookies=cookies, + ) + + +def get_merged_cookies( + orig_cookies: str = "", new_cookies: Optional[str] = None +) -> str: + if not new_cookies: + return orig_cookies + + cookie_map: Dict[str, str] = {} + for cookie in orig_cookies.split(";"): + if "=" in cookie: + name, value = cookie.split("=", 1) + cookie_map[name.strip()] = value + + # Note: This is a simplified version. In production code you'd want to use a proper + # cookie parsing library to handle all cookie attributes correctly + if new_cookies: + for cookie in new_cookies.split(","): + if "=" in cookie: + name, value = cookie.split("=", 1) + cookie_map[name.strip()] = value + + return ";".join(f"{key}={value}" for key, value in cookie_map.items()) + + +def merge_set_cookie_headers( + set_cookie1: Optional[str] = None, set_cookie2: Optional[str] = None +) -> str: + if not set_cookie1: + return set_cookie2 or "" + if not set_cookie2 or set_cookie1 == set_cookie2: + return set_cookie1 + return f"{set_cookie1}, {set_cookie2}" + + +def is_login_internal_redirect(redirect_to: str) -> bool: + instance = Supertokens.get_instance() + api_domain = instance.app_info.api_domain.get_as_string_dangerous() + api_base_path = instance.app_info.api_base_path.get_as_string_dangerous() + base_path = f"{api_domain}{api_base_path}" + + return any( + redirect_to.startswith(f"{base_path}{path}") for path in [LOGIN_PATH, AUTH_PATH] + ) + + +def is_logout_internal_redirect(redirect_to: str) -> bool: + instance = Supertokens.get_instance() + api_domain = instance.app_info.api_domain.get_as_string_dangerous() + api_base_path = instance.app_info.api_base_path.get_as_string_dangerous() + base_path = f"{api_domain}{api_base_path}" + return redirect_to.startswith(f"{base_path}{END_SESSION_PATH}") + + +async def handle_login_internal_redirects( + response: RedirectResponse, + recipe_implementation: RecipeInterface, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + cookie: str = "", + user_context: Dict[str, Any] = {}, +) -> Union[RedirectResponse, ErrorOAuth2Response]: + if not is_login_internal_redirect(response.redirect_to): + return response + + max_redirects = 10 + redirect_count = 0 + + while redirect_count < max_redirects and is_login_internal_redirect( + response.redirect_to + ): + cookie = get_merged_cookies(cookie, response.cookies) + + query_string = ( + response.redirect_to.split("?", 1)[1] if "?" in response.redirect_to else "" + ) + params = parse_qs(query_string) + + if LOGIN_PATH in response.redirect_to: + login_challenge = ( + params.get("login_challenge", [None])[0] + or params.get("loginChallenge", [None])[0] + ) + if not login_challenge: + raise Exception(f"Expected loginChallenge in {response.redirect_to}") + + login_res = await login_get( + recipe_implementation=recipe_implementation, + login_challenge=login_challenge, + session=session, + should_try_refresh=should_try_refresh, + cookies=response.cookies, + is_direct_call=False, + user_context=user_context, + ) + + if isinstance(login_res, ErrorOAuth2Response): + return login_res + + response = RedirectResponse( + redirect_to=login_res.redirect_to, + cookies=merge_set_cookie_headers(login_res.cookies, response.cookies), + ) + + elif AUTH_PATH in response.redirect_to: + auth_res = await recipe_implementation.authorization( + params={k: v[0] for k, v in params.items()}, + cookies=cookie, + session=session, + user_context=user_context, + ) + + if isinstance(auth_res, ErrorOAuth2Response): + return auth_res + + response = RedirectResponse( + redirect_to=auth_res.redirect_to, + cookies=merge_set_cookie_headers(auth_res.cookies, response.cookies), + ) + + else: + raise Exception(f"Unexpected internal redirect {response.redirect_to}") + + redirect_count += 1 + + return response + + +async def handle_logout_internal_redirects( + response: RedirectResponse, + recipe_implementation: RecipeInterface, + session: Optional[SessionContainer] = None, + user_context: Dict[str, Any] = {}, +) -> Union[RedirectResponse, ErrorOAuth2Response]: + if not is_logout_internal_redirect(response.redirect_to): + return response + + max_redirects = 10 + redirect_count = 0 + + while redirect_count < max_redirects and is_logout_internal_redirect( + response.redirect_to + ): + query_string = ( + response.redirect_to.split("?", 1)[1] if "?" in response.redirect_to else "" + ) + params = parse_qs(query_string) + + if END_SESSION_PATH in response.redirect_to: + end_session_res = await recipe_implementation.end_session( + params={k: v[0] for k, v in params.items()}, + session=session, + should_try_refresh=False, + user_context=user_context, + ) + if isinstance(end_session_res, ErrorOAuth2Response): + return end_session_res + response = end_session_res + else: + raise Exception(f"Unexpected internal redirect {response.redirect_to}") + + redirect_count += 1 + + return response diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 019c8f52a..ee5b70226 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -310,9 +310,6 @@ def __init__(self, token: str, client_id: str, client_secret: str): class RecipeInterface(ABC): - def __init__(self): - pass - @abstractmethod async def authorization( self, diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index 89ab8d911..e2d4f6277 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -18,14 +18,22 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.recipe.accountlinking.interfaces import RecipeInterface +from supertokens_python.recipe.oauth2provider.api.introspect_token import ( + introspect_token_post, +) +from supertokens_python.recipe.oauth2provider.api.login import login_get +from supertokens_python.recipe.oauth2provider.api.login_info import login_info_get from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.types import User -from .recipe_implementation import RecipeImplementation - -from .interfaces import APIOptions, PayloadBuilderFunction, UserInfoBuilderFunction +from .interfaces import ( + APIInterface, + APIOptions, + PayloadBuilderFunction, + UserInfoBuilderFunction, + RecipeInterface, +) if TYPE_CHECKING: @@ -83,6 +91,8 @@ def __init__( override, ) + from .recipe_implementation import RecipeImplementation + recipe_implementation: RecipeInterface = RecipeImplementation( Querier.get_instance(recipe_id), app_info, @@ -91,16 +101,18 @@ def __init__( self.get_default_user_info_payload, ) self.recipe_implementation: RecipeInterface = ( - recipe_implementation - if self.config.override.functions is None - else self.config.override.functions(recipe_implementation) + self.config.override.functions(recipe_implementation) + if self.config.override is not None + and self.config.override.functions is not None + else recipe_implementation ) api_implementation = APIImplementation() - self.api_implementation: APIImplementation = ( - api_implementation - if self.config.override.apis is None - else self.config.override.apis(api_implementation) + self.api_implementation: APIInterface = ( + self.config.override.apis(api_implementation) + if self.config.override is not None + and self.config.override.apis is not None + else api_implementation ) self._access_token_builders: List[PayloadBuilderFunction] = [] @@ -116,61 +128,61 @@ def get_apis_handled(self) -> List[APIHandled]: NormalisedURLPath(LOGIN_PATH), "get", LOGIN_PATH, - self.api_implementation.login_get, + self.api_implementation.disable_login_get, ), APIHandled( NormalisedURLPath(TOKEN_PATH), "post", TOKEN_PATH, - self.api_implementation.token_post, + self.api_implementation.disable_token_post, ), APIHandled( NormalisedURLPath(AUTH_PATH), "get", AUTH_PATH, - self.api_implementation.auth_get, + self.api_implementation.disable_auth_get, ), APIHandled( NormalisedURLPath(LOGIN_INFO_PATH), "get", LOGIN_INFO_PATH, - self.api_implementation.login_info_get, + self.api_implementation.disable_login_info_get, ), APIHandled( NormalisedURLPath(USER_INFO_PATH), "get", USER_INFO_PATH, - self.api_implementation.user_info_get, + self.api_implementation.disable_user_info_get, ), APIHandled( NormalisedURLPath(REVOKE_TOKEN_PATH), "post", REVOKE_TOKEN_PATH, - self.api_implementation.revoke_token_post, + self.api_implementation.disable_revoke_token_post, ), APIHandled( NormalisedURLPath(INTROSPECT_TOKEN_PATH), "post", INTROSPECT_TOKEN_PATH, - self.api_implementation.introspect_token_post, + self.api_implementation.disable_introspect_token_post, ), APIHandled( NormalisedURLPath(END_SESSION_PATH), "get", END_SESSION_PATH, - self.api_implementation.end_session_get, + self.api_implementation.disable_end_session_get, ), APIHandled( NormalisedURLPath(END_SESSION_PATH), "post", END_SESSION_PATH, - self.api_implementation.end_session_post, + self.api_implementation.disable_end_session_post, ), APIHandled( NormalisedURLPath(LOGOUT_PATH), "post", LOGOUT_PATH, - self.api_implementation.logout_post, + self.api_implementation.disable_logout_post, ), ] @@ -192,46 +204,54 @@ async def handle_api_request( self.recipe_implementation, ) if request_id == LOGIN_PATH: - return await login_api(self.api_implementation, api_options, user_context) + return await login_get( + tenant_id, self.api_implementation, api_options, user_context + ) if request_id == TOKEN_PATH: - return await token_post(self.api_implementation, api_options, user_context) + return await token_post( + tenant_id, self.api_implementation, api_options, user_context + ) if request_id == AUTH_PATH: - return await auth_get(self.api_implementation, api_options, user_context) + return await auth_get( + tenant_id, self.api_implementation, api_options, user_context + ) if request_id == LOGIN_INFO_PATH: return await login_info_get( - self.api_implementation, api_options, user_context + tenant_id, self.api_implementation, api_options, user_context ) if request_id == USER_INFO_PATH: return await user_info_get( - self.api_implementation, tenant_id, api_options, user_context + tenant_id, self.api_implementation, api_options, user_context ) if request_id == REVOKE_TOKEN_PATH: return await revoke_token_post( - self.api_implementation, api_options, user_context + tenant_id, self.api_implementation, api_options, user_context ) if request_id == INTROSPECT_TOKEN_PATH: return await introspect_token_post( - self.api_implementation, api_options, user_context + tenant_id, self.api_implementation, api_options, user_context ) if request_id == END_SESSION_PATH and method == "get": return await end_session_get( - self.api_implementation, api_options, user_context + tenant_id, self.api_implementation, api_options, user_context ) if request_id == END_SESSION_PATH and method == "post": return await end_session_post( - self.api_implementation, api_options, user_context + tenant_id, self.api_implementation, api_options, user_context ) if request_id == LOGOUT_PATH and method == "post": - return await logout_post(self.api_implementation, api_options, user_context) + return await logout_post( + tenant_id, self.api_implementation, api_options, user_context + ) raise Exception( "Should never come here: handle_api_request called with unknown id" From c4c8d11778ae5b980bb558a21f7f89dbe9e6dfb6 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Mon, 2 Dec 2024 17:39:44 +0530 Subject: [PATCH 20/38] fix: remaining type fixes --- .../recipe/oauth2provider/api/__init__.py | 2 +- .../recipe/oauth2provider/api/auth.py | 2 -- .../recipe/oauth2provider/api/end_session.py | 2 -- .../recipe/oauth2provider/api/introspect_token.py | 2 -- .../recipe/oauth2provider/api/login.py | 4 +--- .../recipe/oauth2provider/api/login_info.py | 2 -- .../recipe/oauth2provider/api/logout.py | 2 -- .../recipe/oauth2provider/api/revoke_token.py | 2 -- .../recipe/oauth2provider/api/token.py | 2 -- .../recipe/oauth2provider/api/user_info.py | 2 -- .../recipe/oauth2provider/interfaces.py | 15 ++++++++++++++- .../recipe/oauth2provider/recipe.py | 5 +++-- 12 files changed, 19 insertions(+), 23 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/__init__.py b/supertokens_python/recipe/oauth2provider/api/__init__.py index 03cd5b95a..fdef2a2bc 100644 --- a/supertokens_python/recipe/oauth2provider/api/__init__.py +++ b/supertokens_python/recipe/oauth2provider/api/__init__.py @@ -16,7 +16,7 @@ from .end_session import end_session_get, end_session_post # type: ignore from .introspect_token import introspect_token_post # type: ignore from .login_info import login_info_get # type: ignore -from .login import login_get, login_post # type: ignore +from .login import login # type: ignore from .logout import logout_post # type: ignore from .revoke_token import revoke_token_post # type: ignore from .token import token_post # type: ignore diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 61bf9650b..4adab3ba9 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def auth_get( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py index 98f0840b1..e3fbee652 100644 --- a/supertokens_python/recipe/oauth2provider/api/end_session.py +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def end_session_get( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/api/introspect_token.py b/supertokens_python/recipe/oauth2provider/api/introspect_token.py index 1595f9b89..defd6a1ea 100644 --- a/supertokens_python/recipe/oauth2provider/api/introspect_token.py +++ b/supertokens_python/recipe/oauth2provider/api/introspect_token.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def introspect_token_post( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index 30479cbc9..265fadba2 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -22,10 +22,8 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - -async def login_get( +async def login( _tenant_id: str, api_implementation: APIInterface, api_options: APIOptions, diff --git a/supertokens_python/recipe/oauth2provider/api/login_info.py b/supertokens_python/recipe/oauth2provider/api/login_info.py index ae4f7bb3c..f160a15f1 100644 --- a/supertokens_python/recipe/oauth2provider/api/login_info.py +++ b/supertokens_python/recipe/oauth2provider/api/login_info.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def login_info_get( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/api/logout.py b/supertokens_python/recipe/oauth2provider/api/logout.py index 86c82bd2d..67ff04a26 100644 --- a/supertokens_python/recipe/oauth2provider/api/logout.py +++ b/supertokens_python/recipe/oauth2provider/api/logout.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def logout_post( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/api/revoke_token.py b/supertokens_python/recipe/oauth2provider/api/revoke_token.py index fe801a98d..b1e4c1337 100644 --- a/supertokens_python/recipe/oauth2provider/api/revoke_token.py +++ b/supertokens_python/recipe/oauth2provider/api/revoke_token.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def revoke_token_post( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/api/token.py b/supertokens_python/recipe/oauth2provider/api/token.py index 3ac9c1616..2d49260bd 100644 --- a/supertokens_python/recipe/oauth2provider/api/token.py +++ b/supertokens_python/recipe/oauth2provider/api/token.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def token_post( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/api/user_info.py b/supertokens_python/recipe/oauth2provider/api/user_info.py index e1c6a3c90..8cee3f4ef 100644 --- a/supertokens_python/recipe/oauth2provider/api/user_info.py +++ b/supertokens_python/recipe/oauth2provider/api/user_info.py @@ -22,8 +22,6 @@ APIInterface, ) -from supertokens_python.utils import send_200_response - async def user_info_get( _tenant_id: str, diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index ee5b70226..1fbff5e14 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -309,6 +309,19 @@ def __init__(self, token: str, client_id: str, client_secret: str): self.client_secret = client_secret +class InactiveTokenResponse: + def to_json(self): + return {"active": False} + + +class ActiveTokenResponse: + def __init__(self, payload: Dict[str, Any]): + self.payload = payload + + def to_json(self): + return {"active": True, **self.payload} + + class RecipeInterface(ABC): @abstractmethod async def authorization( @@ -667,7 +680,7 @@ async def introspect_token_post( scopes: Optional[List[str]], options: APIOptions, user_context: Dict[str, Any] = {}, - ) -> Union[IntrospectTokenResponse, GeneralErrorResponse]: + ) -> Union[ActiveTokenResponse, InactiveTokenResponse, GeneralErrorResponse]: pass @abstractmethod diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index e2d4f6277..c4f0a4a94 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -21,7 +21,7 @@ from supertokens_python.recipe.oauth2provider.api.introspect_token import ( introspect_token_post, ) -from supertokens_python.recipe.oauth2provider.api.login import login_get +from supertokens_python.recipe.oauth2provider.api.login import login from supertokens_python.recipe.oauth2provider.api.login_info import login_info_get from supertokens_python.recipe.oauth2provider.exceptions import OAuth2ProviderError from supertokens_python.recipe_module import APIHandled, RecipeModule @@ -51,6 +51,7 @@ from .api import ( auth_get, + login, end_session_get, end_session_post, logout_post, @@ -204,7 +205,7 @@ async def handle_api_request( self.recipe_implementation, ) if request_id == LOGIN_PATH: - return await login_get( + return await login( tenant_id, self.api_implementation, api_options, user_context ) From a1dff9d41ef3274ede5330d125f834e26c8ff3bc Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 3 Dec 2024 11:28:14 +0530 Subject: [PATCH 21/38] fix: end session --- .../framework/django/django_response.py | 5 ++ .../framework/fastapi/fastapi_response.py | 5 ++ .../framework/flask/flask_response.py | 4 + supertokens_python/framework/request.py | 8 ++ supertokens_python/framework/response.py | 4 + .../recipe/oauth2provider/api/end_session.py | 88 ++++++++++++++++++- 6 files changed, 111 insertions(+), 3 deletions(-) diff --git a/supertokens_python/framework/django/django_response.py b/supertokens_python/framework/django/django_response.py index 1029c1a55..9692b2d55 100644 --- a/supertokens_python/framework/django/django_response.py +++ b/supertokens_python/framework/django/django_response.py @@ -88,3 +88,8 @@ def set_json_content(self, content: Dict[str, Any]): separators=(",", ":"), ).encode("utf-8") self.response_sent = True + + def redirect(self, url: str): + if not self.response_sent: + self.set_header("Location", url) + self.set_status_code(302) diff --git a/supertokens_python/framework/fastapi/fastapi_response.py b/supertokens_python/framework/fastapi/fastapi_response.py index 3f6078af3..45813a5cc 100644 --- a/supertokens_python/framework/fastapi/fastapi_response.py +++ b/supertokens_python/framework/fastapi/fastapi_response.py @@ -94,3 +94,8 @@ def set_json_content(self, content: Dict[str, Any]): self.set_header("Content-Length", str(len(body))) self.response.body = body self.response_sent = True + + def redirect(self, url: str): + if not self.response_sent: + self.set_header("Location", url) + self.set_status_code(302) diff --git a/supertokens_python/framework/flask/flask_response.py b/supertokens_python/framework/flask/flask_response.py index 647f8d3df..a74bdfb83 100644 --- a/supertokens_python/framework/flask/flask_response.py +++ b/supertokens_python/framework/flask/flask_response.py @@ -85,3 +85,7 @@ def set_json_content(self, content: Dict[str, Any]): separators=(",", ":"), ).encode("utf-8") self.response_sent = True + + def redirect(self, url: str): + self.response.headers.set("Location", url) + self.set_status_code(302) diff --git a/supertokens_python/framework/request.py b/supertokens_python/framework/request.py index 6f6533185..f53e3d1f2 100644 --- a/supertokens_python/framework/request.py +++ b/supertokens_python/framework/request.py @@ -47,6 +47,14 @@ async def json(self) -> Union[Any, None]: async def form_data(self) -> Dict[str, Any]: pass + async def get_json_or_form_data(self) -> Union[Dict[str, Any], None]: + content_type = self.get_header("Content-Type") + if content_type is None: + return None + if content_type.startswith("application/json"): + return await self.json() + return await self.form_data() + @abstractmethod def method(self) -> str: pass diff --git a/supertokens_python/framework/response.py b/supertokens_python/framework/response.py index fb5cc3477..8669e3ae4 100644 --- a/supertokens_python/framework/response.py +++ b/supertokens_python/framework/response.py @@ -61,3 +61,7 @@ def set_json_content(self, content: Dict[str, Any]): @abstractmethod def set_html_content(self, content: str): pass + + @abstractmethod + def redirect(self, url: str): + pass diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py index e3fbee652..cfbeed500 100644 --- a/supertokens_python/recipe/oauth2provider/api/end_session.py +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -14,12 +14,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union +import urllib.parse + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.framework import BaseResponse +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.recipe.session.exceptions import TryRefreshTokenError +from supertokens_python.types import GeneralErrorResponse +from supertokens_python.utils import send_200_response, send_non_200_response if TYPE_CHECKING: from ..interfaces import ( APIOptions, APIInterface, + RedirectResponse, + ErrorOAuth2Response, ) @@ -32,7 +43,13 @@ async def end_session_get( if api_implementation.disable_end_session_get is True: return None - raise NotImplementedError() + orig_url = api_options.request.get_original_url() + split_url = orig_url.split("?") + params = dict(urllib.parse.parse_qsl(split_url[1])) + + return await end_session_common( + params, api_implementation.end_session_get, api_options, user_context + ) async def end_session_post( @@ -44,4 +61,69 @@ async def end_session_post( if api_implementation.disable_end_session_post is True: return None - raise NotImplementedError() + params = await api_options.request.get_json_or_form_data() + if params is None: + raise_bad_input_exception("Please provide a JSON body or form data") + + return await end_session_common( + params, api_implementation.end_session_post, api_options, user_context + ) + + +EndSessionCallable = Callable[ + [Dict[str, str], APIOptions, Optional[SessionContainer], bool, Dict[str, Any]], + Awaitable[Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]], +] + + +async def end_session_common( + params: Dict[str, str], + api_implementation: Optional[EndSessionCallable], + options: APIOptions, + user_context: Dict[str, Any], +) -> Optional[BaseResponse]: + if api_implementation is None: + return None + + session = None + should_try_refresh = False + try: + session = await get_session( + options.request, + False, + user_context=user_context, + ) + should_try_refresh = False + except Exception as error: + # We can handle this as if the session is not present, because then we redirect to the frontend, + # which should handle the validation error + session = None + if isinstance(error, TryRefreshTokenError): + should_try_refresh = True + else: + should_try_refresh = False + + response = await api_implementation( + params, + options, + session, + should_try_refresh, + user_context, + ) + + if isinstance(response, RedirectResponse): + options.response.redirect(response.redirect_to) + elif isinstance(response, ErrorOAuth2Response): + send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + options.response, + ) + else: + if isinstance(response, dict): + return send_200_response(response, options.response) + else: + return send_200_response(response.to_json(), options.response) From 1e35b54132407a5f49114a49ec3806edc5a10619 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 3 Dec 2024 17:27:40 +0530 Subject: [PATCH 22/38] fix: api endpoints --- .../oauth2provider/api/introspect_token.py | 23 +++++- .../recipe/oauth2provider/api/login.py | 82 ++++++++++++++++++- .../recipe/oauth2provider/api/login_info.py | 32 +++++++- .../recipe/oauth2provider/api/logout.py | 48 ++++++++++- .../recipe/oauth2provider/api/revoke_token.py | 57 ++++++++++++- .../recipe/oauth2provider/interfaces.py | 42 ++++++++-- 6 files changed, 263 insertions(+), 21 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/introspect_token.py b/supertokens_python/recipe/oauth2provider/api/introspect_token.py index defd6a1ea..cb7364498 100644 --- a/supertokens_python/recipe/oauth2provider/api/introspect_token.py +++ b/supertokens_python/recipe/oauth2provider/api/introspect_token.py @@ -14,7 +14,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List + +from supertokens_python.utils import send_200_response, send_non_200_response if TYPE_CHECKING: from ..interfaces import ( @@ -32,4 +34,21 @@ async def introspect_token_post( if api_implementation.disable_introspect_token_post is True: return None - raise NotImplementedError() + body = await api_options.request.get_json_or_form_data() + if body is None or "token" not in body: + return send_non_200_response( + {"message": "token is required in the request body"}, + 400, + api_options.response, + ) + + scopes: List[str] = body.get("scope", "").split(" ") if "scope" in body else [] + + response = await api_implementation.introspect_token_post( + body["token"], + scopes, + api_options, + user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index 265fadba2..91ff8142f 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -14,12 +14,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.framework import BaseResponse +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.recipe.session.exceptions import TryRefreshTokenError +from supertokens_python.utils import send_200_response, send_non_200_response +from http.cookies import SimpleCookie if TYPE_CHECKING: from ..interfaces import ( APIOptions, APIInterface, + FrontendRedirectResponse, + ErrorOAuth2Response, ) @@ -28,8 +37,75 @@ async def login( api_implementation: APIInterface, api_options: APIOptions, user_context: Dict[str, Any], -): +) -> Optional[BaseResponse]: if api_implementation.disable_login_get is True: return None - raise NotImplementedError() + session = None + should_try_refresh = False + try: + session = await get_session( + api_options.request, + False, + user_context=user_context, + ) + should_try_refresh = False + except Exception as error: + # We can handle this as if the session is not present, because then we redirect to the frontend, + # which should handle the validation error + session = None + if isinstance(error, TryRefreshTokenError): + should_try_refresh = True + else: + should_try_refresh = False + + login_challenge = api_options.request.get_query_param( + "login_challenge" + ) or api_options.request.get_query_param("loginChallenge") + if login_challenge is None: + raise_bad_input_exception("Missing input param: loginChallenge") + + response = await api_implementation.login_get( + login_challenge=login_challenge, + options=api_options, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(response, FrontendRedirectResponse): + if response.cookies: + cookie = SimpleCookie() + cookie.load(response.cookies) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=morsel.get("expires", None), + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax"), + ) + + return send_200_response( + {"frontendRedirectTo": response.frontend_redirect_to}, + api_options.response, + ) + + elif isinstance(response, ErrorOAuth2Response): + # We want to avoid returning a 401 to the frontend, as it may trigger a refresh loop + if response.status_code == 401: + response.status_code = 400 + + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/login_info.py b/supertokens_python/recipe/oauth2provider/api/login_info.py index f160a15f1..aab799f33 100644 --- a/supertokens_python/recipe/oauth2provider/api/login_info.py +++ b/supertokens_python/recipe/oauth2provider/api/login_info.py @@ -16,6 +16,10 @@ from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.oauth2provider.interfaces import ErrorOAuth2Response +from supertokens_python.utils import send_200_response, send_non_200_response + if TYPE_CHECKING: from ..interfaces import ( APIOptions, @@ -32,4 +36,30 @@ async def login_info_get( if api_implementation.disable_login_info_get is True: return None - raise NotImplementedError() + login_challenge = api_options.request.get_query_param( + "login_challenge" + ) or api_options.request.get_query_param("loginChallenge") + + if login_challenge is None: + raise_bad_input_exception("Missing input param: loginChallenge") + + response = await api_implementation.login_info_get( + login_challenge, + api_options, + user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + # We want to avoid returning a 401 to the frontend, as it may trigger a refresh loop + if response.status_code == 401: + response.status_code = 400 + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/logout.py b/supertokens_python/recipe/oauth2provider/api/logout.py index 67ff04a26..646a705ff 100644 --- a/supertokens_python/recipe/oauth2provider/api/logout.py +++ b/supertokens_python/recipe/oauth2provider/api/logout.py @@ -14,12 +14,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.framework import BaseResponse +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.utils import send_200_response, send_non_200_response if TYPE_CHECKING: from ..interfaces import ( APIOptions, APIInterface, + FrontendRedirectResponse, + ErrorOAuth2Response, ) @@ -28,8 +35,43 @@ async def logout_post( api_implementation: APIInterface, api_options: APIOptions, user_context: Dict[str, Any], -): +) -> Optional[BaseResponse]: if api_implementation.disable_logout_post is True: return None + session = None + try: + session = await get_session( + api_options.request, session_required=False, user_context=user_context + ) + except: + pass + + body = await api_options.request.json() + + if body is None or "logoutChallenge" not in body: + raise_bad_input_exception("Missing body param: logoutChallenge") + + response = await api_implementation.logout_post( + logout_challenge=body["logoutChallenge"], + options=api_options, + session=session, + user_context=user_context, + ) + + if isinstance(response, FrontendRedirectResponse): + return send_200_response(response.to_json(), api_options.response) + elif isinstance(response, ErrorOAuth2Response): + # We want to avoid returning a 401 to the frontend, as it may trigger a refresh loop + if response.status_code == 401: + response.status_code = 400 - raise NotImplementedError() + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code if response.status_code is not None else 400, + api_options.response, + ) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/revoke_token.py b/supertokens_python/recipe/oauth2provider/api/revoke_token.py index b1e4c1337..e576a7ff6 100644 --- a/supertokens_python/recipe/oauth2provider/api/revoke_token.py +++ b/supertokens_python/recipe/oauth2provider/api/revoke_token.py @@ -14,12 +14,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response, send_non_200_response if TYPE_CHECKING: from ..interfaces import ( APIOptions, APIInterface, + ErrorOAuth2Response, + GeneralErrorResponse, ) @@ -28,8 +33,54 @@ async def revoke_token_post( api_implementation: APIInterface, api_options: APIOptions, user_context: Dict[str, Any], -): +) -> Optional[BaseResponse]: if api_implementation.disable_revoke_token_post is True: return None - raise NotImplementedError() + body = await api_options.request.get_json_or_form_data() + + if body is None or "token" not in body: + return send_non_200_response( + {"message": "token is required in the request body"}, + 400, + api_options.response, + ) + + authorization_header = api_options.request.get_header("authorization") + + if authorization_header is not None and ( + "client_id" in body or "client_secret" in body + ): + return send_non_200_response( + { + "message": "Only one of authorization header or client_id and client_secret can be provided" + }, + 400, + api_options.response, + ) + + response = await api_implementation.revoke_token_post( + token=body["token"], + options=api_options, + authorization_header=authorization_header, + client_id=body.get("client_id"), + client_secret=body.get("client_secret"), + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code if response.status_code is not None else 400, + api_options.response, + ) + elif isinstance(response, GeneralErrorResponse): + return send_200_response( + response.to_json(), + api_options.response, + ) + + return send_200_response({"status": "OK"}, api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 1fbff5e14..7fdac563e 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -43,7 +43,7 @@ def __init__( self.error_description = error_description self.status_code = status_code - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "status": self.status, "error": self.error, @@ -193,6 +193,20 @@ def __init__( self.client_uri = client_uri self.metadata = metadata + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "info": { + "clientId": self.client_id, + "clientName": self.client_name, + "tosUri": self.tos_uri, + "policyUri": self.policy_uri, + "logoUri": self.logo_uri, + "clientUri": self.client_uri, + "metadata": self.metadata, + }, + } + class RedirectResponse: def __init__(self, redirect_to: str, cookies: Optional[str] = None): @@ -200,6 +214,20 @@ def __init__(self, redirect_to: str, cookies: Optional[str] = None): self.cookies = cookies +class FrontendRedirectResponse: + def __init__(self, frontend_redirect_to: str, cookies: Optional[str] = None): + self.frontend_redirect_to = frontend_redirect_to + self.cookies = cookies + + def to_json(self) -> Dict[str, Any]: + result = { + "frontendRedirectTo": self.frontend_redirect_to, + } + if self.cookies is not None: + result["cookies"] = self.cookies + return result + + class GetOAuth2ClientsOkResult: def __init__( self, clients: List[OAuth2Client], next_pagination_token: Optional[str] @@ -607,9 +635,7 @@ async def login_get( session: Optional[SessionContainer] = None, should_try_refresh: bool = False, user_context: Dict[str, Any] = {}, - ) -> Union[ - Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse - ]: + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @abstractmethod @@ -643,7 +669,7 @@ async def login_info_get( options: APIOptions, user_context: Dict[str, Any] = {}, ) -> Union[ - Dict[Literal["status", "info"], Union[Literal["OK"], LoginInfo]], + LoginInfo, ErrorOAuth2Response, GeneralErrorResponse, ]: @@ -670,7 +696,7 @@ async def revoke_token_post( authorization_header: Optional[str] = None, client_id: Optional[str] = None, client_secret: Optional[str] = None, - ) -> Union[Dict[Literal["status"], Literal["OK"]], ErrorOAuth2Response]: + ) -> Union[None, ErrorOAuth2Response, GeneralErrorResponse]: pass @abstractmethod @@ -712,7 +738,5 @@ async def logout_post( options: APIOptions, session: Optional[SessionContainer] = None, user_context: Dict[str, Any] = {}, - ) -> Union[ - Dict[str, Union[Literal["OK"], str]], ErrorOAuth2Response, GeneralErrorResponse - ]: + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass From 79194a415197c1d5ab9472fb3fefa82134ab35b6 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Wed, 4 Dec 2024 10:43:56 +0530 Subject: [PATCH 23/38] fix: remaining apis --- .../recipe/oauth2provider/api/auth.py | 69 +++++++++++++- .../recipe/oauth2provider/api/token.py | 27 +++++- .../recipe/oauth2provider/api/user_info.py | 92 ++++++++++++++++++- .../recipe/oauth2provider/interfaces.py | 15 ++- 4 files changed, 196 insertions(+), 7 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 4adab3ba9..15b735ff4 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -14,12 +14,22 @@ from __future__ import annotations +from http.cookies import SimpleCookie from typing import TYPE_CHECKING, Any, Dict +from urllib.parse import parse_qsl + +from fastapi.responses import RedirectResponse + +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.recipe.session.exceptions import TryRefreshTokenError +from supertokens_python.utils import send_200_response, send_non_200_response if TYPE_CHECKING: from ..interfaces import ( APIOptions, APIInterface, + RedirectResponse, + ErrorOAuth2Response, ) @@ -32,4 +42,61 @@ async def auth_get( if api_implementation.disable_auth_get is True: return None - raise NotImplementedError() + original_url = api_options.request.get_original_url() + split_url = original_url.split("?") + params = dict(parse_qsl(split_url[1])) if len(split_url) > 1 else {} + + session = None + should_try_refresh = False + try: + session = await get_session( + api_options.request, + session_required=False, + user_context=user_context, + ) + should_try_refresh = False + except Exception as error: + session = None + if isinstance(error, TryRefreshTokenError): + should_try_refresh = True + else: + # This should generally not happen, but we can handle this as if the session is not present, + # because then we redirect to the frontend, which should handle the validation error + should_try_refresh = False + + response = await api_implementation.auth_get( + params=params, + cookie=api_options.request.get_header("cookie"), + session=session, + should_try_refresh=should_try_refresh, + options=api_options, + user_context=user_context, + ) + + if isinstance(response, RedirectResponse): + if response.cookies: + cookie = SimpleCookie() + cookie.load(response.cookies) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=morsel.get("expires", None), + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax"), + ) + return api_options.response.redirect(response.redirect_to) + elif isinstance(response, ErrorOAuth2Response): + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/token.py b/supertokens_python/recipe/oauth2provider/api/token.py index 2d49260bd..011c02ef2 100644 --- a/supertokens_python/recipe/oauth2provider/api/token.py +++ b/supertokens_python/recipe/oauth2provider/api/token.py @@ -16,10 +16,13 @@ from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.utils import send_200_response, send_non_200_response + if TYPE_CHECKING: from ..interfaces import ( APIOptions, APIInterface, + ErrorOAuth2Response, ) @@ -32,4 +35,26 @@ async def token_post( if api_implementation.disable_token_post is True: return None - raise NotImplementedError() + authorization_header = api_options.request.get_header("authorization") + + body = await api_options.request.json() + + response = await api_implementation.token_post( + authorization_header=authorization_header, + body=body, + options=api_options, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + # We do not need to normalize as this is not expected to be called by frontends where interception is enabled + return send_non_200_response( + { + "error": response.error, + "error_description": response.error_description, + }, + response.status_code or 400, + api_options.response, + ) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/api/user_info.py b/supertokens_python/recipe/oauth2provider/api/user_info.py index 8cee3f4ef..7501559e8 100644 --- a/supertokens_python/recipe/oauth2provider/api/user_info.py +++ b/supertokens_python/recipe/oauth2provider/api/user_info.py @@ -14,7 +14,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional + +from supertokens_python.asyncio import get_user +from supertokens_python.utils import ( + send_200_response, + send_non_200_response_with_message, +) if TYPE_CHECKING: from ..interfaces import ( @@ -32,4 +38,86 @@ async def user_info_get( if api_implementation.disable_user_info_get is True: return None - raise NotImplementedError() + authorization_header = api_options.request.get_header("authorization") + + if authorization_header is None or not authorization_header.startswith("Bearer "): + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Missing or invalid Authorization header", + 401, + api_options.response, + ) + + access_token = authorization_header.replace("Bearer ", "").strip() + + payload: Optional[Dict[str, Any]] = None + + try: + payload = await api_options.recipe_implementation.validate_oauth2_access_token( + token=access_token, + user_context=user_context, + ) + + except Exception: + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Invalid or expired OAuth2 access token", + 401, + api_options.response, + ) + + if not isinstance(payload.get("sub"), str) or not isinstance( + payload.get("scp"), list + ): + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Malformed access token payload", + 401, + api_options.response, + ) + + user_id = payload["sub"] + + user = await get_user(user_id, user_context) + + if user is None: + api_options.response.set_header( + "WWW-Authenticate", 'Bearer error="invalid_token"' + ) + api_options.response.set_header( + "Access-Control-Expose-Headers", "WWW-Authenticate" + ) + return send_non_200_response_with_message( + "Couldn't find any user associated with the access token", + 401, + api_options.response, + ) + + response = await api_implementation.user_info_get( + access_token_payload=payload, + user=user, + tenant_id=_tenant_id, + scopes=payload["scp"], + options=api_options, + user_context=user_context, + ) + + if isinstance(response, dict): + return send_200_response(response, api_options.response) + else: + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 7fdac563e..0c741b6ab 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -173,6 +173,17 @@ def from_json(json: Dict[str, Any]): token_type=json["token_type"], ) + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "accessToken": self.access_token, + "expiresIn": self.expires_in, + "idToken": self.id_token, + "refreshToken": self.refresh_token, + "scope": self.scope, + "tokenType": self.token_type, + } + class LoginInfo: def __init__( @@ -647,9 +658,7 @@ async def auth_get( should_try_refresh: bool, options: APIOptions, user_context: Dict[str, Any] = {}, - ) -> Union[ - Dict[str, Union[str, Optional[str]]], ErrorOAuth2Response, GeneralErrorResponse - ]: + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @abstractmethod From fbee6d67cbf02ad5bce0d98a4c34d01ffe837f1b Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Wed, 4 Dec 2024 12:16:30 +0530 Subject: [PATCH 24/38] fix: remaining impl --- .../recipe/oauth2provider/api/end_session.py | 2 +- .../oauth2provider/api/implementation.py | 271 +++++++++++++++++- .../recipe/oauth2provider/interfaces.py | 6 +- .../oauth2provider/recipe_implementation.py | 19 +- 4 files changed, 285 insertions(+), 13 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py index cfbeed500..c2aa5cf60 100644 --- a/supertokens_python/recipe/oauth2provider/api/end_session.py +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -72,7 +72,7 @@ async def end_session_post( EndSessionCallable = Callable[ [Dict[str, str], APIOptions, Optional[SessionContainer], bool, Dict[str, Any]], - Awaitable[Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]], + Awaitable[Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]], ] diff --git a/supertokens_python/recipe/oauth2provider/api/implementation.py b/supertokens_python/recipe/oauth2provider/api/implementation.py index ad683174e..99d102be3 100644 --- a/supertokens_python/recipe/oauth2provider/api/implementation.py +++ b/supertokens_python/recipe/oauth2provider/api/implementation.py @@ -12,8 +12,275 @@ # License for the specific language governing permissions and limitations # under the License. -from ..interfaces import APIInterface +from typing import Any, Dict, List, Optional, Union + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import GeneralErrorResponse, User + +from .utils import ( + handle_login_internal_redirects, + handle_logout_internal_redirects, + login_get, +) +from ..interfaces import ( + APIInterface, + APIOptions, + ActiveTokenResponse, + ErrorOAuth2Response, + FrontendRedirectResponse, + InactiveTokenResponse, + LoginInfo, + RedirectResponse, + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + TokenInfo, +) class APIImplementation(APIInterface): - pass + async def login_get( + self, + login_challenge: str, + options: APIOptions, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Dict[str, Any] = {}, + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await login_get( + recipe_implementation=options.recipe_implementation, + login_challenge=login_challenge, + session=session, + should_try_refresh=should_try_refresh, + is_direct_call=True, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + resp_after_internal_redirects = await handle_login_internal_redirects( + response=response, + cookie=options.request.get_header("cookie") or "", + recipe_implementation=options.recipe_implementation, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(resp_after_internal_redirects, ErrorOAuth2Response): + return resp_after_internal_redirects + + return FrontendRedirectResponse( + frontend_redirect_to=resp_after_internal_redirects.redirect_to, + cookies=resp_after_internal_redirects.cookies, + ) + + async def auth_get( + self, + params: Any, + cookie: Optional[str], + session: Optional[SessionContainer], + should_try_refresh: bool, + options: APIOptions, + user_context: Dict[str, Any] = {}, + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await options.recipe_implementation.authorization( + params=params, + cookies=cookie, + session=session, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + return await handle_login_internal_redirects( + response=response, + recipe_implementation=options.recipe_implementation, + cookie=cookie or "", + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + async def token_post( + self, + authorization_header: Optional[str], + body: Any, + options: APIOptions, + user_context: Dict[str, Any] = {}, + ) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]: + return await options.recipe_implementation.token_exchange( + authorization_header=authorization_header, + body=body, + user_context=user_context, + ) + + async def login_info_get( + self, + login_challenge: str, + options: APIOptions, + user_context: Dict[str, Any] = {}, + ) -> Union[LoginInfo, ErrorOAuth2Response, GeneralErrorResponse]: + login_res = await options.recipe_implementation.get_login_request( + challenge=login_challenge, + user_context=user_context, + ) + + if isinstance(login_res, ErrorOAuth2Response): + return login_res + + client = login_res.client + + return LoginInfo( + client_id=client.client_id, + client_name=client.client_name, + tos_uri=client.tos_uri, + policy_uri=client.policy_uri, + logo_uri=client.logo_uri, + client_uri=client.client_uri, + metadata=client.metadata, + ) + + async def user_info_get( + self, + access_token_payload: Dict[str, Any], + user: User, + scopes: List[str], + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any] = {}, + ) -> Union[Dict[str, Any], GeneralErrorResponse]: + return await options.recipe_implementation.build_user_info( + user=user, + access_token_payload=access_token_payload, + scopes=scopes, + tenant_id=tenant_id, + user_context=user_context, + ) + + async def revoke_token_post( + self, + token: str, + options: APIOptions, + user_context: Dict[str, Any] = {}, + authorization_header: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + ) -> Union[None, ErrorOAuth2Response, GeneralErrorResponse]: + if authorization_header is not None: + return await options.recipe_implementation.revoke_token( + input=RevokeTokenUsingAuthorizationHeader( + token=token, + authorization_header=authorization_header, + ), + user_context=user_context, + ) + elif client_id is not None: + if client_secret is None: + raise Exception("client_secret is required") + + return await options.recipe_implementation.revoke_token( + input=RevokeTokenUsingClientIDAndClientSecret( + token=token, + client_id=client_id, + client_secret=client_secret, + ), + user_context=user_context, + ) + else: + raise Exception( + "Either of 'authorization_header' or 'client_id' must be provided" + ) + + async def introspect_token_post( + self, + token: str, + scopes: Optional[List[str]], + options: APIOptions, + user_context: Dict[str, Any] = {}, + ) -> Union[ActiveTokenResponse, InactiveTokenResponse, GeneralErrorResponse]: + return await options.recipe_implementation.introspect_token( + token=token, + scopes=scopes, + user_context=user_context, + ) + + async def end_session_get( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Dict[str, Any] = {}, + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await options.recipe_implementation.end_session( + params=params, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + return await handle_logout_internal_redirects( + response=response, + session=session, + recipe_implementation=options.recipe_implementation, + user_context=user_context, + ) + + async def end_session_post( + self, + params: Dict[str, str], + options: APIOptions, + session: Optional[SessionContainer] = None, + should_try_refresh: bool = False, + user_context: Dict[str, Any] = {}, + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + response = await options.recipe_implementation.end_session( + params=params, + session=session, + should_try_refresh=should_try_refresh, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + return await handle_logout_internal_redirects( + response=response, + session=session, + recipe_implementation=options.recipe_implementation, + user_context=user_context, + ) + + async def logout_post( + self, + logout_challenge: str, + options: APIOptions, + session: Optional[SessionContainer] = None, + user_context: Dict[str, Any] = {}, + ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: + if session is not None: + await session.revoke_session(user_context) + + response = await options.recipe_implementation.accept_logout_request( + challenge=logout_challenge, + user_context=user_context, + ) + + if isinstance(response, ErrorOAuth2Response): + return response + + res = await handle_logout_internal_redirects( + response=response, + recipe_implementation=options.recipe_implementation, + user_context=user_context, + ) + + if isinstance(res, ErrorOAuth2Response): + return res + + return FrontendRedirectResponse(frontend_redirect_to=res.redirect_to) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 0c741b6ab..10ad4eae1 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -579,7 +579,7 @@ async def introspect_token( token: str, scopes: Optional[List[str]] = None, user_context: Dict[str, Any] = {}, - ) -> Dict[str, Any]: + ) -> Union[ActiveTokenResponse, InactiveTokenResponse]: pass @abstractmethod @@ -726,7 +726,7 @@ async def end_session_get( session: Optional[SessionContainer] = None, should_try_refresh: bool = False, user_context: Dict[str, Any] = {}, - ) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]: + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @abstractmethod @@ -737,7 +737,7 @@ async def end_session_post( session: Optional[SessionContainer] = None, should_try_refresh: bool = False, user_context: Dict[str, Any] = {}, - ) -> Union[Dict[str, str], ErrorOAuth2Response, GeneralErrorResponse]: + ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @abstractmethod diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 7df7752eb..28395e9ef 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -50,6 +50,8 @@ OAuth2Client, TokenInfo, UserInfoBuilderFunction, + ActiveTokenResponse, + InactiveTokenResponse, ) @@ -443,11 +445,11 @@ async def token_exchange( user_context=user_context, ) - if token_info.get("active"): - session_handle = token_info["sessionHandle"] + if isinstance(token_info, ActiveTokenResponse): + session_handle = token_info.payload["sessionHandle"] client_info = await self.get_oauth2_client( - client_id=token_info["client_id"], user_context=user_context + client_id=token_info.payload["client_id"], user_context=user_context ) if isinstance(client_info, ErrorOAuth2Response): @@ -458,7 +460,7 @@ async def token_exchange( ) client = client_info.client - user = await get_user(token_info["sub"]) + user = await get_user(token_info.payload["sub"]) if not user: return ErrorOAuth2Response( @@ -826,7 +828,7 @@ async def introspect_token( token: str, scopes: Optional[List[str]] = None, user_context: Dict[str, Any] = {}, - ) -> Dict[str, Any]: + ) -> Union[ActiveTokenResponse, InactiveTokenResponse]: # Determine if the token is an access token by checking if it doesn't start with "st_rt" is_access_token = not token.startswith("st_rt") @@ -845,7 +847,7 @@ async def introspect_token( user_context=user_context, ) except Exception: - return {"active": False} + return InactiveTokenResponse() # For tokens that passed local validation or if it's a refresh token, # validate the token with the database by calling the core introspection endpoint @@ -858,7 +860,10 @@ async def introspect_token( user_context=user_context, ) - return res + if res.get("active"): + return ActiveTokenResponse(payload=res) + else: + return InactiveTokenResponse() async def end_session( self, From 9eb33ce26fa67a623c3856244258d9e1d00d9a25 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Wed, 4 Dec 2024 18:22:22 +0530 Subject: [PATCH 25/38] fix: typing --- .../oauth2provider/api/implementation.py | 48 +++--- .../recipe/oauth2provider/api/user_info.py | 2 + .../recipe/oauth2provider/api/utils.py | 34 +++-- .../recipe/oauth2provider/interfaces.py | 138 +++++++++--------- .../oauth2provider/recipe_implementation.py | 134 ++++++++--------- 5 files changed, 182 insertions(+), 174 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/implementation.py b/supertokens_python/recipe/oauth2provider/api/implementation.py index 99d102be3..cc06b7961 100644 --- a/supertokens_python/recipe/oauth2provider/api/implementation.py +++ b/supertokens_python/recipe/oauth2provider/api/implementation.py @@ -42,9 +42,9 @@ async def login_get( self, login_challenge: str, options: APIOptions, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: response = await login_get( recipe_implementation=options.recipe_implementation, @@ -52,6 +52,7 @@ async def login_get( session=session, should_try_refresh=should_try_refresh, is_direct_call=True, + cookies=None, user_context=user_context, ) @@ -82,7 +83,7 @@ async def auth_get( session: Optional[SessionContainer], should_try_refresh: bool, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: response = await options.recipe_implementation.authorization( params=params, @@ -108,7 +109,7 @@ async def token_post( authorization_header: Optional[str], body: Any, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]: return await options.recipe_implementation.token_exchange( authorization_header=authorization_header, @@ -120,7 +121,7 @@ async def login_info_get( self, login_challenge: str, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[LoginInfo, ErrorOAuth2Response, GeneralErrorResponse]: login_res = await options.recipe_implementation.get_login_request( challenge=login_challenge, @@ -149,7 +150,7 @@ async def user_info_get( scopes: List[str], tenant_id: str, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[Dict[str, Any], GeneralErrorResponse]: return await options.recipe_implementation.build_user_info( user=user, @@ -161,16 +162,16 @@ async def user_info_get( async def revoke_token_post( self, - token: str, options: APIOptions, - user_context: Dict[str, Any] = {}, - authorization_header: Optional[str] = None, - client_id: Optional[str] = None, - client_secret: Optional[str] = None, + token: str, + authorization_header: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + user_context: Dict[str, Any], ) -> Union[None, ErrorOAuth2Response, GeneralErrorResponse]: if authorization_header is not None: return await options.recipe_implementation.revoke_token( - input=RevokeTokenUsingAuthorizationHeader( + params=RevokeTokenUsingAuthorizationHeader( token=token, authorization_header=authorization_header, ), @@ -181,7 +182,7 @@ async def revoke_token_post( raise Exception("client_secret is required") return await options.recipe_implementation.revoke_token( - input=RevokeTokenUsingClientIDAndClientSecret( + params=RevokeTokenUsingClientIDAndClientSecret( token=token, client_id=client_id, client_secret=client_secret, @@ -198,7 +199,7 @@ async def introspect_token_post( token: str, scopes: Optional[List[str]], options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[ActiveTokenResponse, InactiveTokenResponse, GeneralErrorResponse]: return await options.recipe_implementation.introspect_token( token=token, @@ -210,9 +211,9 @@ async def end_session_get( self, params: Dict[str, str], options: APIOptions, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: response = await options.recipe_implementation.end_session( params=params, @@ -235,9 +236,9 @@ async def end_session_post( self, params: Dict[str, str], options: APIOptions, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: response = await options.recipe_implementation.end_session( params=params, @@ -260,8 +261,8 @@ async def logout_post( self, logout_challenge: str, options: APIOptions, - session: Optional[SessionContainer] = None, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + user_context: Dict[str, Any], ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: if session is not None: await session.revoke_session(user_context) @@ -277,6 +278,7 @@ async def logout_post( res = await handle_logout_internal_redirects( response=response, recipe_implementation=options.recipe_implementation, + session=session, user_context=user_context, ) diff --git a/supertokens_python/recipe/oauth2provider/api/user_info.py b/supertokens_python/recipe/oauth2provider/api/user_info.py index 7501559e8..9fbdeaf1f 100644 --- a/supertokens_python/recipe/oauth2provider/api/user_info.py +++ b/supertokens_python/recipe/oauth2provider/api/user_info.py @@ -60,6 +60,8 @@ async def user_info_get( try: payload = await api_options.recipe_implementation.validate_oauth2_access_token( token=access_token, + requirements=None, + check_database=None, user_context=user_context, ) diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py index 201b87270..8fb9a5db5 100644 --- a/supertokens_python/recipe/oauth2provider/api/utils.py +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -38,11 +38,11 @@ async def login_get( recipe_implementation: RecipeInterface, login_challenge: str, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - cookies: Optional[str] = None, - is_direct_call: bool = False, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + cookies: Optional[str], + is_direct_call: bool, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: login_request = await recipe_implementation.get_login_request( challenge=login_challenge, @@ -116,6 +116,10 @@ async def login_get( ): accept = await recipe_implementation.accept_login_request( challenge=login_challenge, + acr=None, + amr=None, + context=None, + extend_session_lifespan=None, subject=session.get_user_id(), identity_provider_session_id=session.get_handle(), user_context=user_context, @@ -128,7 +132,7 @@ async def login_get( if should_try_refresh and prompt_param != "login": return RedirectResponse( redirect_to=await recipe_implementation.get_frontend_redirection_url( - input=FrontendRedirectionURLTypeTryRefresh( + params=FrontendRedirectionURLTypeTryRefresh( login_challenge=login_challenge, ), user_context=user_context, @@ -152,7 +156,7 @@ async def login_get( return RedirectResponse( redirect_to=await recipe_implementation.get_frontend_redirection_url( - input=FrontendRedirectionURLTypeLogin( + params=FrontendRedirectionURLTypeLogin( login_challenge=login_challenge, force_fresh_auth=session is not None or prompt_param == "login", tenant_id=tenant_id_param or DEFAULT_TENANT_ID, @@ -168,9 +172,7 @@ async def login_get( ) -def get_merged_cookies( - orig_cookies: str = "", new_cookies: Optional[str] = None -) -> str: +def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str: if not new_cookies: return orig_cookies @@ -223,10 +225,10 @@ def is_logout_internal_redirect(redirect_to: str) -> bool: async def handle_login_internal_redirects( response: RedirectResponse, recipe_implementation: RecipeInterface, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - cookie: str = "", - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + cookie: str, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: if not is_login_internal_redirect(response.redirect_to): return response @@ -297,8 +299,8 @@ async def handle_login_internal_redirects( async def handle_logout_internal_redirects( response: RedirectResponse, recipe_implementation: RecipeInterface, - session: Optional[SessionContainer] = None, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: if not is_logout_internal_redirect(response.redirect_to): return response diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 10ad4eae1..af8eec1e6 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -368,7 +368,7 @@ async def authorization( params: Dict[str, str], cookies: Optional[str], session: Optional[SessionContainer], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass @@ -377,13 +377,13 @@ async def token_exchange( self, authorization_header: Optional[str], body: Dict[str, Optional[str]], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[TokenInfo, ErrorOAuth2Response]: pass @abstractmethod async def get_consent_request( - self, challenge: str, user_context: Dict[str, Any] = {} + self, challenge: str, user_context: Dict[str, Any] ) -> ConsentRequest: pass @@ -391,16 +391,16 @@ async def get_consent_request( async def accept_consent_request( self, challenge: str, - context: Optional[Any] = None, - grant_access_token_audience: Optional[List[str]] = None, - grant_scope: Optional[List[str]] = None, - handled_at: Optional[str] = None, - tenant_id: str = "", - rsub: str = "", - session_handle: str = "", - initial_access_token_payload: Optional[Dict[str, Any]] = None, - initial_id_token_payload: Optional[Dict[str, Any]] = None, - user_context: Dict[str, Any] = {}, + context: Optional[Any], + grant_access_token_audience: Optional[List[str]], + grant_scope: Optional[List[str]], + handled_at: Optional[str], + tenant_id: str, + rsub: str, + session_handle: str, + initial_access_token_payload: Optional[Dict[str, Any]], + initial_id_token_payload: Optional[Dict[str, Any]], + user_context: Dict[str, Any], ) -> RedirectResponse: pass @@ -420,13 +420,13 @@ async def get_login_request( async def accept_login_request( self, challenge: str, - acr: Optional[str] = None, - amr: Optional[List[str]] = None, - context: Optional[Any] = None, - extend_session_lifespan: Optional[bool] = None, - identity_provider_session_id: Optional[str] = None, - subject: str = "", - user_context: Dict[str, Any] = {}, + acr: Optional[str], + amr: Optional[List[str]], + context: Optional[Any], + extend_session_lifespan: Optional[bool], + identity_provider_session_id: Optional[str], + subject: str, + user_context: Dict[str, Any], ) -> RedirectResponse: pass @@ -435,17 +435,17 @@ async def reject_login_request( self, challenge: str, error: ErrorOAuth2Response, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> RedirectResponse: pass @abstractmethod async def get_oauth2_clients( self, - page_size: Optional[int] = None, - pagination_token: Optional[str] = None, - client_name: Optional[str] = None, - user_context: Dict[str, Any] = {}, + page_size: Optional[int], + pagination_token: Optional[str], + client_name: Optional[str], + user_context: Dict[str, Any], ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: pass @@ -453,21 +453,21 @@ async def get_oauth2_clients( async def get_oauth2_client( self, client_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: pass @abstractmethod async def create_oauth2_client( self, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: pass @abstractmethod async def update_oauth2_client( self, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: pass @@ -475,7 +475,7 @@ async def update_oauth2_client( async def delete_oauth2_client( self, client_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: pass @@ -483,9 +483,9 @@ async def delete_oauth2_client( async def validate_oauth2_access_token( self, token: str, - requirements: Optional[OAuth2TokenValidationRequirements] = None, - check_database: Optional[bool] = None, - user_context: Dict[str, Any] = {}, + requirements: Optional[OAuth2TokenValidationRequirements], + check_database: Optional[bool], + user_context: Dict[str, Any], ) -> Dict[str, Any]: pass @@ -496,7 +496,7 @@ async def get_requested_scopes( session_handle: Optional[str], scope_param: List[str], client_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> List[str]: pass @@ -507,7 +507,7 @@ async def build_access_token_payload( client: OAuth2Client, session_handle: Optional[str], scopes: List[str], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Dict[str, Any]: pass @@ -518,7 +518,7 @@ async def build_id_token_payload( client: OAuth2Client, session_handle: Optional[str], scopes: List[str], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Dict[str, Any]: pass @@ -529,31 +529,31 @@ async def build_user_info( access_token_payload: Dict[str, Any], scopes: List[str], tenant_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Dict[str, Any]: pass @abstractmethod async def get_frontend_redirection_url( self, - input: Union[ + params: Union[ FrontendRedirectionURLTypeLogin, FrontendRedirectionURLTypeTryRefresh, FrontendRedirectionURLTypeLogoutConfirmation, FrontendRedirectionURLTypePostLogoutFallback, ], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> str: pass @abstractmethod async def revoke_token( self, - input: Union[ + params: Union[ RevokeTokenUsingAuthorizationHeader, RevokeTokenUsingClientIDAndClientSecret, ], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Optional[ErrorOAuth2Response]: pass @@ -561,7 +561,7 @@ async def revoke_token( async def revoke_tokens_by_client_id( self, client_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ): pass @@ -569,7 +569,7 @@ async def revoke_tokens_by_client_id( async def revoke_tokens_by_session_handle( self, session_handle: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ): pass @@ -577,8 +577,8 @@ async def revoke_tokens_by_session_handle( async def introspect_token( self, token: str, - scopes: Optional[List[str]] = None, - user_context: Dict[str, Any] = {}, + scopes: Optional[List[str]], + user_context: Dict[str, Any], ) -> Union[ActiveTokenResponse, InactiveTokenResponse]: pass @@ -587,8 +587,8 @@ async def end_session( self, params: Dict[str, str], should_try_refresh: bool, - session: Optional[SessionContainer] = None, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass @@ -596,7 +596,7 @@ async def end_session( async def accept_logout_request( self, challenge: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: pass @@ -604,7 +604,7 @@ async def accept_logout_request( async def reject_logout_request( self, challenge: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ): pass @@ -643,9 +643,9 @@ async def login_get( self, login_challenge: str, options: APIOptions, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -657,7 +657,7 @@ async def auth_get( session: Optional[SessionContainer], should_try_refresh: bool, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -667,7 +667,7 @@ async def token_post( authorization_header: Optional[str], body: Any, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[TokenInfo, ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -676,7 +676,7 @@ async def login_info_get( self, login_challenge: str, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[ LoginInfo, ErrorOAuth2Response, @@ -692,19 +692,19 @@ async def user_info_get( scopes: List[str], tenant_id: str, options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[Dict[str, Any], GeneralErrorResponse]: pass @abstractmethod async def revoke_token_post( self, - token: str, options: APIOptions, - user_context: Dict[str, Any] = {}, - authorization_header: Optional[str] = None, - client_id: Optional[str] = None, - client_secret: Optional[str] = None, + token: str, + authorization_header: Optional[str], + client_id: Optional[str], + client_secret: Optional[str], + user_context: Dict[str, Any], ) -> Union[None, ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -714,7 +714,7 @@ async def introspect_token_post( token: str, scopes: Optional[List[str]], options: APIOptions, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[ActiveTokenResponse, InactiveTokenResponse, GeneralErrorResponse]: pass @@ -723,9 +723,9 @@ async def end_session_get( self, params: Dict[str, str], options: APIOptions, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -734,9 +734,9 @@ async def end_session_post( self, params: Dict[str, str], options: APIOptions, - session: Optional[SessionContainer] = None, - should_try_refresh: bool = False, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + should_try_refresh: bool, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass @@ -745,7 +745,7 @@ async def logout_post( self, logout_challenge: str, options: APIOptions, - session: Optional[SessionContainer] = None, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + user_context: Dict[str, Any], ) -> Union[FrontendRedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]: pass diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 28395e9ef..cdf34ccf0 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -103,13 +103,13 @@ async def get_login_request( async def accept_login_request( self, challenge: str, - acr: Optional[str] = None, - amr: Optional[List[str]] = None, - context: Optional[Any] = None, - extend_session_lifespan: Optional[bool] = None, - identity_provider_session_id: Optional[str] = None, - subject: str = "", - user_context: Dict[str, Any] = {}, + acr: Optional[str], + amr: Optional[List[str]], + context: Optional[Any], + extend_session_lifespan: Optional[bool], + identity_provider_session_id: Optional[str], + subject: str, + user_context: Dict[str, Any], ) -> RedirectResponse: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/login/accept"), @@ -135,7 +135,7 @@ async def reject_login_request( self, challenge: str, error: ErrorOAuth2Response, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> RedirectResponse: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/login/reject"), @@ -154,7 +154,7 @@ async def reject_login_request( ) async def get_consent_request( - self, challenge: str, user_context: Dict[str, Any] = {} + self, challenge: str, user_context: Dict[str, Any] ) -> ConsentRequest: response = await self.querier.send_get_request( NormalisedURLPath("/recipe/oauth/auth/requests/consent"), @@ -167,16 +167,16 @@ async def get_consent_request( async def accept_consent_request( self, challenge: str, - context: Optional[Any] = None, - grant_access_token_audience: Optional[List[str]] = None, - grant_scope: Optional[List[str]] = None, - handled_at: Optional[str] = None, - tenant_id: str = "", - rsub: str = "", - session_handle: str = "", - initial_access_token_payload: Optional[Dict[str, Any]] = None, - initial_id_token_payload: Optional[Dict[str, Any]] = None, - user_context: Dict[str, Any] = {}, + context: Optional[Any], + grant_access_token_audience: Optional[List[str]], + grant_scope: Optional[List[str]], + handled_at: Optional[str], + tenant_id: str, + rsub: str, + session_handle: str, + initial_access_token_payload: Optional[Dict[str, Any]], + initial_id_token_payload: Optional[Dict[str, Any]], + user_context: Dict[str, Any], ) -> RedirectResponse: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/consent/accept"), @@ -227,7 +227,7 @@ async def authorization( params: Dict[str, str], cookies: Optional[str], session: Optional[SessionContainer], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: # we handle this in the backend SDK level if params.get("prompt") == "none": @@ -352,10 +352,11 @@ async def authorization( ) consent_res = await self.accept_consent_request( - user_context=user_context, challenge=consent_request.challenge, + context=None, grant_access_token_audience=consent_request.requested_access_token_audience, grant_scope=consent_request.requested_scope, + handled_at=None, tenant_id=session.get_tenant_id(), rsub=session.get_recipe_user_id().get_as_string(), session_handle=session.get_handle(), @@ -363,6 +364,7 @@ async def authorization( payloads.get("accessToken") if payloads else None ), initial_id_token_payload=payloads.get("idToken") if payloads else None, + user_context=user_context, ) return RedirectResponse( @@ -375,7 +377,7 @@ async def token_exchange( self, authorization_header: Optional[str], body: Dict[str, Optional[str]], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[TokenInfo, ErrorOAuth2Response]: request_body = { "iss": await OpenIdRecipe.get_issuer(user_context), @@ -511,10 +513,10 @@ async def token_exchange( async def get_oauth2_clients( self, - page_size: Optional[int] = None, - pagination_token: Optional[str] = None, - client_name: Optional[str] = None, - user_context: Dict[str, Any] = {}, + page_size: Optional[int], + pagination_token: Optional[str], + client_name: Optional[str], + user_context: Dict[str, Any], ) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: body: Dict[str, Any] = {} if page_size is not None: @@ -545,7 +547,7 @@ async def get_oauth2_clients( ) async def get_oauth2_client( - self, client_id: str, user_context: Dict[str, Any] = {} + self, client_id: str, user_context: Dict[str, Any] ) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_get_request( NormalisedURLPath("/recipe/oauth/clients"), @@ -567,7 +569,7 @@ async def get_oauth2_client( async def create_oauth2_client( self, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/clients"), @@ -583,7 +585,7 @@ async def create_oauth2_client( async def update_oauth2_client( self, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/clients"), @@ -601,7 +603,7 @@ async def update_oauth2_client( async def delete_oauth2_client( self, client_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/clients/remove"), @@ -618,9 +620,9 @@ async def delete_oauth2_client( async def validate_oauth2_access_token( self, token: str, - requirements: Optional[OAuth2TokenValidationRequirements] = None, - check_database: Optional[bool] = None, - user_context: Dict[str, Any] = {}, + requirements: Optional[OAuth2TokenValidationRequirements], + check_database: Optional[bool], + user_context: Dict[str, Any], ) -> Dict[str, Any]: # Verify token signature using session recipe's JWKS session_recipe = SessionRecipe.get_instance() @@ -677,7 +679,7 @@ async def get_requested_scopes( session_handle: Optional[str], scope_param: List[str], client_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> List[str]: _ = recipe_user_id _ = session_handle @@ -692,7 +694,7 @@ async def build_access_token_payload( client: OAuth2Client, session_handle: Optional[str], scopes: List[str], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Dict[str, Any]: if user is None or session_handle is None: return {} @@ -709,7 +711,7 @@ async def build_id_token_payload( client: OAuth2Client, session_handle: Optional[str], scopes: List[str], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Dict[str, Any]: if user is None or session_handle is None: return {} @@ -726,7 +728,7 @@ async def build_user_info( access_token_payload: Dict[str, Any], scopes: List[str], tenant_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Dict[str, Any]: return await self._get_default_user_info_payload( user, access_token_payload, scopes, tenant_id, user_context @@ -734,26 +736,26 @@ async def build_user_info( async def get_frontend_redirection_url( self, - input: Union[ + params: Union[ FrontendRedirectionURLTypeLogin, FrontendRedirectionURLTypeTryRefresh, FrontendRedirectionURLTypeLogoutConfirmation, FrontendRedirectionURLTypePostLogoutFallback, ], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> str: website_domain = self.app_info.get_origin( None, user_context ).get_as_string_dangerous() website_base_path = self.app_info.api_base_path.get_as_string_dangerous() - if isinstance(input, FrontendRedirectionURLTypeLogin): - query_params: Dict[str, str] = {"loginChallenge": input.login_challenge} - if input.tenant_id != "public": # DEFAULT_TENANT_ID is "public" - query_params["tenantId"] = input.tenant_id - if input.hint is not None: - query_params["hint"] = input.hint - if input.force_fresh_auth: + if isinstance(params, FrontendRedirectionURLTypeLogin): + query_params: Dict[str, str] = {"loginChallenge": params.login_challenge} + if params.tenant_id != "public": # DEFAULT_TENANT_ID is "public" + query_params["tenantId"] = params.tenant_id + if params.hint is not None: + query_params["hint"] = params.hint + if params.force_fresh_auth: query_params["forceFreshAuth"] = "true" query_string = "&".join( @@ -761,30 +763,30 @@ async def get_frontend_redirection_url( ) return f"{website_domain}{website_base_path}?{query_string}" - elif isinstance(input, FrontendRedirectionURLTypeTryRefresh): - return f"{website_domain}{website_base_path}/try-refresh?loginChallenge={input.login_challenge}" + elif isinstance(params, FrontendRedirectionURLTypeTryRefresh): + return f"{website_domain}{website_base_path}/try-refresh?loginChallenge={params.login_challenge}" - elif isinstance(input, FrontendRedirectionURLTypePostLogoutFallback): + elif isinstance(params, FrontendRedirectionURLTypePostLogoutFallback): return f"{website_domain}{website_base_path}" - else: # isinstance(input, FrontendRedirectionURLTypeLogoutConfirmation) - return f"{website_domain}{website_base_path}/oauth/logout?logoutChallenge={input.logout_challenge}" + else: # isinstance(params, FrontendRedirectionURLTypeLogoutConfirmation) + return f"{website_domain}{website_base_path}/oauth/logout?logoutChallenge={params.logout_challenge}" async def revoke_token( self, - input: Union[ + params: Union[ RevokeTokenUsingAuthorizationHeader, RevokeTokenUsingClientIDAndClientSecret, ], - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Optional[ErrorOAuth2Response]: - request_body = {"token": input.token} + request_body = {"token": params.token} - if isinstance(input, RevokeTokenUsingAuthorizationHeader): - request_body["authorizationHeader"] = input.authorization_header + if isinstance(params, RevokeTokenUsingAuthorizationHeader): + request_body["authorizationHeader"] = params.authorization_header else: - request_body["client_id"] = input.client_id - request_body["client_secret"] = input.client_secret + request_body["client_id"] = params.client_id + request_body["client_secret"] = params.client_secret res = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/token/revoke"), @@ -804,7 +806,7 @@ async def revoke_token( async def revoke_tokens_by_client_id( self, client_id: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ): await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/session/revoke"), @@ -815,7 +817,7 @@ async def revoke_tokens_by_client_id( async def revoke_tokens_by_session_handle( self, session_handle: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ): await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/session/revoke"), @@ -826,8 +828,8 @@ async def revoke_tokens_by_session_handle( async def introspect_token( self, token: str, - scopes: Optional[List[str]] = None, - user_context: Dict[str, Any] = {}, + scopes: Optional[List[str]], + user_context: Dict[str, Any], ) -> Union[ActiveTokenResponse, InactiveTokenResponse]: # Determine if the token is an access token by checking if it doesn't start with "st_rt" is_access_token = not token.startswith("st_rt") @@ -869,8 +871,8 @@ async def end_session( self, params: Dict[str, str], should_try_refresh: bool, - session: Optional[SessionContainer] = None, - user_context: Dict[str, Any] = {}, + session: Optional[SessionContainer], + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: # NOTE: The API response has 3 possible cases: # @@ -946,7 +948,7 @@ async def end_session( async def accept_logout_request( self, challenge: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: resp = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/logout/accept"), @@ -977,7 +979,7 @@ async def accept_logout_request( async def reject_logout_request( self, challenge: str, - user_context: Dict[str, Any] = {}, + user_context: Dict[str, Any], ): resp = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/auth/requests/logout/reject"), From e3d1287e97806fecd6c4a4ed9460eb908af5373d Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 10 Dec 2024 17:51:27 +0530 Subject: [PATCH 26/38] fix: type and lint --- .pylintrc | 1 + .../recipe/oauth2provider/__init__.py | 20 +++++++++++++++++++ .../recipe/oauth2provider/api/auth.py | 19 +++++++++--------- .../recipe/oauth2provider/api/end_session.py | 16 ++++++--------- .../recipe/oauth2provider/api/login.py | 12 +++++------ .../recipe/oauth2provider/api/login_info.py | 8 +++----- .../recipe/oauth2provider/api/logout.py | 10 +++++++--- .../recipe/oauth2provider/api/revoke_token.py | 7 +++++-- .../recipe/oauth2provider/api/token.py | 5 ++++- .../recipe/oauth2provider/api/utils.py | 9 +++++++-- .../recipe/oauth2provider/interfaces.py | 1 + .../oauth2provider/recipe_implementation.py | 1 + .../recipe/oauth2provider/utils.py | 8 ++++++-- 13 files changed, 76 insertions(+), 41 deletions(-) diff --git a/.pylintrc b/.pylintrc index 69ed5aab3..221068cf7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -124,6 +124,7 @@ disable=raw-checker-failed, no-else-raise, too-many-nested-blocks, broad-exception-raised, + too-many-public-methods, # Enable the message, report, category or checker with the given id(s). You can diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py index a59c06f52..55b44908d 100644 --- a/supertokens_python/recipe/oauth2provider/__init__.py +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -11,3 +11,23 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Union + +from . import exceptions as ex +from . import recipe + +exceptions = ex + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + from ...recipe_module import RecipeModule + from .utils import InputOverrideConfig + + +def init( + override: Union[InputOverrideConfig, None] = None, +) -> Callable[[AppInfo], RecipeModule]: + return recipe.OAuth2ProviderRecipe.init(override) diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 15b735ff4..1819aa238 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -18,8 +18,6 @@ from typing import TYPE_CHECKING, Any, Dict from urllib.parse import parse_qsl -from fastapi.responses import RedirectResponse - from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.recipe.session.exceptions import TryRefreshTokenError from supertokens_python.utils import send_200_response, send_non_200_response @@ -28,8 +26,6 @@ from ..interfaces import ( APIOptions, APIInterface, - RedirectResponse, - ErrorOAuth2Response, ) @@ -39,6 +35,11 @@ async def auth_get( api_options: APIOptions, user_context: Dict[str, Any], ): + from ..interfaces import ( + RedirectResponse, + ErrorOAuth2Response, + ) + if api_implementation.disable_auth_get is True: return None @@ -57,12 +58,10 @@ async def auth_get( should_try_refresh = False except Exception as error: session = None - if isinstance(error, TryRefreshTokenError): - should_try_refresh = True - else: - # This should generally not happen, but we can handle this as if the session is not present, - # because then we redirect to the frontend, which should handle the validation error - should_try_refresh = False + + # should_try_refresh = False should generally not happen, but we can handle this as if the session is not present, + # because then we redirect to the frontend, which should handle the validation error + should_try_refresh = isinstance(error, TryRefreshTokenError) response = await api_implementation.auth_get( params=params, diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py index c2aa5cf60..ebda43153 100644 --- a/supertokens_python/recipe/oauth2provider/api/end_session.py +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -33,6 +33,11 @@ ErrorOAuth2Response, ) + EndSessionCallable = Callable[ + [Dict[str, str], APIOptions, Optional[SessionContainer], bool, Dict[str, Any]], + Awaitable[Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]], + ] + async def end_session_get( _tenant_id: str, @@ -70,12 +75,6 @@ async def end_session_post( ) -EndSessionCallable = Callable[ - [Dict[str, str], APIOptions, Optional[SessionContainer], bool, Dict[str, Any]], - Awaitable[Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]], -] - - async def end_session_common( params: Dict[str, str], api_implementation: Optional[EndSessionCallable], @@ -98,10 +97,7 @@ async def end_session_common( # We can handle this as if the session is not present, because then we redirect to the frontend, # which should handle the validation error session = None - if isinstance(error, TryRefreshTokenError): - should_try_refresh = True - else: - should_try_refresh = False + should_try_refresh = isinstance(error, TryRefreshTokenError) response = await api_implementation( params, diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index 91ff8142f..59c93dcdb 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -27,8 +27,6 @@ from ..interfaces import ( APIOptions, APIInterface, - FrontendRedirectResponse, - ErrorOAuth2Response, ) @@ -38,6 +36,11 @@ async def login( api_options: APIOptions, user_context: Dict[str, Any], ) -> Optional[BaseResponse]: + from ..interfaces import ( + FrontendRedirectResponse, + ErrorOAuth2Response, + ) + if api_implementation.disable_login_get is True: return None @@ -54,10 +57,7 @@ async def login( # We can handle this as if the session is not present, because then we redirect to the frontend, # which should handle the validation error session = None - if isinstance(error, TryRefreshTokenError): - should_try_refresh = True - else: - should_try_refresh = False + should_try_refresh = isinstance(error, TryRefreshTokenError) login_challenge = api_options.request.get_query_param( "login_challenge" diff --git a/supertokens_python/recipe/oauth2provider/api/login_info.py b/supertokens_python/recipe/oauth2provider/api/login_info.py index aab799f33..532e55f37 100644 --- a/supertokens_python/recipe/oauth2provider/api/login_info.py +++ b/supertokens_python/recipe/oauth2provider/api/login_info.py @@ -17,14 +17,10 @@ from typing import TYPE_CHECKING, Any, Dict from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.recipe.oauth2provider.interfaces import ErrorOAuth2Response from supertokens_python.utils import send_200_response, send_non_200_response if TYPE_CHECKING: - from ..interfaces import ( - APIOptions, - APIInterface, - ) + from ..interfaces import APIOptions, APIInterface async def login_info_get( @@ -33,6 +29,8 @@ async def login_info_get( api_options: APIOptions, user_context: Dict[str, Any], ): + from ..interfaces import ErrorOAuth2Response + if api_implementation.disable_login_info_get is True: return None diff --git a/supertokens_python/recipe/oauth2provider/api/logout.py b/supertokens_python/recipe/oauth2provider/api/logout.py index 646a705ff..8cb4f349e 100644 --- a/supertokens_python/recipe/oauth2provider/api/logout.py +++ b/supertokens_python/recipe/oauth2provider/api/logout.py @@ -25,8 +25,6 @@ from ..interfaces import ( APIOptions, APIInterface, - FrontendRedirectResponse, - ErrorOAuth2Response, ) @@ -36,14 +34,20 @@ async def logout_post( api_options: APIOptions, user_context: Dict[str, Any], ) -> Optional[BaseResponse]: + from ..interfaces import ( + FrontendRedirectResponse, + ErrorOAuth2Response, + ) + if api_implementation.disable_logout_post is True: return None + session = None try: session = await get_session( api_options.request, session_required=False, user_context=user_context ) - except: + except Exception as _: pass body = await api_options.request.json() diff --git a/supertokens_python/recipe/oauth2provider/api/revoke_token.py b/supertokens_python/recipe/oauth2provider/api/revoke_token.py index e576a7ff6..331ecfaa1 100644 --- a/supertokens_python/recipe/oauth2provider/api/revoke_token.py +++ b/supertokens_python/recipe/oauth2provider/api/revoke_token.py @@ -23,8 +23,6 @@ from ..interfaces import ( APIOptions, APIInterface, - ErrorOAuth2Response, - GeneralErrorResponse, ) @@ -34,6 +32,11 @@ async def revoke_token_post( api_options: APIOptions, user_context: Dict[str, Any], ) -> Optional[BaseResponse]: + from ..interfaces import ( + ErrorOAuth2Response, + GeneralErrorResponse, + ) + if api_implementation.disable_revoke_token_post is True: return None diff --git a/supertokens_python/recipe/oauth2provider/api/token.py b/supertokens_python/recipe/oauth2provider/api/token.py index 011c02ef2..16f8d58b2 100644 --- a/supertokens_python/recipe/oauth2provider/api/token.py +++ b/supertokens_python/recipe/oauth2provider/api/token.py @@ -22,7 +22,6 @@ from ..interfaces import ( APIOptions, APIInterface, - ErrorOAuth2Response, ) @@ -32,6 +31,10 @@ async def token_post( api_options: APIOptions, user_context: Dict[str, Any], ): + from ..interfaces import ( + ErrorOAuth2Response, + ) + if api_implementation.disable_token_post is True: return None diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py index 8fb9a5db5..2b524b017 100644 --- a/supertokens_python/recipe/oauth2provider/api/utils.py +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -29,8 +29,6 @@ RecipeInterface, ErrorOAuth2Response, RedirectResponse, - FrontendRedirectionURLTypeTryRefresh, - FrontendRedirectionURLTypeLogin, ) from supertokens_python.recipe.session.interfaces import SessionContainer @@ -44,6 +42,13 @@ async def login_get( is_direct_call: bool, user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: + from ..interfaces import ( + ErrorOAuth2Response, + RedirectResponse, + FrontendRedirectionURLTypeTryRefresh, + FrontendRedirectionURLTypeLogin, + ) + login_request = await recipe_implementation.get_login_request( challenge=login_challenge, user_context=user_context, diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index af8eec1e6..0b6fe9c83 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -11,6 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index cdf34ccf0..c08ec852f 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -11,6 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations import base64 from typing import TYPE_CHECKING, Dict, Optional, Any, Union, List diff --git a/supertokens_python/recipe/oauth2provider/utils.py b/supertokens_python/recipe/oauth2provider/utils.py index cb7553dd1..c7f49c623 100644 --- a/supertokens_python/recipe/oauth2provider/utils.py +++ b/supertokens_python/recipe/oauth2provider/utils.py @@ -11,10 +11,14 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations -from typing import Callable, Union -from .interfaces import APIInterface, RecipeInterface +from typing import Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Union + from .interfaces import APIInterface, RecipeInterface class InputOverrideConfig: From 3041401db7d34539d0a5b2a47918756059d8d743 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Wed, 11 Dec 2024 18:12:43 +0530 Subject: [PATCH 27/38] fix: types, exposed functions and cyclic import --- .../recipe/oauth2provider/asyncio/__init__.py | 228 +++++++ .../recipe/oauth2provider/interfaces.py | 593 +++++++++++++++++- .../oauth2provider/recipe_implementation.py | 13 +- .../recipe/oauth2provider/syncio/__init__.py | 163 +++++ 4 files changed, 992 insertions(+), 5 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py index a59c06f52..2e54a7144 100644 --- a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py +++ b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py @@ -11,3 +11,231 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations + +import base64 +from typing import Any, Dict, Union, Optional, List + +from ..interfaces import ( + ActiveTokenResponse, + CreateOAuth2ClientInput, + CreateOAuth2ClientOkResult, + DeleteOAuth2ClientOkResult, + ErrorOAuth2Response, + GetOAuth2ClientOkResult, + GetOAuth2ClientsOkResult, + InactiveTokenResponse, + OAuth2TokenValidationRequirements, + RevokeTokenUsingAuthorizationHeader, + RevokeTokenUsingClientIDAndClientSecret, + TokenInfo, + UpdateOAuth2ClientInput, + UpdateOAuth2ClientOkResult, +) + + +async def get_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.get_oauth2_client( + client_id=client_id, user_context=user_context + ) + + +async def get_oauth2_clients( + page_size: Optional[int] = None, + pagination_token: Optional[str] = None, + client_name: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.get_oauth2_clients( + page_size=page_size, + pagination_token=pagination_token, + client_name=client_name, + user_context=user_context, + ) + + +async def create_oauth2_client( + params: CreateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.create_oauth2_client( + params=params, + user_context=user_context, + ) + + +async def update_oauth2_client( + params: UpdateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.update_oauth2_client( + params=params, + user_context=user_context, + ) + + +async def delete_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.delete_oauth2_client( + client_id=client_id, user_context=user_context + ) + + +async def validate_oauth2_access_token( + token: str, + requirements: Optional[OAuth2TokenValidationRequirements] = None, + check_database: Optional[bool] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.validate_oauth2_access_token( + token=token, + requirements=requirements, + check_database=check_database, + user_context=user_context, + ) + + +async def create_token_for_client_credentials( + client_id: str, + client_secret: str, + scope: Optional[List[str]] = None, + audience: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[TokenInfo, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return ( + await OAuth2ProviderRecipe.get_instance().recipe_implementation.token_exchange( + authorization_header=None, + body={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + "scope": " ".join(scope) if scope else None, + "audience": audience, + }, + user_context=user_context, + ) + ) + + +async def revoke_token( + token: str, + client_id: str, + client_secret: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Optional[ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + recipe = OAuth2ProviderRecipe.get_instance() + + client_info = await recipe.recipe_implementation.get_oauth2_client( + client_id=client_id, user_context=user_context + ) + + if isinstance(client_info, ErrorOAuth2Response): + raise Exception( + f"Failed to get OAuth2 client with id {client_id}: {client_info.error}" + ) + + token_endpoint_auth_method = client_info.client.token_endpoint_auth_method + + if token_endpoint_auth_method == "none": + auth_header = "Basic " + base64.b64encode(f"{client_id}:".encode()).decode() + return await recipe.recipe_implementation.revoke_token( + RevokeTokenUsingAuthorizationHeader( + token=token, + authorization_header=auth_header, + ), + user_context=user_context, + ) + elif token_endpoint_auth_method == "client_secret_basic" and client_secret: + auth_header = ( + "Basic " + + base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + ) + return await recipe.recipe_implementation.revoke_token( + RevokeTokenUsingAuthorizationHeader( + token=token, + authorization_header=auth_header, + ), + user_context=user_context, + ) + + return await recipe.recipe_implementation.revoke_token( + RevokeTokenUsingClientIDAndClientSecret( + token=token, + client_id=client_id, + client_secret=client_secret, + ), + user_context=user_context, + ) + + +async def revoke_tokens_by_client_id( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.revoke_tokens_by_client_id( + client_id=client_id, user_context=user_context + ) + + +async def revoke_tokens_by_session_handle( + session_handle: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.revoke_tokens_by_session_handle( + session_handle=session_handle, user_context=user_context + ) + + +async def validate_oauth2_refresh_token( + token: str, + scopes: Optional[List[str]] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ActiveTokenResponse, InactiveTokenResponse]: + if user_context is None: + user_context = {} + from ..recipe import OAuth2ProviderRecipe + + return await OAuth2ProviderRecipe.get_instance().recipe_implementation.introspect_token( + token=token, scopes=scopes, user_context=user_context + ) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 0b6fe9c83..8b9f2649d 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -343,7 +343,7 @@ def __init__(self, token: str, authorization_header: str): class RevokeTokenUsingClientIDAndClientSecret: - def __init__(self, token: str, client_id: str, client_secret: str): + def __init__(self, token: str, client_id: str, client_secret: Optional[str]): self.token = token self.client_id = client_id self.client_secret = client_secret @@ -362,6 +362,595 @@ def to_json(self): return {"active": True, **self.payload} +class OAuth2ClientOptions: + def __init__( + self, + client_id: str, + client_secret: Optional[str], + created_at: str, + updated_at: str, + client_name: str, + scope: str, + redirect_uris: Optional[List[str]], + post_logout_redirect_uris: Optional[List[str]], + authorization_code_grant_access_token_lifespan: Optional[str], + authorization_code_grant_id_token_lifespan: Optional[str], + authorization_code_grant_refresh_token_lifespan: Optional[str], + client_credentials_grant_access_token_lifespan: Optional[str], + implicit_grant_access_token_lifespan: Optional[str], + implicit_grant_id_token_lifespan: Optional[str], + refresh_token_grant_access_token_lifespan: Optional[str], + refresh_token_grant_id_token_lifespan: Optional[str], + refresh_token_grant_refresh_token_lifespan: Optional[str], + token_endpoint_auth_method: str, + audience: Optional[List[str]], + grant_types: Optional[List[str]], + response_types: Optional[List[str]], + client_uri: Optional[str], + logo_uri: Optional[str], + policy_uri: Optional[str], + tos_uri: Optional[str], + metadata: Optional[Dict[str, Any]], + enable_refresh_token_rotation: Optional[bool], + ): + self.client_id = client_id + self.client_secret = client_secret + self.created_at = created_at + self.updated_at = updated_at + self.client_name = client_name + self.scope = scope + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.token_endpoint_auth_method = token_endpoint_auth_method + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.client_uri = client_uri + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "clientId": self.client_id, + "createdAt": self.created_at, + "updatedAt": self.updated_at, + "clientName": self.client_name, + "scope": self.scope, + "tokenEndpointAuthMethod": self.token_endpoint_auth_method, + } + if self.client_secret is not None: + result["clientSecret"] = self.client_secret + if self.redirect_uris is not None: + result["redirectUris"] = self.redirect_uris + if self.post_logout_redirect_uris is not None: + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + if self.authorization_code_grant_access_token_lifespan is not None: + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + if self.authorization_code_grant_id_token_lifespan is not None: + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + if self.authorization_code_grant_refresh_token_lifespan is not None: + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + if self.client_credentials_grant_access_token_lifespan is not None: + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + if self.implicit_grant_access_token_lifespan is not None: + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + if self.implicit_grant_id_token_lifespan is not None: + result["implicitGrantIdTokenLifespan"] = ( + self.implicit_grant_id_token_lifespan + ) + if self.refresh_token_grant_access_token_lifespan is not None: + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + if self.refresh_token_grant_id_token_lifespan is not None: + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + if self.refresh_token_grant_refresh_token_lifespan is not None: + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + if self.audience is not None: + result["audience"] = self.audience + if self.grant_types is not None: + result["grantTypes"] = self.grant_types + if self.response_types is not None: + result["responseTypes"] = self.response_types + if self.client_uri is not None: + result["clientUri"] = self.client_uri + if self.logo_uri is not None: + result["logoUri"] = self.logo_uri + if self.policy_uri is not None: + result["policyUri"] = self.policy_uri + if self.tos_uri is not None: + result["tosUri"] = self.tos_uri + if self.metadata is not None: + result["metadata"] = self.metadata + if self.enable_refresh_token_rotation is not None: + result["enableRefreshTokenRotation"] = self.enable_refresh_token_rotation + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> "OAuth2ClientOptions": + return OAuth2ClientOptions( + client_id=json["clientId"], + client_secret=json["clientSecret"], + created_at=json["createdAt"], + updated_at=json["updatedAt"], + client_name=json["clientName"], + scope=json["scope"], + redirect_uris=json["redirectUris"], + post_logout_redirect_uris=json["postLogoutRedirectUris"], + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan" + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan" + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan" + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan" + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan" + ), + implicit_grant_id_token_lifespan=json.get("implicitGrantIdTokenLifespan"), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan" + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan" + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan" + ), + token_endpoint_auth_method=json["tokenEndpointAuthMethod"], + audience=json.get("audience"), + grant_types=json.get("grantTypes"), + response_types=json.get("responseTypes"), + client_uri=json.get("clientUri"), + logo_uri=json.get("logoUri"), + policy_uri=json.get("policyUri"), + tos_uri=json.get("tosUri"), + metadata=json.get("metadata"), + enable_refresh_token_rotation=json.get("enableRefreshTokenRotation"), + ) + + +class CreateOAuth2ClientInput: + def __init__( + self, + client_id: Optional[str], + client_secret: Optional[str], + client_name: Optional[str], + scope: Optional[str], + redirect_uris: Optional[List[str]], + post_logout_redirect_uris: Optional[List[str]], + authorization_code_grant_access_token_lifespan: Optional[str], + authorization_code_grant_id_token_lifespan: Optional[str], + authorization_code_grant_refresh_token_lifespan: Optional[str], + client_credentials_grant_access_token_lifespan: Optional[str], + implicit_grant_access_token_lifespan: Optional[str], + implicit_grant_id_token_lifespan: Optional[str], + refresh_token_grant_access_token_lifespan: Optional[str], + refresh_token_grant_id_token_lifespan: Optional[str], + refresh_token_grant_refresh_token_lifespan: Optional[str], + token_endpoint_auth_method: Optional[str], + audience: Optional[List[str]], + grant_types: Optional[List[str]], + response_types: Optional[List[str]], + client_uri: Optional[str], + logo_uri: Optional[str], + policy_uri: Optional[str], + tos_uri: Optional[str], + metadata: Optional[Dict[str, Any]], + enable_refresh_token_rotation: Optional[bool], + ): + self.client_id = client_id + self.client_secret = client_secret + self.client_name = client_name + self.scope = scope + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.token_endpoint_auth_method = token_endpoint_auth_method + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.client_uri = client_uri + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} + if self.client_id is not None: + result["clientId"] = self.client_id + if self.client_name is not None: + result["clientName"] = self.client_name + if self.scope is not None: + result["scope"] = self.scope + if self.token_endpoint_auth_method is not None: + result["tokenEndpointAuthMethod"] = self.token_endpoint_auth_method + if self.client_secret is not None: + result["clientSecret"] = self.client_secret + if self.redirect_uris is not None: + result["redirectUris"] = self.redirect_uris + if self.post_logout_redirect_uris is not None: + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + if self.authorization_code_grant_access_token_lifespan is not None: + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + if self.authorization_code_grant_id_token_lifespan is not None: + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + if self.authorization_code_grant_refresh_token_lifespan is not None: + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + if self.client_credentials_grant_access_token_lifespan is not None: + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + if self.implicit_grant_access_token_lifespan is not None: + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + if self.implicit_grant_id_token_lifespan is not None: + result["implicitGrantIdTokenLifespan"] = ( + self.implicit_grant_id_token_lifespan + ) + if self.refresh_token_grant_access_token_lifespan is not None: + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + if self.refresh_token_grant_id_token_lifespan is not None: + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + if self.refresh_token_grant_refresh_token_lifespan is not None: + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + if self.audience is not None: + result["audience"] = self.audience + if self.grant_types is not None: + result["grantTypes"] = self.grant_types + if self.response_types is not None: + result["responseTypes"] = self.response_types + if self.client_uri is not None: + result["clientUri"] = self.client_uri + if self.logo_uri is not None: + result["logoUri"] = self.logo_uri + if self.policy_uri is not None: + result["policyUri"] = self.policy_uri + if self.tos_uri is not None: + result["tosUri"] = self.tos_uri + if self.metadata is not None: + result["metadata"] = self.metadata + if self.enable_refresh_token_rotation is not None: + result["enableRefreshTokenRotation"] = self.enable_refresh_token_rotation + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> "CreateOAuth2ClientInput": + return CreateOAuth2ClientInput( + client_id=json.get("clientId"), + client_secret=json.get("clientSecret"), + client_name=json.get("clientName"), + scope=json.get("scope"), + redirect_uris=json.get("redirectUris"), + post_logout_redirect_uris=json.get("postLogoutRedirectUris"), + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan" + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan" + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan" + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan" + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan" + ), + implicit_grant_id_token_lifespan=json.get("implicitGrantIdTokenLifespan"), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan" + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan" + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan" + ), + token_endpoint_auth_method=json.get("tokenEndpointAuthMethod"), + audience=json.get("audience"), + grant_types=json.get("grantTypes"), + response_types=json.get("responseTypes"), + client_uri=json.get("clientUri"), + logo_uri=json.get("logoUri"), + policy_uri=json.get("policyUri"), + tos_uri=json.get("tosUri"), + metadata=json.get("metadata"), + enable_refresh_token_rotation=json.get("enableRefreshTokenRotation"), + ) + + +class NotSet: + pass + + +class UpdateOAuth2ClientInput: + def __init__( + self, + client_id: str, + client_secret: Union[Optional[str], NotSet] = NotSet(), + client_name: Union[Optional[str], NotSet] = NotSet(), + scope: Union[Optional[str], NotSet] = NotSet(), + redirect_uris: Union[Optional[List[str]], NotSet] = NotSet(), + post_logout_redirect_uris: Union[Optional[List[str]], NotSet] = NotSet(), + authorization_code_grant_access_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + authorization_code_grant_id_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + authorization_code_grant_refresh_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + client_credentials_grant_access_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + implicit_grant_access_token_lifespan: Union[Optional[str], NotSet] = NotSet(), + implicit_grant_id_token_lifespan: Union[Optional[str], NotSet] = NotSet(), + refresh_token_grant_access_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + refresh_token_grant_id_token_lifespan: Union[Optional[str], NotSet] = NotSet(), + refresh_token_grant_refresh_token_lifespan: Union[ + Optional[str], NotSet + ] = NotSet(), + token_endpoint_auth_method: Union[Optional[str], NotSet] = NotSet(), + audience: Union[Optional[List[str]], NotSet] = NotSet(), + grant_types: Union[Optional[List[str]], NotSet] = NotSet(), + response_types: Union[Optional[List[str]], NotSet] = NotSet(), + client_uri: Union[Optional[str], NotSet] = NotSet(), + logo_uri: Union[Optional[str], NotSet] = NotSet(), + policy_uri: Union[Optional[str], NotSet] = NotSet(), + tos_uri: Union[Optional[str], NotSet] = NotSet(), + metadata: Union[Optional[Dict[str, Any]], NotSet] = NotSet(), + enable_refresh_token_rotation: Union[Optional[bool], NotSet] = NotSet(), + ): + self.client_id = client_id + self.client_secret = client_secret + self.client_name = client_name + self.scope = scope + self.redirect_uris = redirect_uris + self.post_logout_redirect_uris = post_logout_redirect_uris + self.authorization_code_grant_access_token_lifespan = ( + authorization_code_grant_access_token_lifespan + ) + self.authorization_code_grant_id_token_lifespan = ( + authorization_code_grant_id_token_lifespan + ) + self.authorization_code_grant_refresh_token_lifespan = ( + authorization_code_grant_refresh_token_lifespan + ) + self.client_credentials_grant_access_token_lifespan = ( + client_credentials_grant_access_token_lifespan + ) + self.implicit_grant_access_token_lifespan = implicit_grant_access_token_lifespan + self.implicit_grant_id_token_lifespan = implicit_grant_id_token_lifespan + self.refresh_token_grant_access_token_lifespan = ( + refresh_token_grant_access_token_lifespan + ) + self.refresh_token_grant_id_token_lifespan = ( + refresh_token_grant_id_token_lifespan + ) + self.refresh_token_grant_refresh_token_lifespan = ( + refresh_token_grant_refresh_token_lifespan + ) + self.token_endpoint_auth_method = token_endpoint_auth_method + self.audience = audience + self.grant_types = grant_types + self.response_types = response_types + self.client_uri = client_uri + self.logo_uri = logo_uri + self.policy_uri = policy_uri + self.tos_uri = tos_uri + self.metadata = metadata + self.enable_refresh_token_rotation = enable_refresh_token_rotation + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} + result["clientId"] = self.client_id + + if not isinstance(self.client_name, NotSet): + result["clientName"] = self.client_name + if not isinstance(self.scope, NotSet): + result["scope"] = self.scope + if not isinstance(self.token_endpoint_auth_method, NotSet): + result["tokenEndpointAuthMethod"] = self.token_endpoint_auth_method + if not isinstance(self.client_secret, NotSet): + result["clientSecret"] = self.client_secret + if not isinstance(self.redirect_uris, NotSet): + result["redirectUris"] = self.redirect_uris + if not isinstance(self.post_logout_redirect_uris, NotSet): + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + if not isinstance(self.authorization_code_grant_access_token_lifespan, NotSet): + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + if not isinstance(self.authorization_code_grant_id_token_lifespan, NotSet): + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + if not isinstance(self.authorization_code_grant_refresh_token_lifespan, NotSet): + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + if not isinstance(self.client_credentials_grant_access_token_lifespan, NotSet): + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + if not isinstance(self.implicit_grant_access_token_lifespan, NotSet): + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + if not isinstance(self.implicit_grant_id_token_lifespan, NotSet): + result["implicitGrantIdTokenLifespan"] = ( + self.implicit_grant_id_token_lifespan + ) + if not isinstance(self.refresh_token_grant_access_token_lifespan, NotSet): + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + if not isinstance(self.refresh_token_grant_id_token_lifespan, NotSet): + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + if not isinstance(self.refresh_token_grant_refresh_token_lifespan, NotSet): + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + if not isinstance(self.audience, NotSet): + result["audience"] = self.audience + if not isinstance(self.grant_types, NotSet): + result["grantTypes"] = self.grant_types + if self.response_types is not None: + result["responseTypes"] = self.response_types + if not isinstance(self.client_uri, NotSet): + result["clientUri"] = self.client_uri + if not isinstance(self.logo_uri, NotSet): + result["logoUri"] = self.logo_uri + if not isinstance(self.policy_uri, NotSet): + result["policyUri"] = self.policy_uri + if not isinstance(self.tos_uri, NotSet): + result["tosUri"] = self.tos_uri + if not isinstance(self.metadata, NotSet): + result["metadata"] = self.metadata + if not isinstance(self.enable_refresh_token_rotation, NotSet): + result["enableRefreshTokenRotation"] = self.enable_refresh_token_rotation + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> "UpdateOAuth2ClientInput": + return UpdateOAuth2ClientInput( + client_id=json["clientId"], + client_secret=json.get("clientSecret", NotSet()), + client_name=json.get("clientName", NotSet()), + scope=json.get("scope", NotSet()), + redirect_uris=json.get("redirectUris", NotSet()), + post_logout_redirect_uris=json.get("postLogoutRedirectUris", NotSet()), + authorization_code_grant_access_token_lifespan=json.get( + "authorizationCodeGrantAccessTokenLifespan", NotSet() + ), + authorization_code_grant_id_token_lifespan=json.get( + "authorizationCodeGrantIdTokenLifespan", NotSet() + ), + authorization_code_grant_refresh_token_lifespan=json.get( + "authorizationCodeGrantRefreshTokenLifespan", NotSet() + ), + client_credentials_grant_access_token_lifespan=json.get( + "clientCredentialsGrantAccessTokenLifespan", NotSet() + ), + implicit_grant_access_token_lifespan=json.get( + "implicitGrantAccessTokenLifespan", NotSet() + ), + implicit_grant_id_token_lifespan=json.get( + "implicitGrantIdTokenLifespan", NotSet() + ), + refresh_token_grant_access_token_lifespan=json.get( + "refreshTokenGrantAccessTokenLifespan", NotSet() + ), + refresh_token_grant_id_token_lifespan=json.get( + "refreshTokenGrantIdTokenLifespan", NotSet() + ), + refresh_token_grant_refresh_token_lifespan=json.get( + "refreshTokenGrantRefreshTokenLifespan", NotSet() + ), + token_endpoint_auth_method=json.get("tokenEndpointAuthMethod", NotSet()), + audience=json.get("audience", NotSet()), + grant_types=json.get("grantTypes", NotSet()), + response_types=json.get("responseTypes", NotSet()), + client_uri=json.get("clientUri", NotSet()), + logo_uri=json.get("logoUri", NotSet()), + policy_uri=json.get("policyUri", NotSet()), + tos_uri=json.get("tosUri", NotSet()), + metadata=json.get("metadata", NotSet()), + enable_refresh_token_rotation=json.get( + "enableRefreshTokenRotation", NotSet() + ), + ) + + class RecipeInterface(ABC): @abstractmethod async def authorization( @@ -461,6 +1050,7 @@ async def get_oauth2_client( @abstractmethod async def create_oauth2_client( self, + params: CreateOAuth2ClientInput, user_context: Dict[str, Any], ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: pass @@ -468,6 +1058,7 @@ async def create_oauth2_client( @abstractmethod async def update_oauth2_client( self, + params: UpdateOAuth2ClientInput, user_context: Dict[str, Any], ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: pass diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index c08ec852f..32e134e70 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -30,6 +30,7 @@ from supertokens_python.types import RecipeUserId, User from .interfaces import ( + CreateOAuth2ClientInput, FrontendRedirectionURLTypeLogin, FrontendRedirectionURLTypeLogoutConfirmation, FrontendRedirectionURLTypePostLogoutFallback, @@ -44,6 +45,7 @@ CreateOAuth2ClientOkResult, RevokeTokenUsingAuthorizationHeader, RevokeTokenUsingClientIDAndClientSecret, + UpdateOAuth2ClientInput, UpdateOAuth2ClientOkResult, DeleteOAuth2ClientOkResult, ConsentRequest, @@ -570,11 +572,12 @@ async def get_oauth2_client( async def create_oauth2_client( self, + params: CreateOAuth2ClientInput, user_context: Dict[str, Any], ) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/clients"), - {}, # Empty dict since no input params in function signature + params.to_json(), user_context=user_context, ) @@ -586,11 +589,12 @@ async def create_oauth2_client( async def update_oauth2_client( self, + params: UpdateOAuth2ClientInput, user_context: Dict[str, Any], ) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: response = await self.querier.send_put_request( NormalisedURLPath("/recipe/oauth/clients"), - {}, # TODO update params + params.to_json(), None, user_context=user_context, ) @@ -672,7 +676,7 @@ async def validate_oauth2_access_token( if response.get("active") is not True: raise Exception("The token is expired, invalid or has been revoked") - return {"status": "OK", "payload": payload} + return payload async def get_requested_scopes( self, @@ -787,7 +791,8 @@ async def revoke_token( request_body["authorizationHeader"] = params.authorization_header else: request_body["client_id"] = params.client_id - request_body["client_secret"] = params.client_secret + if params.client_secret is not None: + request_body["client_secret"] = params.client_secret res = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/token/revoke"), diff --git a/supertokens_python/recipe/oauth2provider/syncio/__init__.py b/supertokens_python/recipe/oauth2provider/syncio/__init__.py index a59c06f52..671e614f9 100644 --- a/supertokens_python/recipe/oauth2provider/syncio/__init__.py +++ b/supertokens_python/recipe/oauth2provider/syncio/__init__.py @@ -11,3 +11,166 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from typing import Any, Dict, Union, Optional, List + +from supertokens_python.async_to_sync_wrapper import sync + +from ..interfaces import ( + ActiveTokenResponse, + CreateOAuth2ClientInput, + CreateOAuth2ClientOkResult, + DeleteOAuth2ClientOkResult, + ErrorOAuth2Response, + GetOAuth2ClientOkResult, + GetOAuth2ClientsOkResult, + InactiveTokenResponse, + OAuth2TokenValidationRequirements, + TokenInfo, + UpdateOAuth2ClientInput, + UpdateOAuth2ClientOkResult, +) + + +def get_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[GetOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import get_oauth2_client + + return sync(get_oauth2_client(client_id, user_context)) + + +def get_oauth2_clients( + page_size: Optional[int] = None, + pagination_token: Optional[str] = None, + client_name: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[GetOAuth2ClientsOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import get_oauth2_clients + + return sync( + get_oauth2_clients(page_size, pagination_token, client_name, user_context) + ) + + +def create_oauth2_client( + params: CreateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..asyncio import create_oauth2_client + + return sync(create_oauth2_client(params, user_context)) + + +def update_oauth2_client( + params: UpdateOAuth2ClientInput, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import update_oauth2_client + + return sync(update_oauth2_client(params, user_context)) + + +def delete_oauth2_client( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Union[DeleteOAuth2ClientOkResult, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..asyncio import delete_oauth2_client + + return sync(delete_oauth2_client(client_id, user_context)) + + +def validate_oauth2_access_token( + token: str, + requirements: Optional[OAuth2TokenValidationRequirements] = None, + check_database: Optional[bool] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + if user_context is None: + user_context = {} + + from ..asyncio import validate_oauth2_access_token + + return sync( + validate_oauth2_access_token(token, requirements, check_database, user_context) + ) + + +def create_token_for_client_credentials( + client_id: str, + client_secret: str, + scope: Optional[List[str]] = None, + audience: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[TokenInfo, ErrorOAuth2Response]: + if user_context is None: + user_context = {} + + from ..asyncio import create_token_for_client_credentials + + return sync( + create_token_for_client_credentials( + client_id, client_secret, scope, audience, user_context + ) + ) + + +def revoke_token( + token: str, + client_id: str, + client_secret: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Optional[ErrorOAuth2Response]: + if user_context is None: + user_context = {} + from ..asyncio import revoke_token + + return sync(revoke_token(token, client_id, client_secret, user_context)) + + +def revoke_tokens_by_client_id( + client_id: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + + from ..asyncio import revoke_tokens_by_client_id + + return sync(revoke_tokens_by_client_id(client_id, user_context)) + + +def revoke_tokens_by_session_handle( + session_handle: str, user_context: Optional[Dict[str, Any]] = None +) -> None: + if user_context is None: + user_context = {} + + from ..asyncio import revoke_tokens_by_session_handle + + return sync(revoke_tokens_by_session_handle(session_handle, user_context)) + + +def validate_oauth2_refresh_token( + token: str, + scopes: Optional[List[str]] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ActiveTokenResponse, InactiveTokenResponse]: + if user_context is None: + user_context = {} + + from ..asyncio import validate_oauth2_refresh_token + + return sync(validate_oauth2_refresh_token(token, scopes, user_context)) From 34e96da2978977067b0110d5e09371bee0653214 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Thu, 12 Dec 2024 13:21:23 +0530 Subject: [PATCH 28/38] fix: backend sdk tests --- supertokens_python/auth_utils.py | 4 ++-- tests/test-server/app.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 18b7fbeb1..ab0061f9a 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -536,7 +536,7 @@ async def check_auth_type_and_linking_status( if session_user_result.status == "SHOULD_AUTOMATICALLY_LINK_FALSE": if should_try_linking_with_session_user is True: raise BadInputError( - "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + "shouldDoAutomaticAccountLinking returned false when making the session user primary but shouldTryLinkingWithSessionUser is true" ) return OkFirstFactorResponse() elif ( @@ -565,7 +565,7 @@ async def check_auth_type_and_linking_status( if isinstance(should_link, ShouldNotAutomaticallyLink): if should_try_linking_with_session_user is True: raise BadInputError( - "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + "shouldDoAutomaticAccountLinking returned false when making the session user primary but shouldTryLinkingWithSessionUser is true" ) return OkFirstFactorResponse() else: diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 8ee45dd7e..bdfd9206a 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -278,7 +278,9 @@ async def custom_unauthorised_callback( _: BaseRequest, __: str, response: BaseResponse ) -> BaseResponse: response.set_status_code(401) - response.set_json_content(content={"type": "UNAUTHORISED"}) + response.set_json_content( + content={"type": "UNAUTHORISED", "message": "unauthorised"} + ) return response recipe_config_json = json.loads(recipe_config.get("config", "{}")) From d8dd684789697c73b6bf79e4871f443f6d4b3778 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Thu, 12 Dec 2024 16:30:22 +0530 Subject: [PATCH 29/38] fix: default recipes and fixes for test --- .../framework/django/django_response.py | 3 +- .../framework/fastapi/fastapi_response.py | 3 +- .../framework/flask/flask_response.py | 3 +- supertokens_python/framework/response.py | 2 +- .../recipe/oauth2provider/__init__.py | 4 +- .../recipe/oauth2provider/api/auth.py | 3 +- .../recipe/oauth2provider/api/login.py | 4 +- .../recipe/oauth2provider/api/utils.py | 2 + .../recipe/oauth2provider/interfaces.py | 24 +++++ .../recipe/oauth2provider/oauth2_client.py | 31 +++++++ .../oauth2provider/recipe_implementation.py | 8 +- supertokens_python/supertokens.py | 30 +++++- tests/test-server/app.py | 25 +++++ tests/test-server/oauth2provider.py | 93 +++++++++++++++++++ 14 files changed, 223 insertions(+), 12 deletions(-) create mode 100644 tests/test-server/oauth2provider.py diff --git a/supertokens_python/framework/django/django_response.py b/supertokens_python/framework/django/django_response.py index 9692b2d55..f839b9231 100644 --- a/supertokens_python/framework/django/django_response.py +++ b/supertokens_python/framework/django/django_response.py @@ -89,7 +89,8 @@ def set_json_content(self, content: Dict[str, Any]): ).encode("utf-8") self.response_sent = True - def redirect(self, url: str): + def redirect(self, url: str) -> BaseResponse: if not self.response_sent: self.set_header("Location", url) self.set_status_code(302) + return self diff --git a/supertokens_python/framework/fastapi/fastapi_response.py b/supertokens_python/framework/fastapi/fastapi_response.py index 45813a5cc..76e6f349c 100644 --- a/supertokens_python/framework/fastapi/fastapi_response.py +++ b/supertokens_python/framework/fastapi/fastapi_response.py @@ -95,7 +95,8 @@ def set_json_content(self, content: Dict[str, Any]): self.response.body = body self.response_sent = True - def redirect(self, url: str): + def redirect(self, url: str) -> BaseResponse: if not self.response_sent: self.set_header("Location", url) self.set_status_code(302) + return self diff --git a/supertokens_python/framework/flask/flask_response.py b/supertokens_python/framework/flask/flask_response.py index a74bdfb83..ef016d5d3 100644 --- a/supertokens_python/framework/flask/flask_response.py +++ b/supertokens_python/framework/flask/flask_response.py @@ -86,6 +86,7 @@ def set_json_content(self, content: Dict[str, Any]): ).encode("utf-8") self.response_sent = True - def redirect(self, url: str): + def redirect(self, url: str) -> BaseResponse: self.response.headers.set("Location", url) self.set_status_code(302) + return self diff --git a/supertokens_python/framework/response.py b/supertokens_python/framework/response.py index 8669e3ae4..c28a24104 100644 --- a/supertokens_python/framework/response.py +++ b/supertokens_python/framework/response.py @@ -63,5 +63,5 @@ def set_html_content(self, content: str): pass @abstractmethod - def redirect(self, url: str): + def redirect(self, url: str) -> "BaseResponse": pass diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py index 55b44908d..3397b3430 100644 --- a/supertokens_python/recipe/oauth2provider/__init__.py +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -16,15 +16,15 @@ from typing import TYPE_CHECKING, Callable, Union from . import exceptions as ex -from . import recipe +from . import recipe, utils exceptions = ex +InputOverrideConfig = utils.InputOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo from ...recipe_module import RecipeModule - from .utils import InputOverrideConfig def init( diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 1819aa238..5b5767e60 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -14,6 +14,7 @@ from __future__ import annotations +from datetime import datetime from http.cookies import SimpleCookie from typing import TYPE_CHECKING, Any, Dict from urllib.parse import parse_qsl @@ -83,7 +84,7 @@ async def auth_get( domain=morsel.get("domain"), secure=morsel.get("secure", True), httponly=morsel.get("httponly", True), - expires=morsel.get("expires", None), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), samesite=morsel.get("samesite", "lax"), ) diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index 59c93dcdb..d080211fb 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional +from datetime import datetime + from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.framework import BaseResponse from supertokens_python.recipe.session.asyncio import get_session @@ -84,7 +86,7 @@ async def login( domain=morsel.get("domain"), secure=morsel.get("secure", True), httponly=morsel.get("httponly", True), - expires=morsel.get("expires", None), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), samesite=morsel.get("samesite", "lax"), ) diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py index 2b524b017..80900c034 100644 --- a/supertokens_python/recipe/oauth2provider/api/utils.py +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -235,6 +235,8 @@ async def handle_login_internal_redirects( cookie: str, user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: + from ..interfaces import RedirectResponse, ErrorOAuth2Response + if not is_login_internal_redirect(response.redirect_to): return response diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 8b9f2649d..57d34579a 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -254,6 +254,13 @@ def from_json(json: Dict[str, Any]): next_pagination_token=json["nextPaginationToken"], ) + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "clients": [client.to_json() for client in self.clients], + "nextPaginationToken": self.next_pagination_token, + } + class GetOAuth2ClientOkResult: def __init__(self, client: OAuth2Client): @@ -272,6 +279,12 @@ def __init__(self, client: OAuth2Client): def from_json(json: Dict[str, Any]): return CreateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "client": self.client.to_json(), + } + class UpdateOAuth2ClientOkResult: def __init__(self, client: OAuth2Client): @@ -281,11 +294,22 @@ def __init__(self, client: OAuth2Client): def from_json(json: Dict[str, Any]): return UpdateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "client": self.client.to_json(), + } + class DeleteOAuth2ClientOkResult: def __init__(self): pass + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + } + PayloadBuilderFunction = Callable[ [User, List[str], str, Dict[str, Any]], Awaitable[Dict[str, Any]] diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py index 9b17718fe..0f16db500 100644 --- a/supertokens_python/recipe/oauth2provider/oauth2_client.py +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -243,3 +243,34 @@ def from_json(json: Dict[str, Any]) -> "OAuth2Client": metadata=json.get("metadata", {}), enable_refresh_token_rotation=json.get("enableRefreshTokenRotation", False), ) + + def to_json(self) -> Dict[str, Any]: + return { + "clientId": self.client_id, + "clientName": self.client_name, + "scope": self.scope, + "tokenEndpointAuthMethod": self.token_endpoint_auth_method, + "createdAt": self.created_at, + "updatedAt": self.updated_at, + "clientSecret": self.client_secret, + "redirectUris": self.redirect_uris, + "postLogoutRedirectUris": self.post_logout_redirect_uris, + "authorizationCodeGrantAccessTokenLifespan": self.authorization_code_grant_access_token_lifespan, + "authorizationCodeGrantIdTokenLifespan": self.authorization_code_grant_id_token_lifespan, + "authorizationCodeGrantRefreshTokenLifespan": self.authorization_code_grant_refresh_token_lifespan, + "clientCredentialsGrantAccessTokenLifespan": self.client_credentials_grant_access_token_lifespan, + "implicitGrantAccessTokenLifespan": self.implicit_grant_access_token_lifespan, + "implicitGrantIdTokenLifespan": self.implicit_grant_id_token_lifespan, + "refreshTokenGrantAccessTokenLifespan": self.refresh_token_grant_access_token_lifespan, + "refreshTokenGrantIdTokenLifespan": self.refresh_token_grant_id_token_lifespan, + "refreshTokenGrantRefreshTokenLifespan": self.refresh_token_grant_refresh_token_lifespan, + "clientUri": self.client_uri, + "audience": self.audience, + "grantTypes": self.grant_types, + "responseTypes": self.response_types, + "logoUri": self.logo_uri, + "policyUri": self.policy_uri, + "tosUri": self.tos_uri, + "metadata": self.metadata, + "enableRefreshTokenRotation": self.enable_refresh_token_rotation, + } diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 32e134e70..63228d37a 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -20,7 +20,6 @@ import jwt -from supertokens_python import AppInfo from supertokens_python.asyncio import get_user from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.openid.recipe import OpenIdRecipe @@ -60,6 +59,7 @@ if TYPE_CHECKING: from supertokens_python.querier import Querier + from supertokens_python import AppInfo def get_updated_redirect_to(app_info: AppInfo, redirect_to: str) -> str: @@ -371,10 +371,12 @@ async def authorization( ) return RedirectResponse( - redirect_to=consent_res.redirect_to, cookies=resp["cookies"] + redirect_to=consent_res.redirect_to, cookies=",".join(resp["cookies"]) ) - return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"]) + return RedirectResponse( + redirect_to=redirect_to, cookies=",".join(resp["cookies"]) + ) async def token_exchange( self, diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 53964da2e..6632850b4 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -259,9 +259,12 @@ def __init__( totp_found = False user_metadata_found = False multi_factor_auth_found = False + oauth2_found = False + openid_found = False + jwt_found = False def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: - nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found + nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found, oauth2_found, openid_found, jwt_found recipe_module = recipe(self.app_info) if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True @@ -271,21 +274,46 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: multi_factor_auth_found = True elif recipe_module.get_recipe_id() == "totp": totp_found = True + elif recipe_module.get_recipe_id() == "oauth2provider": + oauth2_found = True + elif recipe_module.get_recipe_id() == "openid": + openid_found = True + elif recipe_module.get_recipe_id() == "jwt": + jwt_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) + if not jwt_found: + from supertokens_python.recipe.jwt.recipe import JWTRecipe + + self.recipe_modules.append(JWTRecipe.init()(self.app_info)) + + if not openid_found: + from supertokens_python.recipe.openid.recipe import OpenIdRecipe + + self.recipe_modules.append(OpenIdRecipe.init()(self.app_info)) + if not multitenancy_found: from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + if totp_found and not multi_factor_auth_found: raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") + if not user_metadata_found: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info)) + if not oauth2_found: + from supertokens_python.recipe.oauth2provider.recipe import ( + OAuth2ProviderRecipe, + ) + + self.recipe_modules.append(OAuth2ProviderRecipe.init()(self.app_info)) + self.telemetry = ( telemetry if telemetry is not None diff --git a/tests/test-server/app.py b/tests/test-server/app.py index bdfd9206a..1ab834d85 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -28,6 +28,8 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.types import RecipeUserId from test_functions_mapper import ( # pylint: disable=import-error get_func, @@ -55,6 +57,7 @@ session, thirdparty, emailverification, + oauth2provider, ) from supertokens_python.recipe.session import InputErrorHandlers, SessionContainer from supertokens_python.recipe.session.framework.flask import verify_session @@ -222,6 +225,8 @@ def st_reset(): AccountLinkingRecipe.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OAuth2ProviderRecipe.reset() + OpenIdRecipe.reset() def init_st(config: Dict[str, Any]): @@ -593,6 +598,22 @@ async def send_sms( ) ) ) + elif recipe_id == "oauth2provider": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + oauth2provider.init( + override=oauth2provider.InputOverrideConfig( + apis=override_builder_with_logging( + "OAuth2Provider.override.apis", + recipe_config_json.get("override", {}).get("apis"), + ), + functions=override_builder_with_logging( + "OAuth2Provider.override.functions", + recipe_config_json.get("override", {}).get("functions"), + ), + ) + ) + ) interceptor_func = None if config.get("supertokens", {}).get("networkInterceptor") is not None: @@ -822,6 +843,10 @@ def handle_exception(e: Exception): add_multifactorauth_routes(app) +from oauth2provider import add_oauth2provider_routes + +add_oauth2provider_routes(app) + if __name__ == "__main__": default_st_init() port = int(os.environ.get("API_PORT", api_port)) diff --git a/tests/test-server/oauth2provider.py b/tests/test-server/oauth2provider.py new file mode 100644 index 000000000..8f8a18cfd --- /dev/null +++ b/tests/test-server/oauth2provider.py @@ -0,0 +1,93 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput +import supertokens_python.recipe.oauth2provider.syncio as OAuth2Provider + + +def add_oauth2provider_routes(app: Flask): + @app.route("/test/oauth2provider/getoauth2clients", methods=["POST"]) # type: ignore + def get_oauth2_clients_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:getOAuth2Clients", request.json) + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + response = OAuth2Provider.get_oauth2_clients( + page_size=data.get("pageSize"), + pagination_token=data.get("paginationToken"), + client_name=data.get("clientName"), + user_context=data.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/createoauth2client", methods=["POST"]) # type: ignore + def create_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:createOAuth2Client", request.json) + + response = OAuth2Provider.create_oauth2_client( + params=CreateOAuth2ClientInput.from_json(request.json.get("input")), + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/updateoauth2client", methods=["POST"]) # type: ignore + def update_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:updateOAuth2Client", request.json) + + response = OAuth2Provider.update_oauth2_client( + params=request.json["input"], user_context=request.json.get("userContext") + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/deleteoauth2client", methods=["POST"]) # type: ignore + def delete_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:deleteOAuth2Client", request.json) + + response = OAuth2Provider.delete_oauth2_client( + client_id=request.json["input"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/validateoauth2accesstoken", methods=["POST"]) # type: ignore + def validate_oauth2_access_token_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:validateOAuth2AccessToken", request.json) + + response = OAuth2Provider.validate_oauth2_access_token( + token=request.json["token"], + requirements=request.json["requirements"], + check_database=request.json["checkDatabase"], + user_context=request.json.get("userContext"), + ) + return jsonify(response) + + @app.route("/test/oauth2provider/validateoauth2refreshtoken", methods=["POST"]) # type: ignore + def validate_oauth2_refresh_token_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:validateOAuth2RefreshToken", request.json) + + response = OAuth2Provider.validate_oauth2_refresh_token( + token=request.json["token"], + scopes=request.json["scopes"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/createtokenforclientcredentials", methods=["POST"]) # type: ignore + def create_token_for_client_credentials_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:createTokenForClientCredentials", request.json) + + response = OAuth2Provider.create_token_for_client_credentials( + client_id=request.json["clientId"], + client_secret=request.json["clientSecret"], + scope=request.json["scope"], + audience=request.json["audience"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) From fc42477b0887cd02913c9de6ece38ef1b467ba2d Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Thu, 12 Dec 2024 16:41:07 +0530 Subject: [PATCH 30/38] fix: tests --- .../recipe/oauth2provider/interfaces.py | 10 +++++----- .../recipe/oauth2provider/recipe_implementation.py | 13 +++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 57d34579a..bbbb1e590 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -177,12 +177,12 @@ def from_json(json: Dict[str, Any]): def to_json(self) -> Dict[str, Any]: return { "status": "OK", - "accessToken": self.access_token, - "expiresIn": self.expires_in, - "idToken": self.id_token, - "refreshToken": self.refresh_token, + "access_token": self.access_token, + "expires_in": self.expires_in, + "id_token": self.id_token, + "refresh_token": self.refresh_token, "scope": self.scope, - "tokenType": self.token_type, + "token_type": self.token_type, } diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 63228d37a..e12d811e2 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -387,7 +387,6 @@ async def token_exchange( request_body = { "iss": await OpenIdRecipe.get_issuer(user_context), "inputBody": body, - "authorizationHeader": authorization_header, } if body.get("grant_type") == "password": @@ -846,7 +845,7 @@ async def introspect_token( # If it fails, the token is not active, and we return early if is_access_token: try: - await self.validate_oauth2_access_token( + payload = await self.validate_oauth2_access_token( token=token, requirements=( OAuth2TokenValidationRequirements(scopes=scopes) @@ -856,17 +855,19 @@ async def introspect_token( check_database=False, user_context=user_context, ) + return ActiveTokenResponse(payload=payload) except Exception: return InactiveTokenResponse() # For tokens that passed local validation or if it's a refresh token, # validate the token with the database by calling the core introspection endpoint + request_body = {"token": token} + if scopes: + request_body["scope"] = " ".join(scopes) + res = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/introspect"), - { - "token": token, - "scope": " ".join(scopes) if scopes else None, - }, + request_body, user_context=user_context, ) From 724c97bcf9c660a2e40302d56bd8e61ad30453f5 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Thu, 12 Dec 2024 17:31:36 +0530 Subject: [PATCH 31/38] fix: tests --- .../recipe/oauth2provider/interfaces.py | 18 ++++++----- .../oauth2provider/recipe_implementation.py | 30 ++++++++++++++----- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index bbbb1e590..c903e3bc7 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -166,24 +166,28 @@ def __init__( @staticmethod def from_json(json: Dict[str, Any]): return TokenInfo( - access_token=json["access_token"], + access_token=json.get("access_token"), expires_in=json["expires_in"], - id_token=json["id_token"], - refresh_token=json["refresh_token"], + id_token=json.get("id_token"), + refresh_token=json.get("refresh_token"), scope=json["scope"], token_type=json["token_type"], ) def to_json(self) -> Dict[str, Any]: - return { + result = { "status": "OK", - "access_token": self.access_token, "expires_in": self.expires_in, - "id_token": self.id_token, - "refresh_token": self.refresh_token, "scope": self.scope, "token_type": self.token_type, } + if self.access_token is not None: + result["access_token"] = self.access_token + if self.id_token is not None: + result["id_token"] = self.id_token + if self.refresh_token is not None: + result["refresh_token"] = self.refresh_token + return result class LoginInfo: diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index e12d811e2..a3252e22a 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -633,12 +633,26 @@ async def validate_oauth2_access_token( # Verify token signature using session recipe's JWKS session_recipe = SessionRecipe.get_instance() matching_keys = get_latest_keys(session_recipe.config) - payload = jwt.decode( - token, - matching_keys[0].key, - algorithms=["RS256"], - options={"verify_signature": True, "verify_exp": True}, - ) + err: Optional[Exception] = None + + payload: Dict[str, Any] = {} + + for matching_key in matching_keys: + err = None + try: + payload = jwt.decode( + token, + matching_key.key, + algorithms=["RS256"], + options={"verify_signature": True, "verify_exp": True}, + ) + except Exception as e: + err = e + continue + break + + if err is not None: + raise err if payload.get("stt") != 1: raise Exception("Wrong token type") @@ -845,7 +859,7 @@ async def introspect_token( # If it fails, the token is not active, and we return early if is_access_token: try: - payload = await self.validate_oauth2_access_token( + await self.validate_oauth2_access_token( token=token, requirements=( OAuth2TokenValidationRequirements(scopes=scopes) @@ -855,7 +869,7 @@ async def introspect_token( check_database=False, user_context=user_context, ) - return ActiveTokenResponse(payload=payload) + except Exception: return InactiveTokenResponse() From dae12043cdb90752067d7e9ec054a3f559e72ef0 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Thu, 12 Dec 2024 17:50:42 +0530 Subject: [PATCH 32/38] fix: tests --- .../recipe/oauth2provider/api/auth.py | 2 +- .../recipe/oauth2provider/api/login.py | 2 +- .../recipe/oauth2provider/recipe_implementation.py | 14 ++++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 5b5767e60..8db3e7eb0 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -86,7 +86,7 @@ async def auth_get( httponly=morsel.get("httponly", True), expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), - samesite=morsel.get("samesite", "lax"), + samesite=morsel.get("samesite", "lax").lower(), ) return api_options.response.redirect(response.redirect_to) elif isinstance(response, ErrorOAuth2Response): diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index d080211fb..a4d3e1811 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -88,7 +88,7 @@ async def login( httponly=morsel.get("httponly", True), expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), - samesite=morsel.get("samesite", "lax"), + samesite=morsel.get("samesite", "lax").lower(), ) return send_200_response( diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index a3252e22a..d2a82f504 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -310,14 +310,16 @@ async def authorization( payloads = {"idToken": id_token, "accessToken": access_token} + request_body = { + "params": {**params, "scope": " ".join(scopes)}, + "iss": await OpenIdRecipe.get_issuer(user_context), + "session": payloads, + } + if cookies is not None: + request_body["cookies"] = cookies resp = await self.querier.send_post_request( NormalisedURLPath("/recipe/oauth/auth"), - { - "params": {**params, "scope": " ".join(scopes)}, - "iss": await OpenIdRecipe.get_issuer(user_context), - "cookies": cookies, - "session": payloads, - }, + request_body, user_context, ) From 213b9fe84f03c56c4b5f49ef4bf3ecfe0614e4fd Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Thu, 12 Dec 2024 21:43:51 +0530 Subject: [PATCH 33/38] fix: tests --- .../recipe/oauth2provider/api/auth.py | 27 +++++----- .../recipe/oauth2provider/api/end_session.py | 2 + .../recipe/oauth2provider/api/login.py | 27 +++++----- .../recipe/oauth2provider/api/utils.py | 19 +++---- .../recipe/oauth2provider/asyncio/__init__.py | 21 +++++--- .../recipe/oauth2provider/interfaces.py | 12 +++-- .../recipe/oauth2provider/oauth2_client.py | 51 +++++++++++++------ .../oauth2provider/recipe_implementation.py | 6 +-- tests/test-server/oauth2provider.py | 2 +- 9 files changed, 99 insertions(+), 68 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 8db3e7eb0..dedb466d1 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -75,19 +75,20 @@ async def auth_get( if isinstance(response, RedirectResponse): if response.cookies: - cookie = SimpleCookie() - cookie.load(response.cookies) - for morsel in cookie.values(): - api_options.response.set_cookie( - key=morsel.key, - value=morsel.value, - domain=morsel.get("domain"), - secure=morsel.get("secure", True), - httponly=morsel.get("httponly", True), - expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore - path=morsel.get("path", "/"), - samesite=morsel.get("samesite", "lax").lower(), - ) + for cookie_string in response.cookies: + cookie = SimpleCookie() + cookie.load(cookie_string) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax"), + ) return api_options.response.redirect(response.redirect_to) elif isinstance(response, ErrorOAuth2Response): return send_non_200_response( diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py index ebda43153..6a2e70bfc 100644 --- a/supertokens_python/recipe/oauth2provider/api/end_session.py +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -81,6 +81,8 @@ async def end_session_common( options: APIOptions, user_context: Dict[str, Any], ) -> Optional[BaseResponse]: + from ..interfaces import RedirectResponse, ErrorOAuth2Response + if api_implementation is None: return None diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index a4d3e1811..3c23da175 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -77,19 +77,20 @@ async def login( if isinstance(response, FrontendRedirectResponse): if response.cookies: - cookie = SimpleCookie() - cookie.load(response.cookies) - for morsel in cookie.values(): - api_options.response.set_cookie( - key=morsel.key, - value=morsel.value, - domain=morsel.get("domain"), - secure=morsel.get("secure", True), - httponly=morsel.get("httponly", True), - expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore - path=morsel.get("path", "/"), - samesite=morsel.get("samesite", "lax").lower(), - ) + for cookie_string in response.cookies: + cookie = SimpleCookie() + cookie.load(cookie_string) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax").lower(), + ) return send_200_response( {"frontendRedirectTo": response.frontend_redirect_to}, diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py index 80900c034..166ae2f3b 100644 --- a/supertokens_python/recipe/oauth2provider/api/utils.py +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from urllib.parse import parse_qs, urlparse import time @@ -38,7 +38,7 @@ async def login_get( login_challenge: str, session: Optional[SessionContainer], should_try_refresh: bool, - cookies: Optional[str], + cookies: Optional[List[str]], is_direct_call: bool, user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: @@ -177,7 +177,7 @@ async def login_get( ) -def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str: +def get_merged_cookies(orig_cookies: str, new_cookies: Optional[List[str]]) -> str: if not new_cookies: return orig_cookies @@ -190,7 +190,8 @@ def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str: # Note: This is a simplified version. In production code you'd want to use a proper # cookie parsing library to handle all cookie attributes correctly if new_cookies: - for cookie in new_cookies.split(","): + for cookie_str in new_cookies: + cookie = cookie_str.split(";")[0].strip() if "=" in cookie: name, value = cookie.split("=", 1) cookie_map[name.strip()] = value @@ -199,13 +200,13 @@ def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str: def merge_set_cookie_headers( - set_cookie1: Optional[str] = None, set_cookie2: Optional[str] = None -) -> str: + set_cookie1: Optional[List[str]] = None, set_cookie2: Optional[List[str]] = None +) -> List[str]: if not set_cookie1: - return set_cookie2 or "" - if not set_cookie2 or set_cookie1 == set_cookie2: + return set_cookie2 or [] + if not set_cookie2 or set(set_cookie1) == set(set_cookie2): return set_cookie1 - return f"{set_cookie1}, {set_cookie2}" + return set_cookie1 + set_cookie2 def is_login_internal_redirect(redirect_to: str) -> bool: diff --git a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py index 2e54a7144..3cdf1b523 100644 --- a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py +++ b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py @@ -124,7 +124,7 @@ async def validate_oauth2_access_token( async def create_token_for_client_credentials( client_id: str, - client_secret: str, + client_secret: Optional[str] = None, scope: Optional[List[str]] = None, audience: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, @@ -133,16 +133,21 @@ async def create_token_for_client_credentials( user_context = {} from ..recipe import OAuth2ProviderRecipe + body: Dict[str, Any] = { + "grant_type": "client_credentials", + "client_id": client_id, + } + if client_secret: + body["client_secret"] = client_secret + if scope: + body["scope"] = " ".join(scope) + if audience: + body["audience"] = audience + return ( await OAuth2ProviderRecipe.get_instance().recipe_implementation.token_exchange( authorization_header=None, - body={ - "grant_type": "client_credentials", - "client_id": client_id, - "client_secret": client_secret, - "scope": " ".join(scope) if scope else None, - "audience": audience, - }, + body=body, user_context=user_context, ) ) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index c903e3bc7..e115d595e 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -45,12 +45,14 @@ def __init__( self.status_code = status_code def to_json(self) -> Dict[str, Any]: - return { + result: Dict[str, Any] = { "status": self.status, "error": self.error, "errorDescription": self.error_description, - "statusCode": self.status_code, } + if self.status_code is not None: + result["statusCode"] = self.status_code + return result @staticmethod def from_json(json: Dict[str, Any]): @@ -225,18 +227,18 @@ def to_json(self) -> Dict[str, Any]: class RedirectResponse: - def __init__(self, redirect_to: str, cookies: Optional[str] = None): + def __init__(self, redirect_to: str, cookies: Optional[List[str]] = None): self.redirect_to = redirect_to self.cookies = cookies class FrontendRedirectResponse: - def __init__(self, frontend_redirect_to: str, cookies: Optional[str] = None): + def __init__(self, frontend_redirect_to: str, cookies: Optional[List[str]] = None): self.frontend_redirect_to = frontend_redirect_to self.cookies = cookies def to_json(self) -> Dict[str, Any]: - result = { + result: Dict[str, Any] = { "frontendRedirectTo": self.frontend_redirect_to, } if self.cookies is not None: diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py index 0f16db500..684562878 100644 --- a/supertokens_python/recipe/oauth2provider/oauth2_client.py +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -245,32 +245,53 @@ def from_json(json: Dict[str, Any]) -> "OAuth2Client": ) def to_json(self) -> Dict[str, Any]: - return { + result: Dict[str, Any] = { "clientId": self.client_id, "clientName": self.client_name, "scope": self.scope, "tokenEndpointAuthMethod": self.token_endpoint_auth_method, "createdAt": self.created_at, "updatedAt": self.updated_at, - "clientSecret": self.client_secret, - "redirectUris": self.redirect_uris, - "postLogoutRedirectUris": self.post_logout_redirect_uris, - "authorizationCodeGrantAccessTokenLifespan": self.authorization_code_grant_access_token_lifespan, - "authorizationCodeGrantIdTokenLifespan": self.authorization_code_grant_id_token_lifespan, - "authorizationCodeGrantRefreshTokenLifespan": self.authorization_code_grant_refresh_token_lifespan, - "clientCredentialsGrantAccessTokenLifespan": self.client_credentials_grant_access_token_lifespan, - "implicitGrantAccessTokenLifespan": self.implicit_grant_access_token_lifespan, - "implicitGrantIdTokenLifespan": self.implicit_grant_id_token_lifespan, - "refreshTokenGrantAccessTokenLifespan": self.refresh_token_grant_access_token_lifespan, - "refreshTokenGrantIdTokenLifespan": self.refresh_token_grant_id_token_lifespan, - "refreshTokenGrantRefreshTokenLifespan": self.refresh_token_grant_refresh_token_lifespan, "clientUri": self.client_uri, "audience": self.audience, - "grantTypes": self.grant_types, - "responseTypes": self.response_types, "logoUri": self.logo_uri, "policyUri": self.policy_uri, "tosUri": self.tos_uri, "metadata": self.metadata, "enableRefreshTokenRotation": self.enable_refresh_token_rotation, } + + if self.client_secret is not None: + result["clientSecret"] = self.client_secret + result["redirectUris"] = self.redirect_uris + if self.post_logout_redirect_uris is not None: + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + result["implicitGrantIdTokenLifespan"] = self.implicit_grant_id_token_lifespan + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + result["grantTypes"] = self.grant_types + result["responseTypes"] = self.response_types + + return result diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index d2a82f504..f65f930a7 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -373,12 +373,10 @@ async def authorization( ) return RedirectResponse( - redirect_to=consent_res.redirect_to, cookies=",".join(resp["cookies"]) + redirect_to=consent_res.redirect_to, cookies=resp["cookies"] ) - return RedirectResponse( - redirect_to=redirect_to, cookies=",".join(resp["cookies"]) - ) + return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"]) async def token_exchange( self, diff --git a/tests/test-server/oauth2provider.py b/tests/test-server/oauth2provider.py index 8f8a18cfd..bed153a75 100644 --- a/tests/test-server/oauth2provider.py +++ b/tests/test-server/oauth2provider.py @@ -27,7 +27,7 @@ def create_oauth2_client_api(): # type: ignore print("OAuth2Provider:createOAuth2Client", request.json) response = OAuth2Provider.create_oauth2_client( - params=CreateOAuth2ClientInput.from_json(request.json.get("input")), + params=CreateOAuth2ClientInput.from_json(request.json.get("input", {})), user_context=request.json.get("userContext"), ) return jsonify(response.to_json()) From 92a7a0b5dbbeea9ec20d5a16d2aeb12bc09be139 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 13 Dec 2024 15:10:36 +0530 Subject: [PATCH 34/38] fix: tests --- .../framework/flask/flask_response.py | 4 ++- .../recipe/oauth2provider/api/auth.py | 4 +-- .../recipe/oauth2provider/api/end_session.py | 10 +++--- .../recipe/oauth2provider/api/token.py | 2 +- .../recipe/oauth2provider/interfaces.py | 10 ++++-- .../oauth2provider/recipe_implementation.py | 33 ++++++++++++++----- tests/test-server/oauth2provider.py | 21 +++++++++--- 7 files changed, 60 insertions(+), 24 deletions(-) diff --git a/supertokens_python/framework/flask/flask_response.py b/supertokens_python/framework/flask/flask_response.py index ef016d5d3..025538ef4 100644 --- a/supertokens_python/framework/flask/flask_response.py +++ b/supertokens_python/framework/flask/flask_response.py @@ -87,6 +87,8 @@ def set_json_content(self, content: Dict[str, Any]): self.response_sent = True def redirect(self, url: str) -> BaseResponse: - self.response.headers.set("Location", url) + self.set_header("Location", url) self.set_status_code(302) + self.response.data = b"" + self.response_sent = True return self diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index dedb466d1..71439b83c 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -45,8 +45,8 @@ async def auth_get( return None original_url = api_options.request.get_original_url() - split_url = original_url.split("?") - params = dict(parse_qsl(split_url[1])) if len(split_url) > 1 else {} + split_url = original_url.split("?", 1) + params = dict(parse_qsl(split_url[1], True)) if len(split_url) > 1 else {} session = None should_try_refresh = False diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py index 6a2e70bfc..13fedab34 100644 --- a/supertokens_python/recipe/oauth2provider/api/end_session.py +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -49,8 +49,10 @@ async def end_session_get( return None orig_url = api_options.request.get_original_url() - split_url = orig_url.split("?") - params = dict(urllib.parse.parse_qsl(split_url[1])) + split_url = orig_url.split("?", 1) + params = ( + dict(urllib.parse.parse_qsl(split_url[1], True)) if len(split_url) > 1 else {} + ) return await end_session_common( params, api_implementation.end_session_get, api_options, user_context @@ -110,9 +112,9 @@ async def end_session_common( ) if isinstance(response, RedirectResponse): - options.response.redirect(response.redirect_to) + return options.response.redirect(response.redirect_to) elif isinstance(response, ErrorOAuth2Response): - send_non_200_response( + return send_non_200_response( { "error": response.error, "error_description": response.error_description, diff --git a/supertokens_python/recipe/oauth2provider/api/token.py b/supertokens_python/recipe/oauth2provider/api/token.py index 16f8d58b2..f88495a74 100644 --- a/supertokens_python/recipe/oauth2provider/api/token.py +++ b/supertokens_python/recipe/oauth2provider/api/token.py @@ -40,7 +40,7 @@ async def token_post( authorization_header = api_options.request.get_header("authorization") - body = await api_options.request.json() + body = await api_options.request.get_json_or_form_data() response = await api_implementation.token_post( authorization_header=authorization_header, diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index e115d595e..7f9550262 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -241,8 +241,6 @@ def to_json(self) -> Dict[str, Any]: result: Dict[str, Any] = { "frontendRedirectTo": self.frontend_redirect_to, } - if self.cookies is not None: - result["cookies"] = self.cookies return result @@ -337,6 +335,14 @@ def __init__( self.scopes = scopes self.audience = audience + @staticmethod + def from_json(json: Dict[str, Any]): + return OAuth2TokenValidationRequirements( + client_id=json.get("clientId"), + scopes=json.get("scopes"), + audience=json.get("audience"), + ) + class FrontendRedirectionURLTypeLogin: def __init__( diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index f65f930a7..e3133e11b 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -239,7 +239,9 @@ async def authorization( payloads = None - if not params.get("client_id") or not isinstance(params.get("client_id"), str): + if params.get("client_id") is None or not isinstance( + params.get("client_id"), str + ): return ErrorOAuth2Response( status_code=400, error="invalid_request", @@ -644,7 +646,11 @@ async def validate_oauth2_access_token( token, matching_key.key, algorithms=["RS256"], - options={"verify_signature": True, "verify_exp": True}, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_aud": False, + }, ) except Exception as e: err = e @@ -908,15 +914,24 @@ async def end_session( # CASE 3: end_session request with a logout_verifier (after accepting the logout request) # - Redirects to the post_logout_redirect_uri or the default logout fallback page. + request_body: Dict[str, Any] = {} + + if params.get("client_id") is not None: + request_body["clientId"] = params.get("client_id") + if params.get("id_token_hint") is not None: + request_body["idTokenHint"] = params.get("id_token_hint") + if params.get("post_logout_redirect_uri") is not None: + request_body["postLogoutRedirectUri"] = params.get( + "post_logout_redirect_uri" + ) + if params.get("state") is not None: + request_body["state"] = params.get("state") + if params.get("logout_verifier") is not None: + request_body["logoutVerifier"] = params.get("logout_verifier") + resp = await self.querier.send_get_request( NormalisedURLPath("/recipe/oauth/sessions/logout"), - { - "clientId": params.get("client_id"), - "idTokenHint": params.get("id_token_hint"), - "postLogoutRedirectUri": params.get("post_logout_redirect_uri"), - "state": params.get("state"), - "logoutVerifier": params.get("logout_verifier"), - }, + request_body, user_context=user_context, ) diff --git a/tests/test-server/oauth2provider.py b/tests/test-server/oauth2provider.py index bed153a75..d39b588c7 100644 --- a/tests/test-server/oauth2provider.py +++ b/tests/test-server/oauth2provider.py @@ -1,5 +1,9 @@ from flask import Flask, request, jsonify -from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput +from supertokens_python.recipe.oauth2provider.interfaces import ( + CreateOAuth2ClientInput, + OAuth2TokenValidationRequirements, + UpdateOAuth2ClientInput, +) import supertokens_python.recipe.oauth2provider.syncio as OAuth2Provider @@ -38,7 +42,8 @@ def update_oauth2_client_api(): # type: ignore print("OAuth2Provider:updateOAuth2Client", request.json) response = OAuth2Provider.update_oauth2_client( - params=request.json["input"], user_context=request.json.get("userContext") + params=UpdateOAuth2ClientInput.from_json(request.json.get("input", {})), + user_context=request.json.get("userContext"), ) return jsonify(response.to_json()) @@ -60,11 +65,17 @@ def validate_oauth2_access_token_api(): # type: ignore response = OAuth2Provider.validate_oauth2_access_token( token=request.json["token"], - requirements=request.json["requirements"], - check_database=request.json["checkDatabase"], + requirements=( + OAuth2TokenValidationRequirements.from_json( + request.json["requirements"] + ) + if "requirements" in request.json + else None + ), + check_database=request.json.get("checkDatabase"), user_context=request.json.get("userContext"), ) - return jsonify(response) + return jsonify({**response, "status": "OK"}) @app.route("/test/oauth2provider/validateoauth2refreshtoken", methods=["POST"]) # type: ignore def validate_oauth2_refresh_token_api(): # type: ignore From 91926aee8179436dcf49721445d095dc6e060d25 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 13 Dec 2024 16:09:42 +0530 Subject: [PATCH 35/38] fix: tests --- .../recipe/oauth2provider/interfaces.py | 8 +++--- .../oauth2provider/recipe_implementation.py | 9 +++++-- tests/test-server/oauth2provider.py | 10 +++++--- tests/test-server/session.py | 25 +++++++++++++++++++ 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 7f9550262..6d07fa89c 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -259,11 +259,13 @@ def from_json(json: Dict[str, Any]): ) def to_json(self) -> Dict[str, Any]: - return { + result = { "status": "OK", "clients": [client.to_json() for client in self.clients], - "nextPaginationToken": self.next_pagination_token, } + if self.next_pagination_token is not None: + result["nextPaginationToken"] = self.next_pagination_token + return result class GetOAuth2ClientOkResult: @@ -920,7 +922,7 @@ def to_json(self) -> Dict[str, Any]: result["audience"] = self.audience if not isinstance(self.grant_types, NotSet): result["grantTypes"] = self.grant_types - if self.response_types is not None: + if not isinstance(self.response_types, NotSet): result["responseTypes"] = self.response_types if not isinstance(self.client_uri, NotSet): result["clientUri"] = self.client_uri diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index e3133e11b..f14fb9449 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -25,6 +25,9 @@ from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.session.interfaces import SessionContainer from supertokens_python.recipe.session.jwks import get_latest_keys +from supertokens_python.recipe.session.jwt import ( + parse_jwt_without_signature_verification, +) from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.types import RecipeUserId, User @@ -543,7 +546,7 @@ async def get_oauth2_clients( clients=[ OAuth2Client.from_json(client) for client in response["clients"] ], - next_pagination_token=response["nextPaginationToken"], + next_pagination_token=response.get("nextPaginationToken"), ) return ErrorOAuth2Response( @@ -632,9 +635,11 @@ async def validate_oauth2_access_token( check_database: Optional[bool], user_context: Dict[str, Any], ) -> Dict[str, Any]: + access_token_obj = parse_jwt_without_signature_verification(token) + # Verify token signature using session recipe's JWKS session_recipe = SessionRecipe.get_instance() - matching_keys = get_latest_keys(session_recipe.config) + matching_keys = get_latest_keys(session_recipe.config, access_token_obj.kid) err: Optional[Exception] = None payload: Dict[str, Any] = {} diff --git a/tests/test-server/oauth2provider.py b/tests/test-server/oauth2provider.py index d39b588c7..06aa21761 100644 --- a/tests/test-server/oauth2provider.py +++ b/tests/test-server/oauth2provider.py @@ -13,7 +13,7 @@ def get_oauth2_clients_api(): # type: ignore assert request.json is not None print("OAuth2Provider:getOAuth2Clients", request.json) - data = request.json + data = request.json.get("input", {}) if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) @@ -52,9 +52,11 @@ def delete_oauth2_client_api(): # type: ignore assert request.json is not None print("OAuth2Provider:deleteOAuth2Client", request.json) + data = request.json.get("input", {}) + response = OAuth2Provider.delete_oauth2_client( - client_id=request.json["input"], - user_context=request.json.get("userContext"), + client_id=data.get("clientId"), + user_context=data.get("userContext"), ) return jsonify(response.to_json()) @@ -75,7 +77,7 @@ def validate_oauth2_access_token_api(): # type: ignore check_database=request.json.get("checkDatabase"), user_context=request.json.get("userContext"), ) - return jsonify({**response, "status": "OK"}) + return jsonify({"payload": response, "status": "OK"}) @app.route("/test/oauth2provider/validateoauth2refreshtoken", methods=["POST"]) # type: ignore def validate_oauth2_refresh_token_api(): # type: ignore diff --git a/tests/test-server/session.py b/tests/test-server/session.py index cf1dc174e..ea3dea1df 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -195,6 +195,31 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore } ) + @app.route("/test/session/sessionobject/revokesession", methods=["POST"]) # type: ignore + def revoke_session(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("sessionobject.revokesession", "CALL", data) + + try: + session = convert_session_to_container(data) + if not session: + raise Exception( + "This should never happen: failed to deserialize session" + ) + ret_val = session.sync_revoke_session(data.get("userContext", {})) + response = { + "retVal": ret_val, + "updatedSession": convert_session_to_json(session), + } + log_override_event("sessionobject.revokesession", "RES", ret_val) + return jsonify(response) + except Exception as e: + log_override_event("sessionobject.revokesession", "REJ", e) + return jsonify({"status": "ERROR", "message": str(e)}), 500 + @app.route("/test/session/mergeintoaccesspayload", methods=["POST"]) # type: ignore def merge_into_access_payload(): # type: ignore data = request.json From 632b5dd33f4af547826bb7649b908890c027b146 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 17 Dec 2024 17:46:46 +0530 Subject: [PATCH 36/38] fix: openid and cookies --- .../recipe/oauth2provider/api/auth.py | 4 +- .../recipe/oauth2provider/api/login.py | 4 +- .../recipe/oauth2provider/recipe.py | 10 +-- .../recipe/openid/api/implementation.py | 12 ++- .../recipe/openid/interfaces.py | 84 ++++++++++++++++++- .../recipe/openid/recipe_implementation.py | 36 ++++++-- 6 files changed, 130 insertions(+), 20 deletions(-) diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 71439b83c..57850ec0f 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -14,10 +14,10 @@ from __future__ import annotations -from datetime import datetime from http.cookies import SimpleCookie from typing import TYPE_CHECKING, Any, Dict from urllib.parse import parse_qsl +from dateutil import parser from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.recipe.session.exceptions import TryRefreshTokenError @@ -85,7 +85,7 @@ async def auth_get( domain=morsel.get("domain"), secure=morsel.get("secure", True), httponly=morsel.get("httponly", True), - expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore + expires=parser.parse(morsel.get("expires", "")).timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), samesite=morsel.get("samesite", "lax"), ) diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index 3c23da175..ce5b7dbd6 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional -from datetime import datetime +from dateutil import parser from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.framework import BaseResponse @@ -87,7 +87,7 @@ async def login( domain=morsel.get("domain"), secure=morsel.get("secure", True), httponly=morsel.get("httponly", True), - expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore + expires=parser.parse(morsel.get("expires", "")).timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), samesite=morsel.get("samesite", "lax").lower(), ) diff --git a/supertokens_python/recipe/oauth2provider/recipe.py b/supertokens_python/recipe/oauth2provider/recipe.py index c4f0a4a94..aa9ec036e 100644 --- a/supertokens_python/recipe/oauth2provider/recipe.py +++ b/supertokens_python/recipe/oauth2provider/recipe.py @@ -346,9 +346,8 @@ async def get_default_access_token_payload( payload["emails"] = user.emails if "phoneNumber" in scopes: - payload["phoneNumber"] = ( - user.phone_numbers[0] if user.phone_numbers else None - ) + if user.phone_numbers: + payload["phoneNumber"] = user.phone_numbers[0] payload["phoneNumber_verified"] = ( any( lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified @@ -387,9 +386,8 @@ async def get_default_id_token_payload( payload["emails"] = user.emails if "phoneNumber" in scopes: - payload["phoneNumber"] = ( - user.phone_numbers[0] if user.phone_numbers else None - ) + if user.phone_numbers: + payload["phoneNumber"] = user.phone_numbers[0] payload["phoneNumber_verified"] = ( any( lm.has_same_phone_number_as(user.phone_numbers[0]) and lm.verified diff --git a/supertokens_python/recipe/openid/api/implementation.py b/supertokens_python/recipe/openid/api/implementation.py index 6d6440c5a..6214f7639 100644 --- a/supertokens_python/recipe/openid/api/implementation.py +++ b/supertokens_python/recipe/openid/api/implementation.py @@ -30,5 +30,15 @@ async def open_id_discovery_configuration_get( ) ) return OpenIdDiscoveryConfigurationGetResponse( - response.issuer, response.jwks_uri + issuer=response.issuer, + jwks_uri=response.jwks_uri, + authorization_endpoint=response.authorization_endpoint, + token_endpoint=response.token_endpoint, + userinfo_endpoint=response.userinfo_endpoint, + revocation_endpoint=response.revocation_endpoint, + token_introspection_endpoint=response.token_introspection_endpoint, + end_session_endpoint=response.end_session_endpoint, + subject_types_supported=response.subject_types_supported, + id_token_signing_alg_values_supported=response.id_token_signing_alg_values_supported, + response_types_supported=response.response_types_supported, ) diff --git a/supertokens_python/recipe/openid/interfaces.py b/supertokens_python/recipe/openid/interfaces.py index 738242d6d..963360616 100644 --- a/supertokens_python/recipe/openid/interfaces.py +++ b/supertokens_python/recipe/openid/interfaces.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Union, Optional +from typing import Any, Dict, List, Union, Optional from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.jwt.interfaces import ( @@ -26,9 +26,48 @@ class GetOpenIdDiscoveryConfigurationResult: - def __init__(self, issuer: str, jwks_uri: str): + def __init__( + self, + issuer: str, + jwks_uri: str, + authorization_endpoint: str, + token_endpoint: str, + userinfo_endpoint: str, + revocation_endpoint: str, + token_introspection_endpoint: str, + end_session_endpoint: str, + subject_types_supported: List[str], + id_token_signing_alg_values_supported: List[str], + response_types_supported: List[str], + ): self.issuer = issuer self.jwks_uri = jwks_uri + self.authorization_endpoint = authorization_endpoint + self.token_endpoint = token_endpoint + self.userinfo_endpoint = userinfo_endpoint + self.revocation_endpoint = revocation_endpoint + self.token_introspection_endpoint = token_introspection_endpoint + self.end_session_endpoint = end_session_endpoint + self.subject_types_supported = subject_types_supported + self.id_token_signing_alg_values_supported = ( + id_token_signing_alg_values_supported + ) + self.response_types_supported = response_types_supported + + def to_json(self) -> Dict[str, Any]: + return { + "issuer": self.issuer, + "jwks_uri": self.jwks_uri, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "userinfo_endpoint": self.userinfo_endpoint, + "revocation_endpoint": self.revocation_endpoint, + "token_introspection_endpoint": self.token_introspection_endpoint, + "end_session_endpoint": self.end_session_endpoint, + "subject_types_supported": self.subject_types_supported, + "id_token_signing_alg_values_supported": self.id_token_signing_alg_values_supported, + "response_types_supported": self.response_types_supported, + } class RecipeInterface(ABC): @@ -75,12 +114,49 @@ def __init__( class OpenIdDiscoveryConfigurationGetResponse(APIResponse): status: str = "OK" - def __init__(self, issuer: str, jwks_uri: str): + def __init__( + self, + issuer: str, + jwks_uri: str, + authorization_endpoint: str, + token_endpoint: str, + userinfo_endpoint: str, + revocation_endpoint: str, + token_introspection_endpoint: str, + end_session_endpoint: str, + subject_types_supported: List[str], + id_token_signing_alg_values_supported: List[str], + response_types_supported: List[str], + ): self.issuer = issuer self.jwks_uri = jwks_uri + self.authorization_endpoint = authorization_endpoint + self.token_endpoint = token_endpoint + self.userinfo_endpoint = userinfo_endpoint + self.revocation_endpoint = revocation_endpoint + self.token_introspection_endpoint = token_introspection_endpoint + self.end_session_endpoint = end_session_endpoint + self.subject_types_supported = subject_types_supported + self.id_token_signing_alg_values_supported = ( + id_token_signing_alg_values_supported + ) + self.response_types_supported = response_types_supported def to_json(self): - return {"status": self.status, "issuer": self.issuer, "jwks_uri": self.jwks_uri} + return { + "status": self.status, + "issuer": self.issuer, + "jwks_uri": self.jwks_uri, + "authorization_endpoint": self.authorization_endpoint, + "token_endpoint": self.token_endpoint, + "userinfo_endpoint": self.userinfo_endpoint, + "revocation_endpoint": self.revocation_endpoint, + "token_introspection_endpoint": self.token_introspection_endpoint, + "end_session_endpoint": self.end_session_endpoint, + "subject_types_supported": self.subject_types_supported, + "id_token_signing_alg_values_supported": self.id_token_signing_alg_values_supported, + "response_types_supported": self.response_types_supported, + } class APIInterface: diff --git a/supertokens_python/recipe/openid/recipe_implementation.py b/supertokens_python/recipe/openid/recipe_implementation.py index 32bda3926..f117c5007 100644 --- a/supertokens_python/recipe/openid/recipe_implementation.py +++ b/supertokens_python/recipe/openid/recipe_implementation.py @@ -39,19 +39,45 @@ class RecipeImplementation(RecipeInterface): async def get_open_id_discovery_configuration( self, user_context: Dict[str, Any] ) -> GetOpenIdDiscoveryConfigurationResult: + from ..oauth2provider.constants import ( + AUTH_PATH, + TOKEN_PATH, + USER_INFO_PATH, + REVOKE_TOKEN_PATH, + INTROSPECT_TOKEN_PATH, + END_SESSION_PATH, + ) + issuer = ( - self.config.issuer_domain.get_as_string_dangerous() - + self.config.issuer_path.get_as_string_dangerous() + self.app_info.api_domain.get_as_string_dangerous() + + self.app_info.api_base_path.get_as_string_dangerous() ) jwks_uri = ( - self.config.issuer_domain.get_as_string_dangerous() - + self.config.issuer_path.append( + self.app_info.api_domain.get_as_string_dangerous() + + self.app_info.api_base_path.append( NormalisedURLPath(GET_JWKS_API) ).get_as_string_dangerous() ) - return GetOpenIdDiscoveryConfigurationResult(issuer, jwks_uri) + api_base_path: str = ( + self.app_info.api_domain.get_as_string_dangerous() + + self.app_info.api_base_path.get_as_string_dangerous() + ) + + return GetOpenIdDiscoveryConfigurationResult( + issuer=issuer, + jwks_uri=jwks_uri, + authorization_endpoint=api_base_path + AUTH_PATH, + token_endpoint=api_base_path + TOKEN_PATH, + userinfo_endpoint=api_base_path + USER_INFO_PATH, + revocation_endpoint=api_base_path + REVOKE_TOKEN_PATH, + token_introspection_endpoint=api_base_path + INTROSPECT_TOKEN_PATH, + end_session_endpoint=api_base_path + END_SESSION_PATH, + subject_types_supported=["public"], + id_token_signing_alg_values_supported=["RS256"], + response_types_supported=["code", "id_token", "id_token token"], + ) def __init__( self, From 2685c315bce76355b26895b7dd0d1610ed59aa55 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 17 Dec 2024 18:10:08 +0530 Subject: [PATCH 37/38] fix: roles and permissions for oauth2 --- supertokens_python/recipe/userroles/recipe.py | 110 +++++++++++++++++- 1 file changed, 108 insertions(+), 2 deletions(-) diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index 844b41895..7a2d830d5 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -21,19 +21,20 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier +from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.recipe.userroles.recipe_implementation import ( RecipeImplementation, ) from supertokens_python.recipe.userroles.utils import validate_and_normalise_user_input from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.supertokens import AppInfo -from supertokens_python.types import RecipeUserId +from supertokens_python.types import RecipeUserId, User from ...post_init_callbacks import PostSTInitCallbacks from ..session import SessionRecipe from ..session.claim_base_classes.primitive_array_claim import PrimitiveArrayClaim from .exceptions import SuperTokensUserRolesError -from .interfaces import GetPermissionsForRoleOkResult +from .interfaces import GetPermissionsForRoleOkResult, UnknownRoleError from .utils import InputOverrideConfig @@ -49,6 +50,8 @@ def __init__( skip_adding_permissions_to_access_token: Optional[bool] = None, override: Union[InputOverrideConfig, None] = None, ): + from ..oauth2provider.recipe import OAuth2ProviderRecipe + super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( self, @@ -72,6 +75,109 @@ def callback(): PermissionClaim ) + async def token_payload_builder( + user: User, + scopes: List[str], + session_handle: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + payload: Dict[str, Any] = {"roles": None, "permissions": None} + + session_info = await get_session_information( + session_handle, user_context + ) + + if session_info is None: + raise Exception("should never come here") + + user_roles: List[str] = [] + + if "roles" in scopes or "permissions" in scopes: + res = await self.recipe_implementation.get_roles_for_user( + tenant_id=session_info.tenant_id, + user_id=user.id, + user_context=user_context, + ) + + user_roles = res.roles + + if "roles" in scopes: + payload["roles"] = user_roles + + if "permissions" in scopes: + user_permissions: Set[str] = set() + for role in user_roles: + role_permissions = ( + await self.recipe_implementation.get_permissions_for_role( + role=role, + user_context=user_context, + ) + ) + + if isinstance(role_permissions, UnknownRoleError): + raise Exception("Failed to fetch permissions for the role") + + for perm in role_permissions.permissions: + user_permissions.add(perm) + + payload["permissions"] = list(user_permissions) + + return payload + + OAuth2ProviderRecipe.get_instance().add_access_token_builder_from_other_recipe( + token_payload_builder + ) + OAuth2ProviderRecipe.get_instance().add_id_token_builder_from_other_recipe( + token_payload_builder + ) + + async def user_info_builder( + user: User, + _access_token_payload: Dict[str, Any], + scopes: List[str], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + user_info: Dict[str, Any] = {"roles": None, "permissions": None} + + user_roles = [] + + if "roles" in scopes or "permissions" in scopes: + res = await self.recipe_implementation.get_roles_for_user( + tenant_id=tenant_id, + user_id=user.id, + user_context=user_context, + ) + + user_roles = res.roles + + if "roles" in scopes: + user_info["roles"] = user_roles + + if "permissions" in scopes: + user_permissions: Set[str] = set() + for role in user_roles: + role_permissions = ( + await self.recipe_implementation.get_permissions_for_role( + role=role, + user_context=user_context, + ) + ) + + if isinstance(role_permissions, UnknownRoleError): + raise Exception("Failed to fetch permissions for the role") + + for perm in role_permissions.permissions: + user_permissions.add(perm) + + user_info["permissions"] = list(user_permissions) + + return user_info + + OAuth2ProviderRecipe.get_instance().add_user_info_builder_from_other_recipe( + user_info_builder + ) + PostSTInitCallbacks.add_post_init_callback(callback) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: From 0226085837025e0f2727c5d27cd369abd77339dc Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Tue, 17 Dec 2024 18:23:38 +0530 Subject: [PATCH 38/38] fix: auth react tests --- tests/auth-react/django3x/mysite/utils.py | 11 ++++++++++- tests/auth-react/django3x/polls/urls.py | 1 + tests/auth-react/django3x/polls/views.py | 11 +++++++++++ tests/auth-react/fastapi-server/app.py | 23 ++++++++++++++++++++++- tests/auth-react/flask-server/app.py | 23 ++++++++++++++++++++++- 5 files changed, 66 insertions(+), 3 deletions(-) diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index cfafd8777..ae8e5872d 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -10,6 +10,7 @@ emailpassword, emailverification, multifactorauth, + oauth2provider, passwordless, session, thirdparty, @@ -40,6 +41,8 @@ APIOptions as EVAPIOptions, ) from supertokens_python.recipe.jwt import JWTRecipe +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.passwordless import ( ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, @@ -315,6 +318,8 @@ def custom_init(): Supertokens.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OpenIdRecipe.reset() + OAuth2ProviderRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -942,6 +947,10 @@ async def resync_session_and_fetch_mfa_info_put( ) ), }, + { + "id": "oauth2provider", + "init": oauth2provider.init(), + }, ] accountlinking_config_input: Dict[str, Any] = { @@ -1002,7 +1011,7 @@ async def should_do_automatic_account_linking( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( app_name="SuperTokens Demo", - api_domain="0.0.0.0:" + get_api_port(), + api_domain="localhost:" + get_api_port(), website_domain=get_website_domain(), ), framework="django", diff --git a/tests/auth-react/django3x/polls/urls.py b/tests/auth-react/django3x/polls/urls.py index 8b3cf999e..cff4e60f7 100644 --- a/tests/auth-react/django3x/polls/urls.py +++ b/tests/auth-react/django3x/polls/urls.py @@ -35,6 +35,7 @@ name="setEnabledRecipes", ), path("test/getTOTPCode", views.test_get_totp_code, name="getTotpCode"), # type: ignore + path("test/create-oauth2-client", views.test_create_oauth2_client, name="createOAuth2Client"), # type: ignore path("test/getDevice", views.test_get_device, name="getDevice"), # type: ignore path("test/featureFlags", views.test_feature_flags, name="featureFlags"), # type: ignore path("beforeeach", views.before_each, name="beforeeach"), # type: ignore diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index cd6b7c9be..08c28afcf 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -27,6 +27,8 @@ from supertokens_python.recipe.multifactorauth.asyncio import ( add_to_required_secondary_factors_for_user, ) +from supertokens_python.recipe.oauth2provider.syncio import create_oauth2_client +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.recipe.thirdparty import ProviderConfig @@ -457,6 +459,14 @@ def test_get_totp_code(request: HttpRequest): return JsonResponse({"totp": code}) +def test_create_oauth2_client(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + client = create_oauth2_client(CreateOAuth2ClientInput.from_json(body)) + return JsonResponse(client.to_json()) + + def before_each(request: HttpRequest): import mysite.store @@ -486,6 +496,7 @@ def test_feature_flags(request: HttpRequest): "mfa", "recipeConfig", "accountlinking-fixes", + "oauth2", ] } ) diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index cfdaaec10..7a727da93 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -43,6 +43,7 @@ from supertokens_python.recipe import ( emailpassword, emailverification, + oauth2provider, passwordless, session, thirdparty, @@ -106,6 +107,10 @@ AssociateUserToTenantUnknownUserIdError, TenantConfigCreateOrUpdate, ) +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.oauth2provider.asyncio import create_oauth2_client +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.passwordless import ( ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, @@ -391,6 +396,8 @@ def custom_init(): Supertokens.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OpenIdRecipe.reset() + OAuth2ProviderRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -1021,6 +1028,10 @@ async def resync_session_and_fetch_mfa_info_put( ) ), }, + { + "id": "oauth2provider", + "init": oauth2provider.init(), + }, ] global accountlinking_config @@ -1084,7 +1095,7 @@ async def should_do_automatic_account_linking( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( app_name="SuperTokens Demo", - api_domain="0.0.0.0:" + get_api_port(), + api_domain="localhost:" + get_api_port(), website_domain=get_website_domain(), ), framework="fastapi", @@ -1374,6 +1385,15 @@ async def test_get_totp_code(request: Request): return JSONResponse({"totp": code}) +@app.post("/test/create-oauth2-client") +async def test_create_oauth2_client(request: Request): + body = await request.json() + if body is None: + raise Exception("Invalid request body") + client = await create_oauth2_client(CreateOAuth2ClientInput.from_json(body)) + return JSONResponse(client.to_json()) + + @app.get("/test/getDevice") def test_get_device(request: Request): global code_store @@ -1397,6 +1417,7 @@ def test_feature_flags(request: Request): "mfa", "recipeConfig", "accountlinking-fixes", + "oauth2", ] return JSONResponse({"available": available}) diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 6c840c9e3..6dbb3ce81 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -35,6 +35,7 @@ accountlinking, emailpassword, emailverification, + oauth2provider, passwordless, session, thirdparty, @@ -77,6 +78,10 @@ delete_tenant, disassociate_user_from_tenant, ) +from supertokens_python.recipe.oauth2provider.syncio import create_oauth2_client +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.recipe.passwordless.syncio import update_user from supertokens_python.recipe.session.exceptions import ( ClaimValidationError, @@ -374,6 +379,8 @@ def custom_init(): Supertokens.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OpenIdRecipe.reset() + OAuth2ProviderRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -1004,6 +1011,10 @@ async def resync_session_and_fetch_mfa_info_put( ) ), }, + { + "id": "oauth2provider", + "init": oauth2provider.init(), + }, ] global accountlinking_config @@ -1066,7 +1077,7 @@ async def should_do_automatic_account_linking( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( app_name="SuperTokens Demo", - api_domain="0.0.0.0:" + get_api_port(), + api_domain="localhost:" + get_api_port(), website_domain=get_website_domain(), ), framework="flask", @@ -1401,6 +1412,15 @@ def test_get_totp_code(): return jsonify({"totp": code}) +@app.post("/test/create-oauth2-client") # type: ignore +def test_create_oauth2_client(): + body = request.get_json() + if body is None: + raise Exception("Invalid request body") + client = create_oauth2_client(CreateOAuth2ClientInput.from_json(body)) + return jsonify(client.to_json()) + + @app.get("/test/getDevice") # type: ignore def test_get_device(): global code_store @@ -1424,6 +1444,7 @@ def test_feature_flags(): "mfa", "recipeConfig", "accountlinking-fixes", + "oauth2", ] return jsonify({"available": available})