Skip to content

Commit

Permalink
fixes more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Oct 10, 2024
1 parent 29a8969 commit 81cf53d
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 32 deletions.
10 changes: 8 additions & 2 deletions tests/test-server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def builder(oI: T) -> T:
for attr in dir(oI)
if callable(getattr(oI, attr)) and not attr.startswith("__")
]

for member in members:
create_override(oI, member, name, override_name)
return oI
Expand Down Expand Up @@ -340,7 +339,7 @@ def init_st(config: Dict[str, Any]):
),
)

include_in_non_public_tenants_by_default = None
include_in_non_public_tenants_by_default = False

if "includeInNonPublicTenantsByDefault" in provider:
include_in_non_public_tenants_by_default = provider[
Expand Down Expand Up @@ -391,6 +390,10 @@ def init_st(config: Dict[str, Any]):
),
),
include_in_non_public_tenants_by_default=include_in_non_public_tenants_by_default,
override=override_builder_with_logging(
"ThirdParty.providers.override",
provider.get("override", None),
),
)
providers.append(provider_input)
recipe_list.append(
Expand Down Expand Up @@ -662,6 +665,9 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument
add_accountlinking_routes(app)
add_passwordless_routes(app)
add_totp_routes(app)
from supertokens import add_supertokens_routes # pylint: disable=import-error

add_supertokens_routes(app)

if __name__ == "__main__":
default_st_init()
Expand Down
62 changes: 33 additions & 29 deletions tests/test-server/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any
from typing import Any, Dict
from flask import Flask, request, jsonify
from override_logging import log_override_event # pylint: disable=import-error
from supertokens_python.recipe.session import SessionContainer
from supertokens_python.recipe.session.interfaces import TokenInfo
from supertokens_python.recipe.session.jwt import (
parse_jwt_without_signature_verification,
Expand Down Expand Up @@ -50,27 +51,7 @@ def create_new_session_without_request_response(): # type: ignore
user_context,
)

return jsonify(
{
"sessionHandle": session_container.get_handle(),
"userId": session_container.get_user_id(),
"tenantId": session_container.get_tenant_id(),
"userDataInAccessToken": session_container.get_access_token_payload(),
"accessToken": session_container.get_access_token(),
"frontToken": session_container.get_all_session_tokens_dangerously()[
"frontToken"
],
"refreshToken": session_container.get_all_session_tokens_dangerously()[
"refreshToken"
],
"antiCsrfToken": session_container.get_all_session_tokens_dangerously()[
"antiCsrfToken"
],
"accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[
"accessAndFrontTokenUpdated"
],
}
)
return jsonify(convert_session_to_json(session_container))

@app.route("/test/session/getsessionwithoutrequestresponse", methods=["POST"]) # type: ignore
def get_session_without_request_response(): # type: ignore
Expand All @@ -83,13 +64,14 @@ def get_session_without_request_response(): # type: ignore
options = data.get("options")
user_context = data.get("userContext", {})

try:
session_container = session.get_session_without_request_response(
access_token, anti_csrf_token, options, user_context
)
return jsonify(session_container)
except Exception as e:
return jsonify({"error": str(e)}), 500
session_container = session.get_session_without_request_response(
access_token, anti_csrf_token, options, user_context
)
return jsonify(
None
if session_container is None
else convert_session_to_json(session_container)
)

@app.route("/test/session/sessionobject/assertclaims", methods=["POST"]) # type: ignore
def assert_claims(): # type: ignore
Expand Down Expand Up @@ -213,6 +195,28 @@ def fetch_and_set_claim_api(): # type: ignore
return jsonify(response)


def convert_session_to_json(session_container: SessionContainer) -> Dict[str, Any]:
return {
"sessionHandle": session_container.get_handle(),
"userId": session_container.get_user_id(),
"tenantId": session_container.get_tenant_id(),
"userDataInAccessToken": session_container.get_access_token_payload(),
"accessToken": session_container.get_access_token(),
"frontToken": session_container.get_all_session_tokens_dangerously()[
"frontToken"
],
"refreshToken": session_container.get_all_session_tokens_dangerously()[
"refreshToken"
],
"antiCsrfToken": session_container.get_all_session_tokens_dangerously()[
"antiCsrfToken"
],
"accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[
"accessAndFrontTokenUpdated"
],
}


