diff --git a/setup.py b/setup.py index add6301..2f7257b 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ author_email='samuel.gulliksson@gmail.com', description='Flask extension for OpenID Connect authentication.', install_requires=[ - 'oic>=1.2.1', + 'oic>=1.4.0', 'Flask', 'requests', 'importlib_resources' diff --git a/src/flask_pyoidc/auth_response_handler.py b/src/flask_pyoidc/auth_response_handler.py index 3025e2c..825ec90 100644 --- a/src/flask_pyoidc/auth_response_handler.py +++ b/src/flask_pyoidc/auth_response_handler.py @@ -71,7 +71,8 @@ def process_auth_response(self, auth_response, auth_request): refresh_token = None # but never refresh token if 'code' in auth_response: - token_resp = self._client.exchange_authorization_code(auth_response['code']) + token_resp = self._client.exchange_authorization_code(auth_response['code'], + auth_response['state']) if token_resp: if 'error' in token_resp: raise AuthResponseErrorResponseError(token_resp.to_dict()) diff --git a/src/flask_pyoidc/message_factory.py b/src/flask_pyoidc/message_factory.py new file mode 100644 index 0000000..3a65a2b --- /dev/null +++ b/src/flask_pyoidc/message_factory.py @@ -0,0 +1,6 @@ +from oic.oauth2.message import AccessTokenResponse, CCAccessTokenRequest, MessageTuple, OauthMessageFactory + + +class CCMessageFactory(OauthMessageFactory): + """Client Credential Request Factory.""" + token_endpoint = MessageTuple(CCAccessTokenRequest, AccessTokenResponse) diff --git a/src/flask_pyoidc/provider_configuration.py b/src/flask_pyoidc/provider_configuration.py index 8b8351a..af5108e 100644 --- a/src/flask_pyoidc/provider_configuration.py +++ b/src/flask_pyoidc/provider_configuration.py @@ -1,8 +1,9 @@ import collections.abc import logging -from oic.oic import Client import requests +from oic.oic import Client +from oic.utils.settings import ClientSettings logger = logging.getLogger(__name__) @@ -169,17 +170,16 @@ def __init__(self, self.userinfo_endpoint_method = userinfo_http_method self.auth_request_params = auth_request_params or {} self.session_refresh_interval_seconds = session_refresh_interval_seconds + # For session persistence + self.client_settings = ClientSettings(timeout=self.DEFAULT_REQUEST_TIMEOUT, + requests_session=requests_session or requests.Session()) - self.requests_session = requests_session or requests.Session() - - def ensure_provider_metadata(self): + def ensure_provider_metadata(self, client: Client): if not self._provider_metadata: - resp = self.requests_session \ - .get(self._issuer + '/.well-known/openid-configuration', - timeout=self.DEFAULT_REQUEST_TIMEOUT) - logger.debug('Received discovery response: ' + resp.text) + resp = client.provider_config(self._issuer) + logger.debug(f'Received discovery response: {resp.to_dict()}') - self._provider_metadata = ProviderMetadata(**resp.json()) + self._provider_metadata = ProviderMetadata(**resp.to_dict()) return self._provider_metadata @@ -200,8 +200,8 @@ def register_client(self, client: Client): registration_response = client.register( url=self._provider_metadata['registration_endpoint'], **registration_request) + logger.info('Received registration response.') self._client_metadata = ClientMetadata( **registration_response.to_dict()) - logger.debug('Received registration response: client_id=' + self._client_metadata['client_id']) return self._client_metadata diff --git a/src/flask_pyoidc/pyoidc_facade.py b/src/flask_pyoidc/pyoidc_facade.py index 086da6b..58aaf01 100644 --- a/src/flask_pyoidc/pyoidc_facade.py +++ b/src/flask_pyoidc/pyoidc_facade.py @@ -1,43 +1,18 @@ import base64 -import json import logging from oic.extension.client import Client as ClientExtension from oic.extension.message import TokenIntrospectionResponse -from oic.oic import Client, RegistrationResponse, AuthorizationResponse, \ - AccessTokenResponse, TokenErrorResponse, AuthorizationErrorResponse, OpenIDSchema -from oic.oic.message import ProviderConfigurationResponse +from oic.oauth2 import Client as Oauth2Client +from oic.oauth2.message import AccessTokenResponse +from oic.oic import Client +from oic.oic import Token +from oic.oic.message import AuthorizationResponse, ProviderConfigurationResponse, RegistrationResponse from oic.utils.authn.client import CLIENT_AUTHN_METHOD -logger = logging.getLogger(__name__) - - -class _ClientAuthentication: - def __init__(self, client_id, client_secret): - self._client_id = client_id - self._client_secret = client_secret - - def __call__(self, method, request): - """ - Args: - method (str): Client Authentication Method. Only 'client_secret_basic' and 'client_secret_post' is - supported. - request (MutableMapping[str, str]): Token request parameters. This may be modified, i.e. if - 'client_secret_post' is used the client credentials will be added. - - Returns: - (Mapping[str, str]): HTTP headers to be included in the token request, or `None` if no extra HTTPS headers - are required for the token request. - """ - if method == 'client_secret_post': - request['client_id'] = self._client_id - request['client_secret'] = self._client_secret - return None # authentication is in the request body, so no Authorization header is returned +from .message_factory import CCMessageFactory - # default to 'client_secret_basic' - credentials = '{}:{}'.format(self._client_id, self._client_secret) - basic_auth = 'Basic {}'.format(base64.urlsafe_b64encode(credentials.encode('utf-8')).decode('utf-8')) - return {'Authorization': basic_auth} +logger = logging.getLogger(__name__) class PyoidcFacade: @@ -51,29 +26,45 @@ def __init__(self, provider_configuration, redirect_uri): provider_configuration (flask_pyoidc.provider_configuration.ProviderConfiguration) """ self._provider_configuration = provider_configuration - self._client = Client(client_authn_method=CLIENT_AUTHN_METHOD) - # Token Introspection is implemented in extension sub-package of the - # client in pyoidc - self._client_extension = ClientExtension(client_authn_method=CLIENT_AUTHN_METHOD) - - provider_metadata = provider_configuration.ensure_provider_metadata() + self._client = Client(client_authn_method=CLIENT_AUTHN_METHOD, + settings=provider_configuration.client_settings) + # Token Introspection is implemented under extension sub-package of + # the client in pyoidc. + self._client_extension = ClientExtension(client_authn_method=CLIENT_AUTHN_METHOD, + settings=provider_configuration.client_settings) + # Client Credentials Flow is implemented under oauth2 sub-package of + # the client in pyoidc. + self._oauth2_client = Oauth2Client(client_authn_method=CLIENT_AUTHN_METHOD, + message_factory=CCMessageFactory, + settings=self._provider_configuration.client_settings) + + provider_metadata = provider_configuration.ensure_provider_metadata(self._client) self._client.handle_provider_config(ProviderConfigurationResponse(**provider_metadata.to_dict()), provider_metadata['issuer']) if self._provider_configuration.registered_client_metadata: client_metadata = self._provider_configuration.registered_client_metadata.to_dict() - registration_response = RegistrationResponse(**client_metadata) - self._client.store_registration_info(registration_response) + client_metadata.update(redirect_uris=list(redirect_uri)) + self._store_registration_info(client_metadata) self._redirect_uri = redirect_uri + def _store_registration_info(self, client_metadata): + registration_response = RegistrationResponse(**client_metadata) + self._client.store_registration_info(registration_response) + self._client_extension.store_registration_info(registration_response) + # Set client_id and client_secret for _oauth2_client. This is used + # by Client Credentials Flow. + self._oauth2_client.client_id = registration_response['client_id'] + self._oauth2_client.client_secret = registration_response['client_secret'] + def is_registered(self): return bool(self._provider_configuration.registered_client_metadata) def register(self): client_metadata = self._provider_configuration.register_client(self._client) - logger.debug('client registration response: %s', client_metadata) - self._client.store_registration_info(RegistrationResponse(**client_metadata.to_dict())) + logger.debug(f'client registration response: {client_metadata}') + self._store_registration_info(client_metadata) def authentication_request(self, state, nonce, extra_auth_params): """ @@ -104,7 +95,7 @@ def authentication_request(self, state, nonce, extra_auth_params): def login_url(self, auth_request): """ Args: - auth_request (AuthorizationRequest): authenticatio request + auth_request (AuthorizationRequest): authentication request Returns: str: Authentication request as a URL to redirect the user to the provider. """ @@ -112,34 +103,55 @@ def login_url(self, auth_request): def parse_authentication_response(self, response_params): """ - Args: - response_params (Mapping[str, str]): authentication response parameters - Returns: - Union[AuthorizationResponse, AuthorizationErrorResponse]: The parsed authorization response + Parameters + ---------- + response_params: Mapping[str, str] + authentication response parameters. + + Returns + ------- + Union[AuthorizationResponse, AuthorizationErrorResponse] + The parsed authorization response. """ - auth_resp = self._parse_response(response_params, AuthorizationResponse, AuthorizationErrorResponse) + auth_resp = self._client.parse_response(AuthorizationResponse, info=response_params, sformat='dict') if 'id_token' in response_params: auth_resp['id_token_jwt'] = response_params['id_token'] return auth_resp - def exchange_authorization_code(self, authorization_code): - """ - Requests tokens from an authorization code. + def exchange_authorization_code(self, authorization_code: str, state: str): + """Requests tokens from an authorization code. - Args: - authorization_code (str): authorization code issued to client after user authorization + Parameters + ---------- + authorization_code: str + authorization code issued to client after user authorization + state: str + state is used to keep track of responses to outstanding requests. - Returns: - Union[AccessTokenResponse, TokenErrorResponse, None]: The parsed token response, or None if no token - request was performed. + Returns + ------- + Union[AccessTokenResponse, TokenErrorResponse, None] + The parsed token response, or None if no token request was performed. """ - request = { + if not self._client.token_endpoint: + return None + + request_args = { 'grant_type': 'authorization_code', 'code': authorization_code, 'redirect_uri': self._redirect_uri } - - return self._token_request(request) + logger.debug('making token request: %s', request_args) + client_auth_method = self._client.registration_response.get('token_endpoint_auth_method', + 'client_secret_basic') + token_response = self._client.do_access_token_request(state=state, + request_args=request_args, + authn_method=client_auth_method, + endpoint=self._client.token_endpoint + ) + logger.info('Received token response.') + + return token_response def verify_id_token(self, id_token, auth_request): """ @@ -156,79 +168,55 @@ def verify_id_token(self, id_token, auth_request): """ self._client.verify_id_token(id_token, auth_request) - def refresh_token(self, refresh_token): - """ - Requests new tokens using a refresh token. + def refresh_token(self, refresh_token: str): + """Requests new tokens using a refresh token. - Args: - refresh_token (str): refresh token issued to client after user authorization + Parameters + ---------- + refresh_token: str + refresh token issued to client after user authorization. - Returns: - Union[AccessTokenResponse, TokenErrorResponse, None]: The parsed token response, or None if no token - request was performed. + Returns + ------- + Union[AccessTokenResponse, TokenErrorResponse, None] + The parsed token response, or None if no token request was performed. """ - request = { + request_args = { 'grant_type': 'refresh_token', 'refresh_token': refresh_token, 'redirect_uri': self._redirect_uri } + client_auth_method = self._client.registration_response.get('token_endpoint_auth_method', + 'client_secret_basic') + return self._client.do_access_token_refresh(request_args=request_args, + authn_method=client_auth_method, + token=Token(resp={'refresh_token': refresh_token}), + endpoint=self._client.token_endpoint + ) - return self._token_request(request) - - def _token_request(self, request): - """ - Makes a token request. If the 'token_endpoint' is not configured in the provider metadata, no request will - be made. - - Args: - request (Mapping[str, str]): token request parameters - - Returns: - Union[AccessTokenResponse, TokenErrorResponse, None]: The parsed token response, or None if no token - request was performed. - """ - - if not self._client.token_endpoint: - return None + def userinfo_request(self, access_token: str): + """Retrieves ID token. - logger.debug('making token request: %s', request) - client_auth_method = self._client.registration_response.get('token_endpoint_auth_method', 'client_secret_basic') - auth_header = _ClientAuthentication(self._client.client_id, self._client.client_secret)(client_auth_method, - request) - resp = self._provider_configuration.requests_session \ - .post(self._client.token_endpoint, - data=request, - headers=auth_header) \ - .json() - logger.debug('received token response: %s', json.dumps(resp)) - - token_resp = self._parse_response(resp, AccessTokenResponse, TokenErrorResponse) - if 'id_token' in resp: - token_resp['id_token_jwt'] = resp['id_token'] - - return token_resp - - def userinfo_request(self, access_token): - """ - Args: - access_token (str): Bearer access token to use when fetching userinfo + Parameters + ---------- + access_token: str + Bearer access token to use when fetching userinfo. - Returns: - oic.oic.message.OpenIDSchema: UserInfo Response + Returns + ------- + Union[OpenIDSchema, UserInfoErrorResponse, ErrorResponse, None] """ http_method = self._provider_configuration.userinfo_endpoint_method if not access_token or http_method is None or not self._client.userinfo_endpoint: return None logger.debug('making userinfo request') - userinfo_response = self._provider_configuration.requests_session \ - .request(http_method, self._client.userinfo_endpoint, headers={'Authorization': f'Bearer {access_token}'}) \ - .json() + userinfo_response = self._client.do_user_info_request(method=http_method, token=access_token) logger.debug('received userinfo response: %s', userinfo_response) - return OpenIDSchema(**userinfo_response) + return userinfo_response - def _token_introspection_request(self, access_token: str): + def _token_introspection_request(self, access_token: str) -> TokenIntrospectionResponse: """Make token introspection request. Parameters @@ -238,23 +226,22 @@ def _token_introspection_request(self, access_token: str): Returns ------- - oic.extension.message.TokenIntrospectionResponse + TokenIntrospectionResponse Response object contains result of the token introspection. """ - request = { + request_args = { 'token': access_token, 'token_type_hint': 'access_token' } - auth_header = _ClientAuthentication(self._client.client_id, self._client.client_secret)('client_secret_basic', - request) + client_auth_method = self._client.registration_response.get('introspection_endpoint_auth_method', + 'client_secret_basic') logger.info('making token introspection request') - response = self._provider_configuration.requests_session \ - .post(self._client.introspection_endpoint, data=request, headers=auth_header) \ - .json() + token_introspection_response = self._client_extension.do_token_introspection( + request_args=request_args, authn_method=client_auth_method, endpoint=self._client.introspection_endpoint) - return TokenIntrospectionResponse(**response) + return token_introspection_response - def client_credentials_grant(self, scope: list = None, **kwargs): + def client_credentials_grant(self, scope: list = None, **kwargs) -> AccessTokenResponse: """Public method to request access_token using client_credentials flow. This is useful for service to service communication where user-agent is not available which is required in authorization code flow. Your @@ -271,6 +258,10 @@ def client_credentials_grant(self, scope: list = None, **kwargs): **kwargs : dict, optional Extra arguments to client credentials flow. + Returns + ------- + AccessTokenResponse + Examples -------- :: @@ -294,13 +285,18 @@ def client_credentials_grant(self, scope: list = None, **kwargs): auth.clients['default'].client_credentials_grant( scope=['read', 'write'], audience=['client_id1', 'client_id2']) """ - request = { + request_args = { 'grant_type': 'client_credentials', **kwargs } if scope: - request['scope'] = ' '.join(scope) - return self._token_request(request) + request_args['scope'] = ' '.join(scope) + client_auth_method = self._client.registration_response.get('token_endpoint_auth_method', + 'client_secret_basic') + access_token = self._oauth2_client.do_access_token_request(request_args=request_args, + authn_method=client_auth_method, + endpoint=self._client.token_endpoint) + return access_token @property def session_refresh_interval_seconds(self): @@ -308,17 +304,9 @@ def session_refresh_interval_seconds(self): @property def provider_end_session_endpoint(self): - provider_metadata = self._provider_configuration.ensure_provider_metadata() + provider_metadata = self._provider_configuration.ensure_provider_metadata(self._client) return provider_metadata.get('end_session_endpoint') @property def post_logout_redirect_uris(self): return self._client.registration_response.get('post_logout_redirect_uris') - - def _parse_response(self, response_params, success_response_cls, error_response_cls): - if 'error' in response_params: - response = error_response_cls(**response_params) - else: - response = success_response_cls(**response_params) - response.verify(keyjar=self._client.keyjar) - return response diff --git a/tests/test_flask_pyoidc.py b/tests/test_flask_pyoidc.py index da5e2c2..ca667f7 100644 --- a/tests/test_flask_pyoidc.py +++ b/tests/test_flask_pyoidc.py @@ -17,8 +17,8 @@ from werkzeug.exceptions import Forbidden, Unauthorized from flask_pyoidc import OIDCAuthentication -from flask_pyoidc.provider_configuration import ProviderConfiguration, ProviderMetadata, ClientMetadata, \ - ClientRegistrationInfo +from flask_pyoidc.provider_configuration import (ProviderConfiguration, ProviderMetadata, ClientMetadata, + ClientRegistrationInfo) from flask_pyoidc.user_session import UserSession from werkzeug.routing import BuildError diff --git a/tests/test_provider_configuration.py b/tests/test_provider_configuration.py index 888f726..364b620 100644 --- a/tests/test_provider_configuration.py +++ b/tests/test_provider_configuration.py @@ -45,7 +45,7 @@ def test_should_fetch_provider_metadata_if_not_given(self): provider_config = ProviderConfiguration(issuer=self.PROVIDER_BASEURL, client_registration_info=ClientRegistrationInfo()) - provider_config.ensure_provider_metadata() + provider_config.ensure_provider_metadata(Client(CLIENT_AUTHN_METHOD)) assert provider_config._provider_metadata['issuer'] == self.PROVIDER_BASEURL assert provider_config._provider_metadata['authorization_endpoint'] == self.PROVIDER_BASEURL + '/auth' assert provider_config._provider_metadata['jwks_uri'] == self.PROVIDER_BASEURL + '/jwks' @@ -55,7 +55,7 @@ def test_should_not_fetch_provider_metadata_if_given(self): provider_config = ProviderConfiguration(provider_metadata=provider_metadata, client_registration_info=ClientRegistrationInfo()) - provider_config.ensure_provider_metadata() + provider_config.ensure_provider_metadata(Client(CLIENT_AUTHN_METHOD)) assert provider_config._provider_metadata == provider_metadata @responses.activate diff --git a/tests/test_pyoidc_facade.py b/tests/test_pyoidc_facade.py index dec4c1d..af1933c 100644 --- a/tests/test_pyoidc_facade.py +++ b/tests/test_pyoidc_facade.py @@ -3,13 +3,13 @@ import pytest import responses -from oic.oic import AuthorizationResponse, AccessTokenResponse, TokenErrorResponse, OpenIDSchema, \ - AuthorizationErrorResponse -from urllib.parse import parse_qsl, urlparse +from oic.oic import (AccessTokenResponse, AuthorizationErrorResponse, AuthorizationResponse, Grant, OpenIDSchema, + TokenErrorResponse) +from urllib.parse import parse_qsl from flask_pyoidc.provider_configuration import ProviderConfiguration, ClientMetadata, ProviderMetadata, \ ClientRegistrationInfo -from flask_pyoidc.pyoidc_facade import PyoidcFacade, _ClientAuthentication +from flask_pyoidc.pyoidc_facade import PyoidcFacade from .util import signed_id_token REDIRECT_URI = 'https://rp.example.com/redirect_uri' @@ -133,12 +133,12 @@ def test_parse_authentication_response_preserves_id_token_jwt(self): assert parsed_auth_response['state'] == state assert parsed_auth_response['id_token_jwt'] == id_token - @pytest.mark.parametrize('request_func,expected_token_request', [ + @pytest.mark.parametrize('request_func, expected_token_request', [ ( - lambda facade: facade.exchange_authorization_code('auth-code'), + lambda facade: facade.exchange_authorization_code('auth-code', 'test-state'), { 'grant_type': 'authorization_code', - 'code': 'auth-code', + 'state': 'test-state', 'redirect_uri': REDIRECT_URI } ), @@ -165,14 +165,20 @@ def test_token_request(self, request_func, expected_token_request): } id_token_jwt, id_token_signing_key = signed_id_token(id_token_claims) token_response = AccessTokenResponse(access_token='test_access_token', + refresh_token='refresh-token', token_type='Bearer', - id_token=id_token_jwt) + id_token=id_token_jwt, + expires_in=now + 1) + responses.add(responses.POST, token_endpoint, json=token_response.to_dict()) provider_metadata = self.PROVIDER_METADATA.copy(token_endpoint=token_endpoint) facade = PyoidcFacade(ProviderConfiguration(provider_metadata=provider_metadata, client_metadata=self.CLIENT_METADATA), REDIRECT_URI) + grant = Grant(resp=token_response) + grant.grant_expiration_time = now + grant.exp_in + facade._client.grant = {'test-state': grant} responses.add(responses.GET, self.PROVIDER_METADATA['jwks_uri'], @@ -198,13 +204,17 @@ def test_token_request_handles_error_response(self): facade = PyoidcFacade(ProviderConfiguration(provider_metadata=provider_metadata, client_metadata=self.CLIENT_METADATA), REDIRECT_URI) - assert facade.exchange_authorization_code('1234') == token_response + state = 'test-state' + grant = Grant() + grant.grant_expiration_time = int(time.time()) + grant.exp_in + facade._client.grant = {state: grant} + assert facade.exchange_authorization_code('1234', state) == token_response def test_token_request_handles_missing_provider_token_endpoint(self): facade = PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, client_metadata=self.CLIENT_METADATA), REDIRECT_URI) - assert facade.exchange_authorization_code('1234') is None + assert facade.exchange_authorization_code(None, None) is None @pytest.mark.parametrize('userinfo_http_method', [ 'GET', @@ -244,8 +254,12 @@ def test_no_userinfo_request_is_made_if_no_access_token(self): assert facade.userinfo_request(None) is None @responses.activate - @pytest.mark.parametrize('scope', [None, ['read', 'write']]) - def test_client_credentials_grant(self, scope): + @pytest.mark.parametrize('scope, extra_args', + [(None, {}), + (['read', 'write'], + {'audience': ['client_id1, client_id2']}) + ]) + def test_client_credentials_grant(self, scope, extra_args): token_endpoint = f'{self.PROVIDER_BASEURL}/token' provider_metadata = self.PROVIDER_METADATA.copy( token_endpoint=token_endpoint) @@ -264,7 +278,7 @@ def test_client_credentials_grant(self, scope): responses.add(responses.POST, token_endpoint, json=client_credentials_grant_response) assert client_credentials_grant_response == facade.client_credentials_grant( - scope=scope, audience=['client_id1, client_id2']).to_dict() + scope=scope, **extra_args).to_dict() def test_post_logout_redirect_uris(self): post_logout_redirect_uris = ['https://client.example.com/logout'] @@ -274,32 +288,3 @@ def test_post_logout_redirect_uris(self): client_metadata=client_metadata), REDIRECT_URI) assert facade.post_logout_redirect_uris == post_logout_redirect_uris - - -class TestClientAuthentication(object): - CLIENT_ID = 'client1' - CLIENT_SECRET = 'secret1' - - @property - def basic_auth(self): - credentials = '{}:{}'.format(self.CLIENT_ID, self.CLIENT_SECRET) - return 'Basic {}'.format(base64.urlsafe_b64encode(credentials.encode('utf-8')).decode('utf-8')) - - @pytest.fixture(autouse=True) - def setup(self): - self.client_auth = _ClientAuthentication(self.CLIENT_ID, self.CLIENT_SECRET) - - def test_client_secret_basic(self): - request = {} - headers = self.client_auth('client_secret_basic', request) - assert headers == {'Authorization': self.basic_auth} - assert request == {} - - def test_client_secret_post(self): - request = {} - headers = self.client_auth('client_secret_post', request) - assert headers is None - assert request == {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} - - def test_defaults_to_client_secret_basic(self): - assert self.client_auth('invalid_client_auth_method', {}) == self.client_auth('client_secret_basic', {})