Skip to content

Commit

Permalink
Remove hardcoded assumption that the only JWT type supported other th…
Browse files Browse the repository at this point in the history
…an "refresh" is "access" (#401)

Cloudflare Teams JWT auth for example, sets a token with a value of "app".
  • Loading branch information
sambonner authored Mar 9, 2021
1 parent 715f9d5 commit 228822d
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 37 deletions.
10 changes: 6 additions & 4 deletions flask_jwt_extended/internal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ def user_lookup(*args, **kwargs):
return jwt_manager._user_lookup_callback(*args, **kwargs)


def verify_token_type(decoded_token, expected_type):
if decoded_token["type"] != expected_type:
raise WrongTokenError("Only {} tokens are allowed".format(expected_type))
def verify_token_type(decoded_token, refresh):
if not refresh and decoded_token["type"] == "refresh":
raise WrongTokenError("Only non-refresh tokens are allowed")
elif refresh and decoded_token["type"] != "refresh":
raise WrongTokenError("Only refresh tokens are allowed")


def verify_token_not_blocklisted(jwt_header, jwt_data, request_type):
def verify_token_not_blocklisted(jwt_header, jwt_data):
jwt_manager = get_jwt_manager()
if jwt_manager._token_in_blocklist_callback(jwt_header, jwt_data):
raise RevokedTokenError(jwt_header, jwt_data)
Expand Down
3 changes: 0 additions & 3 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ def _decode_jwt(
if "type" not in decoded_token:
decoded_token["type"] = "access"

if decoded_token["type"] not in ("access", "refresh"):
raise JWTDecodeError("Invalid token type: {}".format(decoded_token["type"]))

if "fresh" not in decoded_token:
decoded_token["fresh"] = False

Expand Down
41 changes: 20 additions & 21 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
Defaults to ``False``.
:param refresh:
If ``True``, require a refresh JWT to be verified. If ``False`` require an access
JWT to be verified. Defaults to ``False``.
If ``True``, require a refresh JWT to be verified.
:param locations:
A list of locations to look for the JWT in this request, for example:
Expand All @@ -61,9 +60,11 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=

try:
if refresh:
jwt_data, jwt_header = _decode_jwt_from_request("refresh", locations, fresh)
jwt_data, jwt_header = _decode_jwt_from_request(
locations, fresh, refresh=True
)
else:
jwt_data, jwt_header = _decode_jwt_from_request("access", locations, fresh)
jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh)
except (NoAuthorizationError, InvalidHeaderError):
if not optional:
raise
Expand Down Expand Up @@ -170,15 +171,15 @@ def _decode_jwt_from_headers():
return encoded_token, None


def _decode_jwt_from_cookies(token_type):
if token_type == "access":
cookie_key = config.access_cookie_name
csrf_header_key = config.access_csrf_header_name
csrf_field_key = config.access_csrf_field_name
else:
def _decode_jwt_from_cookies(refresh):
if refresh:
cookie_key = config.refresh_cookie_name
csrf_header_key = config.refresh_csrf_header_name
csrf_field_key = config.refresh_csrf_field_name
else:
cookie_key = config.access_cookie_name
csrf_header_key = config.access_csrf_header_name
csrf_field_key = config.access_csrf_field_name

encoded_token = request.cookies.get(cookie_key)
if not encoded_token:
Expand All @@ -205,15 +206,15 @@ def _decode_jwt_from_query_string():
return encoded_token, None


def _decode_jwt_from_json(token_type):
def _decode_jwt_from_json(refresh):
content_type = request.content_type or ""
if not content_type.startswith("application/json"):
raise NoAuthorizationError("Invalid content-type. Must be application/json.")

if token_type == "access":
token_key = config.json_key
else:
if refresh:
token_key = config.refresh_json_key
else:
token_key = config.json_key

try:
encoded_token = request.json.get(token_key, None)
Expand All @@ -225,7 +226,7 @@ def _decode_jwt_from_json(token_type):
return encoded_token, None


def _decode_jwt_from_request(token_type, locations, fresh):
def _decode_jwt_from_request(locations, fresh, refresh=False):
# All the places we can get a JWT from in this request
get_encoded_token_functions = []

Expand All @@ -238,16 +239,14 @@ def _decode_jwt_from_request(token_type, locations, fresh):
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(token_type)
lambda: _decode_jwt_from_cookies(refresh)
)
if location == "query_string":
get_encoded_token_functions.append(_decode_jwt_from_query_string)
if location == "headers":
get_encoded_token_functions.append(_decode_jwt_from_headers)
if location == "json":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_json(token_type)
)
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))

# Try to find the token from one of these locations. It only needs to exist
# in one place to be valid (not every location).
Expand Down Expand Up @@ -277,10 +276,10 @@ def _decode_jwt_from_request(token_type, locations, fresh):
raise NoAuthorizationError(errors[0])

# Additional verifications provided by this extension
verify_token_type(decoded_token, expected_type=token_type)
verify_token_type(decoded_token, refresh)
if fresh:
_verify_token_is_fresh(jwt_header, decoded_token)
verify_token_not_blocklisted(jwt_header, decoded_token, token_type)
verify_token_not_blocklisted(jwt_header, decoded_token)
custom_verification_for_token(jwt_header, decoded_token)

return decoded_token, jwt_header
12 changes: 6 additions & 6 deletions tests/test_decode_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def test_default_decode_token_values(app, default_access_token):
assert decoded["fresh"] is False


def test_bad_token_type(app, default_access_token):
default_access_token["type"] = "banana"
bad_type_token = encode_token(app, default_access_token)
def test_supports_decoding_other_token_types(app, default_access_token):
default_access_token["type"] = "app"
other_token = encode_token(app, default_access_token)

with pytest.raises(JWTDecodeError):
with app.test_request_context():
decode_token(bad_type_token)
with app.test_request_context():
decoded = decode_token(other_token)
assert decoded["type"] == "app"


def test_encode_decode_audience(app):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_jwt_required(app):
# Test refresh token access to jwt_required
response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only access tokens are allowed"}
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}


def test_fresh_jwt_required(app):
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_fresh_jwt_required(app):

response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only access tokens are allowed"}
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}

# Test with custom response
@jwtM.needs_fresh_token_loader
Expand Down Expand Up @@ -176,7 +176,7 @@ def test_jwt_optional(app, delta_func):

response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only access tokens are allowed"}
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}

response = test_client.get(url, headers=None)
assert response.status_code == 200
Expand Down

0 comments on commit 228822d

Please sign in to comment.