From 069e2ee6207ad80193a6398620b687992c829ce6 Mon Sep 17 00:00:00 2001 From: Roberto Prevato Date: Mon, 20 Mar 2023 23:41:22 +0100 Subject: [PATCH] Improve JWKS automatic rotation --- .flake8 | 1 + CHANGELOG.md | 6 +++ guardpost/__about__.py | 2 +- guardpost/authentication.py | 11 +++-- guardpost/authorization.py | 2 - guardpost/jwks/__init__.py | 3 ++ guardpost/jwks/caching.py | 41 +++++++++++++++- guardpost/jwts/__init__.py | 29 ++++++----- tests/test_authentication.py | 1 - tests/test_common.py | 1 - tests/test_jwks.py | 94 ++++++++++++++++++++++++++++++++++++ tests/test_jwts.py | 56 ++++++++++++++++++--- 12 files changed, 216 insertions(+), 31 deletions(-) diff --git a/.flake8 b/.flake8 index a8bdce6..172483a 100644 --- a/.flake8 +++ b/.flake8 @@ -6,3 +6,4 @@ max-complexity = 18 select = B,C,E,F,W,T4,B9 per-file-ignores = guardpost/__init__.py:F401 + tests/test_jwks.py:E501 diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ce20f9..779d576 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.0.1] - 2023-03-20 :sun_with_face: +- Improves the automatic rotation of `JWKS`: when validating `JWTs`, `JWKS` are + refreshed automatically if an unknown `kid` is encountered, and `JWKS` were + last fetched more than `refresh_time` seconds ago (by default 120 seconds). +- Corrects an inconsistency in how `claims` are read in the `User` class. + ## [1.0.0] - 2023-01-07 :star: - Adds built-in support for dependency injection, using the new `ContainerProtocol` in `rodi` v2. diff --git a/guardpost/__about__.py b/guardpost/__about__.py index 5becc17..5c4105c 100644 --- a/guardpost/__about__.py +++ b/guardpost/__about__.py @@ -1 +1 @@ -__version__ = "1.0.0" +__version__ = "1.0.1" diff --git a/guardpost/authentication.py b/guardpost/authentication.py index da15d04..466a67a 100644 --- a/guardpost/authentication.py +++ b/guardpost/authentication.py @@ -26,11 +26,14 @@ def __init__( @property def sub(self) -> Optional[str]: - return self["sub"] + return self.get("sub") def is_authenticated(self) -> bool: return bool(self.authentication_mode) + def get(self, key: str): + return self.claims.get(key) + def __getitem__(self, item): return self.claims[item] @@ -44,15 +47,15 @@ def has_claim_value(self, name: str, value: str) -> bool: class User(Identity): @property def id(self) -> Optional[str]: - return self["id"] or self.sub + return self.get("id") or self.sub @property def name(self) -> Optional[str]: - return self["name"] + return self.get("name") @property def email(self) -> Optional[str]: - return self["email"] + return self.get("email") class AuthenticationHandler(ABC): diff --git a/guardpost/authorization.py b/guardpost/authorization.py index 4f84bfd..383bb41 100644 --- a/guardpost/authorization.py +++ b/guardpost/authorization.py @@ -86,7 +86,6 @@ def _get_message(forced_failure, failed_requirements): class AuthorizationContext: - __slots__ = ("identity", "requirements", "_succeeded", "_failed_forced") def __init__(self, identity: Identity, requirements: Sequence[Requirement]): @@ -222,7 +221,6 @@ async def _handle_with_policy(self, policy: Policy, identity: Identity, scope: A with AuthorizationContext( identity, list(self._get_requirements(policy, scope)) ) as context: - for requirement in context.requirements: if _is_async_handler(type(requirement)): # type: ignore await requirement.handle(context) diff --git a/guardpost/jwks/__init__.py b/guardpost/jwks/__init__.py index b2316c0..1b1af91 100644 --- a/guardpost/jwks/__init__.py +++ b/guardpost/jwks/__init__.py @@ -70,6 +70,9 @@ def from_dict(cls, value) -> "JWK": class JWKS: keys: List[JWK] + def update(self, new_set: "JWKS"): + self.keys = list({key.kid: key for key in self.keys + new_set.keys}.values()) + @classmethod def from_dict(cls, value) -> "JWKS": if "keys" not in value: diff --git a/guardpost/jwks/caching.py b/guardpost/jwks/caching.py index 0698a48..f0e769f 100644 --- a/guardpost/jwks/caching.py +++ b/guardpost/jwks/caching.py @@ -1,7 +1,7 @@ import time from typing import Optional -from . import JWKS, KeysProvider +from . import JWK, JWKS, KeysProvider class CachingKeysProvider(KeysProvider): @@ -9,11 +9,15 @@ class CachingKeysProvider(KeysProvider): Kind of KeysProvider that can cache the result of another KeysProvider. """ - def __init__(self, keys_provider: KeysProvider, cache_time: float) -> None: + def __init__( + self, keys_provider: KeysProvider, cache_time: float, refresh_time: float = 120 + ) -> None: """ Creates a new instance of CachingKeysProvider bound to a given KeysProvider, and caching its result up to an optional amount of seconds described by cache_time. Expiration is disabled if `cache_time` <= 0. + JWKS are refreshed anyway if an unknown `kid` is encountered and the set was + fetched more than `refresh_time` seconds ago. """ super().__init__() @@ -22,6 +26,7 @@ def __init__(self, keys_provider: KeysProvider, cache_time: float) -> None: self._keys: Optional[JWKS] = None self._cache_time = cache_time + self._refresh_time = refresh_time self._last_fetch_time: float = 0 self._keys_provider = keys_provider @@ -34,6 +39,14 @@ async def _fetch_keys(self) -> JWKS: self._last_fetch_time = time.time() return self._keys + async def _refresh_keys(self) -> JWKS: + new_set = await self._fetch_keys() + if self._keys is None: # pragma: no cover + self._keys = new_set + else: + self._keys.update(new_set) + return self._keys + async def get_keys(self) -> JWKS: if self._keys is not None: if self._cache_time > 0 and ( @@ -43,3 +56,27 @@ async def get_keys(self) -> JWKS: else: return self._keys return await self._fetch_keys() + + async def get_key(self, kid: str) -> Optional[JWK]: + """ + Tries to get a JWK by kid. If the JWK is not found and the last time the keys + were fetched is older than `refresh_time` (default 120 seconds), it fetches + again the JWKS from the source. + """ + jwks = await self.get_keys() + + for jwk in jwks.keys.copy(): + if jwk.kid is not None and jwk.kid == kid: + return jwk + + if ( + self._refresh_time > 0 + and time.time() - self._last_fetch_time >= self._refresh_time + ): + jwks = await self._refresh_keys() + + for jwk in jwks.keys.copy(): + if jwk.kid is not None and jwk.kid == kid: + return jwk + + return None diff --git a/guardpost/jwts/__init__.py b/guardpost/jwts/__init__.py index 170e7c8..28f7e85 100644 --- a/guardpost/jwts/__init__.py +++ b/guardpost/jwts/__init__.py @@ -44,7 +44,8 @@ def __init__( require_kid: bool = True, keys_provider: Optional[KeysProvider] = None, keys_url: Optional[str] = None, - cache_time: float = 10800 + cache_time: float = 10800, + refresh_time: float = 120, ) -> None: """ Creates a new instance of JWTValidator. This class only supports validating @@ -72,9 +73,15 @@ def __init__( keys_url : Optional[str], optional If provided, keys are obtained from the given URL through HTTP GET. This parameter is ignored if `keys_provider` is given. - cache_time : float, optional - If >= 0, JWKS are cached in memory and stored for the given amount in - seconds. By default 10800 (3 hours). + cache_time : float + JWKS are cached in memory and stored for the given amount in seconds. + By default 10800 (3 hours). Regardless of this parameter, JWKS are refreshed + automatically if an unknown kid is met and JWKS were last fetched more than + `refresh_time` earlier (in seconds). + refresh_time : float + JWKS are refreshed automatically if an unknown `kid` is encountered, and + JWKS were last fetched more than `refresh_time` seconds ago (by default + 120 seconds) """ if keys_provider: pass @@ -89,13 +96,12 @@ def __init__( "`authority`, or `keys_provider`." ) - if cache_time: - keys_provider = CachingKeysProvider(keys_provider, cache_time) + keys_provider = CachingKeysProvider(keys_provider, cache_time, refresh_time) self._valid_issuers = list(valid_issuers) self._valid_audiences = list(valid_audiences) self._algorithms = list(algorithms) - self._keys_provider: KeysProvider = keys_provider + self._keys_provider = keys_provider self.require_kid = require_kid self.logger = get_logger() @@ -103,12 +109,11 @@ async def get_jwks(self) -> JWKS: return await self._keys_provider.get_keys() async def get_jwk(self, kid: str) -> JWK: - jwks = await self.get_jwks() + key = await self._keys_provider.get_key(kid) - for jwk in jwks.keys: - if jwk.kid is not None and jwk.kid == kid: - return jwk - raise InvalidAccessToken("kid not recognized") + if key is None: + raise InvalidAccessToken("kid not recognized") + return key def _validate_jwt_by_key( self, access_token: str, jwk: JWK diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 3803192..3ba08d1 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -112,7 +112,6 @@ async def authenticate(self, context: Request): @pytest.mark.asyncio async def test_strategy_throws_for_missing_context(): - strategy = AuthenticationStrategy() with raises(ValueError): diff --git a/tests/test_common.py b/tests/test_common.py index ad55273..6ad42fb 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -147,7 +147,6 @@ def test_authorization_strategy_set_default_fluent(): def test_unauthorized_error_supports_error_and_description(): - error = UnauthorizedError( None, [], diff --git a/tests/test_jwks.py b/tests/test_jwks.py index 31cb963..d116746 100644 --- a/tests/test_jwks.py +++ b/tests/test_jwks.py @@ -28,3 +28,97 @@ def test_jwk_raises_for_unsupported_type(): with pytest.raises(ValueError): JWK.from_dict({"kty": "RSA"}) + + +def test_jwks_update(): + jwks_1 = JWKS.from_dict( + { + "keys": [ + { + "kty": "RSA", + "kid": "0", + "n": "xzO7x0gEMbktuu5RLUqiABJNqt4kdm_5ucsKgSdHUdUcbkG28dLAikoFTki9awmyapSbO84zlKMaH24obOe44hd32sdeMOdQp0TxpxE95HfYVFuAWdfCM4Bz_x32Sq51e7x6oZd09vODFFbwTlMJ27LPAEuI5G6UVQKxhIB_wA2FOPkbHeDncB7jYv9kLidvpNgp5PC-aKHKv9ay6gi7M-wUQEpeQMjpyDFN2p_q12BWSUbsRwOjhYtCuSmmBNh07MizzVIQjpmZU5f6qmZHw--iJSBD52wsI87itYbBwRcDN5ffColkFpA8va0hDlShI2qVmwtQ3LUpZVivKuJOSw==", + "e": "AQAB", + }, + { + "kty": "RSA", + "kid": "1", + "n": "3a-KHqLSxXba1e-qa2cWaV6VNd3LsNptZsbd1eZj402lehEbHm8ZdjHlZNwirPeqhvHYbCGRKfqLV2jE1UacfkCmcP8u7klENFbl01IyA8-MiVfmRB6BWlaBNS0NCDIGJ1GY7aPfEOJgGc5L4laIAD6iSVTfUwNtkLVAHXx5OQjJIVIxk6Vkji1n2JvpEO9337Kp96-AqfpIFWyCLg56uGJfK6XdlDYZvPm17xorcLGUB9MBsOID7PbdqeVnmaKW9aFNZj1OaDTZAsNqnxGkmsp3wkds8Th3raIbYvotQEGm1BCdEbqj3hu05bIEZuQbWuNTIseYCKFw7GJXawEKzQ==", + "e": "AQAB", + }, + ] + } + ) + + jwks_2 = JWKS.from_dict( + { + "keys": [ + { + "kty": "RSA", + "kid": "2", + "n": "nk4LTnUzUBqmQTdMmNaHRU6FHHHXfW7TwOoVnCSu36PKyFovRGs5Qiec1VBmF4PZCXlkAwmpBPf4iBbWr3xXU4lE8d3OBuqnf-qFWbOCkyNFp_kyqHu7SlGHJhYilfRzKqDGJ5FqIafBpXID_FsxTqNi-mf98G_jm_QoF5ifMAPUf0eVTCjzs9fcawnKDbeaAED3SbYJt-EVjdcOJalilXJWPNdpGx8ouF1Zn77NDEbj6_1BBk22AZI1yQzDy8c08HlEK1NQgToJyQ-CLP6deHYiHrxMSZe83WbkCvxr1PLMFZlUTWh2AcgbiR9zJARu7nk6PWTbBhreuXRL5meGMQ==", + "e": "AQAB", + }, + { + "kty": "RSA", + "kid": "3", + "n": "v_6KlxHChgEdhvV5t6cDi2h-u2y355dxkwIp1YM4YINXKNStSnFUTkRIPXAY9H15kn6CuWCSWXl7jRwCPm5UOBnC9TjKJTuTK_IVJrTqd1dFkxOEsesKKBPsc0nBjtYMc0c_74K0OxJphy6I4d0M6gXWVOx1avOMEU7LQHE18WtfSYXtBk_Q51foM8StqFARCKAdyRZRXwhtS71lPrHNLhU2aayKBKpWL-r-q4KZGwDLtw0z3bHR5Z_bIJVGushkYLN_DaJvkvypb1y7Lq6ozMovLA5xHgYhv6VCUGWOAJWo9PZXjtwjrO8gXME-msBmB7iO-ltV0FM3O9wTqsJJxw==", + "e": "AQAB", + }, + ] + } + ) + + jwks_1.update(jwks_2) + + assert len(jwks_1.keys) == 4 + + assert [key.kid for key in jwks_1.keys] == "0 1 2 3".split() + + +def test_jwks_update_override(): + jwks_1 = JWKS.from_dict( + { + "keys": [ + { + "kty": "RSA", + "kid": "0", + "n": "xzO7x0gEMbktuu5RLUqiABJNqt4kdm_5ucsKgSdHUdUcbkG28dLAikoFTki9awmyapSbO84zlKMaH24obOe44hd32sdeMOdQp0TxpxE95HfYVFuAWdfCM4Bz_x32Sq51e7x6oZd09vODFFbwTlMJ27LPAEuI5G6UVQKxhIB_wA2FOPkbHeDncB7jYv9kLidvpNgp5PC-aKHKv9ay6gi7M-wUQEpeQMjpyDFN2p_q12BWSUbsRwOjhYtCuSmmBNh07MizzVIQjpmZU5f6qmZHw--iJSBD52wsI87itYbBwRcDN5ffColkFpA8va0hDlShI2qVmwtQ3LUpZVivKuJOSw==", + "e": "AQAB", + }, + { + "kty": "RSA", + "kid": "1", + "n": "3a-KHqLSxXba1e-qa2cWaV6VNd3LsNptZsbd1eZj402lehEbHm8ZdjHlZNwirPeqhvHYbCGRKfqLV2jE1UacfkCmcP8u7klENFbl01IyA8-MiVfmRB6BWlaBNS0NCDIGJ1GY7aPfEOJgGc5L4laIAD6iSVTfUwNtkLVAHXx5OQjJIVIxk6Vkji1n2JvpEO9337Kp96-AqfpIFWyCLg56uGJfK6XdlDYZvPm17xorcLGUB9MBsOID7PbdqeVnmaKW9aFNZj1OaDTZAsNqnxGkmsp3wkds8Th3raIbYvotQEGm1BCdEbqj3hu05bIEZuQbWuNTIseYCKFw7GJXawEKzQ==", + "e": "AQAB", + }, + ] + } + ) + + jwks_2 = JWKS.from_dict( + { + "keys": [ + { + "kty": "RSA", + "kid": "0", + "n": "nk4LTnUzUBqmQTdMmNaHRU6FHHHXfW7TwOoVnCSu36PKyFovRGs5Qiec1VBmF4PZCXlkAwmpBPf4iBbWr3xXU4lE8d3OBuqnf-qFWbOCkyNFp_kyqHu7SlGHJhYilfRzKqDGJ5FqIafBpXID_FsxTqNi-mf98G_jm_QoF5ifMAPUf0eVTCjzs9fcawnKDbeaAED3SbYJt-EVjdcOJalilXJWPNdpGx8ouF1Zn77NDEbj6_1BBk22AZI1yQzDy8c08HlEK1NQgToJyQ-CLP6deHYiHrxMSZe83WbkCvxr1PLMFZlUTWh2AcgbiR9zJARu7nk6PWTbBhreuXRL5meGMQ==", + "e": "AQAB", + }, + { + "kty": "RSA", + "kid": "3", + "n": "v_6KlxHChgEdhvV5t6cDi2h-u2y355dxkwIp1YM4YINXKNStSnFUTkRIPXAY9H15kn6CuWCSWXl7jRwCPm5UOBnC9TjKJTuTK_IVJrTqd1dFkxOEsesKKBPsc0nBjtYMc0c_74K0OxJphy6I4d0M6gXWVOx1avOMEU7LQHE18WtfSYXtBk_Q51foM8StqFARCKAdyRZRXwhtS71lPrHNLhU2aayKBKpWL-r-q4KZGwDLtw0z3bHR5Z_bIJVGushkYLN_DaJvkvypb1y7Lq6ozMovLA5xHgYhv6VCUGWOAJWo9PZXjtwjrO8gXME-msBmB7iO-ltV0FM3O9wTqsJJxw==", + "e": "AQAB", + }, + ] + } + ) + + jwks_1.update(jwks_2) + + assert len(jwks_1.keys) == 3 + + key_0 = next((key for key in jwks_1.keys if key.kid == "0"), None) + assert key_0 is not None + assert key_0.n == jwks_2.keys[0].n diff --git a/tests/test_jwts.py b/tests/test_jwts.py index 09bb292..f303495 100644 --- a/tests/test_jwts.py +++ b/tests/test_jwts.py @@ -1,10 +1,10 @@ import time -from typing import Any, Dict +from typing import Any, Dict, Iterable import jwt import pytest -from guardpost.jwks import InMemoryKeysProvider, KeysProvider +from guardpost.jwks import JWKS, InMemoryKeysProvider, KeysProvider from guardpost.jwks.caching import CachingKeysProvider from guardpost.jwks.openid import AuthorityKeysProvider from guardpost.jwks.urls import URLKeysProvider @@ -19,6 +19,14 @@ def default_keys_provider() -> KeysProvider: return InMemoryKeysProvider(get_test_jwks()) +class MockedKeysProvider(KeysProvider): + def __init__(self, mocked: Iterable[JWKS]) -> None: + self.mocked = iter(mocked) + + async def get_keys(self) -> JWKS: + return next(self.mocked) + + def get_access_token( kid: str, payload: Dict[str, Any], include_headers: bool = True, fake_kid: str = "" ): @@ -35,14 +43,20 @@ def get_access_token( ) -async def _valid_tokens_scenario(validator: JWTValidator, include_headers: bool = True): - for i in range(5): - payload = {"aud": "a", "iss": "b"} - valid_token = get_access_token(str(i), payload, include_headers=include_headers) +async def _valid_token_scenario( + kid: str, validator: JWTValidator, include_headers: bool = True +): + payload = {"aud": "a", "iss": "b"} + valid_token = get_access_token(kid, payload, include_headers=include_headers) + + value = await validator.validate_jwt(valid_token) + + assert value == payload - value = await validator.validate_jwt(valid_token) - assert value == payload +async def _valid_tokens_scenario(validator: JWTValidator, include_headers: bool = True): + for i in range(5): + await _valid_token_scenario(str(i), validator, include_headers) def test_jwt_validator_raises_for_missing_key_source(): @@ -71,6 +85,32 @@ async def test_jwt_validator_cache_expiration(default_keys_provider): await _valid_tokens_scenario(validator) +@pytest.mark.asyncio +async def test_jwt_validator_fetches_tokens_again_for_unknown_kid(): + keys = get_test_jwks() + # configure a key provider that returns the given JWKS in sequence + keys_provider = MockedKeysProvider([JWKS(keys.keys[0:2]), JWKS(keys.keys[2:])]) + validator = JWTValidator( + valid_audiences=["a"], + valid_issuers=["b"], + keys_provider=keys_provider, + cache_time=10, + refresh_time=0.2, + ) + await _valid_token_scenario("0", validator) + await _valid_token_scenario("1", validator) + + # this must fail because tokens were just fetched, and kid "2" is not present + with pytest.raises(InvalidAccessToken): + await _valid_token_scenario("2", validator) + + time.sleep(0.3) + # now the JWTValidator should fetch automatically the new keys + await _valid_token_scenario("2", validator) + await _valid_token_scenario("3", validator) + await _valid_token_scenario("4", validator) + + @pytest.mark.asyncio async def test_jwt_validator_blocks_forged_access_tokens(default_keys_provider): validator = JWTValidator(