diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index e2cb65e8..c0aacc78 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -240,4 +240,8 @@ def user_claims_key(self): def exempt_methods(self): return {"OPTIONS"} + @property + def json_encoder(self): + return current_app.json_encoder + config = _Config() diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index f258f5a9..ab63422b 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -379,6 +379,7 @@ def _create_refresh_token(self, identity, expires_delta=None): expires_delta=expires_delta, csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, + json_encoder=config.json_encoder ) return refresh_token @@ -395,7 +396,8 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None): user_claims=self._user_claims_callback(identity), csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, - user_claims_key=config.user_claims_key + user_claims_key=config.user_claims_key, + json_encoder=config.json_encoder ) return access_token diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 0750bdb9..4aebb9d8 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -13,7 +13,8 @@ def _create_csrf_token(): return str(uuid.uuid4()) -def _encode_jwt(additional_token_data, expires_delta, secret, algorithm): +def _encode_jwt(additional_token_data, expires_delta, secret, algorithm, + json_encoder=None): uid = str(uuid.uuid4()) now = datetime.datetime.utcnow() token_data = { @@ -31,7 +32,8 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm): def encode_access_token(identity, secret, algorithm, expires_delta, fresh, - user_claims, csrf, identity_claim_key, user_claims_key): + user_claims, csrf, identity_claim_key, user_claims_key, + json_encoder=None): """ Creates a new encoded (utf-8) access token. @@ -70,11 +72,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, if csrf: token_data['csrf'] = _create_csrf_token() - return _encode_jwt(token_data, expires_delta, secret, algorithm) + return _encode_jwt(token_data, expires_delta, secret, algorithm, + json_encoder=json_encoder) def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, - identity_claim_key): + identity_claim_key, json_encoder=None): """ Creates a new encoded (utf-8) refresh token. @@ -95,7 +98,8 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, } if csrf: token_data['csrf'] = _create_csrf_token() - return _encode_jwt(token_data, expires_delta, secret, algorithm) + return _encode_jwt(token_data, expires_delta, secret, algorithm, + json_encoder=json_encoder) def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, diff --git a/tests/test_config.py b/tests/test_config.py index b3149022..c69bfbc6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,6 +3,7 @@ import pytest from datetime import timedelta from flask import Flask +from flask.json import JSONEncoder from flask_jwt_extended import JWTManager from flask_jwt_extended.config import config @@ -56,6 +57,8 @@ def test_default_configs(app): assert config.identity_claim_key == 'identity' assert config.user_claims_key == 'user_claims' + assert config.json_encoder is app.json_encoder + def test_override_configs(app): app.config['JWT_TOKEN_LOCATION'] = ['cookies'] @@ -91,6 +94,11 @@ def test_override_configs(app): app.config['JWT_IDENTITY_CLAIM'] = 'foo' app.config['JWT_USER_CLAIMS'] = 'bar' + class CustomJSONEncoder(JSONEncoder): + pass + + app.json_encoder = CustomJSONEncoder + with app.test_request_context(): assert config.token_location == ['cookies'] assert config.jwt_in_cookies is True @@ -131,6 +139,8 @@ def test_override_configs(app): assert config.identity_claim_key == 'foo' assert config.user_claims_key == 'bar' + assert config.json_encoder is CustomJSONEncoder + def test_tokens_never_expire(app): app.config['JWT_ACCESS_TOKEN_EXPIRES'] = False diff --git a/tests/utils.py b/tests/utils.py index c0ba1b32..f211915f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,8 @@ def encode_token(app, token_data): token = jwt.encode( token_data, config.decode_key, - algorithm=config.algorithm + algorithm=config.algorithm, + json_encoder=config.json_encoder ) return token.decode('utf-8')