Skip to content

Commit

Permalink
Merge pull request #453 from supertokens/feat/access-token-validation
Browse files Browse the repository at this point in the history
feat: Add `validate_access_token` function to providers
  • Loading branch information
rishabhpoddar authored Oct 5, 2023
2 parents eae5482 + 69b5d7b commit 6e71e91
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 30 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.16.4] - 2023-10-05

- Add `validate_access_token` function to providers
- This can be used to verify the access token received from providers.
- Implemented `validate_access_token` for the Github provider.

## [0.16.3] - 2023-09-28

- Add Twitter provider for thirdparty login
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.16.3",
version="0.16.4",
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.16.3"
VERSION = "0.16.4"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse:
require_email=p.get("requireEmail", True),
validate_id_token_payload=None,
generate_fake_email=None,
validate_access_token=None,
)
)

Expand Down
21 changes: 21 additions & 0 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
self.third_party_id = third_party_id
self.name = name
Expand All @@ -192,6 +198,7 @@ def __init__(
self.require_email = require_email
self.validate_id_token_payload = validate_id_token_payload
self.generate_fake_email = generate_fake_email
self.validate_access_token = validate_access_token

def to_json(self) -> Dict[str, Any]:
res = {
Expand Down Expand Up @@ -250,6 +257,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
ProviderClientConfig.__init__(
self,
Expand Down Expand Up @@ -277,6 +290,7 @@ def __init__(
require_email,
validate_id_token_payload,
generate_fake_email,
validate_access_token,
)

def to_json(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -313,6 +327,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
super().__init__(
third_party_id,
Expand All @@ -330,6 +350,7 @@ def __init__(
require_email,
validate_id_token_payload,
generate_fake_email,
validate_access_token,
)
self.clients = clients

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def merge_config(
user_info_map=config_from_static.user_info_map,
generate_fake_email=config_from_static.generate_fake_email,
validate_id_token_payload=config_from_static.validate_id_token_payload,
validate_access_token=config_from_static.validate_access_token,
)

if result.user_info_map is None:
Expand Down
34 changes: 20 additions & 14 deletions supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def get_provider_config_for_client(
require_email=config.require_email,
validate_id_token_payload=config.validate_id_token_payload,
generate_fake_email=config.generate_fake_email,
validate_access_token=config.validate_access_token,
)


Expand Down Expand Up @@ -375,7 +376,8 @@ async def exchange_auth_code_for_oauth_tokens(
access_token_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL
# Transformation needed for dev keys END

return await do_post_request(token_api_url, access_token_params)
_, body = await do_post_request(token_api_url, access_token_params)
return body

async def get_user_info(
self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]
Expand All @@ -402,25 +404,29 @@ async def get_user_info(
user_context,
)

if access_token is not None and self.config.token_endpoint is not None:
if self.config.validate_access_token is not None and access_token is not None:
await self.config.validate_access_token(
access_token, self.config, user_context
)

if access_token is not None and self.config.user_info_endpoint is not None:
headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
query_params: Dict[str, str] = {}

if self.config.user_info_endpoint is not None:
if self.config.user_info_endpoint_headers is not None:
headers = merge_into_dict(
self.config.user_info_endpoint_headers, headers
)

if self.config.user_info_endpoint_query_params is not None:
query_params = merge_into_dict(
self.config.user_info_endpoint_query_params, query_params
)
if self.config.user_info_endpoint_headers is not None:
headers = merge_into_dict(
self.config.user_info_endpoint_headers, headers
)

raw_user_info_from_provider.from_user_info_api = await do_get_request(
self.config.user_info_endpoint, query_params, headers
if self.config.user_info_endpoint_query_params is not None:
query_params = merge_into_dict(
self.config.user_info_endpoint_query_params, query_params
)

raw_user_info_from_provider.from_user_info_api = await do_get_request(
self.config.user_info_endpoint, query_params, headers
)

user_info_result = get_supertokens_user_info_result_from_raw_user_info(
self.config, raw_user_info_from_provider
)
Expand Down
32 changes: 31 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
from typing import Any, Dict, List, Optional

from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
from supertokens_python.recipe.thirdparty.providers.utils import (
do_get_request,
do_post_request,
)
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail

from .custom import GenericProvider, NewProvider
Expand Down Expand Up @@ -71,4 +76,29 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built
if input.config.token_endpoint is None:
input.config.token_endpoint = "https://github.com/login/oauth/access_token"

if input.config.validate_access_token is None:
input.config.validate_access_token = validate_access_token

return NewProvider(input, GithubImpl)


async def validate_access_token(
access_token: str, config: ProviderConfigForClient, _: Dict[str, Any]
):
client_secret = "" if config.client_secret is None else config.client_secret
basic_auth_token = base64.b64encode(
f"{config.client_id}:{client_secret}".encode()
).decode()

url = f"https://api.github.com/applications/{config.client_id}/token"
headers = {
"Authorization": f"Basic {basic_auth_token}",
"Content-Type": "application/json",
}

status, body = await do_post_request(url, {"access_token": access_token}, headers)
if status != 200:
raise ValueError("Invalid access token")

if "app" not in body or body["app"].get("client_id") != config.client_id:
raise ValueError("Access token does not belong to your application")
3 changes: 2 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/twitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ async def exchange_auth_code_for_oauth_tokens(

assert self.config.token_endpoint is not None

return await do_post_request(
_, body = await do_post_request(
self.config.token_endpoint,
body_params=twitter_oauth_tokens_params,
headers={"Authorization": f"Basic {auth_token}"},
)
return body


def Twitter(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
Expand Down
6 changes: 3 additions & 3 deletions supertokens_python/recipe/thirdparty/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

from httpx import AsyncClient

Expand Down Expand Up @@ -48,7 +48,7 @@ async def do_post_request(
url: str,
body_params: Optional[Dict[str, str]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
) -> Tuple[int, Dict[str, Any]]:
if body_params is None:
body_params = {}
if headers is None:
Expand All @@ -62,4 +62,4 @@ async def do_post_request(
log_debug_message(
"Received response with status %s and body %s", res.status_code, res.text
)
return res.json()
return res.status_code, res.json()
Loading

0 comments on commit 6e71e91

Please sign in to comment.