Skip to content

Commit

Permalink
fix: Added validate_access_token for github provider
Browse files Browse the repository at this point in the history
  • Loading branch information
IamMayankThakur committed Oct 4, 2023
1 parent 3429fef commit 51378bc
Show file tree
Hide file tree
Showing 2 changed files with 298 additions and 0 deletions.
36 changes: 36 additions & 0 deletions supertokens_python/recipe/thirdparty/providers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
import json
from typing import Any, Dict, List, Optional

import requests

from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail

Expand Down Expand Up @@ -71,4 +76,35 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built
if input.config.token_endpoint is None:
input.config.token_endpoint = "https://github.com/login/oauth/access_token"

if input.config.validate_access_token is None:
input.config.validate_access_token = validate_access_token

return NewProvider(input, GithubImpl)


async def validate_access_token(
access_token: str, config: ProviderConfigForClient, _: Dict[str, Any]
):
client_secret = "" if config.client_secret is None else config.client_secret
basic_auth_token = base64.b64encode(
f"{config.client_id}:{client_secret}".encode()
).decode()

# POST request to get applications response
url = f"https://api.github.com/applications/{config.client_id}/token"
headers = {
"Authorization": f"Basic {basic_auth_token}",
"Content-Type": "application/json",
}
payload = json.dumps({"access_token": access_token})

resp = requests.post(url, headers=headers, data=payload)

# Error handling and validation
if resp.status_code != 200:
raise ValueError("Invalid access token")

body = resp.json()

if "app" not in body or body["app"]["client_id"] != config.client_id:
raise ValueError("Access token does not belong to your application")
262 changes: 262 additions & 0 deletions tests/thirdparty/test_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import datetime
import json
from base64 import b64encode
from typing import Dict, Any, Optional

import respx
from fastapi import FastAPI
from pytest import fixture, mark
from starlette.testclient import TestClient

from supertokens_python import init
from supertokens_python.framework.fastapi import get_middleware
from supertokens_python.recipe import session, thirdparty
from supertokens_python.recipe import thirdpartyemailpassword
from supertokens_python.recipe.thirdparty.provider import (
ProviderClientConfig,
ProviderConfig,
ProviderInput,
Provider,
RedirectUriInfo,
ProviderConfigForClient,
)
from supertokens_python.recipe.thirdparty.types import (
UserInfo,
UserInfoEmail,
RawUserInfoFromProvider,
)
from tests.utils import (
setup_function,
teardown_function,
start_st,
st_init_common_args,
)

_ = setup_function # type:ignore
_ = teardown_function # type:ignore
_ = start_st # type:ignore

pytestmark = mark.asyncio

respx_mock = respx.MockRouter


@fixture(scope="function")
async def fastapi_client():
app = FastAPI()
app.add_middleware(get_middleware())

return TestClient(app, raise_server_exceptions=False)


async def test_thirdpary_parsing_works(fastapi_client: TestClient):
st_init_args = {
**st_init_common_args,
"recipe_list": [
session.init(),
thirdparty.init(
sign_in_and_up_feature=thirdparty.SignInAndUpFeature(
providers=[
thirdparty.ProviderInput(
config=thirdparty.ProviderConfig(
third_party_id="apple",
clients=[
thirdparty.ProviderClientConfig(
client_id="4398792-io.supertokens.example.service",
additional_config={
"keyId": "7M48Y4RYDL",
"teamId": "YWQCXGJRJL",
"privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----",
},
),
],
)
),
]
)
),
],
}
init(**st_init_args) # type: ignore
start_st()

state = b64encode(
json.dumps({"frontendRedirectURI": "http://localhost:3000/redirect"}).encode()
).decode()
code = "testing"

data = {"state": state, "code": code}
res = fastapi_client.post("/auth/callback/apple", data=data)

assert res.status_code == 303
assert res.content == b""
assert (
res.headers["location"]
== f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"
)


