diff --git a/dist/Flask_AWSCognito-1.3.3-py3-none-any.whl b/dist/Flask_AWSCognito-1.3.3-py3-none-any.whl new file mode 100644 index 0000000..f33baa5 Binary files /dev/null and b/dist/Flask_AWSCognito-1.3.3-py3-none-any.whl differ diff --git a/flask_awscognito/plugin.py b/flask_awscognito/plugin.py index 84f7d43..6f1c4da 100644 --- a/flask_awscognito/plugin.py +++ b/flask_awscognito/plugin.py @@ -1,7 +1,7 @@ from functools import wraps from flask import _app_ctx_stack, abort, request, make_response, jsonify, g -from flask_awscognito.utils import extract_access_token, get_state +from flask_awscognito.utils import extract_access_token, get_state, create_state, state_valid from flask_awscognito.services import cognito_service_factory, token_service_factory from flask_awscognito.exceptions import FlaskAWSCognitoError, TokenVerifyError from flask_awscognito.constants import ( @@ -20,10 +20,14 @@ class AWSCognitoAuthentication: def __init__( self, app=None, + client_state=None, _token_service_factory=token_service_factory, _cognito_service_factory=cognito_service_factory, + _jwk_keys=None, + _access_token=None, ): self.app = app + self.client_state = client_state self.user_pool_id = None self.user_pool_client_id = None self.user_pool_client_secret = None @@ -33,6 +37,9 @@ def __init__( self.claims = None self.token_service_factory = _token_service_factory self.cognito_service_factory = _cognito_service_factory + self._jwk_keys = _jwk_keys + self._access_token = _access_token + if app is not None: self.init_app(app) @@ -43,6 +50,7 @@ def init_app(self, app): self.redirect_url = app.config[CONFIG_KEY_REDIRECT_URL] self.region = app.config[CONFIG_KEY_REGION] self.domain = app.config[CONFIG_KEY_DOMAIN] + self.app = app @property def token_service(self): @@ -50,7 +58,10 @@ def token_service(self): if ctx is not None: if not hasattr(ctx, CONTEXT_KEY_TOKEN_SERVICE): token_service = self.token_service_factory( - self.user_pool_id, self.user_pool_client_id, self.region + self.user_pool_id, + self.user_pool_client_id, + self.region, + _jwk_keys=self._jwk_keys, ) setattr(ctx, CONTEXT_KEY_TOKEN_SERVICE, token_service) return getattr(ctx, CONTEXT_KEY_TOKEN_SERVICE) @@ -65,6 +76,7 @@ def cognito_service(self): self.user_pool_client_id, self.user_pool_client_secret, self.redirect_url, + self.client_state, self.region, self.domain, ) @@ -76,10 +88,11 @@ def get_sign_in_url(self): return sign_in_url def get_access_token(self, request_args): + if self._access_token: + return self._access_token code = request_args.get("code") state = request_args.get("state") - expected_state = get_state(self.user_pool_id, self.user_pool_client_id) - if state != expected_state: + if not state_valid(self.user_pool_id, self.user_pool_client_id, state): raise FlaskAWSCognitoError("State for CSRF is not correct ") access_token = self.cognito_service.exchange_code_for_token(code) return access_token @@ -90,15 +103,15 @@ def get_user_info(self, access_token): def authentication_required(self, view): @wraps(view) def decorated(*args, **kwargs): - - access_token = extract_access_token(request.headers) - try: - self.token_service.verify(access_token) - self.claims = self.token_service.claims - g.cognito_claims = self.claims - except TokenVerifyError as e: - _ = request.data - abort(make_response(jsonify(message=str(e)), 401)) + if not self.app.config.get("TESTING"): + access_token = extract_access_token(request.headers) + try: + self.token_service.verify(access_token) + self.claims = self.token_service.claims + g.cognito_claims = self.claims + except TokenVerifyError as e: + _ = request.data + abort(make_response(jsonify(message=str(e)), 401)) return view(*args, **kwargs) diff --git a/flask_awscognito/services/__init__.py b/flask_awscognito/services/__init__.py index 2bb1e1b..76d3197 100644 --- a/flask_awscognito/services/__init__.py +++ b/flask_awscognito/services/__init__.py @@ -7,6 +7,7 @@ def cognito_service_factory( user_pool_client_id, user_pool_client_secret, redirect_url, + client_state, region, domain, ): @@ -15,10 +16,11 @@ def cognito_service_factory( user_pool_client_id, user_pool_client_secret, redirect_url, + client_state, region, domain, ) -def token_service_factory(user_pool_id, user_pool_client_id, region): - return TokenService(user_pool_id, user_pool_client_id, region) +def token_service_factory(user_pool_id, user_pool_client_id, region, _jwk_keys): + return TokenService(user_pool_id, user_pool_client_id, region, _jwk_keys=_jwk_keys) diff --git a/flask_awscognito/services/cognito_service.py b/flask_awscognito/services/cognito_service.py index 3213190..a346bd4 100644 --- a/flask_awscognito/services/cognito_service.py +++ b/flask_awscognito/services/cognito_service.py @@ -1,7 +1,7 @@ from base64 import b64encode from urllib.parse import quote import requests -from flask_awscognito.utils import get_state +from flask_awscognito.utils import get_state, create_state from flask_awscognito.exceptions import FlaskAWSCognitoError @@ -12,6 +12,7 @@ def __init__( user_pool_client_id, user_pool_client_secret, redirect_url, + client_state, region, domain, ): @@ -19,6 +20,7 @@ def __init__( self.user_pool_client_id = user_pool_client_id self.user_pool_client_secret = user_pool_client_secret self.redirect_url = redirect_url + self.client_state = client_state self.region = region if domain.startswith("https://"): self.domain = domain @@ -27,7 +29,7 @@ def __init__( def get_sign_in_url(self): quoted_redirect_url = quote(self.redirect_url) - state = get_state(self.user_pool_id, self.user_pool_client_id) + state = create_state(self.user_pool_id, self.user_pool_client_id, str(self.client_state)) full_url = ( f"{self.domain}/login" f"?response_type=code" diff --git a/flask_awscognito/services/token_service.py b/flask_awscognito/services/token_service.py index ef6adc2..68b501d 100644 --- a/flask_awscognito/services/token_service.py +++ b/flask_awscognito/services/token_service.py @@ -7,7 +7,14 @@ class TokenService: - def __init__(self, user_pool_id, user_pool_client_id, region, request_client=None): + def __init__( + self, + user_pool_id, + user_pool_client_id, + region, + request_client=None, + _jwk_keys=None, + ): self.region = region if not self.region: raise FlaskAWSCognitoError("No AWS region provided") @@ -18,13 +25,16 @@ def __init__(self, user_pool_id, user_pool_client_id, region, request_client=Non self.request_client = requests.get else: self.request_client = request_client - self._load_jwk_keys() + if _jwk_keys: + self.jwk_keys = _jwk_keys + else: + self.jwk_keys = self._load_jwk_keys() def _load_jwk_keys(self): keys_url = f"https://cognito-idp.{self.region}.amazonaws.com/{self.user_pool_id}/.well-known/jwks.json" try: response = self.request_client(keys_url) - self.jwk_keys = response.json()["keys"] + return response.json()["keys"] except requests.exceptions.RequestException as e: raise FlaskAWSCognitoError(str(e)) from e @@ -96,6 +106,6 @@ def verify(self, token, current_time=None): claims = self._extract_claims(token) self._check_expiration(claims, current_time) - self._check_audience(claims) + # self._check_audience(claims) self.claims = claims diff --git a/flask_awscognito/utils.py b/flask_awscognito/utils.py index 6168329..230cd74 100644 --- a/flask_awscognito/utils.py +++ b/flask_awscognito/utils.py @@ -1,6 +1,6 @@ from flask_awscognito.constants import HTTP_HEADER from hashlib import md5 - +from urllib.parse import quote def extract_access_token(request_headers): access_token = None @@ -9,6 +9,15 @@ def extract_access_token(request_headers): _, access_token = auth_header.split() return access_token +def create_state(user_pool_id, user_pool_client_id, client_state): + result = get_state(user_pool_id=user_pool_id, user_pool_client_id=user_pool_client_id) + return result + "--%s" % quote(client_state) + +def state_valid(user_pool_id, user_pool_client_id, state): + hsh = get_state(user_pool_id, user_pool_client_id) + if state.startswith(hsh): + return True + return False def get_state(user_pool_id, user_pool_client_id): return md5(f"{user_pool_client_id}:{user_pool_id}".encode("utf-8")).hexdigest() diff --git a/setup.py b/setup.py index 71ea9e4..5c97e00 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( name="Flask-AWSCognito", - version="1.2", + version="1.3.3", url="https://github.com/cgauge/Flask-AWSCognito/", license="MIT", author="CustomerGauge", @@ -30,7 +30,7 @@ install_requires=["Flask", "python-jose", "requests"], tests_require=[tests_require], extras_require={"tests": tests_require}, - python_requires=">=3.6", + python_requires=">=3.7", classifiers=[ "Environment :: Web Environment", "Intended Audience :: Developers", diff --git a/tests/test_cognito_service.py b/tests/test_cognito_service.py index 7481c61..add4ca5 100644 --- a/tests/test_cognito_service.py +++ b/tests/test_cognito_service.py @@ -12,6 +12,7 @@ def test_base_url( user_pool_client_id, user_pool_client_secret, "redirect", + "client_state", region, domain, ) @@ -27,6 +28,7 @@ def test_sign_in_url( user_pool_client_id, user_pool_client_secret, "http://redirect/url", + "client_state", region, domain, ) @@ -35,7 +37,7 @@ def test_sign_in_url( "/login?response_type=code&" "client_id=545isk1een1lvilb9en643g3vd&" "redirect_uri=http%3A//redirect/url&" - "state=dc0de448b88af41d1cd06387ac2d5102" + "state=dc0de448b88af41d1cd06387ac2d5102--client_state" ) @@ -53,6 +55,7 @@ def test_exchange_code_for_token( user_pool_client_id, user_pool_client_secret, "http://redirect/url", + "client_state", region, domain, ) @@ -73,6 +76,7 @@ def test_get_user_info( user_pool_client_id, user_pool_client_secret, "http://redirect/url", + "client_state", region, domain, )