def convert_session_to_container(data: Any) -> Session:
jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"])
jwt_payload = jwt_info.payload
Expand Down
89 changes: 89 additions & 0 deletions tests/test-server/supertokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from flask import Flask, request, jsonify
from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo
from supertokens_python.types import AccountInfo
from supertokens_python.syncio import (
get_user,
delete_user,
list_users_by_account_info,
get_users_newest_first,
get_users_oldest_first,
)


def add_supertokens_routes(app: Flask):
@app.route("/test/supertokens/getuser", methods=["POST"]) # type: ignore
def get_user_api(): # type: ignore
assert request.json is not None
response = get_user(request.json["userId"], request.json.get("userContext"))
return jsonify(None if response is None else response.to_json())

@app.route("/test/supertokens/deleteuser", methods=["POST"]) # type: ignore
def delete_user_api(): # type: ignore
assert request.json is not None
delete_user(
request.json["userId"],
request.json["removeAllLinkedAccounts"],
request.json.get("userContext"),
)
return jsonify({"status": "OK"})

@app.route("/test/supertokens/listusersbyaccountinfo", methods=["POST"]) # type: ignore
def list_users_by_account_info_api(): # type: ignore
assert request.json is not None
response = list_users_by_account_info(
request.json["tenantId"],
AccountInfo(
email=request.json["accountInfo"].get("email", None),
phone_number=request.json["accountInfo"].get("phoneNumber", None),
third_party=(
None
if "thirdParty" not in request.json["accountInfo"]
else ThirdPartyInfo(
third_party_id=request.json["accountInfo"]["thirdParty"][
"thirdPartyId"
],
third_party_user_id=request.json["accountInfo"]["thirdParty"][
"id"
],
)
),
),
request.json["doUnionOfAccountInfo"],
request.json.get("userContext"),
)

return jsonify([r.to_json() for r in response])

@app.route("/test/supertokens/getusersnewestfirst", methods=["POST"]) # type: ignore
def get_users_newest_first_api(): # type: ignore
assert request.json is not None
response = get_users_newest_first(
include_recipe_ids=request.json["includeRecipeIds"],
limit=request.json["limit"],
pagination_token=request.json["paginationToken"],
tenant_id=request.json["tenantId"],
user_context=request.json.get("userContext"),
)
return jsonify(
{
"nextPaginationToken": response.next_pagination_token,
"users": [r.to_json() for r in response.users],
}
)

@app.route("/test/supertokens/getusersoldestfirst", methods=["POST"]) # type: ignore
def get_users_oldest_first_api(): # type: ignore
assert request.json is not None
response = get_users_oldest_first(
include_recipe_ids=request.json["includeRecipeIds"],
limit=request.json["limit"],
pagination_token=request.json["paginationToken"],
tenant_id=request.json["tenantId"],
user_context=request.json.get("userContext"),
)
return jsonify(
{
"nextPaginationToken": response.next_pagination_token,
"users": [r.to_json() for r in response.users],
}
)
164 changes: 163 additions & 1 deletion tests/test-server/test_functions_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
ShouldAutomaticallyLink,
ShouldNotAutomaticallyLink,
)
from supertokens_python.recipe.thirdparty.types import (
RawUserInfoFromProvider,
UserInfo,
UserInfoEmail,
)
from supertokens_python.types import AccountInfo, RecipeUserId
from supertokens_python.types import APIResponse, User

