Skip to content

Commit

Permalink
Allow passing in a single string location to the locations kwarg (#402)
Browse files Browse the repository at this point in the history
Allow locations kwarg for jwt_required() to be a string
  • Loading branch information
vimalloc authored Mar 9, 2021
1 parent 228822d commit 6f66f0f
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 20 deletions.
4 changes: 2 additions & 2 deletions docs/add_custom_data_claims.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Storing Data in Access Tokens
=============================
Storing Additional Data in JWTs
===============================
You may want to store additional information in the access token which you could
later access in the protected views. This can be done using the ``additional_claims``
argument with the :func:`~flask_jwt_extended.create_access_token` or
Expand Down
2 changes: 1 addition & 1 deletion examples/additional_data_in_access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def login():
# In a protected view, get the claims you added to the jwt with the
# get_jwt() method
@app.route("/protected", methods=["GET"])
@jwt_required
@jwt_required()
def protected():
claims = get_jwt()
return jsonify(foo=claims["foo"])
Expand Down
2 changes: 1 addition & 1 deletion examples/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def login():


@app.route("/protected", methods=["GET"])
@jwt_required
@jwt_required()
def protected():
return jsonify(hello="world")

Expand Down
34 changes: 18 additions & 16 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
If ``True``, require a refresh JWT to be verified.
:param locations:
A list of locations to look for the JWT in this request, for example:
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
configuration option.
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.
"""
if request.method in config.exempt_methods:
return
Expand Down Expand Up @@ -103,10 +103,10 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None):
requires an access JWT to access this endpoint. Defaults to ``False``.
:param locations:
A list of locations to look for the JWT in this request, for example:
``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs
will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION``
configuration option.
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.
"""

def wrapper(fn):
Expand Down Expand Up @@ -227,26 +227,28 @@ def _decode_jwt_from_json(refresh):


def _decode_jwt_from_request(locations, fresh, refresh=False):
# All the places we can get a JWT from in this request
get_encoded_token_functions = []
# Figure out what locations to look for the JWT in this request
if isinstance(locations, str):
locations = [locations]

# Get locations in the order specified by the decorator or JWT_TOKEN_LOCATION
# configuration.
if not locations:
locations = config.token_location

# Add the functions in the order specified by locations.
# Get the decode functions in the order specified by locations.
get_encoded_token_functions = []
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(refresh)
)
if location == "query_string":
elif location == "query_string":
get_encoded_token_functions.append(_decode_jwt_from_query_string)
if location == "headers":
elif location == "headers":
get_encoded_token_functions.append(_decode_jwt_from_headers)
if location == "json":
elif location == "json":
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
else:
raise RuntimeError(f"'{location}' is not a valid location")

# Try to find the token from one of these locations. It only needs to exist
# in one place to be valid (not every location).
Expand Down
32 changes: 32 additions & 0 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,38 @@ def test_jwt_optional(app, delta_func):
assert response.get_json() == {"msg": "Token has expired"}


def test_override_jwt_location(app):
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]

@app.route("/protected_other")
@jwt_required(locations="headers")
def protected_other():
return jsonify(foo="bar")

@app.route("/protected_invalid")
@jwt_required(locations="INVALID_LOCATION")
def protected_invalid():
return jsonify(foo="bar")

test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")

url = "/protected_other"
response = test_client.get(url, headers=make_headers(access_token))
assert response.get_json() == {"foo": "bar"}
assert response.status_code == 200

url = "/protected"
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 401
assert response.get_json() == {"msg": 'Missing cookie "access_token_cookie"'}

url = "/protected_invalid"
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 500


def test_invalid_jwt(app):
url = "/protected"
jwtM = get_jwt_manager(app)
Expand Down

0 comments on commit 6f66f0f

Please sign in to comment.