async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-argument
redirect_uri_info: RedirectUriInfo,
user_context: Dict[str, Any],
) -> Dict[str, Any]:
return {
"access_token": "accesstoken",
"id_token": "idtoken",
}


async def get_user_info( # pylint: disable=unused-argument
oauth_tokens: Dict[str, Any],
user_context: Dict[str, Any],
) -> UserInfo:
time = str(datetime.datetime.now())
return UserInfo(
"" + time,
UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True),
RawUserInfoFromProvider({}, {}),
)


async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument
redirect_uri_info: RedirectUriInfo,
user_context: Dict[str, Any],
) -> Dict[str, Any]:
return {
"access_token": "wrongaccesstoken",
"id_token": "wrongidtoken",
}


def get_custom_invalid_token_provider(provider: Provider) -> Provider:
provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_invalid_oauth_tokens
)
return provider


def get_custom_valid_token_provider(provider: Provider) -> Provider:
provider.exchange_auth_code_for_oauth_tokens = (
exchange_auth_code_for_valid_oauth_tokens
)
provider.get_user_info = get_user_info
return provider


async def invalid_access_token( # pylint: disable=unused-argument
access_token: str,
config: ProviderConfigForClient,
user_context: Optional[Dict[str, Any]],
):
if access_token == "wrongaccesstoken":
raise Exception("Invalid access token")


async def valid_access_token( # pylint: disable=unused-argument
access_token: str, config: ProviderConfig, user_context: Optional[Dict[str, Any]]
):
if access_token == "accesstoken":
return
raise Exception("Unexpected access token")


async def test_signinup_when_validate_access_token_throws(fastapi_client: TestClient):
st_init_args = {
**st_init_common_args,
"recipe_list": [
session.init(),
thirdpartyemailpassword.init(
providers=[
ProviderInput(
config=ProviderConfig(
third_party_id="custom",
clients=[
ProviderClientConfig(
client_id="test",
client_secret="test-secret",
scope=["profile", "email"],
),
],
authorization_endpoint="https://example.com/oauth/authorize",
validate_access_token=invalid_access_token,
authorization_endpoint_query_params={
"response_type": "token", # Changing an existing parameter
"response_mode": "form", # Adding a new parameter
"scope": None, # Removing a parameter
},
token_endpoint="https://example.com/oauth/token",
),
override=get_custom_invalid_token_provider,
)
]
),
],
}
init(**st_init_args) # type: ignore
start_st()

res = fastapi_client.post(
"/auth/signinup",
json={
"thirdPartyId": "custom",
"redirectURIInfo": {
"redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
"redirectURIQueryParams": {
"code": "abcdefghj",
},
},
},
)
assert res.status_code == 500


# async def test_signinup_works_when_validate_access_token_does_not_throw(fastapi_client: TestClient):
# st_init_args = {
# **st_init_common_args,
# "recipe_list": [
# session.init(),
# thirdpartyemailpassword.init(
# providers=[
# ProviderInput(
# config=ProviderConfig(
# third_party_id="custom",
# clients=[
# ProviderClientConfig(
# client_id="test",
# client_secret="test-secret",
# scope=["profile", "email"],
# ),
# ],
# authorization_endpoint="https://example.com/oauth/authorize",
# validate_access_token=valid_access_token,
# authorization_endpoint_query_params={
# "response_type": "token", # Changing an existing parameter
# "response_mode": "form", # Adding a new parameter
# "scope": None, # Removing a parameter
# },
# token_endpoint="https://example.com/oauth/token",
# ),
# override=get_custom_valid_token_provider
# )
# ]
# ),
# ],
# }
#
# init(**st_init_args) # type: ignore
# start_st()
#
# res = fastapi_client.post(
# "/auth/signinup",
# json={
# "thirdPartyId": "custom",
# "redirectURIInfo": {
# "redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
# "redirectURIQueryParams": {
# "code": "abcdefghj",
# },
# },
# }
# )
# assert res.status_code == 200
# assert res.json()["status"] == "OK"

0 comments on commit 51378bc

Please sign in to comment.