Skip to content

Commit

Permalink
cyclic import issue
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Sep 30, 2024
1 parent e0995f0 commit eae7e2d
Show file tree
Hide file tree
Showing 21 changed files with 171 additions and 89 deletions.
8 changes: 7 additions & 1 deletion supertokens_python/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,8 +923,14 @@ async def get_factors_set_up_for_user():
async def get_mfa_requirements_for_auth():
nonlocal mfa_info_prom
if mfa_info_prom is None:
from .recipe.multifactorauth.multi_factor_auth_claim import (
MultiFactorAuthClaim,
)

mfa_info_prom = await update_and_get_mfa_related_info_in_session(
input_session=session, user_context=user_context
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
return mfa_info_prom.mfa_requirements_for_auth

Expand Down
23 changes: 14 additions & 9 deletions supertokens_python/recipe/multifactorauth/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,22 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import importlib

from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing import Any, Dict, List, Union, TYPE_CHECKING

from supertokens_python.recipe.session import SessionContainer
from supertokens_python.recipe.multifactorauth.utils import (
update_and_get_mfa_related_info_in_session,
)
from supertokens_python.recipe.multitenancy.asyncio import get_tenant
from ..multi_factor_auth_claim import MultiFactorAuthClaim
from supertokens_python.asyncio import get_user
from supertokens_python.recipe.session.exceptions import (
InvalidClaimsError,
SuperTokensSessionError,
UnauthorisedError,
)

if TYPE_CHECKING:
from supertokens_python.recipe.multifactorauth.interfaces import (
APIInterface,
APIOptions,
)

from supertokens_python.types import GeneralErrorResponse
from ..interfaces import (
APIInterface,
Expand All @@ -42,6 +36,11 @@
ResyncSessionAndFetchMFAInfoPUTOkResult,
)

if TYPE_CHECKING:
from ..multi_factor_auth_claim import (
MultiFactorAuthClaimClass as MultiFactorAuthClaimType,
)


class APIImplementation(APIInterface):
async def resync_session_and_fetch_mfa_info_put(
Expand All @@ -50,6 +49,11 @@ async def resync_session_and_fetch_mfa_info_put(
session: SessionContainer,
user_context: Dict[str, Any],
) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]:

mfa = importlib.import_module("supertokens_python.recipe.multifactorauth")

MultiFactorAuthClaim: MultiFactorAuthClaimType = mfa.MultiFactorAuthClaim

session_user = await get_user(session.get_user_id(), user_context)

if session_user is None:
Expand All @@ -58,6 +62,7 @@ async def resync_session_and_fetch_mfa_info_put(
)

mfa_info = await update_and_get_mfa_related_info_in_session(
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
Expand Down Expand Up @@ -144,7 +149,7 @@ async def get_mfa_requirements_for_auth():
)
return ResyncSessionAndFetchMFAInfoPUTOkResult(
factors=NextFactors(
next=next_factors,
next_=next_factors,
already_setup=factors_setup_for_user,
allowed_to_setup=factors_allowed_to_setup,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


async def handle_resync_session_and_fetch_mfa_info_api(
tenant_id: str,
_tenant_id: str,
api_implementation: APIInterface,
api_options: APIOptions,
user_context: Dict[str, Any],
Expand Down
16 changes: 12 additions & 4 deletions supertokens_python/recipe/multifactorauth/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ..types import (
MFARequirementList,
)
from ..recipe import MultiFactorAuthRecipe
from ..utils import update_and_get_mfa_related_info_in_session
from supertokens_python.recipe.accountlinking.asyncio import get_user

Expand All @@ -34,13 +33,17 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error(
if user_context is None:
user_context = {}

from ..multi_factor_auth_claim import MultiFactorAuthClaim

mfa_info = await update_and_get_mfa_related_info_in_session(
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
factors_set_up_for_user = await get_factors_setup_for_user(
session.get_user_id(), user_context
)
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()

Expand All @@ -66,7 +69,10 @@ async def get_mfa_requirements_for_auth(
if user_context is None:
user_context = {}

from ..multi_factor_auth_claim import MultiFactorAuthClaim

mfa_info = await update_and_get_mfa_related_info_in_session(
MultiFactorAuthClaim,
input_session=session,
user_context=user_context,
)
Expand All @@ -81,6 +87,7 @@ async def mark_factor_as_complete_in_session(
) -> None:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
await recipe.recipe_implementation.mark_factor_as_complete_in_session(
Expand All @@ -100,6 +107,7 @@ async def get_factors_setup_for_user(
user = await get_user(user_id, user_context)
if user is None:
raise Exception("Unknown user id")
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
return await recipe.recipe_implementation.get_factors_setup_for_user(
Expand All @@ -114,6 +122,7 @@ async def get_required_secondary_factors_for_user(
) -> List[str]:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
return await recipe.recipe_implementation.get_required_secondary_factors_for_user(
Expand All @@ -129,6 +138,7 @@ async def add_to_required_secondary_factors_for_user(
) -> None:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
await recipe.recipe_implementation.add_to_required_secondary_factors_for_user(
Expand All @@ -145,13 +155,11 @@ async def remove_from_required_secondary_factors_for_user(
) -> None:
if user_context is None:
user_context = {}
from ..recipe import MultiFactorAuthRecipe

recipe = MultiFactorAuthRecipe.get_instance_or_throw_error()
await recipe.recipe_implementation.remove_from_required_secondary_factors_for_user(
user_id=user_id,
factor_id=factor_id,
user_context=user_context,
)


init = MultiFactorAuthRecipe.init
6 changes: 3 additions & 3 deletions supertokens_python/recipe/multifactorauth/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ async def resync_session_and_fetch_mfa_info_put(

class NextFactors:
def __init__(
self, next: List[str], already_setup: List[str], allowed_to_setup: List[str]
self, next_: List[str], already_setup: List[str], allowed_to_setup: List[str]
):
self.next = next
self.next_ = next_
self.already_setup = already_setup
self.allowed_to_setup = allowed_to_setup

def to_json(self) -> Dict[str, Any]:
return {
"next": self.next,
"next": self.next_,
"alreadySetup": self.already_setup,
"allowedToSetup": self.allowed_to_setup,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Any, Dict, Optional, Set
Expand All @@ -15,7 +29,6 @@
MFAClaimValue,
MFARequirementList,
)
from .utils import update_and_get_mfa_related_info_in_session


class HasCompletedRequirementListSCV(SessionClaimValidator):
Expand All @@ -29,14 +42,10 @@ def __init__(
self.claim: MultiFactorAuthClaimClass = claim
self.requirement_list = requirement_list

async def should_refetch(
def should_refetch(
self, payload: Dict[str, Any], user_context: Dict[str, Any]
) -> bool:
return (
True
if self.claim.key not in payload or not payload[self.claim.key]
else False
)
return bool(self.claim.key not in payload or not payload[self.claim.key])

async def validate(
self, payload: JSONObject, user_context: Dict[str, Any]
Expand Down Expand Up @@ -65,7 +74,7 @@ async def validate(

factor_ids = next_set_of_unsatisfied_factors.factor_ids

if next_set_of_unsatisfied_factors.type == "string":
if next_set_of_unsatisfied_factors.type_ == "string":
return ClaimValidationResult(
is_valid=False,
reason={
Expand All @@ -74,7 +83,7 @@ async def validate(
},
)

elif next_set_of_unsatisfied_factors.type == "oneOf":
elif next_set_of_unsatisfied_factors.type_ == "oneOf":
return ClaimValidationResult(
is_valid=False,
reason={
Expand All @@ -101,15 +110,11 @@ def __init__(
super().__init__(id_)
self.claim = claim

async def should_refetch(
def should_refetch(
self, payload: Dict[str, Any], user_context: Dict[str, Any]
) -> bool:
assert self.claim is not None
return (
True
if self.claim.key not in payload or not payload[self.claim.key]
else False
)
return bool(self.claim.key not in payload or not payload[self.claim.key])

async def validate(
self, payload: JSONObject, user_context: Dict[str, Any]
Expand Down Expand Up @@ -161,13 +166,16 @@ def __init__(self, key: Optional[str] = None):
key = key or "st-mfa"

async def fetch_value(
user_id: str,
_user_id: str,
recipe_user_id: RecipeUserId,
tenant_id: str,
current_payload: Dict[str, Any],
user_context: Dict[str, Any],
) -> MFAClaimValue:
from .utils import update_and_get_mfa_related_info_in_session

mfa_info = await update_and_get_mfa_related_info_in_session(
self,
input_session_recipe_user_id=recipe_user_id,
input_tenant_id=tenant_id,
input_access_token_payload=current_payload,
Expand Down Expand Up @@ -209,9 +217,11 @@ def get_next_set_of_unsatisfied_factors(
)

if len(next_factors) > 0:
return FactorIdsAndType(factor_ids=list(next_factors), type=factor_type)
return FactorIdsAndType(
factor_ids=list(next_factors), type_=factor_type
)

return FactorIdsAndType(factor_ids=[], type="string")
return FactorIdsAndType(factor_ids=[], type_="string")

def add_to_payload_(
self,
Expand Down
7 changes: 4 additions & 3 deletions supertokens_python/recipe/multifactorauth/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
GetEmailsForFactorOkResult,
GetPhoneNumbersForFactorsOkResult,
)
from .utils import validate_and_normalise_user_input
from .recipe_implementation import RecipeImplementation
from .api.implementation import APIImplementation
from .interfaces import APIOptions


Expand Down Expand Up @@ -79,10 +76,13 @@ def __init__(
] = []
self.is_get_mfa_requirements_for_auth_overridden: bool = False

from .utils import validate_and_normalise_user_input

self.config = validate_and_normalise_user_input(
first_factors,
override,
)
from .recipe_implementation import RecipeImplementation

recipe_implementation = RecipeImplementation(
Querier.get_instance(recipe_id), self
Expand All @@ -92,6 +92,7 @@ def __init__(
if self.config.override.functions is None
else self.config.override.functions(recipe_implementation)
)
from .api.implementation import APIImplementation

api_implementation = APIImplementation()
self.api_implementation = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def __init__(
self.factor_id = factor_id
self.mfa_requirement_for_auth = mfa_requirement_for_auth

async def should_refetch(
def should_refetch(
self, payload: Dict[str, Any], user_context: Dict[str, Any]
) -> bool:
return True if self.claim.get_value_from_payload(payload) is None else False
return self.claim.get_value_from_payload(payload) is None

async def validate(
self, payload: JSONObject, user_context: Dict[str, Any]
Expand Down Expand Up @@ -174,6 +174,7 @@ async def mark_factor_as_complete_in_session(
self, session: SessionContainer, factor_id: str, user_context: Dict[str, Any]
):
await update_and_get_mfa_related_info_in_session(
MultiFactorAuthClaim,
input_session=session,
input_updated_factor_id=factor_id,
user_context=user_context,
Expand Down
4 changes: 0 additions & 4 deletions supertokens_python/recipe/multifactorauth/syncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ..interfaces import (
MFARequirementList,
)
from ..recipe import MultiFactorAuthRecipe


def assert_allowed_to_setup_factor_else_throw_invalid_claim_error(
Expand Down Expand Up @@ -125,6 +124,3 @@ def remove_from_required_secondary_factors_for_user(
)

return sync(async_func(user_id, factor_id, user_context))


init = MultiFactorAuthRecipe.init
4 changes: 2 additions & 2 deletions supertokens_python/recipe/multifactorauth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ class FactorIdsAndType:
def __init__(
self,
factor_ids: List[str],
type: Union[Literal["string"], Literal["oneOf"], Literal["allOfInAnyOrder"]],
type_: Union[Literal["string"], Literal["oneOf"], Literal["allOfInAnyOrder"]],
):
self.factor_ids = factor_ids
self.type = type
self.type_ = type_


class GetFactorsSetupForUserFromOtherRecipesFunc:
Expand Down
Loading

0 comments on commit eae7e2d

Please sign in to comment.