From e5343166c7cb5c41296d90feca0625b8bbb3d206 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Wed, 27 Sep 2023 12:10:43 +0530 Subject: [PATCH 01/10] feat: Added access token validation --- supertokens_python/recipe/thirdparty/provider.py | 14 ++++++++++++++ .../recipe/thirdparty/providers/custom.py | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index aa13b32d6..e216447e9 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -173,6 +173,12 @@ def __init__( Awaitable[None], ] ] = None, + validate_access_token: Optional[ + Callable[ + [str, ProviderConfigForClient, Dict[str, Any]], + Awaitable[None], + ] + ] = None, generate_fake_email: Optional[ Callable[[str, str, Dict[str, Any]], Awaitable[str]] ] = None, @@ -191,6 +197,7 @@ def __init__( self.user_info_map = user_info_map self.require_email = require_email self.validate_id_token_payload = validate_id_token_payload + self.validate_access_token = validate_access_token self.generate_fake_email = generate_fake_email def to_json(self) -> Dict[str, Any]: @@ -247,6 +254,12 @@ def __init__( Awaitable[None], ] ] = None, + validate_access_token: Optional[ + Callable[ + [str, ProviderConfigForClient, Dict[str, Any]], + Awaitable[None], + ] + ] = None, generate_fake_email: Optional[ Callable[[str, str, Dict[str, Any]], Awaitable[str]] ] = None, @@ -278,6 +291,7 @@ def __init__( validate_id_token_payload, generate_fake_email, ) + self.validate_access_token = validate_access_token def to_json(self) -> Dict[str, Any]: d1 = ProviderClientConfig.to_json(self) diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index e9807e914..e42fc3ee9 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -59,6 +59,7 @@ def get_provider_config_for_client( user_info_map=config.user_info_map, require_email=config.require_email, validate_id_token_payload=config.validate_id_token_payload, + validate_access_token=config.validate_access_token, generate_fake_email=config.generate_fake_email, ) @@ -421,6 +422,13 @@ async def get_user_info( self.config.user_info_endpoint, query_params, headers ) + if self.config.validate_access_token is not None: + await self.config.validate_access_token( + access_token, + self.config, + user_context + ) + user_info_result = get_supertokens_user_info_result_from_raw_user_info( self.config, raw_user_info_from_provider ) From 42eb9d86fa7e8810fa38ad02a81f9d34c14f1a97 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Wed, 4 Oct 2023 17:23:16 +0530 Subject: [PATCH 02/10] feat: Added `validate_access_token` function * Added `validate_access_token` function to providers, to allow verifying the access token received from the providers. --- .../multitenancy/recipe_implementation.py | 1 + .../recipe/thirdparty/provider.py | 9 +- .../thirdparty/providers/config_utils.py | 1 + .../recipe/thirdparty/providers/custom.py | 6 +- tests/thirdparty/test_thirdparty.py | 83 ------------------- 5 files changed, 12 insertions(+), 88 deletions(-) delete mode 100644 tests/thirdparty/test_thirdparty.py diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 7d569ed95..023c45b45 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -104,6 +104,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse: user_info_map=user_info_map, require_email=p.get("requireEmail", True), validate_id_token_payload=None, + validate_access_token=None, generate_fake_email=None, ) ) diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index e216447e9..6b42050eb 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -289,9 +289,9 @@ def __init__( user_info_map, require_email, validate_id_token_payload, + validate_access_token, generate_fake_email, ) - self.validate_access_token = validate_access_token def to_json(self) -> Dict[str, Any]: d1 = ProviderClientConfig.to_json(self) @@ -324,6 +324,12 @@ def __init__( Awaitable[None], ] ] = None, + validate_access_token: Optional[ + Callable[ + [str, ProviderConfigForClient, Dict[str, Any]], + Awaitable[None], + ] + ] = None, generate_fake_email: Optional[ Callable[[str, str, Dict[str, Any]], Awaitable[str]] ] = None, @@ -343,6 +349,7 @@ def __init__( user_info_map, require_email, validate_id_token_payload, + validate_access_token, generate_fake_email, ) self.clients = clients diff --git a/supertokens_python/recipe/thirdparty/providers/config_utils.py b/supertokens_python/recipe/thirdparty/providers/config_utils.py index d520106aa..e3c70c517 100644 --- a/supertokens_python/recipe/thirdparty/providers/config_utils.py +++ b/supertokens_python/recipe/thirdparty/providers/config_utils.py @@ -87,6 +87,7 @@ def merge_config( user_info_map=config_from_static.user_info_map, generate_fake_email=config_from_static.generate_fake_email, validate_id_token_payload=config_from_static.validate_id_token_payload, + validate_access_token=config_from_static.validate_access_token, ) if result.user_info_map is None: diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index e42fc3ee9..732a74136 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -422,11 +422,9 @@ async def get_user_info( self.config.user_info_endpoint, query_params, headers ) - if self.config.validate_access_token is not None: + if self.config.validate_access_token is not None and access_token is not None: await self.config.validate_access_token( - access_token, - self.config, - user_context + access_token, self.config, user_context ) user_info_result = get_supertokens_user_info_result_from_raw_user_info( diff --git a/tests/thirdparty/test_thirdparty.py b/tests/thirdparty/test_thirdparty.py deleted file mode 100644 index e20863836..000000000 --- a/tests/thirdparty/test_thirdparty.py +++ /dev/null @@ -1,83 +0,0 @@ -import respx -import json - -from pytest import fixture, mark -from fastapi import FastAPI -from supertokens_python.framework.fastapi import get_middleware -from starlette.testclient import TestClient - -from supertokens_python.recipe import session, thirdparty -from supertokens_python import init -from base64 import b64encode - -from tests.utils import ( - setup_function, - teardown_function, - start_st, - st_init_common_args, -) - - -_ = setup_function # type:ignore -_ = teardown_function # type:ignore -_ = start_st # type:ignore - - -pytestmark = mark.asyncio - -respx_mock = respx.MockRouter - - -@fixture(scope="function") -async def fastapi_client(): - app = FastAPI() - app.add_middleware(get_middleware()) - - return TestClient(app) - - -async def test_thirdpary_parsing_works(fastapi_client: TestClient): - st_init_args = { - **st_init_common_args, - "recipe_list": [ - session.init(), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature( - providers=[ - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="apple", - clients=[ - thirdparty.ProviderClientConfig( - client_id="4398792-io.supertokens.example.service", - additional_config={ - "keyId": "7M48Y4RYDL", - "teamId": "YWQCXGJRJL", - "privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", - }, - ), - ], - ) - ), - ] - ) - ), - ], - } - init(**st_init_args) # type: ignore - start_st() - - state = b64encode( - json.dumps({"frontendRedirectURI": "http://localhost:3000/redirect"}).encode() - ).decode() - code = "testing" - - data = {"state": state, "code": code} - res = fastapi_client.post("/auth/callback/apple", data=data) - - assert res.status_code == 303 - assert res.content == b"" - assert ( - res.headers["location"] - == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}" - ) From ce76af3351f3a7760309a583138d6ba34d30421c Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Wed, 4 Oct 2023 17:31:46 +0530 Subject: [PATCH 03/10] updated the sdk version --- setup.py | 2 +- supertokens_python/constants.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 1a7960248..d81afe157 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.16.3", + version="0.16.4", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 1cd45287b..1094f851f 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["3.0"] -VERSION = "0.16.3" +VERSION = "0.16.4" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" From 3429fef847434e753df26261f85d54c03a24b015 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Wed, 4 Oct 2023 17:37:24 +0530 Subject: [PATCH 04/10] updated changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f14039af5..6010d7119 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +- Add `validate_access_token` function to providers + - This can be used to verify the access token received from providers. + ## [0.16.3] - 2023-09-28 - Add Twitter provider for thirdparty login From 51378bc3c9aa56cbeeb1eca839cd43a0b8deebd0 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Wed, 4 Oct 2023 18:27:47 +0530 Subject: [PATCH 05/10] fix: Added `validate_access_token` for github provider --- .../recipe/thirdparty/providers/github.py | 36 +++ tests/thirdparty/test_providers.py | 262 ++++++++++++++++++ 2 files changed, 298 insertions(+) create mode 100644 tests/thirdparty/test_providers.py diff --git a/supertokens_python/recipe/thirdparty/providers/github.py b/supertokens_python/recipe/thirdparty/providers/github.py index e411d87aa..e14c03acc 100644 --- a/supertokens_python/recipe/thirdparty/providers/github.py +++ b/supertokens_python/recipe/thirdparty/providers/github.py @@ -12,8 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations + +import base64 +import json from typing import Any, Dict, List, Optional +import requests + from supertokens_python.recipe.thirdparty.providers.utils import do_get_request from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail @@ -71,4 +76,35 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built if input.config.token_endpoint is None: input.config.token_endpoint = "https://github.com/login/oauth/access_token" + if input.config.validate_access_token is None: + input.config.validate_access_token = validate_access_token + return NewProvider(input, GithubImpl) + + +async def validate_access_token( + access_token: str, config: ProviderConfigForClient, _: Dict[str, Any] +): + client_secret = "" if config.client_secret is None else config.client_secret + basic_auth_token = base64.b64encode( + f"{config.client_id}:{client_secret}".encode() + ).decode() + + # POST request to get applications response + url = f"https://api.github.com/applications/{config.client_id}/token" + headers = { + "Authorization": f"Basic {basic_auth_token}", + "Content-Type": "application/json", + } + payload = json.dumps({"access_token": access_token}) + + resp = requests.post(url, headers=headers, data=payload) + + # Error handling and validation + if resp.status_code != 200: + raise ValueError("Invalid access token") + + body = resp.json() + + if "app" not in body or body["app"]["client_id"] != config.client_id: + raise ValueError("Access token does not belong to your application") diff --git a/tests/thirdparty/test_providers.py b/tests/thirdparty/test_providers.py new file mode 100644 index 000000000..cdd68d41b --- /dev/null +++ b/tests/thirdparty/test_providers.py @@ -0,0 +1,262 @@ +import datetime +import json +from base64 import b64encode +from typing import Dict, Any, Optional + +import respx +from fastapi import FastAPI +from pytest import fixture, mark +from starlette.testclient import TestClient + +from supertokens_python import init +from supertokens_python.framework.fastapi import get_middleware +from supertokens_python.recipe import session, thirdparty +from supertokens_python.recipe import thirdpartyemailpassword +from supertokens_python.recipe.thirdparty.provider import ( + ProviderClientConfig, + ProviderConfig, + ProviderInput, + Provider, + RedirectUriInfo, + ProviderConfigForClient, +) +from supertokens_python.recipe.thirdparty.types import ( + UserInfo, + UserInfoEmail, + RawUserInfoFromProvider, +) +from tests.utils import ( + setup_function, + teardown_function, + start_st, + st_init_common_args, +) + +_ = setup_function # type:ignore +_ = teardown_function # type:ignore +_ = start_st # type:ignore + +pytestmark = mark.asyncio + +respx_mock = respx.MockRouter + + +@fixture(scope="function") +async def fastapi_client(): + app = FastAPI() + app.add_middleware(get_middleware()) + + return TestClient(app, raise_server_exceptions=False) + + +async def test_thirdpary_parsing_works(fastapi_client: TestClient): + st_init_args = { + **st_init_common_args, + "recipe_list": [ + session.init(), + thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature( + providers=[ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="apple", + clients=[ + thirdparty.ProviderClientConfig( + client_id="4398792-io.supertokens.example.service", + additional_config={ + "keyId": "7M48Y4RYDL", + "teamId": "YWQCXGJRJL", + "privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + }, + ), + ], + ) + ), + ] + ) + ), + ], + } + init(**st_init_args) # type: ignore + start_st() + + state = b64encode( + json.dumps({"frontendRedirectURI": "http://localhost:3000/redirect"}).encode() + ).decode() + code = "testing" + + data = {"state": state, "code": code} + res = fastapi_client.post("/auth/callback/apple", data=data) + + assert res.status_code == 303 + assert res.content == b"" + assert ( + res.headers["location"] + == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}" + ) + + +async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-argument + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], +) -> Dict[str, Any]: + return { + "access_token": "accesstoken", + "id_token": "idtoken", + } + + +async def get_user_info( # pylint: disable=unused-argument + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], +) -> UserInfo: + time = str(datetime.datetime.now()) + return UserInfo( + "" + time, + UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True), + RawUserInfoFromProvider({}, {}), + ) + + +async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], +) -> Dict[str, Any]: + return { + "access_token": "wrongaccesstoken", + "id_token": "wrongidtoken", + } + + +def get_custom_invalid_token_provider(provider: Provider) -> Provider: + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_invalid_oauth_tokens + ) + return provider + + +def get_custom_valid_token_provider(provider: Provider) -> Provider: + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_valid_oauth_tokens + ) + provider.get_user_info = get_user_info + return provider + + +async def invalid_access_token( # pylint: disable=unused-argument + access_token: str, + config: ProviderConfigForClient, + user_context: Optional[Dict[str, Any]], +): + if access_token == "wrongaccesstoken": + raise Exception("Invalid access token") + + +async def valid_access_token( # pylint: disable=unused-argument + access_token: str, config: ProviderConfig, user_context: Optional[Dict[str, Any]] +): + if access_token == "accesstoken": + return + raise Exception("Unexpected access token") + + +async def test_signinup_when_validate_access_token_throws(fastapi_client: TestClient): + st_init_args = { + **st_init_common_args, + "recipe_list": [ + session.init(), + thirdpartyemailpassword.init( + providers=[ + ProviderInput( + config=ProviderConfig( + third_party_id="custom", + clients=[ + ProviderClientConfig( + client_id="test", + client_secret="test-secret", + scope=["profile", "email"], + ), + ], + authorization_endpoint="https://example.com/oauth/authorize", + validate_access_token=invalid_access_token, + authorization_endpoint_query_params={ + "response_type": "token", # Changing an existing parameter + "response_mode": "form", # Adding a new parameter + "scope": None, # Removing a parameter + }, + token_endpoint="https://example.com/oauth/token", + ), + override=get_custom_invalid_token_provider, + ) + ] + ), + ], + } + init(**st_init_args) # type: ignore + start_st() + + res = fastapi_client.post( + "/auth/signinup", + json={ + "thirdPartyId": "custom", + "redirectURIInfo": { + "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", + "redirectURIQueryParams": { + "code": "abcdefghj", + }, + }, + }, + ) + assert res.status_code == 500 + + +# async def test_signinup_works_when_validate_access_token_does_not_throw(fastapi_client: TestClient): +# st_init_args = { +# **st_init_common_args, +# "recipe_list": [ +# session.init(), +# thirdpartyemailpassword.init( +# providers=[ +# ProviderInput( +# config=ProviderConfig( +# third_party_id="custom", +# clients=[ +# ProviderClientConfig( +# client_id="test", +# client_secret="test-secret", +# scope=["profile", "email"], +# ), +# ], +# authorization_endpoint="https://example.com/oauth/authorize", +# validate_access_token=valid_access_token, +# authorization_endpoint_query_params={ +# "response_type": "token", # Changing an existing parameter +# "response_mode": "form", # Adding a new parameter +# "scope": None, # Removing a parameter +# }, +# token_endpoint="https://example.com/oauth/token", +# ), +# override=get_custom_valid_token_provider +# ) +# ] +# ), +# ], +# } +# +# init(**st_init_args) # type: ignore +# start_st() +# +# res = fastapi_client.post( +# "/auth/signinup", +# json={ +# "thirdPartyId": "custom", +# "redirectURIInfo": { +# "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", +# "redirectURIQueryParams": { +# "code": "abcdefghj", +# }, +# }, +# } +# ) +# assert res.status_code == 200 +# assert res.json()["status"] == "OK" From 3bddf981acc3cef1c549dfb1c38334db58c2d455 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Wed, 4 Oct 2023 18:39:17 +0530 Subject: [PATCH 06/10] Updated changelog --- CHANGELOG.md | 1 + supertokens_python/recipe/thirdparty/providers/github.py | 2 -- tests/thirdparty/{test_providers.py => test_thirdparty.py} | 0 3 files changed, 1 insertion(+), 2 deletions(-) rename tests/thirdparty/{test_providers.py => test_thirdparty.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6010d7119..a0a5ac5f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `validate_access_token` function to providers - This can be used to verify the access token received from providers. + - Implemented `validate_access_token` for the Github provider. ## [0.16.3] - 2023-09-28 diff --git a/supertokens_python/recipe/thirdparty/providers/github.py b/supertokens_python/recipe/thirdparty/providers/github.py index e14c03acc..81f9e86f1 100644 --- a/supertokens_python/recipe/thirdparty/providers/github.py +++ b/supertokens_python/recipe/thirdparty/providers/github.py @@ -90,7 +90,6 @@ async def validate_access_token( f"{config.client_id}:{client_secret}".encode() ).decode() - # POST request to get applications response url = f"https://api.github.com/applications/{config.client_id}/token" headers = { "Authorization": f"Basic {basic_auth_token}", @@ -100,7 +99,6 @@ async def validate_access_token( resp = requests.post(url, headers=headers, data=payload) - # Error handling and validation if resp.status_code != 200: raise ValueError("Invalid access token") diff --git a/tests/thirdparty/test_providers.py b/tests/thirdparty/test_thirdparty.py similarity index 100% rename from tests/thirdparty/test_providers.py rename to tests/thirdparty/test_thirdparty.py From d679e31d2ca1076f69bbc2c4fc09befec7dae9d4 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Thu, 5 Oct 2023 12:21:28 +0530 Subject: [PATCH 07/10] added success test case changed order of positional arguments --- CHANGELOG.md | 2 + .../recipe/thirdparty/provider.py | 10 +- .../recipe/thirdparty/providers/custom.py | 12 +- .../recipe/thirdparty/providers/github.py | 19 ++- tests/thirdparty/test_thirdparty.py | 131 +++++++++--------- 5 files changed, 88 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0a5ac5f7..fb63402f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.16.4] - 2023-10-05 + - Add `validate_access_token` function to providers - This can be used to verify the access token received from providers. - Implemented `validate_access_token` for the Github provider. diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index 6b42050eb..0d37e4f8c 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -173,15 +173,15 @@ def __init__( Awaitable[None], ] ] = None, + generate_fake_email: Optional[ + Callable[[str, str, Dict[str, Any]], Awaitable[str]] + ] = None, validate_access_token: Optional[ Callable[ [str, ProviderConfigForClient, Dict[str, Any]], Awaitable[None], ] ] = None, - generate_fake_email: Optional[ - Callable[[str, str, Dict[str, Any]], Awaitable[str]] - ] = None, ): self.third_party_id = third_party_id self.name = name @@ -289,8 +289,8 @@ def __init__( user_info_map, require_email, validate_id_token_payload, - validate_access_token, generate_fake_email, + validate_access_token, ) def to_json(self) -> Dict[str, Any]: @@ -349,8 +349,8 @@ def __init__( user_info_map, require_email, validate_id_token_payload, - validate_access_token, generate_fake_email, + validate_access_token, ) self.clients = clients diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index 732a74136..330d510a1 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -403,7 +403,12 @@ async def get_user_info( user_context, ) - if access_token is not None and self.config.token_endpoint is not None: + if self.config.validate_access_token is not None and access_token is not None: + await self.config.validate_access_token( + access_token, self.config, user_context + ) + + if access_token is not None and self.config.user_info_endpoint is not None: headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"} query_params: Dict[str, str] = {} @@ -422,11 +427,6 @@ async def get_user_info( self.config.user_info_endpoint, query_params, headers ) - if self.config.validate_access_token is not None and access_token is not None: - await self.config.validate_access_token( - access_token, self.config, user_context - ) - user_info_result = get_supertokens_user_info_result_from_raw_user_info( self.config, raw_user_info_from_provider ) diff --git a/supertokens_python/recipe/thirdparty/providers/github.py b/supertokens_python/recipe/thirdparty/providers/github.py index 81f9e86f1..89a39ce0d 100644 --- a/supertokens_python/recipe/thirdparty/providers/github.py +++ b/supertokens_python/recipe/thirdparty/providers/github.py @@ -14,12 +14,12 @@ from __future__ import annotations import base64 -import json from typing import Any, Dict, List, Optional -import requests - -from supertokens_python.recipe.thirdparty.providers.utils import do_get_request +from supertokens_python.recipe.thirdparty.providers.utils import ( + do_get_request, + do_post_request, +) from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail from .custom import GenericProvider, NewProvider @@ -95,14 +95,11 @@ async def validate_access_token( "Authorization": f"Basic {basic_auth_token}", "Content-Type": "application/json", } - payload = json.dumps({"access_token": access_token}) - - resp = requests.post(url, headers=headers, data=payload) - if resp.status_code != 200: + try: + body = await do_post_request(url, {"access_token": access_token}, headers) + except Exception: raise ValueError("Invalid access token") - body = resp.json() - - if "app" not in body or body["app"]["client_id"] != config.client_id: + if "app" not in body or body["app"].get("client_id") != config.client_id: raise ValueError("Access token does not belong to your application") diff --git a/tests/thirdparty/test_thirdparty.py b/tests/thirdparty/test_thirdparty.py index cdd68d41b..30b900b66 100644 --- a/tests/thirdparty/test_thirdparty.py +++ b/tests/thirdparty/test_thirdparty.py @@ -6,6 +6,7 @@ import respx from fastapi import FastAPI from pytest import fixture, mark +from pytest_mock import MockerFixture from starlette.testclient import TestClient from supertokens_python import init @@ -106,18 +107,6 @@ async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-a } -async def get_user_info( # pylint: disable=unused-argument - oauth_tokens: Dict[str, Any], - user_context: Dict[str, Any], -) -> UserInfo: - time = str(datetime.datetime.now()) - return UserInfo( - "" + time, - UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True), - RawUserInfoFromProvider({}, {}), - ) - - async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any], @@ -139,7 +128,6 @@ def get_custom_valid_token_provider(provider: Provider) -> Provider: provider.exchange_auth_code_for_oauth_tokens = ( exchange_auth_code_for_valid_oauth_tokens ) - provider.get_user_info = get_user_info return provider @@ -153,7 +141,9 @@ async def invalid_access_token( # pylint: disable=unused-argument async def valid_access_token( # pylint: disable=unused-argument - access_token: str, config: ProviderConfig, user_context: Optional[Dict[str, Any]] + access_token: str, + config: ProviderConfigForClient, + user_context: Optional[Dict[str, Any]], ): if access_token == "accesstoken": return @@ -210,53 +200,66 @@ async def test_signinup_when_validate_access_token_throws(fastapi_client: TestCl assert res.status_code == 500 -# async def test_signinup_works_when_validate_access_token_does_not_throw(fastapi_client: TestClient): -# st_init_args = { -# **st_init_common_args, -# "recipe_list": [ -# session.init(), -# thirdpartyemailpassword.init( -# providers=[ -# ProviderInput( -# config=ProviderConfig( -# third_party_id="custom", -# clients=[ -# ProviderClientConfig( -# client_id="test", -# client_secret="test-secret", -# scope=["profile", "email"], -# ), -# ], -# authorization_endpoint="https://example.com/oauth/authorize", -# validate_access_token=valid_access_token, -# authorization_endpoint_query_params={ -# "response_type": "token", # Changing an existing parameter -# "response_mode": "form", # Adding a new parameter -# "scope": None, # Removing a parameter -# }, -# token_endpoint="https://example.com/oauth/token", -# ), -# override=get_custom_valid_token_provider -# ) -# ] -# ), -# ], -# } -# -# init(**st_init_args) # type: ignore -# start_st() -# -# res = fastapi_client.post( -# "/auth/signinup", -# json={ -# "thirdPartyId": "custom", -# "redirectURIInfo": { -# "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", -# "redirectURIQueryParams": { -# "code": "abcdefghj", -# }, -# }, -# } -# ) -# assert res.status_code == 200 -# assert res.json()["status"] == "OK" +async def test_signinup_works_when_validate_access_token_does_not_throw( + fastapi_client: TestClient, mocker: MockerFixture +): + time = str(datetime.datetime.now()) + mocker.patch( + "supertokens_python.recipe.thirdparty.providers.custom.get_supertokens_user_info_result_from_raw_user_info", + return_value=UserInfo( + "" + time, + UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True), + RawUserInfoFromProvider({}, {}), + ), + ) + + st_init_args = { + **st_init_common_args, + "recipe_list": [ + session.init(), + thirdpartyemailpassword.init( + providers=[ + ProviderInput( + config=ProviderConfig( + third_party_id="custom", + clients=[ + ProviderClientConfig( + client_id="test", + client_secret="test-secret", + scope=["profile", "email"], + ), + ], + authorization_endpoint="https://example.com/oauth/authorize", + validate_access_token=valid_access_token, + authorization_endpoint_query_params={ + "response_type": "token", # Changing an existing parameter + "response_mode": "form", # Adding a new parameter + "scope": None, # Removing a parameter + }, + token_endpoint="https://example.com/oauth/token", + ), + override=get_custom_valid_token_provider, + ) + ] + ), + ], + } + + init(**st_init_args) # type: ignore + start_st() + + res = fastapi_client.post( + "/auth/signinup", + json={ + "thirdPartyId": "custom", + "redirectURIInfo": { + "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", + "redirectURIQueryParams": { + "code": "abcdefghj", + }, + }, + }, + ) + + assert res.status_code == 200 + assert res.json()["status"] == "OK" From ab2abbf1478e0dde3caf6a53d10df1af6d864c1b Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Thu, 5 Oct 2023 12:36:21 +0530 Subject: [PATCH 08/10] refactor: changed the ordering of member functions --- .../recipe/multitenancy/recipe_implementation.py | 2 +- supertokens_python/recipe/thirdparty/provider.py | 14 +++++++------- .../recipe/thirdparty/providers/custom.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 023c45b45..782a36ea1 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -104,8 +104,8 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse: user_info_map=user_info_map, require_email=p.get("requireEmail", True), validate_id_token_payload=None, - validate_access_token=None, generate_fake_email=None, + validate_access_token=None, ) ) diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index 0d37e4f8c..039816b7e 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -197,8 +197,8 @@ def __init__( self.user_info_map = user_info_map self.require_email = require_email self.validate_id_token_payload = validate_id_token_payload - self.validate_access_token = validate_access_token self.generate_fake_email = generate_fake_email + self.validate_access_token = validate_access_token def to_json(self) -> Dict[str, Any]: res = { @@ -254,15 +254,15 @@ def __init__( Awaitable[None], ] ] = None, + generate_fake_email: Optional[ + Callable[[str, str, Dict[str, Any]], Awaitable[str]] + ] = None, validate_access_token: Optional[ Callable[ [str, ProviderConfigForClient, Dict[str, Any]], Awaitable[None], ] ] = None, - generate_fake_email: Optional[ - Callable[[str, str, Dict[str, Any]], Awaitable[str]] - ] = None, ): ProviderClientConfig.__init__( self, @@ -324,15 +324,15 @@ def __init__( Awaitable[None], ] ] = None, + generate_fake_email: Optional[ + Callable[[str, str, Dict[str, Any]], Awaitable[str]] + ] = None, validate_access_token: Optional[ Callable[ [str, ProviderConfigForClient, Dict[str, Any]], Awaitable[None], ] ] = None, - generate_fake_email: Optional[ - Callable[[str, str, Dict[str, Any]], Awaitable[str]] - ] = None, ): super().__init__( third_party_id, diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index 330d510a1..e7a57ec35 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -59,8 +59,8 @@ def get_provider_config_for_client( user_info_map=config.user_info_map, require_email=config.require_email, validate_id_token_payload=config.validate_id_token_payload, - validate_access_token=config.validate_access_token, generate_fake_email=config.generate_fake_email, + validate_access_token=config.validate_access_token, ) From 0ba43d4a675fa3d9d7c8c6cb1e55f64eb93dab62 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Thu, 5 Oct 2023 13:19:38 +0530 Subject: [PATCH 09/10] refactor: changed `do_post_request()` to return `status_code` --- .../recipe/thirdparty/providers/custom.py | 26 +++++++++---------- .../recipe/thirdparty/providers/github.py | 5 ++-- .../recipe/thirdparty/providers/twitter.py | 3 ++- .../recipe/thirdparty/providers/utils.py | 6 ++--- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index e7a57ec35..05e2ff3b1 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -376,7 +376,8 @@ async def exchange_auth_code_for_oauth_tokens( access_token_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL # Transformation needed for dev keys END - return await do_post_request(token_api_url, access_token_params) + _, body = await do_post_request(token_api_url, access_token_params) + return body async def get_user_info( self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any] @@ -412,21 +413,20 @@ async def get_user_info( headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"} query_params: Dict[str, str] = {} - if self.config.user_info_endpoint is not None: - if self.config.user_info_endpoint_headers is not None: - headers = merge_into_dict( - self.config.user_info_endpoint_headers, headers - ) - - if self.config.user_info_endpoint_query_params is not None: - query_params = merge_into_dict( - self.config.user_info_endpoint_query_params, query_params - ) + if self.config.user_info_endpoint_headers is not None: + headers = merge_into_dict( + self.config.user_info_endpoint_headers, headers + ) - raw_user_info_from_provider.from_user_info_api = await do_get_request( - self.config.user_info_endpoint, query_params, headers + if self.config.user_info_endpoint_query_params is not None: + query_params = merge_into_dict( + self.config.user_info_endpoint_query_params, query_params ) + raw_user_info_from_provider.from_user_info_api = await do_get_request( + self.config.user_info_endpoint, query_params, headers + ) + user_info_result = get_supertokens_user_info_result_from_raw_user_info( self.config, raw_user_info_from_provider ) diff --git a/supertokens_python/recipe/thirdparty/providers/github.py b/supertokens_python/recipe/thirdparty/providers/github.py index 89a39ce0d..7c03aefa9 100644 --- a/supertokens_python/recipe/thirdparty/providers/github.py +++ b/supertokens_python/recipe/thirdparty/providers/github.py @@ -96,9 +96,8 @@ async def validate_access_token( "Content-Type": "application/json", } - try: - body = await do_post_request(url, {"access_token": access_token}, headers) - except Exception: + status, body = await do_post_request(url, {"access_token": access_token}, headers) + if status != 200: raise ValueError("Invalid access token") if "app" not in body or body["app"].get("client_id") != config.client_id: diff --git a/supertokens_python/recipe/thirdparty/providers/twitter.py b/supertokens_python/recipe/thirdparty/providers/twitter.py index 8a8637122..f1c94ecc6 100644 --- a/supertokens_python/recipe/thirdparty/providers/twitter.py +++ b/supertokens_python/recipe/thirdparty/providers/twitter.py @@ -84,11 +84,12 @@ async def exchange_auth_code_for_oauth_tokens( assert self.config.token_endpoint is not None - return await do_post_request( + _, body = await do_post_request( self.config.token_endpoint, body_params=twitter_oauth_tokens_params, headers={"Authorization": f"Basic {auth_token}"}, ) + return body def Twitter(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin diff --git a/supertokens_python/recipe/thirdparty/providers/utils.py b/supertokens_python/recipe/thirdparty/providers/utils.py index c5fbf7295..71868be4b 100644 --- a/supertokens_python/recipe/thirdparty/providers/utils.py +++ b/supertokens_python/recipe/thirdparty/providers/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from httpx import AsyncClient @@ -48,7 +48,7 @@ async def do_post_request( url: str, body_params: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None, -) -> Dict[str, Any]: +) -> Tuple[int, Dict[str, Any]]: if body_params is None: body_params = {} if headers is None: @@ -62,4 +62,4 @@ async def do_post_request( log_debug_message( "Received response with status %s and body %s", res.status_code, res.text ) - return res.json() + return res.status_code, res.json() From 69b5d7b5aad949bbea9e7a5ced78253c8d266d04 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Thu, 5 Oct 2023 13:27:49 +0530 Subject: [PATCH 10/10] added assert to test if is called. --- tests/thirdparty/test_thirdparty.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/thirdparty/test_thirdparty.py b/tests/thirdparty/test_thirdparty.py index 30b900b66..ee6c36255 100644 --- a/tests/thirdparty/test_thirdparty.py +++ b/tests/thirdparty/test_thirdparty.py @@ -41,6 +41,8 @@ respx_mock = respx.MockRouter +access_token_validated: bool = False + @fixture(scope="function") async def fastapi_client(): @@ -145,7 +147,9 @@ async def valid_access_token( # pylint: disable=unused-argument config: ProviderConfigForClient, user_context: Optional[Dict[str, Any]], ): + global access_token_validated if access_token == "accesstoken": + access_token_validated = True return raise Exception("Unexpected access token") @@ -262,4 +266,5 @@ async def test_signinup_works_when_validate_access_token_does_not_throw( ) assert res.status_code == 200 + assert access_token_validated is True assert res.json()["status"] == "OK"