Skip to content

Commit

Permalink
Merge pull request #3 from primait/PLATFORM-371/user-story/improvements
Browse files Browse the repository at this point in the history
[PLATFORM-371]: [Jwks Client] Improvements
  • Loading branch information
cottinisimone authored Mar 14, 2022
2 parents d85eda6 + ad68be4 commit 8071a14
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 40 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "jwks_client_rs"
version = "0.1.1"
edition = "2018"
version = "0.2.0"
edition = "2021"
authors = ["Mite Ristovski <[email protected]>", "Simone Cottini <[email protected]>"]
license = "MIT"
description = "JWKS-sync client implementation for Auth0"
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM rust:1.55.0
FROM rust:1.59.0

WORKDIR /code

Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ use jwks_client_rs::JwksClient;

// here you must join your `BASE_AUTH0_URL` env var with `.well-known/jwks.json` or whatever is the jwks url
let url: reqwest::Url = todo!();
let source: WebSource = WebSource::new(url); // You can define a different source too using `JwksSource` trait
let timeout: std::time::Duration = todo!();
// You can define a different source too using `JwksSource` trait
let source: WebSource = WebSource::builder()
.with_timeout(timeout)
.with_connect_timeout(timeout)
.build(url);
let client: JwksClient = JwksClient::new(source);

