From 51378bc3c9aa56cbeeb1eca839cd43a0b8deebd0 Mon Sep 17 00:00:00 2001 From: Mayank Thakur Date: Wed, 4 Oct 2023 18:27:47 +0530 Subject: [PATCH] fix: Added `validate_access_token` for github provider --- .../recipe/thirdparty/providers/github.py | 36 +++ tests/thirdparty/test_providers.py | 262 ++++++++++++++++++ 2 files changed, 298 insertions(+) create mode 100644 tests/thirdparty/test_providers.py diff --git a/supertokens_python/recipe/thirdparty/providers/github.py b/supertokens_python/recipe/thirdparty/providers/github.py index e411d87aa..e14c03acc 100644 --- a/supertokens_python/recipe/thirdparty/providers/github.py +++ b/supertokens_python/recipe/thirdparty/providers/github.py @@ -12,8 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations + +import base64 +import json from typing import Any, Dict, List, Optional +import requests + from supertokens_python.recipe.thirdparty.providers.utils import do_get_request from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail @@ -71,4 +76,35 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built if input.config.token_endpoint is None: input.config.token_endpoint = "https://github.com/login/oauth/access_token" + if input.config.validate_access_token is None: + input.config.validate_access_token = validate_access_token + return NewProvider(input, GithubImpl) + + +async def validate_access_token( + access_token: str, config: ProviderConfigForClient, _: Dict[str, Any] +): + client_secret = "" if config.client_secret is None else config.client_secret + basic_auth_token = base64.b64encode( + f"{config.client_id}:{client_secret}".encode() + ).decode() + + # POST request to get applications response + url = f"https://api.github.com/applications/{config.client_id}/token" + headers = { + "Authorization": f"Basic {basic_auth_token}", + "Content-Type": "application/json", + } + payload = json.dumps({"access_token": access_token}) + + resp = requests.post(url, headers=headers, data=payload) + + # Error handling and validation + if resp.status_code != 200: + raise ValueError("Invalid access token") + + body = resp.json() + + if "app" not in body or body["app"]["client_id"] != config.client_id: + raise ValueError("Access token does not belong to your application") diff --git a/tests/thirdparty/test_providers.py b/tests/thirdparty/test_providers.py new file mode 100644 index 000000000..cdd68d41b --- /dev/null +++ b/tests/thirdparty/test_providers.py @@ -0,0 +1,262 @@ +import datetime +import json +from base64 import b64encode +from typing import Dict, Any, Optional + +import respx +from fastapi import FastAPI +from pytest import fixture, mark +from starlette.testclient import TestClient + +from supertokens_python import init +from supertokens_python.framework.fastapi import get_middleware +from supertokens_python.recipe import session, thirdparty +from supertokens_python.recipe import thirdpartyemailpassword +from supertokens_python.recipe.thirdparty.provider import ( + ProviderClientConfig, + ProviderConfig, + ProviderInput, + Provider, + RedirectUriInfo, + ProviderConfigForClient, +) +from supertokens_python.recipe.thirdparty.types import ( + UserInfo, + UserInfoEmail, + RawUserInfoFromProvider, +) +from tests.utils import ( + setup_function, + teardown_function, + start_st, + st_init_common_args, +) + +_ = setup_function # type:ignore +_ = teardown_function # type:ignore +_ = start_st # type:ignore + +pytestmark = mark.asyncio + +respx_mock = respx.MockRouter + + +@fixture(scope="function") +async def fastapi_client(): + app = FastAPI() + app.add_middleware(get_middleware()) + + return TestClient(app, raise_server_exceptions=False) + + +async def test_thirdpary_parsing_works(fastapi_client: TestClient): + st_init_args = { + **st_init_common_args, + "recipe_list": [ + session.init(), + thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature( + providers=[ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="apple", + clients=[ + thirdparty.ProviderClientConfig( + client_id="4398792-io.supertokens.example.service", + additional_config={ + "keyId": "7M48Y4RYDL", + "teamId": "YWQCXGJRJL", + "privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + }, + ), + ], + ) + ), + ] + ) + ), + ], + } + init(**st_init_args) # type: ignore + start_st() + + state = b64encode( + json.dumps({"frontendRedirectURI": "http://localhost:3000/redirect"}).encode() + ).decode() + code = "testing" + + data = {"state": state, "code": code} + res = fastapi_client.post("/auth/callback/apple", data=data) + + assert res.status_code == 303 + assert res.content == b"" + assert ( + res.headers["location"] + == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}" + ) + + +async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-argument + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], +) -> Dict[str, Any]: + return { + "access_token": "accesstoken", + "id_token": "idtoken", + } + + +async def get_user_info( # pylint: disable=unused-argument + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], +) -> UserInfo: + time = str(datetime.datetime.now()) + return UserInfo( + "" + time, + UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True), + RawUserInfoFromProvider({}, {}), + ) + + +async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], +) -> Dict[str, Any]: + return { + "access_token": "wrongaccesstoken", + "id_token": "wrongidtoken", + } + + +def get_custom_invalid_token_provider(provider: Provider) -> Provider: + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_invalid_oauth_tokens + ) + return provider + + +def get_custom_valid_token_provider(provider: Provider) -> Provider: + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_valid_oauth_tokens + ) + provider.get_user_info = get_user_info + return provider + + +async def invalid_access_token( # pylint: disable=unused-argument + access_token: str, + config: ProviderConfigForClient, + user_context: Optional[Dict[str, Any]], +): + if access_token == "wrongaccesstoken": + raise Exception("Invalid access token") + + +async def valid_access_token( # pylint: disable=unused-argument + access_token: str, config: ProviderConfig, user_context: Optional[Dict[str, Any]] +): + if access_token == "accesstoken": + return + raise Exception("Unexpected access token") + + +async def test_signinup_when_validate_access_token_throws(fastapi_client: TestClient): + st_init_args = { + **st_init_common_args, + "recipe_list": [ + session.init(), + thirdpartyemailpassword.init( + providers=[ + ProviderInput( + config=ProviderConfig( + third_party_id="custom", + clients=[ + ProviderClientConfig( + client_id="test", + client_secret="test-secret", + scope=["profile", "email"], + ), + ], + authorization_endpoint="https://example.com/oauth/authorize", + validate_access_token=invalid_access_token, + authorization_endpoint_query_params={ + "response_type": "token", # Changing an existing parameter + "response_mode": "form", # Adding a new parameter + "scope": None, # Removing a parameter + }, + token_endpoint="https://example.com/oauth/token", + ), + override=get_custom_invalid_token_provider, + ) + ] + ), + ], + } + init(**st_init_args) # type: ignore + start_st() + + res = fastapi_client.post( + "/auth/signinup", + json={ + "thirdPartyId": "custom", + "redirectURIInfo": { + "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", + "redirectURIQueryParams": { + "code": "abcdefghj", + }, + }, + }, + ) + assert res.status_code == 500 + + +# async def test_signinup_works_when_validate_access_token_does_not_throw(fastapi_client: TestClient): +# st_init_args = { +# **st_init_common_args, +# "recipe_list": [ +# session.init(), +# thirdpartyemailpassword.init( +# providers=[ +# ProviderInput( +# config=ProviderConfig( +# third_party_id="custom", +# clients=[ +# ProviderClientConfig( +# client_id="test", +# client_secret="test-secret", +# scope=["profile", "email"], +# ), +# ], +# authorization_endpoint="https://example.com/oauth/authorize", +# validate_access_token=valid_access_token, +# authorization_endpoint_query_params={ +# "response_type": "token", # Changing an existing parameter +# "response_mode": "form", # Adding a new parameter +# "scope": None, # Removing a parameter +# }, +# token_endpoint="https://example.com/oauth/token", +# ), +# override=get_custom_valid_token_provider +# ) +# ] +# ), +# ], +# } +# +# init(**st_init_args) # type: ignore +# start_st() +# +# res = fastapi_client.post( +# "/auth/signinup", +# json={ +# "thirdPartyId": "custom", +# "redirectURIInfo": { +# "redirectURIOnProviderDashboard": "http://127.0.0.1/callback", +# "redirectURIQueryParams": { +# "code": "abcdefghj", +# }, +# }, +# } +# ) +# assert res.status_code == 200 +# assert res.json()["status"] == "OK"