diff --git a/supertokens_python/framework/django/django_response.py b/supertokens_python/framework/django/django_response.py index 9692b2d5..f839b923 100644 --- a/supertokens_python/framework/django/django_response.py +++ b/supertokens_python/framework/django/django_response.py @@ -89,7 +89,8 @@ def set_json_content(self, content: Dict[str, Any]): ).encode("utf-8") self.response_sent = True - def redirect(self, url: str): + def redirect(self, url: str) -> BaseResponse: if not self.response_sent: self.set_header("Location", url) self.set_status_code(302) + return self diff --git a/supertokens_python/framework/fastapi/fastapi_response.py b/supertokens_python/framework/fastapi/fastapi_response.py index 45813a5c..76e6f349 100644 --- a/supertokens_python/framework/fastapi/fastapi_response.py +++ b/supertokens_python/framework/fastapi/fastapi_response.py @@ -95,7 +95,8 @@ def set_json_content(self, content: Dict[str, Any]): self.response.body = body self.response_sent = True - def redirect(self, url: str): + def redirect(self, url: str) -> BaseResponse: if not self.response_sent: self.set_header("Location", url) self.set_status_code(302) + return self diff --git a/supertokens_python/framework/flask/flask_response.py b/supertokens_python/framework/flask/flask_response.py index a74bdfb8..ef016d5d 100644 --- a/supertokens_python/framework/flask/flask_response.py +++ b/supertokens_python/framework/flask/flask_response.py @@ -86,6 +86,7 @@ def set_json_content(self, content: Dict[str, Any]): ).encode("utf-8") self.response_sent = True - def redirect(self, url: str): + def redirect(self, url: str) -> BaseResponse: self.response.headers.set("Location", url) self.set_status_code(302) + return self diff --git a/supertokens_python/framework/response.py b/supertokens_python/framework/response.py index 8669e3ae..c28a2410 100644 --- a/supertokens_python/framework/response.py +++ b/supertokens_python/framework/response.py @@ -63,5 +63,5 @@ def set_html_content(self, content: str): pass @abstractmethod - def redirect(self, url: str): + def redirect(self, url: str) -> "BaseResponse": pass diff --git a/supertokens_python/recipe/oauth2provider/__init__.py b/supertokens_python/recipe/oauth2provider/__init__.py index 55b44908..3397b343 100644 --- a/supertokens_python/recipe/oauth2provider/__init__.py +++ b/supertokens_python/recipe/oauth2provider/__init__.py @@ -16,15 +16,15 @@ from typing import TYPE_CHECKING, Callable, Union from . import exceptions as ex -from . import recipe +from . import recipe, utils exceptions = ex +InputOverrideConfig = utils.InputOverrideConfig if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo from ...recipe_module import RecipeModule - from .utils import InputOverrideConfig def init( diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 1819aa23..5b5767e6 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -14,6 +14,7 @@ from __future__ import annotations +from datetime import datetime from http.cookies import SimpleCookie from typing import TYPE_CHECKING, Any, Dict from urllib.parse import parse_qsl @@ -83,7 +84,7 @@ async def auth_get( domain=morsel.get("domain"), secure=morsel.get("secure", True), httponly=morsel.get("httponly", True), - expires=morsel.get("expires", None), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), samesite=morsel.get("samesite", "lax"), ) diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index 59c93dcd..d080211f 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional +from datetime import datetime + from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.framework import BaseResponse from supertokens_python.recipe.session.asyncio import get_session @@ -84,7 +86,7 @@ async def login( domain=morsel.get("domain"), secure=morsel.get("secure", True), httponly=morsel.get("httponly", True), - expires=morsel.get("expires", None), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore path=morsel.get("path", "/"), samesite=morsel.get("samesite", "lax"), ) diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py index 2b524b01..80900c03 100644 --- a/supertokens_python/recipe/oauth2provider/api/utils.py +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -235,6 +235,8 @@ async def handle_login_internal_redirects( cookie: str, user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: + from ..interfaces import RedirectResponse, ErrorOAuth2Response + if not is_login_internal_redirect(response.redirect_to): return response diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index 8b9f2649..57d34579 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -254,6 +254,13 @@ def from_json(json: Dict[str, Any]): next_pagination_token=json["nextPaginationToken"], ) + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "clients": [client.to_json() for client in self.clients], + "nextPaginationToken": self.next_pagination_token, + } + class GetOAuth2ClientOkResult: def __init__(self, client: OAuth2Client): @@ -272,6 +279,12 @@ def __init__(self, client: OAuth2Client): def from_json(json: Dict[str, Any]): return CreateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "client": self.client.to_json(), + } + class UpdateOAuth2ClientOkResult: def __init__(self, client: OAuth2Client): @@ -281,11 +294,22 @@ def __init__(self, client: OAuth2Client): def from_json(json: Dict[str, Any]): return UpdateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"])) + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "client": self.client.to_json(), + } + class DeleteOAuth2ClientOkResult: def __init__(self): pass + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + } + PayloadBuilderFunction = Callable[ [User, List[str], str, Dict[str, Any]], Awaitable[Dict[str, Any]] diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py index 9b17718f..0f16db50 100644 --- a/supertokens_python/recipe/oauth2provider/oauth2_client.py +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -243,3 +243,34 @@ def from_json(json: Dict[str, Any]) -> "OAuth2Client": metadata=json.get("metadata", {}), enable_refresh_token_rotation=json.get("enableRefreshTokenRotation", False), ) + + def to_json(self) -> Dict[str, Any]: + return { + "clientId": self.client_id, + "clientName": self.client_name, + "scope": self.scope, + "tokenEndpointAuthMethod": self.token_endpoint_auth_method, + "createdAt": self.created_at, + "updatedAt": self.updated_at, + "clientSecret": self.client_secret, + "redirectUris": self.redirect_uris, + "postLogoutRedirectUris": self.post_logout_redirect_uris, + "authorizationCodeGrantAccessTokenLifespan": self.authorization_code_grant_access_token_lifespan, + "authorizationCodeGrantIdTokenLifespan": self.authorization_code_grant_id_token_lifespan, + "authorizationCodeGrantRefreshTokenLifespan": self.authorization_code_grant_refresh_token_lifespan, + "clientCredentialsGrantAccessTokenLifespan": self.client_credentials_grant_access_token_lifespan, + "implicitGrantAccessTokenLifespan": self.implicit_grant_access_token_lifespan, + "implicitGrantIdTokenLifespan": self.implicit_grant_id_token_lifespan, + "refreshTokenGrantAccessTokenLifespan": self.refresh_token_grant_access_token_lifespan, + "refreshTokenGrantIdTokenLifespan": self.refresh_token_grant_id_token_lifespan, + "refreshTokenGrantRefreshTokenLifespan": self.refresh_token_grant_refresh_token_lifespan, + "clientUri": self.client_uri, + "audience": self.audience, + "grantTypes": self.grant_types, + "responseTypes": self.response_types, + "logoUri": self.logo_uri, + "policyUri": self.policy_uri, + "tosUri": self.tos_uri, + "metadata": self.metadata, + "enableRefreshTokenRotation": self.enable_refresh_token_rotation, + } diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index 32e134e7..63228d37 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -20,7 +20,6 @@ import jwt -from supertokens_python import AppInfo from supertokens_python.asyncio import get_user from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.openid.recipe import OpenIdRecipe @@ -60,6 +59,7 @@ if TYPE_CHECKING: from supertokens_python.querier import Querier + from supertokens_python import AppInfo def get_updated_redirect_to(app_info: AppInfo, redirect_to: str) -> str: @@ -371,10 +371,12 @@ async def authorization( ) return RedirectResponse( - redirect_to=consent_res.redirect_to, cookies=resp["cookies"] + redirect_to=consent_res.redirect_to, cookies=",".join(resp["cookies"]) ) - return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"]) + return RedirectResponse( + redirect_to=redirect_to, cookies=",".join(resp["cookies"]) + ) async def token_exchange( self, diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 53964da2..6632850b 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -259,9 +259,12 @@ def __init__( totp_found = False user_metadata_found = False multi_factor_auth_found = False + oauth2_found = False + openid_found = False + jwt_found = False def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: - nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found + nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found, oauth2_found, openid_found, jwt_found recipe_module = recipe(self.app_info) if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True @@ -271,21 +274,46 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: multi_factor_auth_found = True elif recipe_module.get_recipe_id() == "totp": totp_found = True + elif recipe_module.get_recipe_id() == "oauth2provider": + oauth2_found = True + elif recipe_module.get_recipe_id() == "openid": + openid_found = True + elif recipe_module.get_recipe_id() == "jwt": + jwt_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) + if not jwt_found: + from supertokens_python.recipe.jwt.recipe import JWTRecipe + + self.recipe_modules.append(JWTRecipe.init()(self.app_info)) + + if not openid_found: + from supertokens_python.recipe.openid.recipe import OpenIdRecipe + + self.recipe_modules.append(OpenIdRecipe.init()(self.app_info)) + if not multitenancy_found: from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + if totp_found and not multi_factor_auth_found: raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") + if not user_metadata_found: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info)) + if not oauth2_found: + from supertokens_python.recipe.oauth2provider.recipe import ( + OAuth2ProviderRecipe, + ) + + self.recipe_modules.append(OAuth2ProviderRecipe.init()(self.app_info)) + self.telemetry = ( telemetry if telemetry is not None diff --git a/tests/test-server/app.py b/tests/test-server/app.py index bdfd9206..1ab834d8 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -28,6 +28,8 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe +from supertokens_python.recipe.openid.recipe import OpenIdRecipe from supertokens_python.types import RecipeUserId from test_functions_mapper import ( # pylint: disable=import-error get_func, @@ -55,6 +57,7 @@ session, thirdparty, emailverification, + oauth2provider, ) from supertokens_python.recipe.session import InputErrorHandlers, SessionContainer from supertokens_python.recipe.session.framework.flask import verify_session @@ -222,6 +225,8 @@ def st_reset(): AccountLinkingRecipe.reset() TOTPRecipe.reset() MultiFactorAuthRecipe.reset() + OAuth2ProviderRecipe.reset() + OpenIdRecipe.reset() def init_st(config: Dict[str, Any]): @@ -593,6 +598,22 @@ async def send_sms( ) ) ) + elif recipe_id == "oauth2provider": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + oauth2provider.init( + override=oauth2provider.InputOverrideConfig( + apis=override_builder_with_logging( + "OAuth2Provider.override.apis", + recipe_config_json.get("override", {}).get("apis"), + ), + functions=override_builder_with_logging( + "OAuth2Provider.override.functions", + recipe_config_json.get("override", {}).get("functions"), + ), + ) + ) + ) interceptor_func = None if config.get("supertokens", {}).get("networkInterceptor") is not None: @@ -822,6 +843,10 @@ def handle_exception(e: Exception): add_multifactorauth_routes(app) +from oauth2provider import add_oauth2provider_routes + +add_oauth2provider_routes(app) + if __name__ == "__main__": default_st_init() port = int(os.environ.get("API_PORT", api_port)) diff --git a/tests/test-server/oauth2provider.py b/tests/test-server/oauth2provider.py new file mode 100644 index 00000000..8f8a18cf --- /dev/null +++ b/tests/test-server/oauth2provider.py @@ -0,0 +1,93 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.oauth2provider.interfaces import CreateOAuth2ClientInput +import supertokens_python.recipe.oauth2provider.syncio as OAuth2Provider + + +def add_oauth2provider_routes(app: Flask): + @app.route("/test/oauth2provider/getoauth2clients", methods=["POST"]) # type: ignore + def get_oauth2_clients_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:getOAuth2Clients", request.json) + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + response = OAuth2Provider.get_oauth2_clients( + page_size=data.get("pageSize"), + pagination_token=data.get("paginationToken"), + client_name=data.get("clientName"), + user_context=data.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/createoauth2client", methods=["POST"]) # type: ignore + def create_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:createOAuth2Client", request.json) + + response = OAuth2Provider.create_oauth2_client( + params=CreateOAuth2ClientInput.from_json(request.json.get("input")), + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/updateoauth2client", methods=["POST"]) # type: ignore + def update_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:updateOAuth2Client", request.json) + + response = OAuth2Provider.update_oauth2_client( + params=request.json["input"], user_context=request.json.get("userContext") + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/deleteoauth2client", methods=["POST"]) # type: ignore + def delete_oauth2_client_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:deleteOAuth2Client", request.json) + + response = OAuth2Provider.delete_oauth2_client( + client_id=request.json["input"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/validateoauth2accesstoken", methods=["POST"]) # type: ignore + def validate_oauth2_access_token_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:validateOAuth2AccessToken", request.json) + + response = OAuth2Provider.validate_oauth2_access_token( + token=request.json["token"], + requirements=request.json["requirements"], + check_database=request.json["checkDatabase"], + user_context=request.json.get("userContext"), + ) + return jsonify(response) + + @app.route("/test/oauth2provider/validateoauth2refreshtoken", methods=["POST"]) # type: ignore + def validate_oauth2_refresh_token_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:validateOAuth2RefreshToken", request.json) + + response = OAuth2Provider.validate_oauth2_refresh_token( + token=request.json["token"], + scopes=request.json["scopes"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json()) + + @app.route("/test/oauth2provider/createtokenforclientcredentials", methods=["POST"]) # type: ignore + def create_token_for_client_credentials_api(): # type: ignore + assert request.json is not None + print("OAuth2Provider:createTokenForClientCredentials", request.json) + + response = OAuth2Provider.create_token_for_client_credentials( + client_id=request.json["clientId"], + client_secret=request.json["clientSecret"], + scope=request.json["scope"], + audience=request.json["audience"], + user_context=request.json.get("userContext"), + ) + return jsonify(response.to_json())