// Store your client in your application context or whatever
Expand Down
5 changes: 4 additions & 1 deletion examples/get_jwks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ async fn main() {
let url: Url = Url::from_str(url_string.as_str()).unwrap();
let url: Url = url.join(".well-known/jwks.json").unwrap();

let source: WebSource = WebSource::new(url);
let source: WebSource = WebSource::builder()
.build(url)
.expect("Failed to build WebSource");

let client: JwksClient<WebSource> = JwksClient::new(source);

// The kid "unknown" cannot be a JWKS valid KID. This must not be found here
Expand Down
48 changes: 25 additions & 23 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use jsonwebtoken::{Algorithm, DecodingKey, Validation};
use moka::future::{Cache, CacheBuilder};
use serde::de::DeserializeOwned;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use jsonwebtoken::{Algorithm, DecodingKey, Validation};
use moka::future::{Cache, CacheBuilder};
use serde::de::DeserializeOwned;

use crate::error::{Error, JwksClientError};
use crate::keyset::JsonWebKey;
use crate::source::JwksSource;

const DEFAULT_CACHE_SIZE: usize = 100;
//TODO: we can also use auth0 response cache-control headers instead of this const
const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(60);
//TODO: we can also use auth0 response cache-control headers instead of this const (1 day)
const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(86400);

pub struct JwksClient<T: JwksSource> {
source: Arc<T>,
Expand Down Expand Up @@ -102,22 +103,23 @@ impl<T: JwksSource + Send + Sync + 'static> JwksClient<T> {

#[cfg(test)]
mod test {
use crate::error::Error;
use crate::source::WebSource;
use crate::{JwksClient, JwksClientError};
use httpmock::prelude::*;
use jsonwebtoken::{Algorithm, EncodingKey, Header};
use serde_json::{json, Value};
use url::Url;

use crate::error::Error;
use crate::source::WebSource;
use crate::{JwksClient, JwksClientError};

const MODULUS: &str = "qjNzuylUQpyU9qX3_bMGpiRUO1G_xKbB0fyqQy0naETviHIqPS2D3lGcfK9XIFLZOq1O7K2KRXEE5nSDTf-S9qc0nPRkS38CXK4DBKPTBXtjufLK3e9lN9dh8Ehazx8xNmdCc6aocVKKlamOJv7Qr_UgmoFllq7W-UQ0YK2qfN8WgqxOQUPrss-40RWslCAKpjZmMOpIpRXQLGmR-GGZUdQZXnTUhnhRyDz5VcXHH--o1PkH_F0rlabMxgNFfsCIWKWbGy8G89bNrvoeVKq15QPCeaGBV13f2Do6XHGt0l2M3eYz85wyz1pISvjQuR4PrtJr6VsuHz3Puh_KgY8GqQ";
const EXPONENT: &str = "AQAB";

#[tokio::test]
async fn get_key() {
let server = MockServer::start();
let path = "/keys";
let kid = "go14h7EBWUvPRncjniI_2";
let path: &str = "/keys";
let kid: &str = "go14h7EBWUvPRncjniI_2";

let mock = server.mock(|when, then| {
when.method(GET).path(path);
Expand All @@ -127,9 +129,9 @@ mod test {
.json_body(jwks_endpoint_response(kid));
});

let url = Url::parse(&server.url(path)).unwrap();
let source = WebSource::new(url);
let client = JwksClient::new(source);
let url: Url = Url::parse(&server.url(path)).unwrap();
let source: WebSource = WebSource::builder().build(url).unwrap();
let client: JwksClient<WebSource> = JwksClient::new(source);

assert!(client.get(kid.to_string()).await.is_ok());
mock.assert();
Expand All @@ -138,18 +140,18 @@ mod test {
#[tokio::test]
async fn get_key_fails_to_fetch_keys() {
let server = MockServer::start();
let path = "/keys";
let kid = "go14h7EBWUvPRncjniI_2";
let path: &str = "/keys";
let kid: &str = "go14h7EBWUvPRncjniI_2";

let mock = server.mock(|when, then| {
when.method(GET).path(path);

then.status(400).body("Error");
});

let url = Url::parse(&server.url(path)).unwrap();
let source = WebSource::new(url);
let client = JwksClient::new(source);
let url: Url = Url::parse(&server.url(path)).unwrap();
let source: WebSource = WebSource::builder().build(url).unwrap();
let client: JwksClient<WebSource> = JwksClient::new(source);

let result = client.get(kid.to_string()).await;
assert!(result.is_err());
Expand All @@ -168,8 +170,8 @@ mod test {
#[tokio::test]
async fn get_key_key_not_found() {
let server = MockServer::start();
let path = "/keys";
let kid = "other_kid";
let path: &str = "/keys";
let kid: &str = "other_kid";

let mock = server.mock(|when, then| {
when.method(GET).path(path);
Expand All @@ -179,9 +181,9 @@ mod test {
.json_body(jwks_endpoint_response("go14h7EBWUvPRncjniI_2"));
});

let url = Url::parse(&server.url(path)).unwrap();
let source = WebSource::new(url);
let client = JwksClient::new(source);
let url: Url = Url::parse(&server.url(path)).unwrap();
let source: WebSource = WebSource::builder().build(url).unwrap();
let client: JwksClient<WebSource> = JwksClient::new(source);

let result = client.get(kid.to_string()).await;
assert!(result.is_err());
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pub use client::JwksClient;
pub use error::JwksClientError;
pub use keyset::JsonWebKey;

mod client;
mod error;
mod keyset;
pub mod source;

pub use client::JwksClient;
pub use error::JwksClientError;
pub use keyset::JsonWebKey;
61 changes: 53 additions & 8 deletions src/source.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use std::time::Duration;

use async_trait::async_trait;
use reqwest::Url;
use reqwest::{Request, Url};

use crate::error::Error;
use crate::keyset::JsonWebKeySet;

const CONNECT_TIMEOUT: Duration = Duration::from_secs(20);
const TIMEOUT: Duration = Duration::from_secs(10);

#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait JwksSource {
Expand All @@ -16,19 +21,16 @@ pub struct WebSource {
}

impl WebSource {
pub fn new(url: Url) -> Self {
Self {
client: reqwest::Client::default(),
url,
}
pub fn builder() -> WebSourceBuilder {
WebSourceBuilder::new()
}
}

#[async_trait]
impl JwksSource for WebSource {
async fn fetch_keys(&self) -> Result<JsonWebKeySet, Error> {
let request = self.client.get(self.url.clone()).build()?;
let keys = self
let request: Request = self.client.get(self.url.clone()).build()?;
let keys: JsonWebKeySet = self
.client
.execute(request)
.await?
Expand All @@ -39,3 +41,46 @@ impl JwksSource for WebSource {
Ok(keys)
}
}

pub struct WebSourceBuilder {
client_builder: reqwest::ClientBuilder,
timeout_opt: Option<Duration>,
connect_timeout_opt: Option<Duration>,
}

impl WebSourceBuilder {
fn new() -> Self {
Self {
client_builder: reqwest::ClientBuilder::default(),
timeout_opt: None,
connect_timeout_opt: None,
}
}

pub fn with_timeout(self, timeout: Duration) -> Self {
Self {
timeout_opt: Some(timeout),
..self
}
}

pub fn with_connect_timeout(self, connect_timeout: Duration) -> Self {
Self {
connect_timeout_opt: Some(connect_timeout),
..self
}
}

pub fn build(self, url: Url) -> Result<WebSource, reqwest::Error> {
let timeout: Duration = self.timeout_opt.unwrap_or(TIMEOUT);
let connect_timeout: Duration = self.connect_timeout_opt.unwrap_or(CONNECT_TIMEOUT);
Ok(WebSource {
url,
client: self
.client_builder
.timeout(timeout)
.connect_timeout(connect_timeout)
.build()?,
})
}
}

0 comments on commit 8071a14

Please sign in to comment.