Skip to content

Commit

Permalink
Improve JWKS automatic rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertoPrevato authored Mar 20, 2023
1 parent 4537386 commit 069e2ee
Show file tree
Hide file tree
Showing 12 changed files with 216 additions and 31 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ max-complexity = 18
select = B,C,E,F,W,T4,B9
per-file-ignores =
guardpost/__init__.py:F401
tests/test_jwks.py:E501
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.0.1] - 2023-03-20 :sun_with_face:
- Improves the automatic rotation of `JWKS`: when validating `JWTs`, `JWKS` are
refreshed automatically if an unknown `kid` is encountered, and `JWKS` were
last fetched more than `refresh_time` seconds ago (by default 120 seconds).
- Corrects an inconsistency in how `claims` are read in the `User` class.

## [1.0.0] - 2023-01-07 :star:
- Adds built-in support for dependency injection, using the new `ContainerProtocol`
in `rodi` v2.
Expand Down
2 changes: 1 addition & 1 deletion guardpost/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.0.1"
11 changes: 7 additions & 4 deletions guardpost/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def __init__(

@property
def sub(self) -> Optional[str]:
return self["sub"]
return self.get("sub")

def is_authenticated(self) -> bool:
return bool(self.authentication_mode)

def get(self, key: str):
return self.claims.get(key)

def __getitem__(self, item):
return self.claims[item]

Expand All @@ -44,15 +47,15 @@ def has_claim_value(self, name: str, value: str) -> bool:
class User(Identity):
@property
def id(self) -> Optional[str]:
return self["id"] or self.sub
return self.get("id") or self.sub

@property
def name(self) -> Optional[str]:
return self["name"]
return self.get("name")

@property
def email(self) -> Optional[str]:
return self["email"]
return self.get("email")


class AuthenticationHandler(ABC):
Expand Down
2 changes: 0 additions & 2 deletions guardpost/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def _get_message(forced_failure, failed_requirements):


class AuthorizationContext:

__slots__ = ("identity", "requirements", "_succeeded", "_failed_forced")

def __init__(self, identity: Identity, requirements: Sequence[Requirement]):
Expand Down Expand Up @@ -222,7 +221,6 @@ async def _handle_with_policy(self, policy: Policy, identity: Identity, scope: A
with AuthorizationContext(
identity, list(self._get_requirements(policy, scope))
) as context:

for requirement in context.requirements:
if _is_async_handler(type(requirement)): # type: ignore
await requirement.handle(context)
Expand Down
3 changes: 3 additions & 0 deletions guardpost/jwks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def from_dict(cls, value) -> "JWK":
class JWKS:
keys: List[JWK]

def update(self, new_set: "JWKS"):
self.keys = list({key.kid: key for key in self.keys + new_set.keys}.values())

@classmethod
def from_dict(cls, value) -> "JWKS":
if "keys" not in value:
Expand Down
41 changes: 39 additions & 2 deletions guardpost/jwks/caching.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import time
from typing import Optional

from . import JWKS, KeysProvider
from . import JWK, JWKS, KeysProvider


class CachingKeysProvider(KeysProvider):
"""
Kind of KeysProvider that can cache the result of another KeysProvider.
"""

def __init__(self, keys_provider: KeysProvider, cache_time: float) -> None:
def __init__(
self, keys_provider: KeysProvider, cache_time: float, refresh_time: float = 120
) -> None:
"""
Creates a new instance of CachingKeysProvider bound to a given KeysProvider,
and caching its result up to an optional amount of seconds described by
cache_time. Expiration is disabled if `cache_time` <= 0.
JWKS are refreshed anyway if an unknown `kid` is encountered and the set was
fetched more than `refresh_time` seconds ago.
"""
super().__init__()

Expand All @@ -22,6 +26,7 @@ def __init__(self, keys_provider: KeysProvider, cache_time: float) -> None:

self._keys: Optional[JWKS] = None
self._cache_time = cache_time
self._refresh_time = refresh_time
self._last_fetch_time: float = 0
self._keys_provider = keys_provider

Expand All @@ -34,6 +39,14 @@ async def _fetch_keys(self) -> JWKS:
self._last_fetch_time = time.time()
return self._keys

async def _refresh_keys(self) -> JWKS:
new_set = await self._fetch_keys()
if self._keys is None: # pragma: no cover
self._keys = new_set
else:
self._keys.update(new_set)
return self._keys

async def get_keys(self) -> JWKS:
if self._keys is not None:
if self._cache_time > 0 and (
Expand All @@ -43,3 +56,27 @@ async def get_keys(self) -> JWKS:
else:
return self._keys
return await self._fetch_keys()

async def get_key(self, kid: str) -> Optional[JWK]:
"""
Tries to get a JWK by kid. If the JWK is not found and the last time the keys
were fetched is older than `refresh_time` (default 120 seconds), it fetches
again the JWKS from the source.
"""
jwks = await self.get_keys()

for jwk in jwks.keys.copy():
if jwk.kid is not None and jwk.kid == kid:
return jwk

if (
self._refresh_time > 0
and time.time() - self._last_fetch_time >= self._refresh_time
):
jwks = await self._refresh_keys()

for jwk in jwks.keys.copy():
if jwk.kid is not None and jwk.kid == kid:
return jwk

return None
29 changes: 17 additions & 12 deletions guardpost/jwts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __init__(
require_kid: bool = True,
keys_provider: Optional[KeysProvider] = None,
keys_url: Optional[str] = None,
cache_time: float = 10800
cache_time: float = 10800,
refresh_time: float = 120,
) -> None:
"""
Creates a new instance of JWTValidator. This class only supports validating
Expand Down Expand Up @@ -72,9 +73,15 @@ def __init__(
keys_url : Optional[str], optional
If provided, keys are obtained from the given URL through HTTP GET.
This parameter is ignored if `keys_provider` is given.
cache_time : float, optional
If >= 0, JWKS are cached in memory and stored for the given amount in
seconds. By default 10800 (3 hours).
cache_time : float
JWKS are cached in memory and stored for the given amount in seconds.
By default 10800 (3 hours). Regardless of this parameter, JWKS are refreshed
automatically if an unknown kid is met and JWKS were last fetched more than
`refresh_time` earlier (in seconds).
refresh_time : float
JWKS are refreshed automatically if an unknown `kid` is encountered, and
JWKS were last fetched more than `refresh_time` seconds ago (by default
120 seconds)
"""
if keys_provider:
pass
Expand All @@ -89,26 +96,24 @@ def __init__(
"`authority`, or `keys_provider`."
)

if cache_time:
keys_provider = CachingKeysProvider(keys_provider, cache_time)
keys_provider = CachingKeysProvider(keys_provider, cache_time, refresh_time)

self._valid_issuers = list(valid_issuers)
self._valid_audiences = list(valid_audiences)
self._algorithms = list(algorithms)
self._keys_provider: KeysProvider = keys_provider
self._keys_provider = keys_provider
self.require_kid = require_kid
self.logger = get_logger()

async def get_jwks(self) -> JWKS:
return await self._keys_provider.get_keys()

async def get_jwk(self, kid: str) -> JWK:
jwks = await self.get_jwks()
key = await self._keys_provider.get_key(kid)

for jwk in jwks.keys:
if jwk.kid is not None and jwk.kid == kid:
return jwk
raise InvalidAccessToken("kid not recognized")
if key is None:
raise InvalidAccessToken("kid not recognized")
return key

def _validate_jwt_by_key(
self, access_token: str, jwk: JWK
Expand Down
1 change: 0 additions & 1 deletion tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ async def authenticate(self, context: Request):

@pytest.mark.asyncio
async def test_strategy_throws_for_missing_context():

strategy = AuthenticationStrategy()

with raises(ValueError):
Expand Down
1 change: 0 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def test_authorization_strategy_set_default_fluent():


def test_unauthorized_error_supports_error_and_description():

error = UnauthorizedError(
None,
[],
Expand Down
94 changes: 94 additions & 0 deletions tests/test_jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,97 @@ def test_jwk_raises_for_unsupported_type():

with pytest.raises(ValueError):
JWK.from_dict({"kty": "RSA"})


def test_jwks_update():
jwks_1 = JWKS.from_dict(
{
"keys": [
{
"kty": "RSA",
"kid": "0",
"n": "xzO7x0gEMbktuu5RLUqiABJNqt4kdm_5ucsKgSdHUdUcbkG28dLAikoFTki9awmyapSbO84zlKMaH24obOe44hd32sdeMOdQp0TxpxE95HfYVFuAWdfCM4Bz_x32Sq51e7x6oZd09vODFFbwTlMJ27LPAEuI5G6UVQKxhIB_wA2FOPkbHeDncB7jYv9kLidvpNgp5PC-aKHKv9ay6gi7M-wUQEpeQMjpyDFN2p_q12BWSUbsRwOjhYtCuSmmBNh07MizzVIQjpmZU5f6qmZHw--iJSBD52wsI87itYbBwRcDN5ffColkFpA8va0hDlShI2qVmwtQ3LUpZVivKuJOSw==",
"e": "AQAB",
},
{
"kty": "RSA",
"kid": "1",
"n": "3a-KHqLSxXba1e-qa2cWaV6VNd3LsNptZsbd1eZj402lehEbHm8ZdjHlZNwirPeqhvHYbCGRKfqLV2jE1UacfkCmcP8u7klENFbl01IyA8-MiVfmRB6BWlaBNS0NCDIGJ1GY7aPfEOJgGc5L4laIAD6iSVTfUwNtkLVAHXx5OQjJIVIxk6Vkji1n2JvpEO9337Kp96-AqfpIFWyCLg56uGJfK6XdlDYZvPm17xorcLGUB9MBsOID7PbdqeVnmaKW9aFNZj1OaDTZAsNqnxGkmsp3wkds8Th3raIbYvotQEGm1BCdEbqj3hu05bIEZuQbWuNTIseYCKFw7GJXawEKzQ==",
"e": "AQAB",
},
]
}
)

jwks_2 = JWKS.from_dict(
{
"keys": [
{
"kty": "RSA",
"kid": "2",
"n": "nk4LTnUzUBqmQTdMmNaHRU6FHHHXfW7TwOoVnCSu36PKyFovRGs5Qiec1VBmF4PZCXlkAwmpBPf4iBbWr3xXU4lE8d3OBuqnf-qFWbOCkyNFp_kyqHu7SlGHJhYilfRzKqDGJ5FqIafBpXID_FsxTqNi-mf98G_jm_QoF5ifMAPUf0eVTCjzs9fcawnKDbeaAED3SbYJt-EVjdcOJalilXJWPNdpGx8ouF1Zn77NDEbj6_1BBk22AZI1yQzDy8c08HlEK1NQgToJyQ-CLP6deHYiHrxMSZe83WbkCvxr1PLMFZlUTWh2AcgbiR9zJARu7nk6PWTbBhreuXRL5meGMQ==",
"e": "AQAB",
},
{
"kty": "RSA",
"kid": "3",
"n": "v_6KlxHChgEdhvV5t6cDi2h-u2y355dxkwIp1YM4YINXKNStSnFUTkRIPXAY9H15kn6CuWCSWXl7jRwCPm5UOBnC9TjKJTuTK_IVJrTqd1dFkxOEsesKKBPsc0nBjtYMc0c_74K0OxJphy6I4d0M6gXWVOx1avOMEU7LQHE18WtfSYXtBk_Q51foM8StqFARCKAdyRZRXwhtS71lPrHNLhU2aayKBKpWL-r-q4KZGwDLtw0z3bHR5Z_bIJVGushkYLN_DaJvkvypb1y7Lq6ozMovLA5xHgYhv6VCUGWOAJWo9PZXjtwjrO8gXME-msBmB7iO-ltV0FM3O9wTqsJJxw==",
"e": "AQAB",
},
]
}
)

jwks_1.update(jwks_2)

assert len(jwks_1.keys) == 4

assert [key.kid for key in jwks_1.keys] == "0 1 2 3".split()


def test_jwks_update_override():
jwks_1 = JWKS.from_dict(
{
"keys": [
{
"kty": "RSA",
"kid": "0",
"n": "xzO7x0gEMbktuu5RLUqiABJNqt4kdm_5ucsKgSdHUdUcbkG28dLAikoFTki9awmyapSbO84zlKMaH24obOe44hd32sdeMOdQp0TxpxE95HfYVFuAWdfCM4Bz_x32Sq51e7x6oZd09vODFFbwTlMJ27LPAEuI5G6UVQKxhIB_wA2FOPkbHeDncB7jYv9kLidvpNgp5PC-aKHKv9ay6gi7M-wUQEpeQMjpyDFN2p_q12BWSUbsRwOjhYtCuSmmBNh07MizzVIQjpmZU5f6qmZHw--iJSBD52wsI87itYbBwRcDN5ffColkFpA8va0hDlShI2qVmwtQ3LUpZVivKuJOSw==",
"e": "AQAB",
},
{
"kty": "RSA",
"kid": "1",
"n": "3a-KHqLSxXba1e-qa2cWaV6VNd3LsNptZsbd1eZj402lehEbHm8ZdjHlZNwirPeqhvHYbCGRKfqLV2jE1UacfkCmcP8u7klENFbl01IyA8-MiVfmRB6BWlaBNS0NCDIGJ1GY7aPfEOJgGc5L4laIAD6iSVTfUwNtkLVAHXx5OQjJIVIxk6Vkji1n2JvpEO9337Kp96-AqfpIFWyCLg56uGJfK6XdlDYZvPm17xorcLGUB9MBsOID7PbdqeVnmaKW9aFNZj1OaDTZAsNqnxGkmsp3wkds8Th3raIbYvotQEGm1BCdEbqj3hu05bIEZuQbWuNTIseYCKFw7GJXawEKzQ==",
"e": "AQAB",
},
]
}
)

jwks_2 = JWKS.from_dict(
{
"keys": [
{
"kty": "RSA",
"kid": "0",
"n": "nk4LTnUzUBqmQTdMmNaHRU6FHHHXfW7TwOoVnCSu36PKyFovRGs5Qiec1VBmF4PZCXlkAwmpBPf4iBbWr3xXU4lE8d3OBuqnf-qFWbOCkyNFp_kyqHu7SlGHJhYilfRzKqDGJ5FqIafBpXID_FsxTqNi-mf98G_jm_QoF5ifMAPUf0eVTCjzs9fcawnKDbeaAED3SbYJt-EVjdcOJalilXJWPNdpGx8ouF1Zn77NDEbj6_1BBk22AZI1yQzDy8c08HlEK1NQgToJyQ-CLP6deHYiHrxMSZe83WbkCvxr1PLMFZlUTWh2AcgbiR9zJARu7nk6PWTbBhreuXRL5meGMQ==",
"e": "AQAB",
},
{
"kty": "RSA",
"kid": "3",
"n": "v_6KlxHChgEdhvV5t6cDi2h-u2y355dxkwIp1YM4YINXKNStSnFUTkRIPXAY9H15kn6CuWCSWXl7jRwCPm5UOBnC9TjKJTuTK_IVJrTqd1dFkxOEsesKKBPsc0nBjtYMc0c_74K0OxJphy6I4d0M6gXWVOx1avOMEU7LQHE18WtfSYXtBk_Q51foM8StqFARCKAdyRZRXwhtS71lPrHNLhU2aayKBKpWL-r-q4KZGwDLtw0z3bHR5Z_bIJVGushkYLN_DaJvkvypb1y7Lq6ozMovLA5xHgYhv6VCUGWOAJWo9PZXjtwjrO8gXME-msBmB7iO-ltV0FM3O9wTqsJJxw==",
"e": "AQAB",
},
]
}
)

jwks_1.update(jwks_2)

assert len(jwks_1.keys) == 3

key_0 = next((key for key in jwks_1.keys if key.kid == "0"), None)
assert key_0 is not None
assert key_0.n == jwks_2.keys[0].n
Loading

0 comments on commit 069e2ee

Please sign in to comment.