diff --git a/proto/user.proto b/proto/user.proto index 0ebb1cb30649b..014a8d0c1b0d3 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -18,7 +18,7 @@ message AuthInfo { } EncryptionType encryption_type = 1; bytes encrypted_value = 2; - map meta_data = 3; + map metadata = 3; } // User defines a user in the system. diff --git a/src/frontend/src/handler/alter_user.rs b/src/frontend/src/handler/alter_user.rs index 05dffaa57ef0d..431a217a20cf3 100644 --- a/src/frontend/src/handler/alter_user.rs +++ b/src/frontend/src/handler/alter_user.rs @@ -234,7 +234,7 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) ); } diff --git a/src/frontend/src/handler/create_user.rs b/src/frontend/src/handler/create_user.rs index 6022693c5cc36..bfdc33e6db80f 100644 --- a/src/frontend/src/handler/create_user.rs +++ b/src/frontend/src/handler/create_user.rs @@ -171,7 +171,7 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) ); frontend diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 3ffd8dc7a6f6a..30d1b02df7c03 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -977,7 +977,7 @@ impl SessionManager for SessionManagerImpl { salt, } } else if auth_info.encryption_type == EncryptionType::Oauth as i32 { - UserAuthenticator::OAuth(auth_info.meta_data.clone()) + UserAuthenticator::OAuth(auth_info.metadata.clone()) } else { return Err(Box::new(Error::new( ErrorKind::Unsupported, diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index 10dea11c4e13c..b0cefc1faedcb 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -35,18 +35,18 @@ pub const OAUTH_ISSUER_KEY: &str = "issuer"; /// Build `AuthInfo` for `OAuth`. #[inline(always)] pub fn build_oauth_info(options: &Vec) -> Option { - let meta_data: HashMap = WithOptions::try_from(options.as_slice()) + let metadata: HashMap = WithOptions::try_from(options.as_slice()) .ok()? .into_inner() .into_iter() .collect(); - if !meta_data.contains_key(OAUTH_JWKS_URL_KEY) || !meta_data.contains_key(OAUTH_ISSUER_KEY) { + if !metadata.contains_key(OAUTH_JWKS_URL_KEY) || !metadata.contains_key(OAUTH_ISSUER_KEY) { return None; } Some(AuthInfo { encryption_type: EncryptionType::Oauth as i32, encrypted_value: Vec::new(), - meta_data, + metadata, }) } @@ -79,13 +79,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option { Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) } else if valid_md5_password(password) { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(), - meta_data: HashMap::new(), + metadata: HashMap::new(), }) } else { Some(encrypt_default(name, password)) @@ -98,7 +98,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo { AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), } } @@ -186,18 +186,18 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), }), None, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), }), Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: sha256_hash(user_name, password), - meta_data: HashMap::new(), + metadata: HashMap::new(), }), ]; let output_passwords = input_passwords diff --git a/src/storage/src/hummock/sstable_store.rs b/src/storage/src/hummock/sstable_store.rs index c603b7d8f503a..f0cacf863fcc9 100644 --- a/src/storage/src/hummock/sstable_store.rs +++ b/src/storage/src/hummock/sstable_store.rs @@ -1020,9 +1020,9 @@ impl SstableWriter for StreamingUploadWriter { } async fn finish(mut self, meta: SstableMeta) -> HummockResult { - let meta_data = Bytes::from(meta.encode_to_bytes()); + let metadata = Bytes::from(meta.encode_to_bytes()); - self.object_uploader.write_bytes(meta_data).await?; + self.object_uploader.write_bytes(metadata).await?; let join_handle = tokio::spawn(async move { let uploader_memory_usage = self.object_uploader.get_memory_usage(); let _tracker = self.tracker.map(|mut t| { diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index 5f61d6d5ab6e9..7f6dd41368d45 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -16,11 +16,12 @@ use std::collections::HashMap; use std::future::Future; use std::io; use std::result::Result; +use std::str::FromStr; use std::sync::Arc; use std::time::Instant; use bytes::Bytes; -use jsonwebtoken::{decode, decode_header, DecodingKey, Validation}; +use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; @@ -161,34 +162,34 @@ pub enum UserAuthenticator { OAuth(HashMap), } +/// A JWK Set is a JSON object that represents a set of JWKs. +/// The JSON object MUST have a "keys" member, with its value being an array of JWKs. +/// See for more details. #[derive(Debug, Deserialize)] struct Jwks { keys: Vec, } -#[allow(dead_code)] +/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key. +/// See for more details. #[derive(Debug, Deserialize)] struct Jwk { - kid: String, - alg: String, - n: String, - e: String, -} - -async fn fetch_jwks(url: &str) -> Result { - let resp: Jwks = reqwest::get(url).await?.json().await?; - Ok(resp) + kid: String, // Key ID + alg: String, // Algorithm + n: String, // Modulus + e: String, // Exponent } async fn validate_jwt( jwt: &str, jwks_url: &str, issuer: &str, - meta_data: &HashMap, + metadata: &HashMap, ) -> Result { let header = decode_header(jwt)?; - let jwks = fetch_jwks(jwks_url).await?; + let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?; + // 1. Retrieve the kid from the header to find the right JWK in the JWK Set. let kid = header.kid.ok_or("kid not found in jwt header")?; let jwk = jwks .keys @@ -196,15 +197,25 @@ async fn validate_jwt( .find(|k| k.kid == kid) .ok_or("kid not found in jwks")?; + // 2. Check if the algorithms are matched. + if Algorithm::from_str(&jwk.alg)? != header.alg { + return Err("alg in jwt header does not match with alg in jwk".into()); + } + + // 3. Decode the JWT and validate the claims. let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?; let mut validation = Validation::new(header.alg); validation.set_issuer(&[issuer]); validation.set_required_spec_claims(&["exp", "iss"]); let token_data = decode::>(jwt, &decoding_key, &validation)?; - Ok(meta_data.iter().all( + // 4. Check if the metadata in the token matches. + if !metadata.iter().all( |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v), - )) + ) { + return Err("metadata in jwt does not match with metadata declared with user".into()); + } + Ok(true) } impl UserAuthenticator { @@ -215,15 +226,15 @@ impl UserAuthenticator { UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, - UserAuthenticator::OAuth(meta_data) => { - let mut meta_data = meta_data.clone(); - let jwks_url = meta_data.remove("jwks_url").unwrap(); - let issuer = meta_data.remove("issuer").unwrap(); + UserAuthenticator::OAuth(metadata) => { + let mut metadata = metadata.clone(); + let jwks_url = metadata.remove("jwks_url").unwrap(); + let issuer = metadata.remove("issuer").unwrap(); validate_jwt( &String::from_utf8_lossy(password), &jwks_url, &issuer, - &meta_data, + &metadata, ) .await .map_err(PsqlError::StartupError)?