diff --git a/CHANGELOG.md b/CHANGELOG.md index f14039af5..fb63402f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.16.4] - 2023-10-05 + +- Add `validate_access_token` function to providers + - This can be used to verify the access token received from providers. + - Implemented `validate_access_token` for the Github provider. + ## [0.16.3] - 2023-09-28 - Add Twitter provider for thirdparty login diff --git a/setup.py b/setup.py index 1a7960248..d81afe157 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.16.3", + version="0.16.4", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 1cd45287b..1094f851f 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["3.0"] -VERSION = "0.16.3" +VERSION = "0.16.4" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 7d569ed95..782a36ea1 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -105,6 +105,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse: require_email=p.get("requireEmail", True), validate_id_token_payload=None, generate_fake_email=None, + validate_access_token=None, ) ) diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index aa13b32d6..039816b7e 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -176,6 +176,12 @@ def __init__( generate_fake_email: Optional[ Callable[[str, str, Dict[str, Any]], Awaitable[str]] ] = None, + validate_access_token: Optional[ + Callable[ + [str, ProviderConfigForClient, Dict[str, Any]], + Awaitable[None], + ] + ] = None, ): self.third_party_id = third_party_id self.name = name @@ -192,6 +198,7 @@ def __init__( self.require_email = require_email self.validate_id_token_payload = validate_id_token_payload self.generate_fake_email = generate_fake_email + self.validate_access_token = validate_access_token def to_json(self) -> Dict[str, Any]: res = { @@ -250,6 +257,12 @@ def __init__( generate_fake_email: Optional[ Callable[[str, str, Dict[str, Any]], Awaitable[str]] ] = None, + validate_access_token: Optional[ + Callable[ + [str, ProviderConfigForClient, Dict[str, Any]], + Awaitable[None], + ] + ] = None, ): ProviderClientConfig.__init__( self, @@ -277,6 +290,7 @@ def __init__( require_email, validate_id_token_payload, generate_fake_email, + validate_access_token, ) def to_json(self) -> Dict[str, Any]: @@ -313,6 +327,12 @@ def __init__( generate_fake_email: Optional[ Callable[[str, str, Dict[str, Any]], Awaitable[str]] ] = None, + validate_access_token: Optional[ + Callable[ + [str, ProviderConfigForClient, Dict[str, Any]], + Awaitable[None], + ] + ] = None, ): super().__init__( third_party_id, @@ -330,6 +350,7 @@ def __init__( require_email, validate_id_token_payload, generate_fake_email, + validate_access_token, ) self.clients = clients diff --git a/supertokens_python/recipe/thirdparty/providers/config_utils.py b/supertokens_python/recipe/thirdparty/providers/config_utils.py index d520106aa..e3c70c517 100644 --- a/supertokens_python/recipe/thirdparty/providers/config_utils.py +++ b/supertokens_python/recipe/thirdparty/providers/config_utils.py @@ -87,6 +87,7 @@ def merge_config( user_info_map=config_from_static.user_info_map, generate_fake_email=config_from_static.generate_fake_email, validate_id_token_payload=config_from_static.validate_id_token_payload, + validate_access_token=config_from_static.validate_access_token, ) if result.user_info_map is None: diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index e9807e914..05e2ff3b1 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -60,6 +60,7 @@ def get_provider_config_for_client( require_email=config.require_email, validate_id_token_payload=config.validate_id_token_payload, generate_fake_email=config.generate_fake_email, + validate_access_token=config.validate_access_token, ) @@ -375,7 +376,8 @@ async def exchange_auth_code_for_oauth_tokens( access_token_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL # Transformation needed for dev keys END - return await do_post_request(token_api_url, access_token_params) + _, body = await do_post_request(token_api_url, access_token_params) + return body async def get_user_info( self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any] @@ -402,25 +404,29 @@ async def get_user_info( user_context, ) - if access_token is not None and self.config.token_endpoint is not None: + if self.config.validate_access_token is not None and access_token is not None: + await self.config.validate_access_token( + access_token, self.config, user_context + ) + + if access_token is not None and self.config.user_info_endpoint is not None: headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"} query_params: Dict[str, str] = {} - if self.config.user_info_endpoint is not None: - if self.config.user_info_endpoint_headers is not None: - headers = merge_into_dict( - self.config.user_info_endpoint_headers, headers - ) - - if self.config.user_info_endpoint_query_params is not None: - query_params = merge_into_dict( - self.config.user_info_endpoint_query_params, query_params - ) + if self.config.user_info_endpoint_headers is not None: + headers = merge_into_dict( + self.config.user_info_endpoint_headers, headers + ) - raw_user_info_from_provider.from_user_info_api = await do_get_request( - self.config.user_info_endpoint, query_params, headers + if self.config.user_info_endpoint_query_params is not None: + query_params = merge_into_dict( + self.config.user_info_endpoint_query_params, query_params ) + raw_user_info_from_provider.from_user_info_api = await do_get_request( + self.config.user_info_endpoint, query_params, headers + ) + user_info_result = get_supertokens_user_info_result_from_raw_user_info( self.config, raw_user_info_from_provider ) diff --git a/supertokens_python/recipe/thirdparty/providers/github.py b/supertokens_python/recipe/thirdparty/providers/github.py index e411d87aa..7c03aefa9 100644 --- a/supertokens_python/recipe/thirdparty/providers/github.py +++ b/supertokens_python/recipe/thirdparty/providers/github.py @@ -12,9 +12,14 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations + +import base64 from typing import Any, Dict, List, Optional -from supertokens_python.recipe.thirdparty.providers.utils import do_get_request +from supertokens_python.recipe.thirdparty.providers.utils import ( + do_get_request, + do_post_request, +) from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail from .custom import GenericProvider, NewProvider @@ -71,4 +76,29 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built if input.config.token_endpoint is None: input.config.token_endpoint = "https://github.com/login/oauth/access_token" + if input.config.validate_access_token is None: + input.config.validate_access_token = validate_access_token + return NewProvider(input, GithubImpl) + + +async def validate_access_token( + access_token: str, config: ProviderConfigForClient, _: Dict[str, Any] +): + client_secret = "" if config.client_secret is None else config.client_secret + basic_auth_token = base64.b64encode( + f"{config.client_id}:{client_secret}".encode() + ).decode() + + url = f"https://api.github.com/applications/{config.client_id}/token" + headers = { + "Authorization": f"Basic {basic_auth_token}", + "Content-Type": "application/json", + } + + status, body = await do_post_request(url, {"access_token": access_token}, headers) + if status != 200: + raise ValueError("Invalid access token") + + if "app" not in body or body["app"].get("client_id") != config.client_id: + raise ValueError("Access token does not belong to your application") diff --git a/supertokens_python/recipe/thirdparty/providers/twitter.py b/supertokens_python/recipe/thirdparty/providers/twitter.py index 8a8637122..f1c94ecc6 100644 --- a/supertokens_python/recipe/thirdparty/providers/twitter.py +++ b/supertokens_python/recipe/thirdparty/providers/twitter.py @@ -84,11 +84,12 @@ async def exchange_auth_code_for_oauth_tokens( assert self.config.token_endpoint is not None - return await do_post_request( + _, body = await do_post_request( self.config.token_endpoint, body_params=twitter_oauth_tokens_params, headers={"Authorization": f"Basic {auth_token}"}, ) + return body def Twitter(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin diff --git a/supertokens_python/recipe/thirdparty/providers/utils.py b/supertokens_python/recipe/thirdparty/providers/utils.py index c5fbf7295..71868be4b 100644 --- a/supertokens_python/recipe/thirdparty/providers/utils.py +++ b/supertokens_python/recipe/thirdparty/providers/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from httpx import AsyncClient @@ -48,7 +48,7 @@ async def do_post_request( url: str, body_params: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None, -) -> Dict[str, Any]: +) -> Tuple[int, Dict[str, Any]]: if body_params is None: body_params = {} if headers is None: @@ -62,4 +62,4 @@ async def do_post_request( log_debug_message( "Received response with status %s and body %s", res.status_code, res.text ) - return res.json() + return res.status_code, res.json() diff --git a/tests/thirdparty/test_thirdparty.py b/tests/thirdparty/test_thirdparty.py index e20863836..ee6c36255 100644 --- a/tests/thirdparty/test_thirdparty.py +++ b/tests/thirdparty/test_thirdparty.py @@ -1,15 +1,31 @@ -import respx +import datetime import json +from base64 import b64encode +from typing import Dict, Any, Optional -from pytest import fixture, mark +import respx from fastapi import FastAPI -from supertokens_python.framework.fastapi import get_middleware +from pytest import fixture, mark +from pytest_mock import MockerFixture from starlette.testclient import TestClient -from supertokens_python.recipe import session, thirdparty from supertokens_python import init -from base64 import b64encode - +from supertokens_python.framework.fastapi import get_middleware +from supertokens_python.recipe import session, thirdparty +from supertokens_python.recipe import thirdpartyemailpassword +from supertokens_python.recipe.thirdparty.provider import ( + ProviderClientConfig, + ProviderConfig, + ProviderInput, + Provider, + RedirectUriInfo, + ProviderConfigForClient, +) +from supertokens_python.recipe.thirdparty.types import ( + UserInfo, + UserInfoEmail, + RawUserInfoFromProvider, +) from tests.utils import ( setup_function, teardown_function, @@ -17,23 +33,23 @@ st_init_common_args, ) - _ = setup_function # type:ignore _ = teardown_function # type:ignore _ = start_st # type:ignore - pytestmark = mark.asyncio respx_mock = respx.MockRouter +access_token_validated: bool = False + @fixture(scope="function") async def fastapi_client(): app = FastAPI() app.add_middleware(get_middleware()) - return TestClient(app) + return TestClient(app, raise_server_exceptions=False) async def test_thirdpary_parsing_works(fastapi_client: TestClient): @@ -81,3 +97,174 @@ async def test_thirdpary_parsing_works(fastapi_client: TestClient): res.headers["location"] == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}" ) + + +async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-argument + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], +) -> Dict[str, Any]: + return { + "access_token": "accesstoken", + "id_token": "idtoken", + } + + +async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], +) -> Dict[str, Any]: + return { + "access_token": "wrongaccesstoken", + "id_token": "wrongidtoken", + } + + +def get_custom_invalid_token_provider(provider: Provider) -> Provider: + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_invalid_oauth_tokens + ) + return provider + + +def get_custom_valid_token_provider(provider: Provider) -> Provider: + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_valid_oauth_tokens + ) + return provider + + +async def invalid_access_token( # pylint: disable=unused-argument + access_token: str, + config: ProviderConfigForClient, + user_context: Optional[Dict[str, Any]], +): + if access_token == "wrongaccesstoken": + raise Exception("Invalid access token") + + +async def valid_access_token( # pylint: disable=unused-argument + access_token: str, + config: ProviderConfigForClient, + user_context: Optional[Dict[str, Any]], +): + global access_token_validated + if access_token == "accesstoken": + access_token_validated = True + return + raise Exception("Unexpected access token") + + +async def test_signinup_when_validate_access_token_throws(fastapi_client: TestClient): + st_init_args = { + **st_init_common_args, + "recipe_list": [ + session.init(), + thirdpartyemailpassword.init( + providers=[ + ProviderInput( + config=ProviderConfig( + third_party_id="custom", + clients=[ + ProviderClientConfig( + client_id="test", + client_secret="test-secret", + scope=["profile", "email"], + ), + ], + authorization_endpoint="https://example.com/oauth/authorize", + validate_access_token=invalid_access_token, + authorization_endpoint_query_params={ + "response_type": "token", # Changing an existing parameter + "response_mode": "form", # Adding a new parameter + "scope": None, # Removing a parameter + }, + token_endpoint="https://example.com/oauth/token", + ), + override=get_custom_invalid_token_provider, + ) + ] + ), + ], + } + init(**st_init_args) # type: ignore + start_st() + + res = fastapi_client.post( + "/auth/signinup", + json={ + "thirdPartyId": "custom", + "redirectURIInfo": { + "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", + "redirectURIQueryParams": { + "code": "abcdefghj", + }, + }, + }, + ) + assert res.status_code == 500 + + +async def test_signinup_works_when_validate_access_token_does_not_throw( + fastapi_client: TestClient, mocker: MockerFixture +): + time = str(datetime.datetime.now()) + mocker.patch( + "supertokens_python.recipe.thirdparty.providers.custom.get_supertokens_user_info_result_from_raw_user_info", + return_value=UserInfo( + "" + time, + UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True), + RawUserInfoFromProvider({}, {}), + ), + ) + + st_init_args = { + **st_init_common_args, + "recipe_list": [ + session.init(), + thirdpartyemailpassword.init( + providers=[ + ProviderInput( + config=ProviderConfig( + third_party_id="custom", + clients=[ + ProviderClientConfig( + client_id="test", + client_secret="test-secret", + scope=["profile", "email"], + ), + ], + authorization_endpoint="https://example.com/oauth/authorize", + validate_access_token=valid_access_token, + authorization_endpoint_query_params={ + "response_type": "token", # Changing an existing parameter + "response_mode": "form", # Adding a new parameter + "scope": None, # Removing a parameter + }, + token_endpoint="https://example.com/oauth/token", + ), + override=get_custom_valid_token_provider, + ) + ] + ), + ], + } + + init(**st_init_args) # type: ignore + start_st() + + res = fastapi_client.post( + "/auth/signinup", + json={ + "thirdPartyId": "custom", + "redirectURIInfo": { + "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", + "redirectURIQueryParams": { + "code": "abcdefghj", + }, + }, + }, + ) + + assert res.status_code == 200 + assert access_token_validated is True + assert res.json()["status"] == "OK"