diff --git a/supertokens_python/recipe/oauth2provider/api/auth.py b/supertokens_python/recipe/oauth2provider/api/auth.py index 8db3e7eb..dedb466d 100644 --- a/supertokens_python/recipe/oauth2provider/api/auth.py +++ b/supertokens_python/recipe/oauth2provider/api/auth.py @@ -75,19 +75,20 @@ async def auth_get( if isinstance(response, RedirectResponse): if response.cookies: - cookie = SimpleCookie() - cookie.load(response.cookies) - for morsel in cookie.values(): - api_options.response.set_cookie( - key=morsel.key, - value=morsel.value, - domain=morsel.get("domain"), - secure=morsel.get("secure", True), - httponly=morsel.get("httponly", True), - expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore - path=morsel.get("path", "/"), - samesite=morsel.get("samesite", "lax").lower(), - ) + for cookie_string in response.cookies: + cookie = SimpleCookie() + cookie.load(cookie_string) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax"), + ) return api_options.response.redirect(response.redirect_to) elif isinstance(response, ErrorOAuth2Response): return send_non_200_response( diff --git a/supertokens_python/recipe/oauth2provider/api/end_session.py b/supertokens_python/recipe/oauth2provider/api/end_session.py index ebda4315..6a2e70bf 100644 --- a/supertokens_python/recipe/oauth2provider/api/end_session.py +++ b/supertokens_python/recipe/oauth2provider/api/end_session.py @@ -81,6 +81,8 @@ async def end_session_common( options: APIOptions, user_context: Dict[str, Any], ) -> Optional[BaseResponse]: + from ..interfaces import RedirectResponse, ErrorOAuth2Response + if api_implementation is None: return None diff --git a/supertokens_python/recipe/oauth2provider/api/login.py b/supertokens_python/recipe/oauth2provider/api/login.py index a4d3e181..3c23da17 100644 --- a/supertokens_python/recipe/oauth2provider/api/login.py +++ b/supertokens_python/recipe/oauth2provider/api/login.py @@ -77,19 +77,20 @@ async def login( if isinstance(response, FrontendRedirectResponse): if response.cookies: - cookie = SimpleCookie() - cookie.load(response.cookies) - for morsel in cookie.values(): - api_options.response.set_cookie( - key=morsel.key, - value=morsel.value, - domain=morsel.get("domain"), - secure=morsel.get("secure", True), - httponly=morsel.get("httponly", True), - expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore - path=morsel.get("path", "/"), - samesite=morsel.get("samesite", "lax").lower(), - ) + for cookie_string in response.cookies: + cookie = SimpleCookie() + cookie.load(cookie_string) + for morsel in cookie.values(): + api_options.response.set_cookie( + key=morsel.key, + value=morsel.value, + domain=morsel.get("domain"), + secure=morsel.get("secure", True), + httponly=morsel.get("httponly", True), + expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore + path=morsel.get("path", "/"), + samesite=morsel.get("samesite", "lax").lower(), + ) return send_200_response( {"frontendRedirectTo": response.frontend_redirect_to}, diff --git a/supertokens_python/recipe/oauth2provider/api/utils.py b/supertokens_python/recipe/oauth2provider/api/utils.py index 80900c03..166ae2f3 100644 --- a/supertokens_python/recipe/oauth2provider/api/utils.py +++ b/supertokens_python/recipe/oauth2provider/api/utils.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from urllib.parse import parse_qs, urlparse import time @@ -38,7 +38,7 @@ async def login_get( login_challenge: str, session: Optional[SessionContainer], should_try_refresh: bool, - cookies: Optional[str], + cookies: Optional[List[str]], is_direct_call: bool, user_context: Dict[str, Any], ) -> Union[RedirectResponse, ErrorOAuth2Response]: @@ -177,7 +177,7 @@ async def login_get( ) -def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str: +def get_merged_cookies(orig_cookies: str, new_cookies: Optional[List[str]]) -> str: if not new_cookies: return orig_cookies @@ -190,7 +190,8 @@ def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str: # Note: This is a simplified version. In production code you'd want to use a proper # cookie parsing library to handle all cookie attributes correctly if new_cookies: - for cookie in new_cookies.split(","): + for cookie_str in new_cookies: + cookie = cookie_str.split(";")[0].strip() if "=" in cookie: name, value = cookie.split("=", 1) cookie_map[name.strip()] = value @@ -199,13 +200,13 @@ def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str: def merge_set_cookie_headers( - set_cookie1: Optional[str] = None, set_cookie2: Optional[str] = None -) -> str: + set_cookie1: Optional[List[str]] = None, set_cookie2: Optional[List[str]] = None +) -> List[str]: if not set_cookie1: - return set_cookie2 or "" - if not set_cookie2 or set_cookie1 == set_cookie2: + return set_cookie2 or [] + if not set_cookie2 or set(set_cookie1) == set(set_cookie2): return set_cookie1 - return f"{set_cookie1}, {set_cookie2}" + return set_cookie1 + set_cookie2 def is_login_internal_redirect(redirect_to: str) -> bool: diff --git a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py index 2e54a714..3cdf1b52 100644 --- a/supertokens_python/recipe/oauth2provider/asyncio/__init__.py +++ b/supertokens_python/recipe/oauth2provider/asyncio/__init__.py @@ -124,7 +124,7 @@ async def validate_oauth2_access_token( async def create_token_for_client_credentials( client_id: str, - client_secret: str, + client_secret: Optional[str] = None, scope: Optional[List[str]] = None, audience: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, @@ -133,16 +133,21 @@ async def create_token_for_client_credentials( user_context = {} from ..recipe import OAuth2ProviderRecipe + body: Dict[str, Any] = { + "grant_type": "client_credentials", + "client_id": client_id, + } + if client_secret: + body["client_secret"] = client_secret + if scope: + body["scope"] = " ".join(scope) + if audience: + body["audience"] = audience + return ( await OAuth2ProviderRecipe.get_instance().recipe_implementation.token_exchange( authorization_header=None, - body={ - "grant_type": "client_credentials", - "client_id": client_id, - "client_secret": client_secret, - "scope": " ".join(scope) if scope else None, - "audience": audience, - }, + body=body, user_context=user_context, ) ) diff --git a/supertokens_python/recipe/oauth2provider/interfaces.py b/supertokens_python/recipe/oauth2provider/interfaces.py index c903e3bc..e115d595 100644 --- a/supertokens_python/recipe/oauth2provider/interfaces.py +++ b/supertokens_python/recipe/oauth2provider/interfaces.py @@ -45,12 +45,14 @@ def __init__( self.status_code = status_code def to_json(self) -> Dict[str, Any]: - return { + result: Dict[str, Any] = { "status": self.status, "error": self.error, "errorDescription": self.error_description, - "statusCode": self.status_code, } + if self.status_code is not None: + result["statusCode"] = self.status_code + return result @staticmethod def from_json(json: Dict[str, Any]): @@ -225,18 +227,18 @@ def to_json(self) -> Dict[str, Any]: class RedirectResponse: - def __init__(self, redirect_to: str, cookies: Optional[str] = None): + def __init__(self, redirect_to: str, cookies: Optional[List[str]] = None): self.redirect_to = redirect_to self.cookies = cookies class FrontendRedirectResponse: - def __init__(self, frontend_redirect_to: str, cookies: Optional[str] = None): + def __init__(self, frontend_redirect_to: str, cookies: Optional[List[str]] = None): self.frontend_redirect_to = frontend_redirect_to self.cookies = cookies def to_json(self) -> Dict[str, Any]: - result = { + result: Dict[str, Any] = { "frontendRedirectTo": self.frontend_redirect_to, } if self.cookies is not None: diff --git a/supertokens_python/recipe/oauth2provider/oauth2_client.py b/supertokens_python/recipe/oauth2provider/oauth2_client.py index 0f16db50..68456287 100644 --- a/supertokens_python/recipe/oauth2provider/oauth2_client.py +++ b/supertokens_python/recipe/oauth2provider/oauth2_client.py @@ -245,32 +245,53 @@ def from_json(json: Dict[str, Any]) -> "OAuth2Client": ) def to_json(self) -> Dict[str, Any]: - return { + result: Dict[str, Any] = { "clientId": self.client_id, "clientName": self.client_name, "scope": self.scope, "tokenEndpointAuthMethod": self.token_endpoint_auth_method, "createdAt": self.created_at, "updatedAt": self.updated_at, - "clientSecret": self.client_secret, - "redirectUris": self.redirect_uris, - "postLogoutRedirectUris": self.post_logout_redirect_uris, - "authorizationCodeGrantAccessTokenLifespan": self.authorization_code_grant_access_token_lifespan, - "authorizationCodeGrantIdTokenLifespan": self.authorization_code_grant_id_token_lifespan, - "authorizationCodeGrantRefreshTokenLifespan": self.authorization_code_grant_refresh_token_lifespan, - "clientCredentialsGrantAccessTokenLifespan": self.client_credentials_grant_access_token_lifespan, - "implicitGrantAccessTokenLifespan": self.implicit_grant_access_token_lifespan, - "implicitGrantIdTokenLifespan": self.implicit_grant_id_token_lifespan, - "refreshTokenGrantAccessTokenLifespan": self.refresh_token_grant_access_token_lifespan, - "refreshTokenGrantIdTokenLifespan": self.refresh_token_grant_id_token_lifespan, - "refreshTokenGrantRefreshTokenLifespan": self.refresh_token_grant_refresh_token_lifespan, "clientUri": self.client_uri, "audience": self.audience, - "grantTypes": self.grant_types, - "responseTypes": self.response_types, "logoUri": self.logo_uri, "policyUri": self.policy_uri, "tosUri": self.tos_uri, "metadata": self.metadata, "enableRefreshTokenRotation": self.enable_refresh_token_rotation, } + + if self.client_secret is not None: + result["clientSecret"] = self.client_secret + result["redirectUris"] = self.redirect_uris + if self.post_logout_redirect_uris is not None: + result["postLogoutRedirectUris"] = self.post_logout_redirect_uris + result["authorizationCodeGrantAccessTokenLifespan"] = ( + self.authorization_code_grant_access_token_lifespan + ) + result["authorizationCodeGrantIdTokenLifespan"] = ( + self.authorization_code_grant_id_token_lifespan + ) + result["authorizationCodeGrantRefreshTokenLifespan"] = ( + self.authorization_code_grant_refresh_token_lifespan + ) + result["clientCredentialsGrantAccessTokenLifespan"] = ( + self.client_credentials_grant_access_token_lifespan + ) + result["implicitGrantAccessTokenLifespan"] = ( + self.implicit_grant_access_token_lifespan + ) + result["implicitGrantIdTokenLifespan"] = self.implicit_grant_id_token_lifespan + result["refreshTokenGrantAccessTokenLifespan"] = ( + self.refresh_token_grant_access_token_lifespan + ) + result["refreshTokenGrantIdTokenLifespan"] = ( + self.refresh_token_grant_id_token_lifespan + ) + result["refreshTokenGrantRefreshTokenLifespan"] = ( + self.refresh_token_grant_refresh_token_lifespan + ) + result["grantTypes"] = self.grant_types + result["responseTypes"] = self.response_types + + return result diff --git a/supertokens_python/recipe/oauth2provider/recipe_implementation.py b/supertokens_python/recipe/oauth2provider/recipe_implementation.py index d2a82f50..f65f930a 100644 --- a/supertokens_python/recipe/oauth2provider/recipe_implementation.py +++ b/supertokens_python/recipe/oauth2provider/recipe_implementation.py @@ -373,12 +373,10 @@ async def authorization( ) return RedirectResponse( - redirect_to=consent_res.redirect_to, cookies=",".join(resp["cookies"]) + redirect_to=consent_res.redirect_to, cookies=resp["cookies"] ) - return RedirectResponse( - redirect_to=redirect_to, cookies=",".join(resp["cookies"]) - ) + return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"]) async def token_exchange( self, diff --git a/tests/test-server/oauth2provider.py b/tests/test-server/oauth2provider.py index 8f8a18cf..bed153a7 100644 --- a/tests/test-server/oauth2provider.py +++ b/tests/test-server/oauth2provider.py @@ -27,7 +27,7 @@ def create_oauth2_client_api(): # type: ignore print("OAuth2Provider:createOAuth2Client", request.json) response = OAuth2Provider.create_oauth2_client( - params=CreateOAuth2ClientInput.from_json(request.json.get("input")), + params=CreateOAuth2ClientInput.from_json(request.json.get("input", {})), user_context=request.json.get("userContext"), ) return jsonify(response.to_json())