Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Dec 12, 2024
1 parent dae1204 commit 213b9fe
Showing 9 changed files with 99 additions and 68 deletions.
27 changes: 14 additions & 13 deletions supertokens_python/recipe/oauth2provider/api/auth.py
Original file line number Diff line number Diff line change
@@ -75,19 +75,20 @@ async def auth_get(

if isinstance(response, RedirectResponse):
if response.cookies:
cookie = SimpleCookie()
cookie.load(response.cookies)
for morsel in cookie.values():
api_options.response.set_cookie(
key=morsel.key,
value=morsel.value,
domain=morsel.get("domain"),
secure=morsel.get("secure", True),
httponly=morsel.get("httponly", True),
expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore
path=morsel.get("path", "/"),
samesite=morsel.get("samesite", "lax").lower(),
)
for cookie_string in response.cookies:
cookie = SimpleCookie()
cookie.load(cookie_string)
for morsel in cookie.values():
api_options.response.set_cookie(
key=morsel.key,
value=morsel.value,
domain=morsel.get("domain"),
secure=morsel.get("secure", True),
httponly=morsel.get("httponly", True),
expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore
path=morsel.get("path", "/"),
samesite=morsel.get("samesite", "lax"),
)
return api_options.response.redirect(response.redirect_to)
elif isinstance(response, ErrorOAuth2Response):
return send_non_200_response(
2 changes: 2 additions & 0 deletions supertokens_python/recipe/oauth2provider/api/end_session.py
Original file line number Diff line number Diff line change
@@ -81,6 +81,8 @@ async def end_session_common(
options: APIOptions,
user_context: Dict[str, Any],
) -> Optional[BaseResponse]:
from ..interfaces import RedirectResponse, ErrorOAuth2Response

if api_implementation is None:
return None

27 changes: 14 additions & 13 deletions supertokens_python/recipe/oauth2provider/api/login.py
Original file line number Diff line number Diff line change
@@ -77,19 +77,20 @@ async def login(

if isinstance(response, FrontendRedirectResponse):
if response.cookies:
cookie = SimpleCookie()
cookie.load(response.cookies)
for morsel in cookie.values():
api_options.response.set_cookie(
key=morsel.key,
value=morsel.value,
domain=morsel.get("domain"),
secure=morsel.get("secure", True),
httponly=morsel.get("httponly", True),
expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore
path=morsel.get("path", "/"),
samesite=morsel.get("samesite", "lax").lower(),
)
for cookie_string in response.cookies:
cookie = SimpleCookie()
cookie.load(cookie_string)
for morsel in cookie.values():
api_options.response.set_cookie(
key=morsel.key,
value=morsel.value,
domain=morsel.get("domain"),
secure=morsel.get("secure", True),
httponly=morsel.get("httponly", True),
expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore
path=morsel.get("path", "/"),
samesite=morsel.get("samesite", "lax").lower(),
)

return send_200_response(
{"frontendRedirectTo": response.frontend_redirect_to},
19 changes: 10 additions & 9 deletions supertokens_python/recipe/oauth2provider/api/utils.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from urllib.parse import parse_qs, urlparse
import time

@@ -38,7 +38,7 @@ async def login_get(
login_challenge: str,
session: Optional[SessionContainer],
should_try_refresh: bool,
cookies: Optional[str],
cookies: Optional[List[str]],
is_direct_call: bool,
user_context: Dict[str, Any],
) -> Union[RedirectResponse, ErrorOAuth2Response]:
@@ -177,7 +177,7 @@ async def login_get(
)


def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str:
def get_merged_cookies(orig_cookies: str, new_cookies: Optional[List[str]]) -> str:
if not new_cookies:
return orig_cookies

@@ -190,7 +190,8 @@ def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str:
# Note: This is a simplified version. In production code you'd want to use a proper
# cookie parsing library to handle all cookie attributes correctly
if new_cookies:
for cookie in new_cookies.split(","):
for cookie_str in new_cookies:
cookie = cookie_str.split(";")[0].strip()
if "=" in cookie:
name, value = cookie.split("=", 1)
cookie_map[name.strip()] = value
@@ -199,13 +200,13 @@ def get_merged_cookies(orig_cookies: str, new_cookies: Optional[str]) -> str:


def merge_set_cookie_headers(
set_cookie1: Optional[str] = None, set_cookie2: Optional[str] = None
) -> str:
set_cookie1: Optional[List[str]] = None, set_cookie2: Optional[List[str]] = None
) -> List[str]:
if not set_cookie1:
return set_cookie2 or ""
if not set_cookie2 or set_cookie1 == set_cookie2:
return set_cookie2 or []
if not set_cookie2 or set(set_cookie1) == set(set_cookie2):
return set_cookie1
return f"{set_cookie1}, {set_cookie2}"
return set_cookie1 + set_cookie2


def is_login_internal_redirect(redirect_to: str) -> bool:
21 changes: 13 additions & 8 deletions supertokens_python/recipe/oauth2provider/asyncio/__init__.py
Original file line number Diff line number Diff line change
@@ -124,7 +124,7 @@ async def validate_oauth2_access_token(

async def create_token_for_client_credentials(
client_id: str,
client_secret: str,
client_secret: Optional[str] = None,
scope: Optional[List[str]] = None,
audience: Optional[str] = None,
user_context: Optional[Dict[str, Any]] = None,
@@ -133,16 +133,21 @@ async def create_token_for_client_credentials(
user_context = {}
from ..recipe import OAuth2ProviderRecipe

body: Dict[str, Any] = {
"grant_type": "client_credentials",
"client_id": client_id,
}
if client_secret:
body["client_secret"] = client_secret
if scope:
body["scope"] = " ".join(scope)
if audience:
body["audience"] = audience

return (
await OAuth2ProviderRecipe.get_instance().recipe_implementation.token_exchange(
authorization_header=None,
body={
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret,
"scope": " ".join(scope) if scope else None,
"audience": audience,
},
body=body,
user_context=user_context,
)
)
12 changes: 7 additions & 5 deletions supertokens_python/recipe/oauth2provider/interfaces.py
Original file line number Diff line number Diff line change
@@ -45,12 +45,14 @@ def __init__(
self.status_code = status_code

def to_json(self) -> Dict[str, Any]:
return {
result: Dict[str, Any] = {
"status": self.status,
"error": self.error,
"errorDescription": self.error_description,
"statusCode": self.status_code,
}
if self.status_code is not None:
result["statusCode"] = self.status_code
return result

@staticmethod
def from_json(json: Dict[str, Any]):
@@ -225,18 +227,18 @@ def to_json(self) -> Dict[str, Any]:


class RedirectResponse:
def __init__(self, redirect_to: str, cookies: Optional[str] = None):
def __init__(self, redirect_to: str, cookies: Optional[List[str]] = None):
self.redirect_to = redirect_to
self.cookies = cookies


class FrontendRedirectResponse:
def __init__(self, frontend_redirect_to: str, cookies: Optional[str] = None):
def __init__(self, frontend_redirect_to: str, cookies: Optional[List[str]] = None):
self.frontend_redirect_to = frontend_redirect_to
self.cookies = cookies

def to_json(self) -> Dict[str, Any]:
result = {
result: Dict[str, Any] = {
"frontendRedirectTo": self.frontend_redirect_to,
}
if self.cookies is not None:
51 changes: 36 additions & 15 deletions supertokens_python/recipe/oauth2provider/oauth2_client.py
Original file line number Diff line number Diff line change
@@ -245,32 +245,53 @@ def from_json(json: Dict[str, Any]) -> "OAuth2Client":
)

def to_json(self) -> Dict[str, Any]:
return {
result: Dict[str, Any] = {
"clientId": self.client_id,
"clientName": self.client_name,
"scope": self.scope,
"tokenEndpointAuthMethod": self.token_endpoint_auth_method,
"createdAt": self.created_at,
"updatedAt": self.updated_at,
"clientSecret": self.client_secret,
"redirectUris": self.redirect_uris,
"postLogoutRedirectUris": self.post_logout_redirect_uris,
"authorizationCodeGrantAccessTokenLifespan": self.authorization_code_grant_access_token_lifespan,
"authorizationCodeGrantIdTokenLifespan": self.authorization_code_grant_id_token_lifespan,
"authorizationCodeGrantRefreshTokenLifespan": self.authorization_code_grant_refresh_token_lifespan,
"clientCredentialsGrantAccessTokenLifespan": self.client_credentials_grant_access_token_lifespan,
"implicitGrantAccessTokenLifespan": self.implicit_grant_access_token_lifespan,
"implicitGrantIdTokenLifespan": self.implicit_grant_id_token_lifespan,
"refreshTokenGrantAccessTokenLifespan": self.refresh_token_grant_access_token_lifespan,
"refreshTokenGrantIdTokenLifespan": self.refresh_token_grant_id_token_lifespan,
"refreshTokenGrantRefreshTokenLifespan": self.refresh_token_grant_refresh_token_lifespan,
"clientUri": self.client_uri,
"audience": self.audience,
"grantTypes": self.grant_types,
"responseTypes": self.response_types,
"logoUri": self.logo_uri,
"policyUri": self.policy_uri,
"tosUri": self.tos_uri,
"metadata": self.metadata,
"enableRefreshTokenRotation": self.enable_refresh_token_rotation,
}

if self.client_secret is not None:
result["clientSecret"] = self.client_secret
result["redirectUris"] = self.redirect_uris
if self.post_logout_redirect_uris is not None:
result["postLogoutRedirectUris"] = self.post_logout_redirect_uris
result["authorizationCodeGrantAccessTokenLifespan"] = (
self.authorization_code_grant_access_token_lifespan
)
result["authorizationCodeGrantIdTokenLifespan"] = (
self.authorization_code_grant_id_token_lifespan
)
result["authorizationCodeGrantRefreshTokenLifespan"] = (
self.authorization_code_grant_refresh_token_lifespan
)
result["clientCredentialsGrantAccessTokenLifespan"] = (
self.client_credentials_grant_access_token_lifespan
)
result["implicitGrantAccessTokenLifespan"] = (
self.implicit_grant_access_token_lifespan
)
result["implicitGrantIdTokenLifespan"] = self.implicit_grant_id_token_lifespan
result["refreshTokenGrantAccessTokenLifespan"] = (
self.refresh_token_grant_access_token_lifespan
)
result["refreshTokenGrantIdTokenLifespan"] = (
self.refresh_token_grant_id_token_lifespan
)
result["refreshTokenGrantRefreshTokenLifespan"] = (
self.refresh_token_grant_refresh_token_lifespan
)
result["grantTypes"] = self.grant_types
result["responseTypes"] = self.response_types

return result
Original file line number Diff line number Diff line change
@@ -373,12 +373,10 @@ async def authorization(
)

return RedirectResponse(
redirect_to=consent_res.redirect_to, cookies=",".join(resp["cookies"])
redirect_to=consent_res.redirect_to, cookies=resp["cookies"]
)

return RedirectResponse(
redirect_to=redirect_to, cookies=",".join(resp["cookies"])
)
return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"])

async def token_exchange(
self,
2 changes: 1 addition & 1 deletion tests/test-server/oauth2provider.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ def create_oauth2_client_api(): # type: ignore
print("OAuth2Provider:createOAuth2Client", request.json)

response = OAuth2Provider.create_oauth2_client(
params=CreateOAuth2ClientInput.from_json(request.json.get("input")),
params=CreateOAuth2ClientInput.from_json(request.json.get("input", {})),
user_context=request.json.get("userContext"),
)
return jsonify(response.to_json())

0 comments on commit 213b9fe

Please sign in to comment.