Skip to content

Commit

Permalink
Merge pull request #67 from psafont/identity_fix
Browse files Browse the repository at this point in the history
Fix regressions introduced in 3.1.0
  • Loading branch information
vimalloc authored Jul 13, 2017
2 parents 81a4363 + f150fe0 commit 4701f93
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 35 deletions.
6 changes: 3 additions & 3 deletions examples/database_blacklist/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def login():
refresh_token = create_refresh_token(identity=username)

# Store the tokens in our store with a status of not currently revoked.
add_token_to_database(access_token)
add_token_to_database(refresh_token)
add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM'])
add_token_to_database(refresh_token, app.config['JWT_IDENTITY_CLAIM'])

ret = {
'access_token': access_token,
Expand All @@ -72,7 +72,7 @@ def refresh():
# Do the same thing that we did in the login endpoint here
current_user = get_jwt_identity()
access_token = create_access_token(identity=current_user)
add_token_to_database(access_token)
add_token_to_database(access_token, app.config['JWT_IDENTITY_CLAIM'])
return jsonify({'access_token': access_token}), 201

# Provide a way for a user to look at their tokens
Expand Down
5 changes: 3 additions & 2 deletions examples/database_blacklist/blacklist_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ def _epoch_utc_to_datetime(epoch_utc):
return datetime.fromtimestamp(epoch_utc)


def add_token_to_database(encoded_token):
def add_token_to_database(encoded_token, identity_claim):
"""
Adds a new token to the database. It is not revoked when it is added.
:param identity_claim:
"""
decoded_token = decode_token(encoded_token)
jti = decoded_token['jti']
token_type = decoded_token['type']
user_identity = decoded_token['identity']
user_identity = decoded_token[identity_claim]
expires = _epoch_utc_to_datetime(decoded_token['exp'])
revoked = False

Expand Down
8 changes: 4 additions & 4 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def jwt_required(fn):
def wrapper(*args, **kwargs):
jwt_data = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
return fn(*args, **kwargs)
return wrapper

Expand All @@ -53,7 +53,7 @@ def wrapper(*args, **kwargs):
try:
jwt_data = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
except NoAuthorizationError:
pass
return fn(*args, **kwargs)
Expand All @@ -77,7 +77,7 @@ def wrapper(*args, **kwargs):
raise FreshTokenRequired('Fresh token required')

ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
return fn(*args, **kwargs)
return wrapper

Expand All @@ -92,7 +92,7 @@ def jwt_refresh_token_required(fn):
def wrapper(*args, **kwargs):
jwt_data = _decode_jwt_from_request(request_type='refresh')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
_load_user(jwt_data[config.identity_claim])
return fn(*args, **kwargs)
return wrapper

Expand Down
1 change: 1 addition & 0 deletions tests/test_blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def setUp(self):
self.app = Flask(__name__)
self.app.secret_key = 'super=secret'
self.app.config['JWT_BLACKLIST_ENABLED'] = True
self.app.config['JWT_IDENTITY_CLAIM'] = 'sub'
self.jwt_manager = JWTManager(self.app)
self.client = self.app.test_client()
self.blacklist = set()
Expand Down
54 changes: 30 additions & 24 deletions tests/test_jwt_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,25 @@ def test_encode_access_token(self):
algorithm = 'HS256'
token_expire_delta = timedelta(minutes=5)
user_claims = {'foo': 'bar'}
identity_claim = 'identity'

# Check with a fresh token
with self.app.test_request_context():
identity = 'user1'
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=True, user_claims=user_claims, csrf=False,
identity_claim='identity')
identity_claim=identity_claim)
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('fresh', data)
self.assertIn('type', data)
self.assertIn('user_claims', data)
self.assertNotIn('csrf', data)
self.assertEqual(data['identity'], identity)
self.assertEqual(data[identity_claim], identity)
self.assertEqual(data['fresh'], True)
self.assertEqual(data['type'], 'access')
self.assertEqual(data['user_claims'], user_claims)
Expand All @@ -61,18 +62,18 @@ def test_encode_access_token(self):
identity = 12345 # identity can be anything json serializable
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=False, user_claims=user_claims, csrf=True,
identity_claim='identity')
identity_claim=identity_claim)
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('fresh', data)
self.assertIn('type', data)
self.assertIn('user_claims', data)
self.assertIn('csrf', data)
self.assertEqual(data['identity'], identity)
self.assertEqual(data[identity_claim], identity)
self.assertEqual(data['fresh'], False)
self.assertEqual(data['type'], 'access')
self.assertEqual(data['user_claims'], user_claims)
Expand All @@ -86,16 +87,17 @@ def test_encode_invalid_access_token(self):
# Check with non-serializable json
with self.app.test_request_context():
user_claims = datetime
identity_claim = 'identity'
with self.assertRaises(Exception):
encode_access_token('user1', 'secret', 'HS256',
timedelta(hours=1), True, user_claims,
csrf=True, identity_claim='identity')
csrf=True, identity_claim=identity_claim)

