Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add validate_access_token function to providers #453

Merged
merged 10 commits into from
Oct 5, 2023
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

IamMayankThakur marked this conversation as resolved.
Show resolved Hide resolved
## [0.16.4] - 2023-10-05

- Add `validate_access_token` function to providers
- This can be used to verify the access token received from providers.
- Implemented `validate_access_token` for the Github provider.

## [0.16.3] - 2023-09-28

- Add Twitter provider for thirdparty login
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.16.3",
version="0.16.4",
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.16.3"
VERSION = "0.16.4"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse:
require_email=p.get("requireEmail", True),
validate_id_token_payload=None,
generate_fake_email=None,
validate_access_token=None,
)
)

Expand Down
21 changes: 21 additions & 0 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
self.third_party_id = third_party_id
self.name = name
Expand All @@ -192,6 +198,7 @@ def __init__(
self.require_email = require_email
self.validate_id_token_payload = validate_id_token_payload
self.generate_fake_email = generate_fake_email
self.validate_access_token = validate_access_token

def to_json(self) -> Dict[str, Any]:
res = {
Expand Down Expand Up @@ -250,6 +257,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
ProviderClientConfig.__init__(
self,
Expand Down Expand Up @@ -277,6 +290,7 @@ def __init__(
require_email,
validate_id_token_payload,
generate_fake_email,
validate_access_token,
)

def to_json(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -313,6 +327,12 @@ def __init__(
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_access_token: Optional[
Callable[
[str, ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
super().__init__(
third_party_id,
Expand All @@ -330,6 +350,7 @@ def __init__(
require_email,
validate_id_token_payload,
generate_fake_email,
validate_access_token,
)
self.clients = clients

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def merge_config(
user_info_map=config_from_static.user_info_map,
generate_fake_email=config_from_static.generate_fake_email,
validate_id_token_payload=config_from_static.validate_id_token_payload,
validate_access_token=config_from_static.validate_access_token,
)

if result.user_info_map is None:
Expand Down
8 changes: 7 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def get_provider_config_for_client(
require_email=config.require_email,
validate_id_token_payload=config.validate_id_token_payload,
generate_fake_email=config.generate_fake_email,
validate_access_token=config.validate_access_token,
)


Expand Down Expand Up @@ -402,7 +403,12 @@ async def get_user_info(
user_context,
)

if access_token is not None and self.config.token_endpoint is not None:
if self.config.validate_access_token is not None and access_token is not None:
await self.config.validate_access_token(
access_token, self.config, user_context
)

if access_token is not None and self.config.user_info_endpoint is not None:
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
query_params: Dict[str, str] = {}

Expand Down
33 changes: 32 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
from typing import Any, Dict, List, Optional

from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
from supertokens_python.recipe.thirdparty.providers.utils import (
do_get_request,
do_post_request,
)
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail

from .custom import GenericProvider, NewProvider
Expand Down Expand Up @@ -71,4 +76,30 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built
if input.config.token_endpoint is None:
input.config.token_endpoint = "https://github.com/login/oauth/access_token"

if input.config.validate_access_token is None:
input.config.validate_access_token = validate_access_token

return NewProvider(input, GithubImpl)


async def validate_access_token(
access_token: str, config: ProviderConfigForClient, _: Dict[str, Any]
):
client_secret = "" if config.client_secret is None else config.client_secret
basic_auth_token = base64.b64encode(
f"{config.client_id}:{client_secret}".encode()
).decode()

url = f"https://api.github.com/applications/{config.client_id}/token"
headers = {
"Authorization": f"Basic {basic_auth_token}",
"Content-Type": "application/json",
}

try:
body = await do_post_request(url, {"access_token": access_token}, headers)
IamMayankThakur marked this conversation as resolved.
Show resolved Hide resolved
except Exception:
raise ValueError("Invalid access token")

if "app" not in body or body["app"].get("client_id") != config.client_id:
raise ValueError("Access token does not belong to your application")
Loading
Loading