Skip to content

Commit

Permalink
fix: default recipes and fixes for test
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Dec 12, 2024
1 parent 34e96da commit d8dd684
Show file tree
Hide file tree
Showing 14 changed files with 223 additions and 12 deletions.
3 changes: 2 additions & 1 deletion supertokens_python/framework/django/django_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def set_json_content(self, content: Dict[str, Any]):
).encode("utf-8")
self.response_sent = True

def redirect(self, url: str):
def redirect(self, url: str) -> BaseResponse:
if not self.response_sent:
self.set_header("Location", url)
self.set_status_code(302)
return self
3 changes: 2 additions & 1 deletion supertokens_python/framework/fastapi/fastapi_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def set_json_content(self, content: Dict[str, Any]):
self.response.body = body
self.response_sent = True

def redirect(self, url: str):
def redirect(self, url: str) -> BaseResponse:
if not self.response_sent:
self.set_header("Location", url)
self.set_status_code(302)
return self
3 changes: 2 additions & 1 deletion supertokens_python/framework/flask/flask_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def set_json_content(self, content: Dict[str, Any]):
).encode("utf-8")
self.response_sent = True

def redirect(self, url: str):
def redirect(self, url: str) -> BaseResponse:
self.response.headers.set("Location", url)
self.set_status_code(302)
return self
2 changes: 1 addition & 1 deletion supertokens_python/framework/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ def set_html_content(self, content: str):
pass

@abstractmethod
def redirect(self, url: str):
def redirect(self, url: str) -> "BaseResponse":
pass
4 changes: 2 additions & 2 deletions supertokens_python/recipe/oauth2provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
from typing import TYPE_CHECKING, Callable, Union

from . import exceptions as ex
from . import recipe
from . import recipe, utils

exceptions = ex
InputOverrideConfig = utils.InputOverrideConfig

if TYPE_CHECKING:
from supertokens_python.supertokens import AppInfo

from ...recipe_module import RecipeModule
from .utils import InputOverrideConfig


def init(
Expand Down
3 changes: 2 additions & 1 deletion supertokens_python/recipe/oauth2provider/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from datetime import datetime
from http.cookies import SimpleCookie
from typing import TYPE_CHECKING, Any, Dict
from urllib.parse import parse_qsl
Expand Down Expand Up @@ -83,7 +84,7 @@ async def auth_get(
domain=morsel.get("domain"),
secure=morsel.get("secure", True),
httponly=morsel.get("httponly", True),
expires=morsel.get("expires", None),
expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore
path=morsel.get("path", "/"),
samesite=morsel.get("samesite", "lax"),
)
Expand Down
4 changes: 3 additions & 1 deletion supertokens_python/recipe/oauth2provider/api/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from typing import TYPE_CHECKING, Any, Dict, Optional

from datetime import datetime

from supertokens_python.exceptions import raise_bad_input_exception
from supertokens_python.framework import BaseResponse
from supertokens_python.recipe.session.asyncio import get_session
Expand Down Expand Up @@ -84,7 +86,7 @@ async def login(
domain=morsel.get("domain"),
secure=morsel.get("secure", True),
httponly=morsel.get("httponly", True),
expires=morsel.get("expires", None),
expires=datetime.strptime(morsel.get("expires", ""), "%a, %d %b %Y %H:%M:%S %Z").timestamp() * 1000, # type: ignore
path=morsel.get("path", "/"),
samesite=morsel.get("samesite", "lax"),
)
Expand Down
2 changes: 2 additions & 0 deletions supertokens_python/recipe/oauth2provider/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ async def handle_login_internal_redirects(
cookie: str,
user_context: Dict[str, Any],
) -> Union[RedirectResponse, ErrorOAuth2Response]:
from ..interfaces import RedirectResponse, ErrorOAuth2Response

if not is_login_internal_redirect(response.redirect_to):
return response

Expand Down
24 changes: 24 additions & 0 deletions supertokens_python/recipe/oauth2provider/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ def from_json(json: Dict[str, Any]):
next_pagination_token=json["nextPaginationToken"],
)