Expand Down Expand Up @@ -124,7 +129,164 @@ async def func(

return func

raise Exception("Unknown eval string")
if eval_str.startswith("thirdparty.init.signInAndUpFeature.providers"):

def custom_provider(provider: Any):
if "custom-ev" in eval_str:

def exchange_auth_code_for_oauth_tokens1(
redirect_uri_info: Any, # pylint: disable=unused-argument
user_context: Any, # pylint: disable=unused-argument
) -> Any:
return {}

def get_user_info1(
oauth_tokens: Any,
user_context: Any, # pylint: disable=unused-argument
): # pylint: disable=unused-argument
return UserInfo(
third_party_user_id=oauth_tokens.get("userId", "user"),
email=UserInfoEmail(
email=oauth_tokens.get("email", "[email protected]"),
is_verified=True,
),
raw_user_info_from_provider=RawUserInfoFromProvider(
from_id_token_payload=None,
from_user_info_api=None,
),
)

provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_oauth_tokens1
)
provider.get_user_info = get_user_info1
return provider

if "custom-no-ev" in eval_str:

def exchange_auth_code_for_oauth_tokens2(
redirect_uri_info: Any, # pylint: disable=unused-argument
user_context: Any, # pylint: disable=unused-argument
) -> Any:
return {}

def get_user_info2(
oauth_tokens: Any, user_context: Any
): # pylint: disable=unused-argument
return UserInfo(
third_party_user_id=oauth_tokens.get("userId", "user"),
email=UserInfoEmail(
email=oauth_tokens.get("email", "[email protected]"),
is_verified=False,
),
raw_user_info_from_provider=RawUserInfoFromProvider(
from_id_token_payload=None,
from_user_info_api=None,
),
)

provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_oauth_tokens2
)
provider.get_user_info = get_user_info2
return provider

if "custom2" in eval_str:

def exchange_auth_code_for_oauth_tokens3(
redirect_uri_info: Any,
user_context: Any, # pylint: disable=unused-argument
) -> Any:
return redirect_uri_info["redirectURIQueryParams"]

def get_user_info3(
oauth_tokens: Any, user_context: Any
): # pylint: disable=unused-argument
return UserInfo(
third_party_user_id=f"custom2{oauth_tokens['email']}",
email=UserInfoEmail(
email=oauth_tokens["email"],
is_verified=True,
),
raw_user_info_from_provider=RawUserInfoFromProvider(
from_id_token_payload=None,
from_user_info_api=None,
),
)

provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_oauth_tokens3
)
provider.get_user_info = get_user_info3
return provider

if "custom3" in eval_str:

def exchange_auth_code_for_oauth_tokens4(
redirect_uri_info: Any,
user_context: Any, # pylint: disable=unused-argument
) -> Any:
return redirect_uri_info["redirectURIQueryParams"]

def get_user_info4(
oauth_tokens: Any, user_context: Any
): # pylint: disable=unused-argument
return UserInfo(
third_party_user_id=oauth_tokens["email"],
email=UserInfoEmail(
email=oauth_tokens["email"],
is_verified=True,
),
raw_user_info_from_provider=RawUserInfoFromProvider(
from_id_token_payload=None,
from_user_info_api=None,
),
)

provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_oauth_tokens4
)
provider.get_user_info = get_user_info4
return provider

if "custom" in eval_str:

def exchange_auth_code_for_oauth_tokens5(
redirect_uri_info: Any,
user_context: Any, # pylint: disable=unused-argument
) -> Any:
return redirect_uri_info

def get_user_info5(
oauth_tokens: Any, user_context: Any
): # pylint: disable=unused-argument
if oauth_tokens.get("error"):
raise Exception("Credentials error")
return UserInfo(
third_party_user_id=oauth_tokens.get("userId", "userId"),
email=(
None
if oauth_tokens.get("email") is None
else UserInfoEmail(
email=oauth_tokens.get("email"),
is_verified=oauth_tokens.get("isVerified", False),
)
),
raw_user_info_from_provider=RawUserInfoFromProvider(
from_id_token_payload=None,
from_user_info_api=None,
),
)

provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_oauth_tokens5
)
provider.get_user_info = get_user_info5
return provider

return custom_provider

raise Exception("Unknown eval string: " + eval_str)


class OverrideParams(APIResponse):
Expand Down

0 comments on commit 81cf53d

Please sign in to comment.