Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add validate_access_token function to providers #453

Merged
merged 10 commits into from
Oct 5, 2023
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

IamMayankThakur marked this conversation as resolved.
Show resolved Hide resolved
## [0.16.4] - 2023-10-05

- Add `validate_access_token` function to providers
- This can be used to verify the access token received from providers.
- Implemented `validate_access_token` for the Github provider.

## [0.16.3] - 2023-09-28

- Add Twitter provider for thirdparty login
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.16.3",
version="0.16.4",
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.16.3"
VERSION = "0.16.4"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse:
require_email=p.get("requireEmail", True),
validate_id_token_payload=None,
generate_fake_email=None,
validate_access_token=None,
)
)

Expand Down
21 changes: 21 additions & 0 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
self.third_party_id = third_party_id
self.name = name
Expand All @@ -192,6 +198,7 @@ def __init__(
self.require_email = require_email
self.validate_id_token_payload = validate_id_token_payload
self.generate_fake_email = generate_fake_email
self.validate_access_token = validate_access_token

def to_json(self) -> Dict[str, Any]:
res = {
Expand Down Expand Up @@ -250,6 +257,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
ProviderClientConfig.__init__(
self,
Expand Down Expand Up @@ -277,6 +290,7 @@ def __init__(
require_email,
validate_id_token_payload,
generate_fake_email,
validate_access_token,
)

def to_json(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -313,6 +327,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
super().__init__(
third_party_id,
Expand All @@ -330,6 +350,7 @@ def __init__(
require_email,
validate_id_token_payload,
generate_fake_email,
validate_access_token,
)
self.clients = clients

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def merge_config(
user_info_map=config_from_static.user_info_map,
generate_fake_email=config_from_static.generate_fake_email,
validate_id_token_payload=config_from_static.validate_id_token_payload,
validate_access_token=config_from_static.validate_access_token,
)

if result.user_info_map is None:
Expand Down
34 changes: 20 additions & 14 deletions supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def get_provider_config_for_client(
require_email=config.require_email,
validate_id_token_payload=config.validate_id_token_payload,
generate_fake_email=config.generate_fake_email,
validate_access_token=config.validate_access_token,
)


Expand Down Expand Up @@ -375,7 +376,8 @@ async def exchange_auth_code_for_oauth_tokens(
access_token_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL
# Transformation needed for dev keys END

return await do_post_request(token_api_url, access_token_params)
_, body = await do_post_request(token_api_url, access_token_params)
return body

async def get_user_info(
self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]
Expand All @@ -402,25 +404,29 @@ async def get_user_info(
user_context,
)

if access_token is not None and self.config.token_endpoint is not None:
if self.config.validate_access_token is not None and access_token is not None:
await self.config.validate_access_token(
access_token, self.config, user_context
)

if access_token is not None and self.config.user_info_endpoint is not None:
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
query_params: Dict[str, str] = {}

if self.config.user_info_endpoint is not None:
if self.config.user_info_endpoint_headers is not None:
headers = merge_into_dict(
self.config.user_info_endpoint_headers, headers
)

if self.config.user_info_endpoint_query_params is not None:
query_params = merge_into_dict(
self.config.user_info_endpoint_query_params, query_params
)
if self.config.user_info_endpoint_headers is not None:
headers = merge_into_dict(
self.config.user_info_endpoint_headers, headers
)

raw_user_info_from_provider.from_user_info_api = await do_get_request(
self.config.user_info_endpoint, query_params, headers
if self.config.user_info_endpoint_query_params is not None:
query_params = merge_into_dict(
self.config.user_info_endpoint_query_params, query_params
)

raw_user_info_from_provider.from_user_info_api = await do_get_request(
self.config.user_info_endpoint, query_params, headers
)

user_info_result = get_supertokens_user_info_result_from_raw_user_info(
self.config, raw_user_info_from_provider
)
Expand Down
32 changes: 31 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
from typing import Any, Dict, List, Optional

from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
from supertokens_python.recipe.thirdparty.providers.utils import (
do_get_request,
do_post_request,
)
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail

from .custom import GenericProvider, NewProvider
Expand Down Expand Up @@ -71,4 +76,29 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built
if input.config.token_endpoint is None:
input.config.token_endpoint = "https://github.com/login/oauth/access_token"

if input.config.validate_access_token is None:
input.config.validate_access_token = validate_access_token

return NewProvider(input, GithubImpl)


async def validate_access_token(
access_token: str, config: ProviderConfigForClient, _: Dict[str, Any]
):
client_secret = "" if config.client_secret is None else config.client_secret
basic_auth_token = base64.b64encode(
f"{config.client_id}:{client_secret}".encode()
).decode()

url = f"https://api.github.com/applications/{config.client_id}/token"
headers = {
"Authorization": f"Basic {basic_auth_token}",
"Content-Type": "application/json",
}

status, body = await do_post_request(url, {"access_token": access_token}, headers)
if status != 200:
raise ValueError("Invalid access token")

if "app" not in body or body["app"].get("client_id") != config.client_id:
raise ValueError("Access token does not belong to your application")
3 changes: 2 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/twitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ async def exchange_auth_code_for_oauth_tokens(

assert self.config.token_endpoint is not None

return await do_post_request(
_, body = await do_post_request(
self.config.token_endpoint,
body_params=twitter_oauth_tokens_params,
headers={"Authorization": f"Basic {auth_token}"},
)
return body


def Twitter(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin
Expand Down
6 changes: 3 additions & 3 deletions supertokens_python/recipe/thirdparty/providers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what version is Tuple supported in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from httpx import AsyncClient

Expand Down Expand Up @@ -48,7 +48,7 @@ async def do_post_request(
url: str,
body_params: Optional[Dict[str, str]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
) -> Tuple[int, Dict[str, Any]]:
if body_params is None:
body_params = {}
if headers is None:
Expand All @@ -62,4 +62,4 @@ async def do_post_request(
log_debug_message(
"Received response with status %s and body %s", res.status_code, res.text
)
return res.json()
return res.status_code, res.json()
Loading
Loading