diff --git a/lib/galaxy/authnz/custos_authnz.py b/lib/galaxy/authnz/custos_authnz.py index 18e59d716e9d..8f7b65d5de8c 100644 --- a/lib/galaxy/authnz/custos_authnz.py +++ b/lib/galaxy/authnz/custos_authnz.py @@ -4,10 +4,15 @@ import logging import os import time +from dataclasses import dataclass from datetime import ( datetime, timedelta, ) +from typing import ( + List, + Optional, +) from urllib.parse import quote import jwt @@ -43,35 +48,54 @@ class InvalidAuthnzConfigException(Exception): pass -class CustosAuthnz(IdentityProvider): +@dataclass +class CustosAuthnzConfiguration: + provider: str + verify_ssl: Optional[bool] + url: str + label: str + client_id: str + client_secret: str + require_create_confirmation: bool + redirect_uri: str + ca_bundle: Optional[str] + pkce_support: bool + extra_params: Optional[dict] + authorization_endpoint: Optional[str] + token_endpoint: Optional[str] + end_session_endpoint: Optional[str] + well_known_oidc_config_uri: Optional[str] + iam_client_secret: Optional[str] + userinfo_endpoint: Optional[str] + credential_url: Optional[str] + + +class OIDCAuthnzBase(IdentityProvider): def __init__(self, provider, oidc_config, oidc_backend_config, idphint=None): provider = provider.lower() - self.config = {"provider": provider} - self.config["verify_ssl"] = oidc_config["VERIFY_SSL"] - self.config["url"] = oidc_backend_config["url"] - self.config["label"] = oidc_backend_config.get("label", provider.capitalize()) - self.config["client_id"] = oidc_backend_config["client_id"] - self.config["client_secret"] = oidc_backend_config["client_secret"] - self.config["require_create_confirmation"] = oidc_backend_config.get( - "require_create_confirmation", provider == "custos" + self.config = CustosAuthnzConfiguration( + provider=provider, + verify_ssl=oidc_config["VERIFY_SSL"], + url=oidc_backend_config["url"], + label=oidc_backend_config.get("label", provider.capitalize()), + client_id=oidc_backend_config["client_id"], + client_secret=oidc_backend_config["client_secret"], + require_create_confirmation=oidc_backend_config.get("require_create_confirmation", provider == "custos"), + redirect_uri=oidc_backend_config["redirect_uri"], + ca_bundle=oidc_backend_config.get("ca_bundle", None), + pkce_support=oidc_backend_config.get("pkce_support", False), + extra_params={}, + authorization_endpoint=None, + token_endpoint=None, + end_session_endpoint=None, + well_known_oidc_config_uri=None, + iam_client_secret=None, + userinfo_endpoint=None, + credential_url=None, ) - self.config["redirect_uri"] = oidc_backend_config["redirect_uri"] - self.config["ca_bundle"] = oidc_backend_config.get("ca_bundle", None) - self.config["pkce_support"] = oidc_backend_config.get("pkce_support", False) - self.config["extra_params"] = { - "kc_idp_hint": oidc_backend_config.get( - "idphint", "oidc" if self.config["provider"] in ["custos", "keycloak"] else "cilogon" - ) - } - if provider == "cilogon": - self._load_config_for_cilogon() - elif provider == "custos": - self._load_config_for_custos() - elif provider == "keycloak": - self._load_config_for_keycloak() def _decode_token_no_signature(self, token): - return jwt.decode(token, audience=self.config["client_id"], options={"verify_signature": False}) + return jwt.decode(token, audience=self.config.client_id, options={"verify_signature": False}) def refresh(self, trans, custos_authnz_token): if custos_authnz_token is None: @@ -82,14 +106,15 @@ def refresh(self, trans, custos_authnz_token): return False log.info(custos_authnz_token.access_token) oauth2_session = self._create_oauth2_session() - token_endpoint = self.config["token_endpoint"] - if self.config.get("iam_client_secret"): - client_secret = self.config["iam_client_secret"] + token_endpoint = self.config.token_endpoint + if self.config.iam_client_secret: + client_secret = self.config.iam_client_secret else: - client_secret = self.config["client_secret"] - clientIdAndSec = f"{self.config['client_id']}:{self.config['client_secret']}" # for custos + client_secret = self.config.client_secret + clientIdAndSec = f"{self.config.client_id}:{self.config.client_secret}" # for custos params = { + "client_id": self.config.client_id, "client_secret": client_secret, "refresh_token": custos_authnz_token.refresh_token, "headers": { @@ -110,18 +135,20 @@ def refresh(self, trans, custos_authnz_token): trans.sa_session.flush() return True + def _get_provider_specific_scopes(self): + return [] + def authenticate(self, trans, idphint=None): - base_authorize_url = self.config["authorization_endpoint"] + base_authorize_url = self.config.authorization_endpoint scopes = ["openid", "email", "profile"] - if self.config["provider"] in ["custos", "cilogon"]: - scopes.append("org.cilogon.userinfo") + scopes.extend(self._get_provider_specific_scopes()) oauth2_session = self._create_oauth2_session(scope=scopes) nonce = generate_nonce() nonce_hash = self._hash_nonce(nonce) extra_params = {"nonce": nonce_hash} if idphint is not None: extra_params["idphint"] = idphint - if self.config["pkce_support"]: + if self.config.pkce_support: if not pkce: raise InvalidAuthnzConfigException( "The python 'pkce' library is not installed but Galaxy is configured to use it " @@ -131,8 +158,8 @@ def authenticate(self, trans, idphint=None): extra_params["code_challenge"] = code_challenge extra_params["code_challenge_method"] = "S256" trans.set_cookie(value=code_verifier, name=VERIFIER_COOKIE_NAME) - if "extra_params" in self.config: - extra_params.update(self.config["extra_params"]) + if self.config.extra_params: + extra_params.update(self.config.extra_params) authorization_url, state = oauth2_session.authorization_url(base_authorize_url, **extra_params) trans.set_cookie(value=state, name=STATE_COOKIE_NAME) trans.set_cookie(value=nonce, name=NONCE_COOKIE_NAME) @@ -185,7 +212,7 @@ def callback(self, state_token, authz_code, trans, login_redirect_url): refresh_expiration_time = processed_token["refresh_expiration_time"] # Create or update custos_authnz_token record - custos_authnz_token = self._get_custos_authnz_token(trans.sa_session, user_id, self.config["provider"]) + custos_authnz_token = self._get_custos_authnz_token(trans.sa_session, user_id, self.config.provider) if custos_authnz_token is None: user = trans.user existing_user = trans.sa_session.query(User).filter_by(email=email).first() @@ -208,13 +235,13 @@ def callback(self, state_token, authz_code, trans, login_redirect_url): log.info(message) login_redirect_url = ( f"{login_redirect_url}login/start" - f"?connect_external_provider={self.config['provider']}" + f"?connect_external_provider={self.config.provider}" f"&connect_external_email={email}" - f"&connect_external_label={self.config['label']}" + f"&connect_external_label={self.config.label}" ) return login_redirect_url, None - elif self.config["require_create_confirmation"]: - login_redirect_url = f"{login_redirect_url}login/start?confirm=true&provider_token={json.dumps(token)}&provider={self.config['provider']}" + elif self.config.require_create_confirmation: + login_redirect_url = f"{login_redirect_url}login/start?confirm=true&provider_token={json.dumps(token)}&provider={self.config.provider}" return login_redirect_url, None else: user = trans.app.user_manager.create(email=email, username=username) @@ -225,14 +252,14 @@ def callback(self, state_token, authz_code, trans, login_redirect_url): custos_authnz_token = CustosAuthnzToken( user=user, external_user_id=user_id, - provider=self.config["provider"], + provider=self.config.provider, access_token=access_token, id_token=id_token, refresh_token=refresh_token, expiration_time=expiration_time, refresh_expiration_time=refresh_expiration_time, ) - label = self.config["label"] + label = self.config.label if existing_user and existing_user != user: redirect_url = ( f"{login_redirect_url}user/external_ids" @@ -293,7 +320,7 @@ def create_user(self, token, trans, login_redirect_url): custos_authnz_token = CustosAuthnzToken( external_user_id=user_id, - provider=self.config["provider"], + provider=self.config.provider, access_token=access_token, id_token=id_token, refresh_token=refresh_token, @@ -314,9 +341,9 @@ def disconnect(self, provider, trans, email=None, disconnect_redirect_url=None): user = trans.user index = 0 # Find CustosAuthnzToken record for this provider (should only be one) - provider_tokens = [token for token in user.custos_auth if token.provider == self.config["provider"]] + provider_tokens = [token for token in user.custos_auth if token.provider == self.config.provider] if len(provider_tokens) == 0: - raise Exception(f"User is not associated with provider {self.config['provider']}") + raise Exception(f"User is not associated with provider {self.config.provider}") if len(provider_tokens) > 1: for idx, token in enumerate(provider_tokens): id_token_decoded = self._decode_token_no_signature(token.id_token) @@ -330,8 +357,12 @@ def disconnect(self, provider, trans, email=None, disconnect_redirect_url=None): return False, f"Failed to disconnect provider {provider}: {util.unicodify(e)}", None def logout(self, trans, post_user_logout_href=None): + if not self.config.redirect_uri: + log.error("Failed to generate logout redirect_url") + return None try: - redirect_url = self.config["end_session_endpoint"] + if self.config.end_session_endpoint: + redirect_url = self.config.end_session_endpoint if post_user_logout_href is not None: redirect_url += f"?redirect_uri={quote(post_user_logout_href)}" return redirect_url @@ -340,8 +371,8 @@ def logout(self, trans, post_user_logout_href=None): return None def _create_oauth2_session(self, state=None, scope=None): - client_id = self.config["client_id"] - redirect_uri = self.config["redirect_uri"] + client_id = self.config.client_id + redirect_uri = self.config.redirect_uri if redirect_uri.startswith("http://localhost") and os.environ.get("OAUTHLIB_INSECURE_TRANSPORT", None) != "1": log.warning("Setting OAUTHLIB_INSECURE_TRANSPORT to '1' to allow plain HTTP (non-SSL) callback") os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" @@ -350,13 +381,13 @@ def _create_oauth2_session(self, state=None, scope=None): return session def _fetch_token(self, oauth2_session, trans): - if self.config.get("iam_client_secret"): + if self.config.iam_client_secret: # Custos uses the Keycloak client secret to get the token - client_secret = self.config["iam_client_secret"] + client_secret = self.config.iam_client_secret else: - client_secret = self.config["client_secret"] - token_endpoint = self.config["token_endpoint"] - clientIdAndSec = f"{self.config['client_id']}:{self.config['client_secret']}" # for custos + client_secret = self.config.client_secret + token_endpoint = self.config.token_endpoint + clientIdAndSec = f"{self.config.client_id}:{self.config.client_secret}" # for custos params = { "client_secret": client_secret, @@ -366,20 +397,22 @@ def _fetch_token(self, oauth2_session, trans): }, # for custos "verify": self._get_verify_param(), } - if self.config["pkce_support"]: + if self.config.pkce_support: code_verifier = trans.get_cookie(name=VERIFIER_COOKIE_NAME) trans.set_cookie("", name=VERIFIER_COOKIE_NAME, age=-1) params["code_verifier"] = code_verifier return oauth2_session.fetch_token(token_endpoint, **params) def _get_userinfo(self, oauth2_session): - userinfo_endpoint = self.config["userinfo_endpoint"] + userinfo_endpoint = self.config.userinfo_endpoint return oauth2_session.get(userinfo_endpoint, verify=self._get_verify_param()).json() - def _get_custos_authnz_token(self, sa_session, user_id, provider): + @staticmethod + def _get_custos_authnz_token(sa_session, user_id, provider): return sa_session.query(CustosAuthnzToken).filter_by(external_user_id=user_id, provider=provider).one_or_none() - def _hash_nonce(self, nonce): + @staticmethod + def _hash_nonce(nonce): return hashlib.sha256(util.smart_str(nonce)).hexdigest() def _validate_nonce(self, trans, nonce_hash): @@ -390,86 +423,51 @@ def _validate_nonce(self, trans, nonce_hash): if nonce_hash != nonce_cookie_hash: raise Exception("Nonce mismatch. Check that configured redirect_uri matches the URL you are using.") - def _load_config_for_cilogon(self): - # Set cilogon endpoints - self.config["well_known_oidc_config_uri"] = self._get_well_known_uri_from_url(self.config["provider"]) - well_known_oidc_config = self._fetch_well_known_oidc_config(self.config["well_known_oidc_config_uri"]) - self._load_well_known_oidc_config(well_known_oidc_config) - - def _load_config_for_custos(self): - self.config["well_known_oidc_config_uri"] = self._get_well_known_uri_from_url(self.config["provider"]) - self.config["credential_url"] = f"{self.config['url'].rstrip('/')}/credentials" - self._get_custos_credentials() - # Set custos endpoints - clientIdAndSec = f"{self.config['client_id']}:{self.config['client_secret']}" - eps = requests.get( - self.config["well_known_oidc_config_uri"], - headers={"Authorization": f"Basic {util.unicodify(base64.b64encode(util.smart_str(clientIdAndSec)))}"}, - verify=False, - params={"client_id": self.config["client_id"]}, - timeout=util.DEFAULT_SOCKET_TIMEOUT, - ) - well_known_oidc_config = eps.json() - self._load_well_known_oidc_config(well_known_oidc_config) - - def _load_config_for_keycloak(self): - self.config["well_known_oidc_config_uri"] = self._get_well_known_uri_from_url(self.config["provider"]) - well_known_oidc_config = self._fetch_well_known_oidc_config(self.config["well_known_oidc_config_uri"]) - self._load_well_known_oidc_config(well_known_oidc_config) - - def _get_custos_credentials(self): - clientIdAndSec = f"{self.config['client_id']}:{self.config['client_secret']}" - creds = requests.get( - self.config["credential_url"], - headers={"Authorization": f"Basic {util.unicodify(base64.b64encode(util.smart_str(clientIdAndSec)))}"}, - verify=False, - params={"client_id": self.config["client_id"]}, - timeout=util.DEFAULT_SOCKET_TIMEOUT, - ) - credentials = creds.json() - self.config["iam_client_secret"] = credentials["iam_client_secret"] - - def _get_well_known_uri_from_url(self, provider): - # TODO: Look up this URL from a Python library - if provider in ["custos", "keycloak"]: - base_url = self.config["url"] - # Remove potential trailing slash to avoid "//realms" - base_url = base_url if base_url[-1] != "/" else base_url[:-1] - return f"{base_url}/.well-known/openid-configuration" - if provider == "cilogon": - base_url = self.config["url"] - # backwards compatibility. CILogon URL is given with /authorize in the examples. not sure if - # this applies to the wild, but let's be safe here and remove the /authorize if it exists - # which will lead to the correct openid configuration - base_url = base_url if base_url.split("/")[-1] != "authorize" else "/".join(base_url.split("/")[:-1]) - return f"{base_url}/.well-known/openid-configuration" - else: - raise Exception(f"Unknown Custos provider name: {provider}") - - def _fetch_well_known_oidc_config(self, well_known_uri): + def _load_config(self, headers: Optional[dict] = None, params: Optional[dict] = None): + if not headers: + headers = {} + if not params: + params = {} + self.config.well_known_oidc_config_uri = self._get_well_known_uri_from_url(self.config.provider) + if not self.config.well_known_oidc_config_uri: + log.error(f"Failed to load well-known OIDC config URI: {self.config.well_known_oidc_config_uri}") + raise Exception(f"Failed to load well-known OIDC config URI: {self.config.well_known_oidc_config_uri}") try: - return requests.get( - well_known_uri, verify=self._get_verify_param(), timeout=util.DEFAULT_SOCKET_TIMEOUT + well_known_oidc_config = requests.get( + self.config.well_known_oidc_config_uri, + headers=headers, + verify=self._get_verify_param(), + timeout=util.DEFAULT_SOCKET_TIMEOUT, + params=params, ).json() + self._load_well_known_oidc_config(well_known_oidc_config) except Exception: - log.error(f"Failed to load well-known OIDC config URI: {well_known_uri}") + log.error(f"Failed to load well-known OIDC config URI: {self.config.well_known_oidc_config_uri}") raise + def _get_well_known_uri_from_url(self, provider): + # TODO: Look up this URL from a Python library + base_url = self.config.url + # Remove potential trailing slash to avoid "//realms" + base_url = base_url if base_url[-1] != "/" else base_url[:-1] + return f"{base_url}/.well-known/openid-configuration" + def _load_well_known_oidc_config(self, well_known_oidc_config): - self.config["authorization_endpoint"] = well_known_oidc_config["authorization_endpoint"] - self.config["token_endpoint"] = well_known_oidc_config["token_endpoint"] - self.config["userinfo_endpoint"] = well_known_oidc_config["userinfo_endpoint"] - self.config["end_session_endpoint"] = well_known_oidc_config.get("end_session_endpoint") + self.config.authorization_endpoint = well_known_oidc_config["authorization_endpoint"] + self.config.token_endpoint = well_known_oidc_config["token_endpoint"] + self.config.userinfo_endpoint = well_known_oidc_config["userinfo_endpoint"] + self.config.end_session_endpoint = well_known_oidc_config.get("end_session_endpoint") def _get_verify_param(self): """Return 'ca_bundle' if 'verify_ssl' is true and 'ca_bundle' is configured.""" # in requests_oauthlib, the verify param can either be a boolean or a CA bundle path - if self.config["ca_bundle"] is not None and self.config["verify_ssl"]: - return self.config["ca_bundle"] + if self.config.ca_bundle is not None and self.config.verify_ssl: + return self.config.ca_bundle else: - return self.config["verify_ssl"] + return self.config.verify_ssl - def _username_from_userinfo(self, trans, userinfo): + @staticmethod + def _username_from_userinfo(trans, userinfo): username = userinfo.get("preferred_username", userinfo["email"]) if "@" in username: username = username.split("@")[0] # username created from username portion of email @@ -482,3 +480,110 @@ def _username_from_userinfo(self, trans, userinfo): return f"{username}{count}" else: return username + + +class OIDCAuthnzBaseKeycloak(OIDCAuthnzBase): + def __init__(self, provider, oidc_config, oidc_backend_config, idphint=None): + super().__init__(provider, oidc_config, oidc_backend_config, idphint) + self.config.extra_params = {"kc_idp_hint": oidc_backend_config.get("idphint", "oidc")} + self._load_config() + + +class OIDCAuthnzBaseCiLogon(OIDCAuthnzBase): + def __init__(self, provider, oidc_config, oidc_backend_config, idphint=None): + super().__init__(provider, oidc_config, oidc_backend_config, idphint) + self.config.extra_params = {"kc_idp_hint": oidc_backend_config.get("idphint", "cilogon")} + self._load_config() + + def _get_provider_specific_scopes(self): + return ["org.cilogon.userinfo"] + + def _get_well_known_uri_from_url(self, provider): + base_url = self.config.url + # backwards compatibility. CILogon URL is given with /authorize in the examples. not sure if + # this applies to the wild, but let's be safe here and remove the /authorize if it exists + # which will lead to the correct openid configuration + base_url = base_url if base_url.split("/")[-1] != "authorize" else "/".join(base_url.split("/")[:-1]) + return f"{base_url}/.well-known/openid-configuration" + + +class CustosAuthFactory: + @dataclass + class _CustosAuthBasedProviderCacheItem: + created_at: datetime + item: OIDCAuthnzBase + provider: str + oidc_config: dict + oidc_backend_config: dict + idphint: str + + _CustosAuthBasedProvidersCache: List[_CustosAuthBasedProviderCacheItem] = [] + + @staticmethod + def GetCustosBasedAuthProvider(provider, oidc_config, oidc_backend_config, idphint=None): + # see if we have a config loaded up already + for item in CustosAuthFactory._CustosAuthBasedProvidersCache: + if ( + item.provider == provider + and item.oidc_config == oidc_config + and item.oidc_backend_config == oidc_backend_config + and item.idphint == idphint + ): + return item.item + + auth_adapter: OIDCAuthnzBase + if provider.lower() == "custos": + auth_adapter = OIDCAuthnzBaseCustos(provider, oidc_config, oidc_backend_config, idphint) + elif provider.lower() == "keycloak": + auth_adapter = OIDCAuthnzBaseKeycloak(provider, oidc_config, oidc_backend_config, idphint) + elif provider.lower() == "cilogon": + auth_adapter = OIDCAuthnzBaseCiLogon(provider, oidc_config, oidc_backend_config, idphint) + else: + raise Exception(f"Unknown Custos provider name: {provider}") + + if auth_adapter: + CustosAuthFactory._CustosAuthBasedProvidersCache.append( + CustosAuthFactory._CustosAuthBasedProviderCacheItem( + created_at=datetime.now(), + item=auth_adapter, + provider=provider, + oidc_config=oidc_config, + oidc_backend_config=oidc_backend_config, + idphint=idphint, + ) + ) + + return auth_adapter + + +class OIDCAuthnzBaseCustos(OIDCAuthnzBase): + def __init__(self, provider, oidc_config, oidc_backend_config, idphint=None): + super().__init__(provider, oidc_config, oidc_backend_config, idphint) + self.config.extra_params = {"kc_idp_hint": oidc_backend_config.get("idphint", "oidc")} + self._load_config_for_custos() + + def _get_custos_credentials(self): + clientIdAndSec = f"{self.config.client_id}:{self.config.client_secret}" + if not self.config.credential_url: + raise Exception( + f"Error OIDC provider {self.config.provider} is of type Custos, but does not have the credential url set" + ) + creds = requests.get( + self.config.credential_url, + headers={"Authorization": f"Basic {util.unicodify(base64.b64encode(util.smart_str(clientIdAndSec)))}"}, + verify=False, + params={"client_id": self.config.client_id}, + timeout=util.DEFAULT_SOCKET_TIMEOUT, + ) + credentials = creds.json() + self.config.iam_client_secret = credentials["iam_client_secret"] + + def _load_config_for_custos(self): + self.config.credential_url = f"{self.config.url.rstrip('/')}/credentials" + self._get_custos_credentials() + # Set custos endpoints + clientIdAndSec = f"{self.config.client_id}:{self.config.client_secret}" + headers = {"Authorization": f"Basic {util.unicodify(base64.b64encode(util.smart_str(clientIdAndSec)))}"} + params = {"client_id": self.config.client_id} + + self._load_config(headers, params) diff --git a/lib/galaxy/authnz/managers.py b/lib/galaxy/authnz/managers.py index 97bfdebe0c04..d766efcd111d 100644 --- a/lib/galaxy/authnz/managers.py +++ b/lib/galaxy/authnz/managers.py @@ -23,7 +23,7 @@ unicodify, ) from .custos_authnz import ( - CustosAuthnz, + CustosAuthFactory, KEYCLOAK_BACKENDS, ) from .psa_authnz import ( @@ -210,7 +210,7 @@ def _get_authnz_backend(self, provider, idphint=None): unified_provider_name = self._unify_provider_name(provider) if unified_provider_name in self.oidc_backends_config: provider = unified_provider_name - identity_provider_class = self._get_identity_provider_class(self.oidc_backends_implementation[provider]) + identity_provider_class = self._get_identity_provider_factory(self.oidc_backends_implementation[provider]) try: if provider in KEYCLOAK_BACKENDS: return ( @@ -240,11 +240,11 @@ def _get_authnz_backend(self, provider, idphint=None): return False, msg, None @staticmethod - def _get_identity_provider_class(implementation): + def _get_identity_provider_factory(implementation): if implementation == "psa": return PSAAuthnz elif implementation == "custos": - return CustosAuthnz + return CustosAuthFactory.GetCustosBasedAuthProvider else: return None diff --git a/test/unit/app/authnz/test_custos_authnz.py b/test/unit/app/authnz/test_custos_authnz.py index fd73e3bb2f04..adda2ebcb116 100644 --- a/test/unit/app/authnz/test_custos_authnz.py +++ b/test/unit/app/authnz/test_custos_authnz.py @@ -56,7 +56,7 @@ def setUp(self): self._get_credential_url(): {"iam_client_secret": "TESTSECRET"}, } ) - self.custos_authnz = custos_authnz.CustosAuthnz( + self.custos_authnz = custos_authnz.CustosAuthFactory.GetCustosBasedAuthProvider( "Custos", {"VERIFY_SSL": True}, { @@ -246,15 +246,15 @@ def tearDown(self): os.environ.pop("OAUTHLIB_INSECURE_TRANSPORT", None) def test_parse_config(self): - assert self.custos_authnz.config["verify_ssl"] - assert self.custos_authnz.config["client_id"] == "test-client-id" - assert self.custos_authnz.config["client_secret"] == "test-client-secret" - assert self.custos_authnz.config["redirect_uri"] == "https://test-redirect-uri" - assert self.custos_authnz.config["authorization_endpoint"] == "https://test-auth-endpoint" - assert self.custos_authnz.config["token_endpoint"] == "https://test-token-endpoint" - assert self.custos_authnz.config["userinfo_endpoint"] == "https://test-userinfo-endpoint" - assert self.custos_authnz.config["label"] == "test-identity-provider" - assert self.custos_authnz.config["require_create_confirmation"] is False + assert self.custos_authnz.config.verify_ssl + assert self.custos_authnz.config.client_id == "test-client-id" + assert self.custos_authnz.config.client_secret == "test-client-secret" + assert self.custos_authnz.config.redirect_uri == "https://test-redirect-uri" + assert self.custos_authnz.config.authorization_endpoint == "https://test-auth-endpoint" + assert self.custos_authnz.config.token_endpoint == "https://test-token-endpoint" + assert self.custos_authnz.config.userinfo_endpoint == "https://test-userinfo-endpoint" + assert self.custos_authnz.config.label == "test-identity-provider" + assert self.custos_authnz.config.require_create_confirmation is False def test_authenticate_set_state_cookie(self): """Verify that authenticate() sets a state cookie.""" @@ -278,7 +278,7 @@ def test_authenticate_set_pkce_verifier_cookie(self): except ImportError: raise SkipTest("pkce library is not available") """Verify that authenticate() sets a code verifier cookie.""" - self.custos_authnz.config["pkce_support"] = True + self.custos_authnz.config.pkce_support = True authorization_url = self.custos_authnz.authenticate(self.trans) parsed = urlparse(authorization_url) code_challenge_in_url = parse_qs(parsed.query)["code_challenge"][0] @@ -295,7 +295,7 @@ def test_authenticate_adds_extra_params(self): def test_authenticate_sets_env_var_when_localhost_redirect(self): """Verify that OAUTHLIB_INSECURE_TRANSPORT var is set with localhost redirect.""" - self.custos_authnz = custos_authnz.CustosAuthnz( + self.custos_authnz = custos_authnz.CustosAuthFactory.GetCustosBasedAuthProvider( "Custos", {"VERIFY_SSL": True}, { @@ -312,7 +312,7 @@ def test_authenticate_sets_env_var_when_localhost_redirect(self): assert os.environ["OAUTHLIB_INSECURE_TRANSPORT"] == "1" def test_authenticate_does_not_set_env_var_when_https_redirect(self): - assert self.custos_authnz.config["redirect_uri"].startswith("https:") + assert self.custos_authnz.config.redirect_uri.startswith("https:") assert os.environ.get("OAUTHLIB_INSECURE_TRANSPORT") is None self.custos_authnz.authenticate(self.trans) assert os.environ.get("OAUTHLIB_INSECURE_TRANSPORT") is None @@ -329,7 +329,7 @@ def test_callback_verify_with_state_cookie(self): existing_custos_authnz_token = CustosAuthnzToken( user=User(email=self.test_email, username=self.test_username), external_user_id=self.test_user_id, - provider=self.custos_authnz.config["provider"], + provider=self.custos_authnz.config.provider, access_token=old_access_token, id_token=old_id_token, refresh_token=old_refresh_token, @@ -340,7 +340,7 @@ def test_callback_verify_with_state_cookie(self): self.trans.sa_session._query.custos_authnz_token = existing_custos_authnz_token assert ( self.trans.sa_session.query(CustosAuthnzToken) - .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config["provider"]) + .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config.provider) .one_or_none() is not None ) @@ -380,7 +380,7 @@ def test_callback_nonce_validation_with_bad_nonce(self): assert not self._get_userinfo_called def test_callback_user_not_created_when_does_not_exists(self): - self.custos_authnz = custos_authnz.CustosAuthnz( + self.custos_authnz = custos_authnz.CustosAuthFactory.GetCustosBasedAuthProvider( "Keycloak", {"VERIFY_SSL": True}, { @@ -399,7 +399,7 @@ def test_callback_user_not_created_when_does_not_exists(self): assert ( self.trans.sa_session.query(CustosAuthnzToken) - .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config["provider"]) + .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config.provider) .one_or_none() is None ) @@ -415,7 +415,7 @@ def test_callback_user_not_created_when_does_not_exists(self): def test_create_user(self): assert ( self.trans.sa_session.query(CustosAuthnzToken) - .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config["provider"]) + .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config.provider) .one_or_none() is None ) @@ -468,7 +468,7 @@ def test_create_user(self): expected_refresh_expiration_time - added_custos_authnz_token.refresh_expiration_time ) assert refresh_expiration_timedelta.total_seconds() < 1 - assert self.custos_authnz.config["provider"] == added_custos_authnz_token.provider + assert self.custos_authnz.config.provider == added_custos_authnz_token.provider assert self.trans.sa_session.commit_called def test_callback_galaxy_user_not_created_when_user_logged_in_and_no_custos_authnz_token_exists(self): @@ -482,7 +482,7 @@ def test_callback_galaxy_user_not_created_when_user_logged_in_and_no_custos_auth assert ( self.trans.sa_session.query(CustosAuthnzToken) - .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config["provider"]) + .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config.provider) .one_or_none() is None ) @@ -512,7 +512,7 @@ def test_callback_galaxy_user_not_created_when_custos_authnz_token_exists(self): existing_custos_authnz_token = CustosAuthnzToken( user=User(email=self.test_email, username=self.test_username), external_user_id=self.test_user_id, - provider=self.custos_authnz.config["provider"], + provider=self.custos_authnz.config.provider, access_token=old_access_token, id_token=old_id_token, refresh_token=old_refresh_token, @@ -524,7 +524,7 @@ def test_callback_galaxy_user_not_created_when_custos_authnz_token_exists(self): assert ( self.trans.sa_session.query(CustosAuthnzToken) - .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config["provider"]) + .filter_by(external_user_id=self.test_user_id, provider=self.custos_authnz.config.provider) .one_or_none() is not None ) @@ -536,7 +536,7 @@ def test_callback_galaxy_user_not_created_when_custos_authnz_token_exists(self): assert self._get_userinfo_called # Make sure query was called with correct parameters assert self.test_user_id == self.trans.sa_session._query.external_user_id - assert self.custos_authnz.config["provider"] == self.trans.sa_session._query.provider + assert self.custos_authnz.config.provider == self.trans.sa_session._query.provider assert 1 == len(self.trans.sa_session.items), "Session has updated CustosAuthnzToken" session_custos_authnz_token = self.trans.sa_session.items[0] assert isinstance(session_custos_authnz_token, CustosAuthnzToken) @@ -615,7 +615,7 @@ def test_disconnect(self): custos_authnz_token = CustosAuthnzToken( user=User(email=self.test_email, username=self.test_username), external_user_id=self.test_user_id, - provider=self.custos_authnz.config["provider"], + provider=self.custos_authnz.config.provider, access_token=self.test_access_token, id_token=self.test_id_token, refresh_token=self.test_refresh_token, @@ -651,7 +651,7 @@ def test_disconnect_when_more_than_one_associated_token_for_provider(self): custos_authnz_token1 = CustosAuthnzToken( user=self.trans.user, external_user_id=self.test_user_id + "1", - provider=self.custos_authnz.config["provider"], + provider=self.custos_authnz.config.provider, access_token=self.test_access_token, id_token=self.test_id_token, refresh_token=self.test_refresh_token, @@ -661,7 +661,7 @@ def test_disconnect_when_more_than_one_associated_token_for_provider(self): custos_authnz_token2 = CustosAuthnzToken( user=self.trans.user, external_user_id=self.test_user_id + "2", - provider=self.custos_authnz.config["provider"], + provider=self.custos_authnz.config.provider, access_token=self.test_access_token, id_token=self.test_id_token, refresh_token=self.test_refresh_token,