Skip to content

Commit

Permalink
fixes stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Oct 11, 2024
1 parent 579ac9a commit c812d73
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 68 deletions.
20 changes: 20 additions & 0 deletions supertokens_python/process_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,23 @@ def get_event_by_last_event_by_name(
if event == state:
return event
return None

def wait_for_event(
self, state: PROCESS_STATE, time_in_ms: int = 7000
) -> Optional[PROCESS_STATE]:
from time import time, sleep

start_time = time()

def try_and_get() -> Optional[PROCESS_STATE]:
result = self.get_event_by_last_event_by_name(state)
if result is None:
if (time() - start_time) * 1000 > time_in_ms:
return None
else:
sleep(1)
return try_and_get()
else:
return result

return try_and_get()
14 changes: 14 additions & 0 deletions supertokens_python/recipe/accountlinking/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def __init__(self, user: User, was_already_a_primary_user: bool):
self.user = user
self.was_already_a_primary_user = was_already_a_primary_user

def to_json(self) -> Dict[str, Any]:
return {
"status": self.status,
"user": self.user.to_json(),
"wasAlreadyAPrimaryUser": self.was_already_a_primary_user,
}


class CreatePrimaryUserRecipeUserIdAlreadyLinkedError:
def __init__(self, primary_user_id: str, description: Optional[str] = None):
Expand Down Expand Up @@ -216,6 +223,13 @@ def __init__(self, accounts_already_linked: bool, user: User):
self.accounts_already_linked = accounts_already_linked
self.user = user

def to_json(self) -> Dict[str, Any]:
return {
"status": self.status,
"accountsAlreadyLinked": self.accounts_already_linked,
"user": self.user.to_json(),
}


class LinkAccountsRecipeUserIdAlreadyLinkedError:
def __init__(
Expand Down
6 changes: 6 additions & 0 deletions supertokens_python/recipe/accountlinking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def __init__(
"emailpassword", "thirdparty", "passwordless"
] = recipe_id

def to_json(self) -> Dict[str, Any]:
return {
**super().to_json(),
"recipeId": self.recipe_id,
}


class RecipeLevelUser(AccountInfoWithRecipeId):
def __init__(
Expand Down
7 changes: 7 additions & 0 deletions supertokens_python/recipe/emailpassword/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def __init__(self, user: User, recipe_user_id: RecipeUserId):
self.user = user
self.recipe_user_id = recipe_user_id

def to_json(self) -> Dict[str, Any]:
return {
"status": self.status,
"user": self.user.to_json(),
"recipeUserId": self.recipe_user_id.get_as_string(),
}


class EmailAlreadyExistsError(APIResponse):
status: str = "EMAIL_ALREADY_EXISTS_ERROR"
Expand Down
5 changes: 4 additions & 1 deletion supertokens_python/recipe/emailpassword/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Awaitable, Callable, Optional, TypeVar, Union, Any
from typing import Awaitable, Callable, Dict, Optional, TypeVar, Union, Any

from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient
from supertokens_python.ingredients.emaildelivery.types import (
Expand All @@ -33,6 +33,9 @@ def __init__(self, id: str, value: Any): # pylint: disable=redefined-builtin
self.id: str = id
self.value: Any = value

def to_json(self) -> Dict[str, Any]:
return {"id": self.id, "value": self.value}


class InputFormField:
def __init__(
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/session/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_info_from_access_token(
user_data = payload

session_handle = sanitize_string(payload.get("sessionHandle"))
recipe_user_id = sanitize_string(payload.get("recipeUserId", user_id))
recipe_user_id = sanitize_string(payload.get("rsub", user_id))
refresh_token_hash_1 = sanitize_string(payload.get("refreshTokenHash1"))
parent_refresh_token_hash_1 = sanitize_string(
payload.get("parentRefreshTokenHash1")
Expand Down
6 changes: 6 additions & 0 deletions supertokens_python/recipe/session/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def __init__(
self.invalid_claims = invalid_claims
self.access_token_payload_update = access_token_payload_update

def to_json(self) -> Dict[str, Any]:
return {
"invalidClaims": [i.to_json() for i in self.invalid_claims],
"accessTokenPayloadUpdate": self.access_token_payload_update,
}


class GetSessionTokensDangerouslyDict(TypedDict):
accessToken: str
Expand Down
13 changes: 13 additions & 0 deletions supertokens_python/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ def __init__(
self.phone_number = phone_number
self.third_party = third_party

def to_json(self) -> Dict[str, Any]:
json_repo: Dict[str, Any] = {}
if self.email is not None:
json_repo["email"] = self.email
if self.phone_number is not None:
json_repo["phoneNumber"] = self.phone_number
if self.third_party is not None:
json_repo["thirdParty"] = {
"id": self.third_party.id,
"userId": self.third_party.user_id,
}
return json_repo


class LoginMethod(AccountInfo):
def __init__(
Expand Down
23 changes: 20 additions & 3 deletions tests/test-server/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple
from flask import Flask, request, jsonify
from supertokens_python import process_state
from supertokens_python.framework import BaseRequest
from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig
from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig
Expand Down Expand Up @@ -150,9 +151,10 @@ def builder(oI: T) -> T:

def logging_override_func_sync(name: str, c: Any) -> Any:
def inner(*args: Any, **kwargs: Any) -> Any:
override_logging.log_override_event(
name, "CALL", {"args": args, "kwargs": kwargs}
)
if len(args) > 0:
override_logging.log_override_event(name, "CALL", args)
else:
override_logging.log_override_event(name, "CALL", kwargs)
try:
res = c(*args, **kwargs)
override_logging.log_override_event(name, "RES", res)
Expand Down Expand Up @@ -668,6 +670,21 @@ def verify_session_route():
return jsonify({"status": "OK"})


@app.route("/test/waitforevent", methods=["GET"]) # type: ignore
def wait_for_event_api(): # type: ignore
event = request.args.get("event")
if not event:
raise ValueError("event query param missing")

event_enum = process_state.PROCESS_STATE(int(event))
instance = process_state.ProcessState.get_instance()
event_result = instance.wait_for_event(event_enum)
if event_result is None:
return jsonify(None)
else:
return jsonify("Found")


@app.errorhandler(404)
def not_found(error: Any) -> Any: # pylint: disable=unused-argument
return jsonify({"error": f"Route not found: {request.method} {request.path}"}), 404
Expand Down
54 changes: 52 additions & 2 deletions tests/test-server/override_logging.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
from typing import Any, Dict, List, Set, Union
from typing import Any, Callable, Coroutine, Dict, List, Set, Union
import time

from httpx import Response

from supertokens_python.framework.flask.flask_request import FlaskRequest
from supertokens_python.types import RecipeUserId
from supertokens_python.recipe.accountlinking.interfaces import (
CreatePrimaryUserOkResult,
LinkAccountsOkResult,
)
from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId
from supertokens_python.recipe.emailpassword.types import FormField
from supertokens_python.recipe.emailpassword.interfaces import (
APIOptions as EmailPasswordAPIOptions,
SignUpOkResult,
SignUpPostOkResult,
)
from supertokens_python.recipe.session.interfaces import ClaimsValidationResult
from supertokens_python.recipe.session.session_class import Session
from supertokens_python.recipe.thirdparty.interfaces import (
APIOptions as ThirdPartyAPIOptions,
)
from supertokens_python.recipe.passwordless.interfaces import (
APIOptions as PasswordlessAPIOptions,
)
from supertokens_python.types import AccountInfo, RecipeUserId, User

override_logs: List[Dict[str, Any]] = []

Expand Down Expand Up @@ -43,5 +62,36 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A
return "Response"
if isinstance(data, RecipeUserId):
return data.get_as_string()
if isinstance(data, AccountInfoWithRecipeId):
return data.to_json()
if isinstance(data, AccountInfo):
return data.to_json()
if isinstance(data, User):
return data.to_json()
if isinstance(data, Coroutine):
return "Coroutine"
if isinstance(data, Callable):
return "Callable"
if isinstance(data, FormField):
return data.to_json()
if isinstance(data, EmailPasswordAPIOptions):
return "EmailPasswordAPIOptions"
if isinstance(data, ThirdPartyAPIOptions):
return "ThirdPartyAPIOptions"
if isinstance(data, PasswordlessAPIOptions):
return "PasswordlessAPIOptions"
if isinstance(data, SignUpOkResult):
return data.to_json()
if isinstance(data, CreatePrimaryUserOkResult):
return data.to_json()
if isinstance(data, LinkAccountsOkResult):
return data.to_json()
if isinstance(data, Session):
from session import convert_session_to_json

return convert_session_to_json(data)
if isinstance(data, SignUpPostOkResult):
return data.to_json()
if isinstance(data, ClaimsValidationResult):
return data.to_json()
return data
67 changes: 6 additions & 61 deletions tests/test-server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,7 @@ def assert_claims(): # type: ignore
return jsonify(
{
"status": "OK",
"updatedSession": {
"sessionHandle": session_container.get_handle(),
"userId": session_container.get_user_id(),
"tenantId": session_container.get_tenant_id(),
"userDataInAccessToken": session_container.get_access_token_payload(),
"accessToken": session_container.get_access_token(),
"frontToken": session_container.get_all_session_tokens_dangerously()[
"frontToken"
],
"refreshToken": session_container.get_all_session_tokens_dangerously()[
"refreshToken"
],
"antiCsrfToken": session_container.get_all_session_tokens_dangerously()[
"antiCsrfToken"
],
"accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[
"accessAndFrontTokenUpdated"
],
},
"updatedSession": convert_session_to_json(session_container),
}
)
except Exception as e:
Expand All @@ -134,26 +116,7 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore
return jsonify(
{
"status": "OK",
"updatedSession": {
"sessionHandle": session_container.get_handle(),
"userId": session_container.get_user_id(),
"recipeUserId": session_container.get_recipe_user_id().get_as_string(),
"tenantId": session_container.get_tenant_id(),
"userDataInAccessToken": session_container.get_access_token_payload(),
"accessToken": session_container.get_access_token(),
"frontToken": session_container.get_all_session_tokens_dangerously()[
"frontToken"
],
"refreshToken": session_container.get_all_session_tokens_dangerously()[
"refreshToken"
],
"antiCsrfToken": session_container.get_all_session_tokens_dangerously()[
"antiCsrfToken"
],
"accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[
"accessAndFrontTokenUpdated"
],
},
"updatedSession": convert_session_to_json(session_container),
}
)

Expand All @@ -170,28 +133,7 @@ def fetch_and_set_claim_api(): # type: ignore
user_context = data.get("userContext", {})

session.sync_fetch_and_set_claim(claim, user_context)
response = {
"updatedSession": {
"sessionHandle": session.get_handle(),
"userId": session.get_user_id(),
"recipeUserId": session.get_recipe_user_id().get_as_string(),
"tenantId": session.get_tenant_id(),
"userDataInAccessToken": session.get_access_token_payload(),
"accessToken": session.get_access_token(),
"frontToken": session.get_all_session_tokens_dangerously()[
"frontToken"
],
"refreshToken": session.get_all_session_tokens_dangerously()[
"refreshToken"
],
"antiCsrfToken": session.get_all_session_tokens_dangerously()[
"antiCsrfToken"
],
"accessTokenUpdated": session.get_all_session_tokens_dangerously()[
"accessAndFrontTokenUpdated"
],
}
}
response = {"updatedSession": convert_session_to_json(session)}
return jsonify(response)


Expand All @@ -214,6 +156,9 @@ def convert_session_to_json(session_container: SessionContainer) -> Dict[str, An
"accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[
"accessAndFrontTokenUpdated"
],
"recipeUserId": {
"recipeUserId": session_container.get_recipe_user_id().get_as_string()
},
}


Expand Down

0 comments on commit c812d73

Please sign in to comment.