Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add injection of custom state (useful for internal redirects) #8

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
8 changes: 5 additions & 3 deletions flask_awscognito/plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import wraps

from flask import _app_ctx_stack, abort, request, make_response, jsonify, g
from flask_awscognito.utils import extract_access_token, get_state
from flask_awscognito.utils import extract_access_token, get_state, create_state, state_valid
from flask_awscognito.services import cognito_service_factory, token_service_factory
from flask_awscognito.exceptions import FlaskAWSCognitoError, TokenVerifyError
from flask_awscognito.constants import (
Expand All @@ -20,10 +20,12 @@ class AWSCognitoAuthentication:
def __init__(
self,
app=None,
client_state=None,
_token_service_factory=token_service_factory,
_cognito_service_factory=cognito_service_factory,
):
self.app = app
self.client_state = client_state
self.user_pool_id = None
self.user_pool_client_id = None
self.user_pool_client_secret = None
Expand Down Expand Up @@ -65,6 +67,7 @@ def cognito_service(self):
self.user_pool_client_id,
self.user_pool_client_secret,
self.redirect_url,
self.client_state,
self.region,
self.domain,
)
Expand All @@ -78,8 +81,7 @@ def get_sign_in_url(self):
def get_access_token(self, request_args):
code = request_args.get("code")
state = request_args.get("state")
expected_state = get_state(self.user_pool_id, self.user_pool_client_id)
if state != expected_state:
if not state_valid(self.user_pool_id, self.user_pool_client_id, state):
raise FlaskAWSCognitoError("State for CSRF is not correct ")
access_token = self.cognito_service.exchange_code_for_token(code)
return access_token
Expand Down
2 changes: 2 additions & 0 deletions flask_awscognito/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def cognito_service_factory(
user_pool_client_id,
user_pool_client_secret,
redirect_url,
client_state,
region,
domain,
):
Expand All @@ -15,6 +16,7 @@ def cognito_service_factory(
user_pool_client_id,
user_pool_client_secret,
redirect_url,
client_state,
region,
domain,
)
Expand Down
6 changes: 4 additions & 2 deletions flask_awscognito/services/cognito_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from base64 import b64encode
from urllib.parse import quote
import requests
from flask_awscognito.utils import get_state
from flask_awscognito.utils import get_state, create_state
from flask_awscognito.exceptions import FlaskAWSCognitoError


Expand All @@ -12,13 +12,15 @@ def __init__(
user_pool_client_id,
user_pool_client_secret,
redirect_url,
client_state,
region,
domain,
):
self.user_pool_id = user_pool_id
self.user_pool_client_id = user_pool_client_id
self.user_pool_client_secret = user_pool_client_secret
self.redirect_url = redirect_url
self.client_state = client_state
self.region = region
if domain.startswith("https://"):
self.domain = domain
Expand All @@ -27,7 +29,7 @@ def __init__(

def get_sign_in_url(self):
quoted_redirect_url = quote(self.redirect_url)
state = get_state(self.user_pool_id, self.user_pool_client_id)
state = create_state(self.user_pool_id, self.user_pool_client_id, quote(str(self.client_state)))
m4g005 marked this conversation as resolved.
Show resolved Hide resolved
full_url = (
f"{self.domain}/login"
f"?response_type=code"
Expand Down
10 changes: 9 additions & 1 deletion flask_awscognito/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from flask_awscognito.constants import HTTP_HEADER
from hashlib import md5


def extract_access_token(request_headers):
access_token = None
auth_header = request_headers.get(HTTP_HEADER)
if auth_header and " " in auth_header:
_, access_token = auth_header.split()
return access_token

def create_state(user_pool_id, user_pool_client_id, client_state):
result = get_state(user_pool_id=user_pool_id, user_pool_client_id=user_pool_client_id)
return result + "--%s" % client_state

def state_valid(user_pool_id, user_pool_client_id, state):
hsh = get_state(user_pool_id, user_pool_client_id)
if state.startswith(hsh):
return True
return False

def get_state(user_pool_id, user_pool_client_id):
return md5(f"{user_pool_client_id}:{user_pool_id}".encode("utf-8")).hexdigest()
6 changes: 5 additions & 1 deletion tests/test_cognito_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_base_url(
user_pool_client_id,
user_pool_client_secret,
"redirect",
"client_state",
region,
domain,
)
Expand All @@ -27,6 +28,7 @@ def test_sign_in_url(
user_pool_client_id,
user_pool_client_secret,
"http://redirect/url",
"client_state",
region,
domain,
)
Expand All @@ -35,7 +37,7 @@ def test_sign_in_url(
"/login?response_type=code&"
"client_id=545isk1een1lvilb9en643g3vd&"
"redirect_uri=http%3A//redirect/url&"
"state=dc0de448b88af41d1cd06387ac2d5102"
"state=dc0de448b88af41d1cd06387ac2d5102--client_state"
)


Expand All @@ -53,6 +55,7 @@ def test_exchange_code_for_token(
user_pool_client_id,
user_pool_client_secret,
"http://redirect/url",
"client_state",
region,
domain,
)
Expand All @@ -73,6 +76,7 @@ def test_get_user_info(
user_pool_client_id,
user_pool_client_secret,
"http://redirect/url",
"client_state",
region,
domain,
)
Expand Down