Skip to content

Commit

Permalink
Merge pull request #4 from primait/PLATFORM-378/user-story/cache-refa…
Browse files Browse the repository at this point in the history
…ctor

[PLATFORM-378]: [Jwks Client]  Cache refactor
  • Loading branch information
cottinisimone authored Apr 1, 2022
2 parents 8071a14 + 5f333ad commit bd13777
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 27 deletions.
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "jwks_client_rs"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
authors = ["Mite Ristovski <[email protected]>", "Simone Cottini <[email protected]>"]
license = "MIT"
Expand All @@ -11,20 +11,21 @@ readme = "README.md"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "1", features = ["sync", "macros"] }
anyhow = "1.0.44"
async-trait = "0.1.51"
jsonwebtoken = "8.0.0-beta.2"
moka = { version = "0.6.0", features = ["future"] }
chrono = "0.4"
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1.0", features = ["derive"]}
serde_json = "1.0"
thiserror = "1.0.29"
url = "2.2"

[dev-dependencies]
tokio = { version = "1.11.0", features = ["macros"]}
mockall = "0.10"
httpmock = "0.6"
rand = "0.8"

[[example]]
name = "get_jwks"
Expand Down
2 changes: 1 addition & 1 deletion Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ args = [

[tasks.test]
command = "cargo"
args = ["test"]
args = ["test", "${@}"]

[tasks.clippy]
command = "cargo"
Expand Down
7 changes: 6 additions & 1 deletion examples/get_jwks.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::str::FromStr;
use std::time::Duration;

use reqwest::Url;

Expand All @@ -21,7 +22,11 @@ async fn main() {
.build(url)
.expect("Failed to build WebSource");

let client: JwksClient<WebSource> = JwksClient::new(source);
let time_to_live: Duration = Duration::from_secs(60);

let client: JwksClient<WebSource> = JwksClient::builder()
.time_to_live(time_to_live)
.build(source);

// The kid "unknown" cannot be a JWKS valid KID. This must not be found here
let result: Result<JsonWebKey, JwksClientError> = client.get("unknown".to_string()).await;
Expand Down
33 changes: 33 additions & 0 deletions src/builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use std::marker::PhantomData;
use std::time::Duration;

use crate::source::JwksSource;
use crate::JwksClient;

pub struct JwksClientBuilder<T> {
ttl_opt: Option<Duration>,
t: PhantomData<*const T>,
// New PR to add this?
// cache_size: Option<usize>,
}

impl<T: JwksSource + Send + Sync + 'static> JwksClientBuilder<T> {
pub(crate) fn new() -> Self {
Self {
ttl_opt: None,
t: PhantomData::default(),
}
}

pub fn time_to_live(&self, ttl: Duration) -> Self {
Self {
ttl_opt: Some(ttl),
t: PhantomData::default(),
}
}

#[must_use]
pub fn build(self, source: T) -> JwksClient<T> {
JwksClient::new(source, self.ttl_opt)
}
}
93 changes: 93 additions & 0 deletions src/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration as StdDuration;

use chrono::{Duration, Utc};
use tokio::sync::RwLock;
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};

use crate::error::Error;
use crate::keyset::JsonWebKeySet;
use crate::JsonWebKey;

#[derive(Clone)]
pub struct Cache {
inner: Arc<RwLock<Entry>>,
time_to_live: Duration,
refreshed: Arc<AtomicBool>,
}

impl Cache {
pub fn new(time_to_live: StdDuration) -> Self {
let ttl: Duration = Duration::from_std(time_to_live)
.expect("Failed to convert from `std::time::Duration` to `chrono::Duration`");
let json_web_key_set: JsonWebKeySet = JsonWebKeySet::empty();

Self {
inner: Arc::new(RwLock::new(Entry::new(json_web_key_set, &ttl))),
time_to_live: ttl,
refreshed: Arc::new(AtomicBool::new(false)),
}
}

pub async fn get_or_refresh<F>(&self, key: &str, future: F) -> Result<JsonWebKey, Error>
where
F: Future<Output = Result<JsonWebKeySet, Error>> + Send + 'static,
{
let read: RwLockReadGuard<Entry> = self.inner.read().await;
let is_entry_expired: bool = (*read).is_expired();
let get_key_result: Result<JsonWebKey, Error> = (*read).set.get_key(key).cloned();
// Drop RwLock read guard prematurely to be able to write in the lock
drop(read);

match get_key_result {
// Key not found. Maybe a refresh is needed
Err(_) => self.try_refresh(future).await.and_then(|v| v.take_key(key)),
// Specified key exist but a refresh is needed
Ok(json_web_key) if is_entry_expired => self
.try_refresh(future)
.await
.and_then(|v| v.take_key(key))
.or(Ok(json_web_key)),
// Specified key exist and is still valid. Return this one
Ok(key) => Ok(key),
}
}

async fn try_refresh<F>(&self, future: F) -> Result<JsonWebKeySet, Error>
where
F: Future<Output = Result<JsonWebKeySet, Error>> + Send + 'static,
{
self.refreshed.store(false, Ordering::SeqCst);
let mut guard: RwLockWriteGuard<Entry> = self.inner.write().await;

if !self.refreshed.load(Ordering::SeqCst) {
let set: JsonWebKeySet = future.await?;
*guard = Entry::new(set.clone(), &self.time_to_live);
self.refreshed.store(true, Ordering::SeqCst);
Ok(set)
} else {
Ok((*guard).set.clone())
}
// we drop the write guard here so "refresh=true" for the other threads/tasks
}
}

struct Entry {
set: JsonWebKeySet,
expire_time_millis: i64,
}

impl Entry {
fn new(set: JsonWebKeySet, expiration: &Duration) -> Self {
Self {
set,
expire_time_millis: Utc::now().timestamp_millis() + expiration.num_milliseconds(),
}
}

fn is_expired(&self) -> bool {
Utc::now().timestamp_millis() > self.expire_time_millis
}
}
Loading

0 comments on commit bd13777

Please sign in to comment.