def to_json(self) -> Dict[str, Any]:
return {
"status": "OK",
"clients": [client.to_json() for client in self.clients],
"nextPaginationToken": self.next_pagination_token,
}


class GetOAuth2ClientOkResult:
def __init__(self, client: OAuth2Client):
Expand All @@ -272,6 +279,12 @@ def __init__(self, client: OAuth2Client):
def from_json(json: Dict[str, Any]):
return CreateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"]))

def to_json(self) -> Dict[str, Any]:
return {
"status": "OK",
"client": self.client.to_json(),
}


class UpdateOAuth2ClientOkResult:
def __init__(self, client: OAuth2Client):
Expand All @@ -281,11 +294,22 @@ def __init__(self, client: OAuth2Client):
def from_json(json: Dict[str, Any]):
return UpdateOAuth2ClientOkResult(client=OAuth2Client.from_json(json["client"]))

def to_json(self) -> Dict[str, Any]:
return {
"status": "OK",
"client": self.client.to_json(),
}


class DeleteOAuth2ClientOkResult:
def __init__(self):
pass

def to_json(self) -> Dict[str, Any]:
return {
"status": "OK",
}


PayloadBuilderFunction = Callable[
[User, List[str], str, Dict[str, Any]], Awaitable[Dict[str, Any]]
Expand Down
31 changes: 31 additions & 0 deletions supertokens_python/recipe/oauth2provider/oauth2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,34 @@ def from_json(json: Dict[str, Any]) -> "OAuth2Client":
metadata=json.get("metadata", {}),
enable_refresh_token_rotation=json.get("enableRefreshTokenRotation", False),
)

def to_json(self) -> Dict[str, Any]:
return {
"clientId": self.client_id,
"clientName": self.client_name,
"scope": self.scope,
"tokenEndpointAuthMethod": self.token_endpoint_auth_method,
"createdAt": self.created_at,
"updatedAt": self.updated_at,
"clientSecret": self.client_secret,
"redirectUris": self.redirect_uris,
"postLogoutRedirectUris": self.post_logout_redirect_uris,
"authorizationCodeGrantAccessTokenLifespan": self.authorization_code_grant_access_token_lifespan,
"authorizationCodeGrantIdTokenLifespan": self.authorization_code_grant_id_token_lifespan,
"authorizationCodeGrantRefreshTokenLifespan": self.authorization_code_grant_refresh_token_lifespan,
"clientCredentialsGrantAccessTokenLifespan": self.client_credentials_grant_access_token_lifespan,
"implicitGrantAccessTokenLifespan": self.implicit_grant_access_token_lifespan,
"implicitGrantIdTokenLifespan": self.implicit_grant_id_token_lifespan,
"refreshTokenGrantAccessTokenLifespan": self.refresh_token_grant_access_token_lifespan,
"refreshTokenGrantIdTokenLifespan": self.refresh_token_grant_id_token_lifespan,
"refreshTokenGrantRefreshTokenLifespan": self.refresh_token_grant_refresh_token_lifespan,
"clientUri": self.client_uri,
"audience": self.audience,
"grantTypes": self.grant_types,
"responseTypes": self.response_types,
"logoUri": self.logo_uri,
"policyUri": self.policy_uri,
"tosUri": self.tos_uri,
"metadata": self.metadata,
"enableRefreshTokenRotation": self.enable_refresh_token_rotation,
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import jwt

from supertokens_python import AppInfo
from supertokens_python.asyncio import get_user
from supertokens_python.normalised_url_path import NormalisedURLPath
from supertokens_python.recipe.openid.recipe import OpenIdRecipe
Expand Down Expand Up @@ -60,6 +59,7 @@

if TYPE_CHECKING:
from supertokens_python.querier import Querier
from supertokens_python import AppInfo


def get_updated_redirect_to(app_info: AppInfo, redirect_to: str) -> str:
Expand Down Expand Up @@ -371,10 +371,12 @@ async def authorization(
)

return RedirectResponse(
redirect_to=consent_res.redirect_to, cookies=resp["cookies"]
redirect_to=consent_res.redirect_to, cookies=",".join(resp["cookies"])
)

return RedirectResponse(redirect_to=redirect_to, cookies=resp["cookies"])
return RedirectResponse(
redirect_to=redirect_to, cookies=",".join(resp["cookies"])
)

async def token_exchange(
self,
Expand Down
30 changes: 29 additions & 1 deletion supertokens_python/supertokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,12 @@ def __init__(
totp_found = False
user_metadata_found = False
multi_factor_auth_found = False
oauth2_found = False
openid_found = False
jwt_found = False

def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule:
nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found
nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found, oauth2_found, openid_found, jwt_found
recipe_module = recipe(self.app_info)
if recipe_module.get_recipe_id() == "multitenancy":
multitenancy_found = True
Expand All @@ -271,21 +274,46 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule:
multi_factor_auth_found = True
elif recipe_module.get_recipe_id() == "totp":
totp_found = True
elif recipe_module.get_recipe_id() == "oauth2provider":
oauth2_found = True
elif recipe_module.get_recipe_id() == "openid":
openid_found = True
elif recipe_module.get_recipe_id() == "jwt":
jwt_found = True
return recipe_module

self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list))

if not jwt_found:
from supertokens_python.recipe.jwt.recipe import JWTRecipe

self.recipe_modules.append(JWTRecipe.init()(self.app_info))

if not openid_found:
from supertokens_python.recipe.openid.recipe import OpenIdRecipe

self.recipe_modules.append(OpenIdRecipe.init()(self.app_info))

if not multitenancy_found:
from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe

self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info))