user_claims = {'foo': timedelta(hours=4)}
with self.assertRaises(Exception):
encode_access_token('user1', 'secret', 'HS256',
timedelta(hours=1), True, user_claims,
csrf=True, identity_claim='identity')
csrf=True, identity_claim=identity_claim)

def test_encode_refresh_token(self):
secret = 'super-totally-secret-key'
Expand Down Expand Up @@ -212,25 +214,27 @@ def test_decode_jwt(self):

def test_decode_invalid_jwt(self):
with self.app.test_request_context():
identity_claim = 'identity'
# Verify underlying pyjwt expires verification works
with self.assertRaises(jwt.ExpiredSignatureError):
token_data = {
'exp': datetime.utcnow() - timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing jti
with self.assertRaises(JWTDecodeError):

token_data = {
'exp': datetime.utcnow() + timedelta(minutes=5),
'identity': 'banana',
identity_claim: 'banana',
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing identity
with self.assertRaises(JWTDecodeError):
Expand All @@ -241,83 +245,85 @@ def test_decode_invalid_jwt(self):
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Non-matching identity claim
with self.assertRaises(JWTDecodeError):
token_data = {
'exp': datetime.utcnow() + timedelta(minutes=5),
'identity': 'banana',
identity_claim: 'banana',
'type': 'refresh'
}
other_identity_claim = 'sub'
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
self.assertNotEqual(identity_claim, other_identity_claim)
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='sub')
csrf=False, identity_claim=other_identity_claim)

# Missing type
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing fresh in access token
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'access',
'user_claims': {}
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing user claims in access token
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'access',
'fresh': True
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Bad token type
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'banana',
'fresh': True,
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')
csrf=False, identity_claim=identity_claim)

# Missing csrf in csrf enabled token
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'access',
'fresh': True,
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True,
identity_claim='identity')
identity_claim=identity_claim)

def test_create_jwt_with_object(self):
# Complex object to test building a JWT from. Normally if you are using
Expand Down
9 changes: 7 additions & 2 deletions tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def setUp(self):
self.app.config['JWT_ALGORITHM'] = 'HS256'
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.app.config['JWT_IDENTITY_CLAIM'] = 'sub'
self.jwt_manager = JWTManager(self.app)
self.client = self.app.test_client()

Expand Down Expand Up @@ -454,6 +455,9 @@ def claims():
claims_keys = [claim for claim in jwt]
return jsonify(claims_keys), 200

# Grab custom identity claim
identity_claim = self.app.config['JWT_IDENTITY_CLAIM']

# Login
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
Expand All @@ -466,7 +470,7 @@ def claims():
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('fresh', data)
self.assertIn('type', data)
self.assertIn('user_claims', data)
Expand Down Expand Up @@ -836,12 +840,13 @@ def test_access_endpoints_with_cookie_missing_csrf_field(self):

def test_access_endpoints_with_cookie_csrf_claim_not_string(self):
now = datetime.utcnow()
identity_claim = self.app.config['JWT_IDENTITY_CLAIM']
token_data = {
'exp': now + timedelta(minutes=5),
'iat': now,
'nbf': now,
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'type': 'refresh',
'csrf': 404
}
Expand Down

0 comments on commit 4701f93

Please sign in to comment.