-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from primait/PLATFORM-378/user-story/cache-refa…
…ctor [PLATFORM-378]: [Jwks Client] Cache refactor
- Loading branch information
Showing
8 changed files
with
252 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "jwks_client_rs" | ||
version = "0.2.0" | ||
version = "0.3.0" | ||
edition = "2021" | ||
authors = ["Mite Ristovski <[email protected]>", "Simone Cottini <[email protected]>"] | ||
license = "MIT" | ||
|
@@ -11,20 +11,21 @@ readme = "README.md" | |
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
tokio = { version = "1", features = ["sync", "macros"] } | ||
anyhow = "1.0.44" | ||
async-trait = "0.1.51" | ||
jsonwebtoken = "8.0.0-beta.2" | ||
moka = { version = "0.6.0", features = ["future"] } | ||
chrono = "0.4" | ||
reqwest = { version = "0.11", features = ["json"] } | ||
serde = { version = "1.0", features = ["derive"]} | ||
serde_json = "1.0" | ||
thiserror = "1.0.29" | ||
url = "2.2" | ||
|
||
[dev-dependencies] | ||
tokio = { version = "1.11.0", features = ["macros"]} | ||
mockall = "0.10" | ||
httpmock = "0.6" | ||
rand = "0.8" | ||
|
||
[[example]] | ||
name = "get_jwks" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
use std::marker::PhantomData; | ||
use std::time::Duration; | ||
|
||
use crate::source::JwksSource; | ||
use crate::JwksClient; | ||
|
||
pub struct JwksClientBuilder<T> { | ||
ttl_opt: Option<Duration>, | ||
t: PhantomData<*const T>, | ||
// New PR to add this? | ||
// cache_size: Option<usize>, | ||
} | ||
|
||
impl<T: JwksSource + Send + Sync + 'static> JwksClientBuilder<T> { | ||
pub(crate) fn new() -> Self { | ||
Self { | ||
ttl_opt: None, | ||
t: PhantomData::default(), | ||
} | ||
} | ||
|
||
pub fn time_to_live(&self, ttl: Duration) -> Self { | ||
Self { | ||
ttl_opt: Some(ttl), | ||
t: PhantomData::default(), | ||
} | ||
} | ||
|
||
#[must_use] | ||
pub fn build(self, source: T) -> JwksClient<T> { | ||
JwksClient::new(source, self.ttl_opt) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
use std::future::Future; | ||
use std::sync::atomic::{AtomicBool, Ordering}; | ||
use std::sync::Arc; | ||
use std::time::Duration as StdDuration; | ||
|
||
use chrono::{Duration, Utc}; | ||
use tokio::sync::RwLock; | ||
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard}; | ||
|
||
use crate::error::Error; | ||
use crate::keyset::JsonWebKeySet; | ||
use crate::JsonWebKey; | ||
|
||
#[derive(Clone)] | ||
pub struct Cache { | ||
inner: Arc<RwLock<Entry>>, | ||
time_to_live: Duration, | ||
refreshed: Arc<AtomicBool>, | ||
} | ||
|
||
impl Cache { | ||
pub fn new(time_to_live: StdDuration) -> Self { | ||
let ttl: Duration = Duration::from_std(time_to_live) | ||
.expect("Failed to convert from `std::time::Duration` to `chrono::Duration`"); | ||
let json_web_key_set: JsonWebKeySet = JsonWebKeySet::empty(); | ||
|
||
Self { | ||
inner: Arc::new(RwLock::new(Entry::new(json_web_key_set, &ttl))), | ||
time_to_live: ttl, | ||
refreshed: Arc::new(AtomicBool::new(false)), | ||
} | ||
} | ||
|
||
pub async fn get_or_refresh<F>(&self, key: &str, future: F) -> Result<JsonWebKey, Error> | ||
where | ||
F: Future<Output = Result<JsonWebKeySet, Error>> + Send + 'static, | ||
{ | ||
let read: RwLockReadGuard<Entry> = self.inner.read().await; | ||
let is_entry_expired: bool = (*read).is_expired(); | ||
let get_key_result: Result<JsonWebKey, Error> = (*read).set.get_key(key).cloned(); | ||
// Drop RwLock read guard prematurely to be able to write in the lock | ||
drop(read); | ||
|
||
match get_key_result { | ||
// Key not found. Maybe a refresh is needed | ||
Err(_) => self.try_refresh(future).await.and_then(|v| v.take_key(key)), | ||
// Specified key exist but a refresh is needed | ||
Ok(json_web_key) if is_entry_expired => self | ||
.try_refresh(future) | ||
.await | ||
.and_then(|v| v.take_key(key)) | ||
.or(Ok(json_web_key)), | ||
// Specified key exist and is still valid. Return this one | ||
Ok(key) => Ok(key), | ||
} | ||
} | ||
|
||
async fn try_refresh<F>(&self, future: F) -> Result<JsonWebKeySet, Error> | ||
where | ||
F: Future<Output = Result<JsonWebKeySet, Error>> + Send + 'static, | ||
{ | ||
self.refreshed.store(false, Ordering::SeqCst); | ||
let mut guard: RwLockWriteGuard<Entry> = self.inner.write().await; | ||
|
||
if !self.refreshed.load(Ordering::SeqCst) { | ||
let set: JsonWebKeySet = future.await?; | ||
*guard = Entry::new(set.clone(), &self.time_to_live); | ||
self.refreshed.store(true, Ordering::SeqCst); | ||
Ok(set) | ||
} else { | ||
Ok((*guard).set.clone()) | ||
} | ||
// we drop the write guard here so "refresh=true" for the other threads/tasks | ||
} | ||
} | ||
|
||
struct Entry { | ||
set: JsonWebKeySet, | ||
expire_time_millis: i64, | ||
} | ||
|
||
impl Entry { | ||
fn new(set: JsonWebKeySet, expiration: &Duration) -> Self { | ||
Self { | ||
set, | ||
expire_time_millis: Utc::now().timestamp_millis() + expiration.num_milliseconds(), | ||
} | ||
} | ||
|
||
fn is_expired(&self) -> bool { | ||
Utc::now().timestamp_millis() > self.expire_time_millis | ||
} | ||
} |
Oops, something went wrong.