if totp_found and not multi_factor_auth_found:
raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.")

if not user_metadata_found:
from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe

self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info))

if not oauth2_found:
from supertokens_python.recipe.oauth2provider.recipe import (
OAuth2ProviderRecipe,
)

self.recipe_modules.append(OAuth2ProviderRecipe.init()(self.app_info))

self.telemetry = (
telemetry
if telemetry is not None
Expand Down
25 changes: 25 additions & 0 deletions tests/test-server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe
from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe
from supertokens_python.recipe.userroles.recipe import UserRolesRecipe
from supertokens_python.recipe.oauth2provider.recipe import OAuth2ProviderRecipe
from supertokens_python.recipe.openid.recipe import OpenIdRecipe
from supertokens_python.types import RecipeUserId
from test_functions_mapper import ( # pylint: disable=import-error
get_func,
Expand Down Expand Up @@ -55,6 +57,7 @@
session,
thirdparty,
emailverification,
oauth2provider,
)
from supertokens_python.recipe.session import InputErrorHandlers, SessionContainer
from supertokens_python.recipe.session.framework.flask import verify_session
Expand Down Expand Up @@ -222,6 +225,8 @@ def st_reset():
AccountLinkingRecipe.reset()
TOTPRecipe.reset()
MultiFactorAuthRecipe.reset()
OAuth2ProviderRecipe.reset()
OpenIdRecipe.reset()


def init_st(config: Dict[str, Any]):
Expand Down Expand Up @@ -593,6 +598,22 @@ async def send_sms(
)
)
)
elif recipe_id == "oauth2provider":
recipe_config_json = json.loads(recipe_config.get("config", "{}"))
recipe_list.append(
oauth2provider.init(
override=oauth2provider.InputOverrideConfig(
apis=override_builder_with_logging(
"OAuth2Provider.override.apis",
recipe_config_json.get("override", {}).get("apis"),
),
functions=override_builder_with_logging(
"OAuth2Provider.override.functions",
recipe_config_json.get("override", {}).get("functions"),
),
)
)
)

interceptor_func = None
if config.get("supertokens", {}).get("networkInterceptor") is not None:
Expand Down Expand Up @@ -822,6 +843,10 @@ def handle_exception(e: Exception):

add_multifactorauth_routes(app)

from oauth2provider import add_oauth2provider_routes

add_oauth2provider_routes(app)

if __name__ == "__main__":
default_st_init()
port = int(os.environ.get("API_PORT", api_port))
Expand Down
Loading

0 comments on commit d8dd684

Please sign in to comment.