Skip to content

Commit

Permalink
Merge pull request #66 from psafont/sub-decode
Browse files Browse the repository at this point in the history
Allow changing subject claim
  • Loading branch information
vimalloc authored Jul 13, 2017
2 parents 17c3254 + f8d83f2 commit 64ec456
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 48 deletions.
3 changes: 3 additions & 0 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ General Options:
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms,
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used as source of identity.
For interoperativity, the JWT RFC recommends using ``'sub'``.
Defaults to ``'identity'``.
================================= =========================================


Expand Down
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def cookie_max_age(self):
# seconds a long ways in the future
return None if self.session_cookie else 2147483647 # 2^31

@property
def identity_claim(self):
return current_app.config['JWT_IDENTITY_CLAIM']

config = _Config()


8 changes: 6 additions & 2 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ['access', 'refresh'])

app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')

def user_claims_loader(self, callback):
"""
This sets the callback method for adding custom user claims to a JWT.
Expand Down Expand Up @@ -319,7 +321,8 @@ def create_refresh_token(self, identity, expires_delta=None):
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=expires_delta,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)
return refresh_token

Expand Down Expand Up @@ -352,7 +355,8 @@ def create_access_token(self, identity, fresh=False, expires_delta=None):
expires_delta=expires_delta,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)
return access_token

17 changes: 10 additions & 7 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm):


def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
user_claims, csrf):
user_claims, csrf, identity_claim):
"""
Creates a new encoded (utf-8) access token.
Expand All @@ -40,11 +40,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim: Which claim should be used to store the identity in
:return: Encoded access token
"""
# Create the jwt
token_data = {
'identity': identity,
identity_claim: identity,
'fresh': fresh,
'type': 'access',
'user_claims': user_claims,
Expand All @@ -54,7 +55,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
return _encode_jwt(token_data, expires_delta, secret, algorithm)


def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, identity_claim):
"""
Creates a new encoded (utf-8) refresh token.
Expand All @@ -65,18 +66,19 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
(datetime.timedelta)
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim: Which claim should be used to store the identity in
:return: Encoded refresh token
"""
token_data = {
'identity': identity,
identity_claim: identity,
'type': 'refresh',
}
if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm)


def decode_jwt(encoded_token, secret, algorithm, csrf):
def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
"""
Decodes an encoded JWT
Expand All @@ -85,6 +87,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
:param algorithm: Algorithm used to encode the JWT
:param csrf: If this token is expected to have a CSRF double submit
value present (boolean)
:param identity_claim: expected claim that is used to identify the subject
:return: Dictionary containing contents of the JWT
"""
# This call verifies the ext, iat, and nbf claims
Expand All @@ -93,8 +96,8 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
# Make sure that any custom claims we expect in the token are present
if 'jti' not in data:
raise JWTDecodeError("Missing claim: jti")
if 'identity' not in data:
raise JWTDecodeError("Missing claim: identity")
if identity_claim not in data:
raise JWTDecodeError("Missing claim: {}".format(identity_claim))
if 'type' not in data or data['type'] not in ('refresh', 'access'):
raise JWTDecodeError("Missing or invalid claim: type")
if data['type'] == 'access':
Expand Down
13 changes: 10 additions & 3 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_jwt_identity():
Returns the identity of the JWT in this context. If no JWT is present,
None is returned.
"""
return get_raw_jwt().get('identity', None)
return get_raw_jwt().get(config.identity_claim, None)


def get_jwt_claims():
Expand Down Expand Up @@ -63,7 +63,8 @@ def decode_token(encoded_token):
encoded_token=encoded_token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)


Expand Down Expand Up @@ -106,7 +107,13 @@ def token_in_blacklist(*args, **kwargs):


def get_csrf_token(encoded_token):
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
token = decode_jwt(
encoded_token,
config.decode_key,
config.algorithm,
csrf=True,
identity_claim=config.identity_claim
)
return token['csrf']


Expand Down
11 changes: 9 additions & 2 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,13 @@ def _decode_jwt_from_headers():
raise InvalidHeaderError(msg)
token = parts[1]

return decode_jwt(token, config.decode_key, config.algorithm, csrf=False)
return decode_jwt(
encoded_token=token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=False,
identity_claim=config.identity_claim
)


def _decode_jwt_from_cookies(request_type):
Expand All @@ -163,7 +169,8 @@ def _decode_jwt_from_cookies(request_type):
encoded_token=encoded_token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)

# Verify csrf double submit tokens match if required
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def test_default_configs(self):
self.assertEqual(config.decode_key, self.app.secret_key)
self.assertEqual(config.cookie_max_age, None)

self.assertEqual(config.identity_claim, 'identity')

def test_override_configs(self):
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies']
self.app.config['JWT_HEADER_NAME'] = 'TestHeader'
Expand Down Expand Up @@ -86,6 +88,8 @@ def test_override_configs(self):

self.app.secret_key = 'banana'

self.app.config['JWT_IDENTITY_CLAIM'] = 'foo'

with self.app.test_request_context():
self.assertEqual(config.token_location, ['cookies'])
self.assertEqual(config.jwt_in_cookies, True)
Expand Down Expand Up @@ -122,6 +126,8 @@ def test_override_configs(self):

self.assertEqual(config.cookie_max_age, 2147483647)

self.assertEqual(config.identity_claim, 'foo')

def test_invalid_config_options(self):
with self.app.test_request_context():
self.app.config['JWT_TOKEN_LOCATION'] = 'banana'
Expand Down
Loading

0 comments on commit 64ec456

Please sign in to comment.