From 713efb5a574525793a94f843b388a7134ce3d7cc Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 16 Sep 2024 17:58:09 +0530 Subject: [PATCH] more apis --- .../multitenancy/get_third_party_config.py | 378 ++++++++++++++++++ supertokens_python/recipe/dashboard/recipe.py | 5 + 2 files changed, 383 insertions(+) create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py new file mode 100644 index 00000000..375e3f06 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py @@ -0,0 +1,378 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, List, Optional, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.thirdparty import ( + ProviderClientConfig, + ProviderConfig, + ProviderInput, +) +from supertokens_python.recipe.thirdparty.provider import CommonProviderConfig, Provider +from supertokens_python.recipe.thirdparty.providers.utils import do_get_request +from supertokens_python.types import APIResponse +from supertokens_python.recipe.thirdparty.providers.config_utils import ( + find_and_create_provider_instance, + merge_providers_from_core_and_static, +) +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.normalised_url_domain import NormalisedURLDomain +from ...interfaces import APIInterface, APIOptions + + +class ProviderConfigResponse(APIResponse): + def __init__( + self, + provider_config: ProviderConfig, + is_get_authorisation_redirect_url_overridden: bool, + is_exchange_auth_code_for_oauth_tokens_overridden: bool, + is_get_user_info_overridden: bool, + ): + self.provider_config = provider_config + self.is_get_authorisation_redirect_url_overridden = ( + is_get_authorisation_redirect_url_overridden + ) + self.is_exchange_auth_code_for_oauth_tokens_overridden = ( + is_exchange_auth_code_for_oauth_tokens_overridden + ) + self.is_get_user_info_overridden = is_get_user_info_overridden + + def to_json(self) -> Dict[str, Any]: + json_response = self.provider_config.to_json() + json_response[ + "isGetAuthorisationRedirectUrlOverridden" + ] = self.is_get_authorisation_redirect_url_overridden + json_response[ + "isExchangeAuthCodeForOAuthTokensOverridden" + ] = self.is_exchange_auth_code_for_oauth_tokens_overridden + json_response["isGetUserInfoOverridden"] = self.is_get_user_info_overridden + json_response["status"] = "OK" + return json_response + + +class GetThirdPartyConfigUnknownTenantError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def get_third_party_config( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ProviderConfigResponse, GetThirdPartyConfigUnknownTenantError]: + tenant_res = await get_tenant(tenant_id, user_context) + + if tenant_res is None: + return GetThirdPartyConfigUnknownTenantError() + + third_party_id = options.request.get_query_param("thirdPartyId") + + if third_party_id is None: + raise_bad_input_exception("Please provide thirdPartyId") + + providers_from_core = tenant_res.third_party_providers + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = mt_recipe.static_third_party_providers or [] + + additional_config: Optional[Dict[str, Any]] = None + + providers_from_core = [ + provider + for provider in providers_from_core + if provider.third_party_id == third_party_id + ] + + if not providers_from_core: + providers_from_core.append(ProviderConfig(third_party_id=third_party_id)) + + if third_party_id in ["okta", "active-directory", "boxy-saml", "google-workspaces"]: + if third_party_id == "okta": + okta_domain = options.request.get_query_param("oktaDomain") + if okta_domain is not None: + additional_config = {"oktaDomain": okta_domain} + elif third_party_id == "active-directory": + directory_id = options.request.get_query_param("directoryId") + if directory_id is not None: + additional_config = {"directoryId": directory_id} + elif third_party_id == "boxy-saml": + boxy_url = options.request.get_query_param("boxyUrl") + boxy_api_key = options.request.get_query_param("boxyAPIKey") + if boxy_url is not None: + additional_config = {"boxyURL": boxy_url} + if boxy_api_key is not None: + additional_config["boxyAPIKey"] = boxy_api_key + elif third_party_id == "google-workspaces": + hd = options.request.get_query_param("hd") + if hd is not None: + additional_config = {"hd": hd} + + if additional_config is not None: + providers_from_core[0].oidc_discovery_endpoint = None + providers_from_core[0].authorization_endpoint = None + providers_from_core[0].token_endpoint = None + providers_from_core[0].user_info_endpoint = None + + if providers_from_core[0].clients is not None: + for existing_client in providers_from_core[0].clients: + if existing_client.additional_config is not None: + existing_client.additional_config = { + **existing_client.additional_config, + **additional_config, + } + else: + existing_client.additional_config = additional_config + else: + providers_from_core[0].clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=additional_config, + ) + ] + + static_providers = [ + provider + for provider in static_providers + if provider.config.third_party_id == third_party_id + ] + + if not static_providers and third_party_id == "apple": + static_providers.append( + ProviderInput( + config=ProviderConfig( + third_party_id="apple", + clients=[ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id" + ) + ], + ) + ) + ) + + additional_config = { + "teamId": "", + "keyId": "", + "privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + } + + if len(static_providers) == 1 and additional_config is not None: + static_providers[0].config.oidc_discovery_endpoint = None + static_providers[0].config.authorization_endpoint = None + static_providers[0].config.token_endpoint = None + static_providers[0].config.user_info_endpoint = None + if static_providers[0].config.clients is not None: + for existing_client in static_providers[0].config.clients: + if existing_client.additional_config is not None: + existing_client.additional_config = { + **existing_client.additional_config, + **additional_config, + } + else: + existing_client.additional_config = additional_config + else: + static_providers[0].config.clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=additional_config, + ) + ] + + merged_providers_from_core_and_static = merge_providers_from_core_and_static( + providers_from_core, static_providers, True + ) + + if len(merged_providers_from_core_and_static) != 1: + raise Exception("should never come here!") + + for merged_provider in merged_providers_from_core_and_static: + if merged_provider.config.third_party_id == third_party_id: + if not merged_provider.config.clients: + merged_provider.config.clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=( + additional_config if additional_config is not None else None + ), + ) + ] + clients: List[ProviderClientConfig] = [] + common_provider_config: CommonProviderConfig = CommonProviderConfig( + third_party_id=third_party_id + ) + is_get_authorisation_redirect_url_overridden = False + is_exchange_auth_code_for_oauth_tokens_overridden = False + is_get_user_info_overridden = False + + for provider in merged_providers_from_core_and_static: + if provider.config.third_party_id == third_party_id: + found_correct_config = False + + for client in provider.config.clients or []: + try: + provider_instance = await find_and_create_provider_instance( + merged_providers_from_core_and_static, + third_party_id, + client.client_type, + user_context, + ) + assert provider_instance is not None + clients.append( + ProviderClientConfig( + client_id=provider_instance.config.client_id, + client_secret=provider_instance.config.client_secret, + scope=provider_instance.config.scope, + client_type=provider_instance.config.client_type, + additional_config=provider_instance.config.additional_config, + force_pkce=provider_instance.config.force_pkce, + ) + ) + # common_provider_config = CommonProviderConfig( + # third_party_id=provider_instance.config.third_party_id, + # name=provider_instance.config.name, + # authorization_endpoint=provider_instance.config.authorization_endpoint, + # authorization_endpoint_query_params=provider_instance.config.authorization_endpoint_query_params, + # token_endpoint=provider_instance.config.token_endpoint, + # token_endpoint_body_params=provider_instance.config.token_endpoint_body_params, + # user_info_endpoint=provider_instance.config.user_info_endpoint, + # user_info_endpoint_query_params=provider_instance.config.user_info_endpoint_query_params, + # user_info_endpoint_headers=provider_instance.config.user_info_endpoint_headers, + # jwks_uri=provider_instance.config.jwks_uri, + # oidc_discovery_endpoint=provider_instance.config.oidc_discovery_endpoint, + # user_info_map=provider_instance.config.user_info_map, + # require_email=provider_instance.config.require_email, + # validate_id_token_payload=provider_instance.config.validate_id_token_payload, + # validate_access_token=provider_instance.config.validate_access_token, + # generate_fake_email=provider_instance.config.generate_fake_email, + # ) + common_provider_config = provider_instance.config + + if provider.override is not None: + before_override = Provider( + config=provider_instance.config, + id=provider_instance.id, + ) + after_override = provider.override(before_override) + + if ( + before_override.get_authorisation_redirect_url + != after_override.get_authorisation_redirect_url + ): + is_get_authorisation_redirect_url_overridden = True + if ( + before_override.exchange_auth_code_for_oauth_tokens + != after_override.exchange_auth_code_for_oauth_tokens + ): + is_exchange_auth_code_for_oauth_tokens_overridden = True + if ( + before_override.get_user_info + != after_override.get_user_info + ): + is_get_user_info_overridden = True + + found_correct_config = True + except Exception: + clients.append(client) + + if not found_correct_config: + common_provider_config = provider.config + + break + + if additional_config and "privateKey" in additional_config: + additional_config["privateKey"] = "" + + temp_clients = [ + client + for client in clients + if client.client_id == "nonguessable-temporary-client-id" + ] + + final_clients = [ + client + for client in clients + if client.client_id != "nonguessable-temporary-client-id" + ] + if not final_clients: + final_clients = [ + ProviderClientConfig( + client_id="", + client_secret="", + additional_config=additional_config, + client_type=temp_clients[0].client_type, + force_pkce=temp_clients[0].force_pkce, + scope=temp_clients[0].scope, + ) + ] + + if third_party_id.startswith("boxy-saml"): + boxy_api_key = options.request.get_query_param("boxyAPIKey") + if boxy_api_key and final_clients[0].client_id: + assert isinstance(final_clients[0].additional_config, dict) + boxy_url = final_clients[0].additional_config["boxyURL"] + normalised_domain = NormalisedURLDomain(boxy_url) + normalised_base_path = NormalisedURLPath(boxy_url) + connections_path = NormalisedURLPath("/api/v1/saml/config") + + resp = await do_get_request( + normalised_domain.get_as_string_dangerous() + + normalised_base_path.append( + connections_path + ).get_as_string_dangerous(), + {"clientID": final_clients[0].client_id}, + {"Authorization": f"Api-Key {boxy_api_key}"}, + ) + + json_response = resp + final_clients[0].additional_config.update( + { + "redirectURLs": json_response["redirectUrl"], + "boxyTenant": json_response["tenant"], + "boxyProduct": json_response["product"], + } + ) + + provider_config = ProviderConfig( + third_party_id=third_party_id, + clients=final_clients, + authorization_endpoint=common_provider_config.authorization_endpoint, + authorization_endpoint_query_params=common_provider_config.authorization_endpoint_query_params, + token_endpoint=common_provider_config.token_endpoint, + token_endpoint_body_params=common_provider_config.token_endpoint_body_params, + user_info_endpoint=common_provider_config.user_info_endpoint, + user_info_endpoint_query_params=common_provider_config.user_info_endpoint_query_params, + user_info_endpoint_headers=common_provider_config.user_info_endpoint_headers, + jwks_uri=common_provider_config.jwks_uri, + oidc_discovery_endpoint=common_provider_config.oidc_discovery_endpoint, + user_info_map=common_provider_config.user_info_map, + require_email=common_provider_config.require_email, + validate_id_token_payload=common_provider_config.validate_id_token_payload, + validate_access_token=common_provider_config.validate_access_token, + generate_fake_email=common_provider_config.generate_fake_email, + name=common_provider_config.name, + ) + + return ProviderConfigResponse( + provider_config=provider_config, + is_get_authorisation_redirect_url_overridden=is_get_authorisation_redirect_url_overridden, + is_exchange_auth_code_for_oauth_tokens_overridden=is_exchange_auth_code_for_oauth_tokens_overridden, + is_get_user_info_overridden=is_get_user_info_overridden, + ) diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 32040a20..fc7761f9 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -32,6 +32,9 @@ from supertokens_python.recipe.dashboard.api.multitenancy.get_tenant_info import ( get_tenant_info, ) +from supertokens_python.recipe.dashboard.api.multitenancy.get_third_party_config import ( + get_third_party_config, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -402,6 +405,8 @@ async def handle_api_request( api_function = handle_create_or_update_third_party_config if method == "delete": api_function = delete_third_party_config_api + if method == "get": + api_function = get_third_party_config if api_function is not None: return await api_key_protector(