Skip to content

Commit

Permalink
Merge pull request #529 from supertokens/auto-transpiling-cyclic-import
Browse files Browse the repository at this point in the history
auto transpile: pr2
  • Loading branch information
rishabhpoddar authored Sep 30, 2024
2 parents e0995f0 + 0cfec6f commit 0c48c1b
Show file tree
Hide file tree
Showing 23 changed files with 206 additions and 103 deletions.
11 changes: 9 additions & 2 deletions supertokens_python/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,19 +912,26 @@ async def filter_out_invalid_second_factors_or_throw_if_all_are_invalid(
factors_set_up_for_user_prom: Optional[List[str]] = None
mfa_info_prom = None

async def get_factors_set_up_for_user():
async def get_factors_set_up_for_user() -> List[str]:
nonlocal factors_set_up_for_user_prom
if factors_set_up_for_user_prom is None:
factors_set_up_for_user_prom = await mfa_instance.recipe_implementation.get_factors_setup_for_user(
user=session_user, user_context=user_context
)
assert factors_set_up_for_user_prom is not None
return factors_set_up_for_user_prom

async def get_mfa_requirements_for_auth():
nonlocal mfa_info_prom
if mfa_info_prom is None:
from .recipe.multifactorauth.multi_factor_auth_claim import (
MultiFactorAuthClaim,
)

mfa_info_prom = await update_and_get_mfa_related_info_in_session(
input_session=session, user_context=user_context
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
return mfa_info_prom.mfa_requirements_for_auth

Expand Down
32 changes: 19 additions & 13 deletions supertokens_python/recipe/multifactorauth/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,19 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import importlib

from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing import Any, Dict, List, Union, TYPE_CHECKING

from supertokens_python.recipe.session import SessionContainer
from supertokens_python.recipe.multifactorauth.utils import (
update_and_get_mfa_related_info_in_session,
)
from supertokens_python.recipe.multitenancy.asyncio import get_tenant
from ..multi_factor_auth_claim import MultiFactorAuthClaim
from supertokens_python.asyncio import get_user
from supertokens_python.recipe.session.exceptions import (
InvalidClaimsError,
SuperTokensSessionError,
UnauthorisedError,
)

if TYPE_CHECKING:
from supertokens_python.recipe.multifactorauth.interfaces import (
APIInterface,
APIOptions,
)

from supertokens_python.types import GeneralErrorResponse
from ..interfaces import (
APIInterface,
Expand All @@ -42,6 +33,11 @@
ResyncSessionAndFetchMFAInfoPUTOkResult,
)

if TYPE_CHECKING:
from ..multi_factor_auth_claim import (
MultiFactorAuthClaimClass as MultiFactorAuthClaimType,
)


class APIImplementation(APIInterface):
async def resync_session_and_fetch_mfa_info_put(
Expand All @@ -50,14 +46,24 @@ async def resync_session_and_fetch_mfa_info_put(
session: SessionContainer,
user_context: Dict[str, Any],
) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]:

mfa = importlib.import_module("supertokens_python.recipe.multifactorauth")

MultiFactorAuthClaim: MultiFactorAuthClaimType = mfa.MultiFactorAuthClaim

module = importlib.import_module(
"supertokens_python.recipe.multifactorauth.utils"
)

session_user = await get_user(session.get_user_id(), user_context)

if session_user is None:
raise UnauthorisedError(
"Session user not found",
)

mfa_info = await update_and_get_mfa_related_info_in_session(
mfa_info = await module.update_and_get_mfa_related_info_in_session(
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
Expand Down Expand Up @@ -144,7 +150,7 @@ async def get_mfa_requirements_for_auth():
)
return ResyncSessionAndFetchMFAInfoPUTOkResult(
factors=NextFactors(
next=next_factors,
next_=next_factors,
already_setup=factors_setup_for_user,
allowed_to_setup=factors_allowed_to_setup,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


async def handle_resync_session_and_fetch_mfa_info_api(
tenant_id: str,
_tenant_id: str,
api_implementation: APIInterface,
api_options: APIOptions,
user_context: Dict[str, Any],
Expand Down
16 changes: 12 additions & 4 deletions supertokens_python/recipe/multifactorauth/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ..types import (
MFARequirementList,
)
from ..recipe import MultiFactorAuthRecipe
from ..utils import update_and_get_mfa_related_info_in_session
from supertokens_python.recipe.accountlinking.asyncio import get_user

Expand All @@ -34,13 +33,17 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error(
if user_context is None:
user_context = {}

from ..multi_factor_auth_claim import MultiFactorAuthClaim

mfa_info = await update_and_get_mfa_related_info_in_session(
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
factors_set_up_for_user = await get_factors_setup_for_user(
session.get_user_id(), user_context
)
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()

Expand All @@ -66,7 +69,10 @@ async def get_mfa_requirements_for_auth(
if user_context is None:
user_context = {}

from ..multi_factor_auth_claim import MultiFactorAuthClaim

mfa_info = await update_and_get_mfa_related_info_in_session(
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
Expand All @@ -81,6 +87,7 @@ async def mark_factor_as_complete_in_session(
) -> None:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
await recipe.recipe_implementation.mark_factor_as_complete_in_session(
Expand All @@ -100,6 +107,7 @@ async def get_factors_setup_for_user(
user = await get_user(user_id, user_context)
if user is None:
raise Exception("Unknown user id")
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
return await recipe.recipe_implementation.get_factors_setup_for_user(
Expand All @@ -114,6 +122,7 @@ async def get_required_secondary_factors_for_user(
) -> List[str]:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
return await recipe.recipe_implementation.get_required_secondary_factors_for_user(
Expand All @@ -129,6 +138,7 @@ async def add_to_required_secondary_factors_for_user(
) -> None:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
await recipe.recipe_implementation.add_to_required_secondary_factors_for_user(
Expand All @@ -145,13 +155,11 @@ async def remove_from_required_secondary_factors_for_user(
) -> None:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
await recipe.recipe_implementation.remove_from_required_secondary_factors_for_user(
user_id=user_id,
factor_id=factor_id,
user_context=user_context,
)


init = MultiFactorAuthRecipe.init
6 changes: 3 additions & 3 deletions supertokens_python/recipe/multifactorauth/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ async def resync_session_and_fetch_mfa_info_put(

class NextFactors:
def __init__(
self, next: List[str], already_setup: List[str], allowed_to_setup: List[str]
self, next_: List[str], already_setup: List[str], allowed_to_setup: List[str]
):
self.next = next
self.next_ = next_
self.already_setup = already_setup
self.allowed_to_setup = allowed_to_setup

def to_json(self) -> Dict[str, Any]:
return {
"next": self.next,
"next": self.next_,
"alreadySetup": self.already_setup,
"allowedToSetup": self.allowed_to_setup,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import annotations
import importlib

from typing import Any, Dict, Optional, Set

Expand All @@ -15,7 +30,6 @@
MFAClaimValue,
MFARequirementList,
)
from .utils import update_and_get_mfa_related_info_in_session


class HasCompletedRequirementListSCV(SessionClaimValidator):
Expand All @@ -29,14 +43,10 @@ def __init__(
self.claim: MultiFactorAuthClaimClass = claim
self.requirement_list = requirement_list

async def should_refetch(
def should_refetch(
self, payload: Dict[str, Any], user_context: Dict[str, Any]
) -> bool:
return (
True
if self.claim.key not in payload or not payload[self.claim.key]
else False
)
return bool(self.claim.key not in payload or not payload[self.claim.key])

async def validate(
self, payload: JSONObject, user_context: Dict[str, Any]
Expand Down Expand Up @@ -65,7 +75,7 @@ async def validate(

factor_ids = next_set_of_unsatisfied_factors.factor_ids

if next_set_of_unsatisfied_factors.type == "string":
if next_set_of_unsatisfied_factors.type_ == "string":
return ClaimValidationResult(
is_valid=False,
reason={
Expand All @@ -74,7 +84,7 @@ async def validate(
},
)

elif next_set_of_unsatisfied_factors.type == "oneOf":
elif next_set_of_unsatisfied_factors.type_ == "oneOf":
return ClaimValidationResult(
is_valid=False,
reason={
Expand All @@ -101,15 +111,11 @@ def __init__(
super().__init__(id_)
self.claim = claim

async def should_refetch(
def should_refetch(
self, payload: Dict[str, Any], user_context: Dict[str, Any]
) -> bool:
assert self.claim is not None
return (
True
if self.claim.key not in payload or not payload[self.claim.key]
else False
)
return bool(self.claim.key not in payload or not payload[self.claim.key])

async def validate(
self, payload: JSONObject, user_context: Dict[str, Any]
Expand Down Expand Up @@ -161,13 +167,18 @@ def __init__(self, key: Optional[str] = None):
key = key or "st-mfa"

async def fetch_value(
user_id: str,
_user_id: str,
recipe_user_id: RecipeUserId,
tenant_id: str,
current_payload: Dict[str, Any],
user_context: Dict[str, Any],
) -> MFAClaimValue:
mfa_info = await update_and_get_mfa_related_info_in_session(
module = importlib.import_module(
"supertokens_python.recipe.multifactorauth.utils"
)

mfa_info = await module.update_and_get_mfa_related_info_in_session(
self,
input_session_recipe_user_id=recipe_user_id,
input_tenant_id=tenant_id,
input_access_token_payload=current_payload,
Expand Down Expand Up @@ -209,9 +220,11 @@ def get_next_set_of_unsatisfied_factors(
)

if len(next_factors) > 0:
return FactorIdsAndType(factor_ids=list(next_factors), type=factor_type)
return FactorIdsAndType(
factor_ids=list(next_factors), type_=factor_type
)

return FactorIdsAndType(factor_ids=[], type="string")
return FactorIdsAndType(factor_ids=[], type_="string")

def add_to_payload_(
self,
Expand Down
15 changes: 10 additions & 5 deletions supertokens_python/recipe/multifactorauth/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import importlib

from os import environ
from typing import Any, Dict, Optional, List, Union
Expand All @@ -31,7 +32,6 @@
MultiFactorAuthClaim,
)
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig
from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe
from supertokens_python.recipe.session.recipe import SessionRecipe
from supertokens_python.recipe_module import APIHandled, RecipeModule
from supertokens_python.supertokens import AppInfo
Expand All @@ -47,9 +47,6 @@
GetEmailsForFactorOkResult,
GetPhoneNumbersForFactorsOkResult,
)
from .utils import validate_and_normalise_user_input
from .recipe_implementation import RecipeImplementation
from .api.implementation import APIImplementation
from .interfaces import APIOptions


Expand Down Expand Up @@ -79,10 +76,15 @@ def __init__(
] = []
self.is_get_mfa_requirements_for_auth_overridden: bool = False

self.config = validate_and_normalise_user_input(
module = importlib.import_module(
"supertokens_python.recipe.multifactorauth.utils"
)

self.config = module.validate_and_normalise_user_input(
first_factors,
override,
)
from .recipe_implementation import RecipeImplementation

recipe_implementation = RecipeImplementation(
Querier.get_instance(recipe_id), self
Expand All @@ -92,6 +94,7 @@ def __init__(
if self.config.override.functions is None
else self.config.override.functions(recipe_implementation)
)
from .api.implementation import APIImplementation

api_implementation = APIImplementation()
self.api_implementation = (
Expand All @@ -101,6 +104,8 @@ def __init__(
)

def callback():
from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe

mt_recipe = MultitenancyRecipe.get_instance()
mt_recipe.static_first_factors = self.config.first_factors

Expand Down
Loading

0 comments on commit 0c48c1b

Please sign in to comment.