diff --git a/e2e_test/batch/catalog/pg_settings.slt.part b/e2e_test/batch/catalog/pg_settings.slt.part index eeec9713382d2..9ff41b8dbea45 100644 --- a/e2e_test/batch/catalog/pg_settings.slt.part +++ b/e2e_test/batch/catalog/pg_settings.slt.part @@ -13,7 +13,6 @@ postmaster barrier_interval_ms postmaster checkpoint_frequency postmaster enable_tracing postmaster max_concurrent_creating_streaming_jobs -postmaster oauth_jwks_url postmaster pause_on_next_bootstrap user application_name user background_ddl diff --git a/proto/meta.proto b/proto/meta.proto index 4cb08f872f6a2..1db290af7b308 100644 --- a/proto/meta.proto +++ b/proto/meta.proto @@ -557,7 +557,6 @@ message SystemParams { optional bool pause_on_next_bootstrap = 13; optional string wasm_storage_url = 14 [deprecated = true]; optional bool enable_tracing = 15; - optional string oauth_jwks_url = 16; } message GetSystemParamsRequest {} diff --git a/proto/user.proto b/proto/user.proto index dd04dd558a6a3..0ebb1cb30649b 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -18,6 +18,7 @@ message AuthInfo { } EncryptionType encryption_type = 1; bytes encrypted_value = 2; + map meta_data = 3; } // User defines a user in the system. diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index 998450a7f79dd..19c36baf09c68 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -87,7 +87,6 @@ macro_rules! for_all_params { { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", }, { pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", }, { enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", }, - { oauth_jwks_url, String, Some("".to_string()), true, "Url to get JSON Web Key Set(JWKS) for oauth authentication.", }, } }; } @@ -376,7 +375,6 @@ macro_rules! impl_system_params_for_test { ret.state_store = Some("hummock+memory".to_string()); ret.backup_storage_url = Some("memory".into()); ret.backup_storage_directory = Some("backup".into()); - ret.oauth_jwks_url = Some("https://auth-static.confluent.io/jwks".into()); ret } }; @@ -442,7 +440,6 @@ mod tests { (MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"), (PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"), (ENABLE_TRACING_KEY, "true"), - (OAUTH_JWKS_URL_KEY, "a"), ("a_deprecated_param", "foo"), ]; diff --git a/src/common/src/system_param/reader.rs b/src/common/src/system_param/reader.rs index 7b0e0d4667e08..3374e72120238 100644 --- a/src/common/src/system_param/reader.rs +++ b/src/common/src/system_param/reader.rs @@ -160,11 +160,4 @@ where .enable_tracing .unwrap_or_else(default::enable_tracing) } - - fn oauth_jwks_url(&self) -> &str { - self.inner() - .oauth_jwks_url - .as_ref() - .unwrap_or(&default::OAUTH_JWKS_URL) - } } diff --git a/src/config/docs.md b/src/config/docs.md index 07736636b9e37..63e8f7ce1278d 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -146,7 +146,6 @@ This page is automatically generated by `./risedev generate-example-config` | data_directory | Remote directory for storing data and metadata objects. | | | enable_tracing | Whether to enable distributed tracing. | false | | max_concurrent_creating_streaming_jobs | Max number of concurrent creating streaming jobs. | 1 | -| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | "" | | parallel_compact_size_mb | | 512 | | pause_on_next_bootstrap | Whether to pause all data sources on next bootstrap. | false | | sstable_size_mb | Target size of the Sstable. | 256 | diff --git a/src/config/example.toml b/src/config/example.toml index 9d83b7d5502f5..a1d5fadb52c53 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -196,4 +196,3 @@ bloom_false_positive = 0.001 max_concurrent_creating_streaming_jobs = 1 pause_on_next_bootstrap = false enable_tracing = false -oauth_jwks_url = "" diff --git a/src/frontend/src/handler/alter_user.rs b/src/frontend/src/handler/alter_user.rs index 062b9c401b5a1..05dffaa57ef0d 100644 --- a/src/frontend/src/handler/alter_user.rs +++ b/src/frontend/src/handler/alter_user.rs @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{AlterUserStatement, ObjectName, UserOption, User use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::CatalogError; -use crate::error::ErrorCode::{InternalError, PermissionDenied}; +use crate::error::ErrorCode::{self, InternalError, PermissionDenied}; use crate::error::Result; use crate::handler::HandlerArgs; -use crate::user::user_authentication::{build_oauth_info, encrypted_password}; +use crate::user::user_authentication::{ + build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY, +}; use crate::user::user_catalog::UserCatalog; fn alter_prost_user_info( @@ -111,8 +113,14 @@ fn alter_prost_user_info( } update_fields.push(UpdateField::AuthInfo); } - UserOption::OAuth => { - user_info.auth_info = build_oauth_info(); + UserOption::OAuth(options) => { + let auth_info = build_oauth_info(options).ok_or_else(|| { + ErrorCode::InvalidParameterValue(format!( + "{} and {} must be provided", + OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY + )) + })?; + user_info.auth_info = Some(auth_info); update_fields.push(UpdateField::AuthInfo) } } @@ -185,6 +193,8 @@ pub async fn handle_alter_user( #[cfg(test)] mod tests { + use std::collections::HashMap; + use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; @@ -223,7 +233,8 @@ mod tests { user_info.auth_info, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, - encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec() + encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(), + meta_data: HashMap::new(), }) ); } diff --git a/src/frontend/src/handler/create_user.rs b/src/frontend/src/handler/create_user.rs index 429a35754d1aa..6022693c5cc36 100644 --- a/src/frontend/src/handler/create_user.rs +++ b/src/frontend/src/handler/create_user.rs @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{CreateUserStatement, UserOption, UserOptions}; use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::{CatalogError, DatabaseId}; -use crate::error::ErrorCode::PermissionDenied; +use crate::error::ErrorCode::{self, PermissionDenied}; use crate::error::Result; use crate::handler::HandlerArgs; -use crate::user::user_authentication::{build_oauth_info, encrypted_password}; +use crate::user::user_authentication::{ + build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY, +}; use crate::user::user_catalog::UserCatalog; fn make_prost_user_info( @@ -91,7 +93,15 @@ fn make_prost_user_info( user_info.auth_info = encrypted_password(&user_info.name, &password.0); } } - UserOption::OAuth => user_info.auth_info = build_oauth_info(), + UserOption::OAuth(options) => { + let auth_info = build_oauth_info(options).ok_or_else(|| { + ErrorCode::InvalidParameterValue(format!( + "{} and {} must be provided", + OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY + )) + })?; + user_info.auth_info = Some(auth_info); + } } } @@ -131,6 +141,8 @@ pub async fn handle_create_user( #[cfg(test)] mod tests { + use std::collections::HashMap; + use risingwave_common::catalog::DEFAULT_DATABASE_NAME; use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; @@ -158,7 +170,8 @@ mod tests { user_info.auth_info, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, - encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec() + encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(), + meta_data: HashMap::new(), }) ); frontend diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 17e4f7ef09e0c..3ffd8dc7a6f6a 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -46,7 +46,6 @@ use risingwave_common::session_config::{ConfigMap, ConfigReporter, VisibilityMod use risingwave_common::system_param::local_manager::{ LocalSystemParamsManager, LocalSystemParamsManagerRef, }; -use risingwave_common::system_param::reader::SystemParamsRead; use risingwave_common::telemetry::manager::TelemetryManager; use risingwave_common::telemetry::telemetry_env_enabled; use risingwave_common::types::DataType; @@ -978,20 +977,7 @@ impl SessionManager for SessionManagerImpl { salt, } } else if auth_info.encryption_type == EncryptionType::Oauth as i32 { - let oauth_jwks_url = self - .env - .system_params_manager - .get_params() - .load() - .oauth_jwks_url() - .to_string(); - if oauth_jwks_url.is_empty() { - return Err(Box::new(Error::new( - ErrorKind::PermissionDenied, - "OAuth JWKS URL is not set", - ))); - } - UserAuthenticator::OAuth(oauth_jwks_url) + UserAuthenticator::OAuth(auth_info.meta_data.clone()) } else { return Err(Box::new(Error::new( ErrorKind::Unsupported, diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index d558fb03ee3b6..c1f3e570878c5 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; +use risingwave_sqlparser::ast::SqlOption; use sha2::{Digest, Sha256}; // SHA-256 is not supported in PostgreSQL protocol. We need to implement SCRAM-SHA-256 instead @@ -24,12 +27,23 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5"; const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64; const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32; +pub const OAUTH_JWKS_URL_KEY: &str = "jwks_url"; +pub const OAUTH_ISSUER_KEY: &str = "issuer"; + /// Build `AuthInfo` for `OAuth`. #[inline(always)] -pub fn build_oauth_info() -> Option { +pub fn build_oauth_info(options: &Vec) -> Option { + let meta_data: HashMap = options + .iter() + .map(|opt| (opt.name.real_value(), opt.value.to_string())) + .collect(); + if !meta_data.contains_key(OAUTH_JWKS_URL_KEY) || !meta_data.contains_key(OAUTH_ISSUER_KEY) { + return None; + } Some(AuthInfo { encryption_type: EncryptionType::Oauth as i32, encrypted_value: Vec::new(), + meta_data, }) } @@ -62,11 +76,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option { Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(), + meta_data: HashMap::new(), }) } else if valid_md5_password(password) { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(), + meta_data: HashMap::new(), }) } else { Some(encrypt_default(name, password)) @@ -79,6 +95,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo { AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(name, password), + meta_data: HashMap::new(), } } @@ -166,15 +183,18 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), + meta_data: HashMap::new(), }), None, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), + meta_data: HashMap::new(), }), Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: sha256_hash(user_name, password), + meta_data: HashMap::new(), }), ]; let output_passwords = input_passwords diff --git a/src/sqlparser/src/ast/statement.rs b/src/sqlparser/src/ast/statement.rs index d5def93c81fef..6edc0702a425d 100644 --- a/src/sqlparser/src/ast/statement.rs +++ b/src/sqlparser/src/ast/statement.rs @@ -741,7 +741,7 @@ pub enum UserOption { NoLogin, EncryptedPassword(AstString), Password(Option), - OAuth, + OAuth(Vec), } impl fmt::Display for UserOption { @@ -758,7 +758,9 @@ impl fmt::Display for UserOption { UserOption::EncryptedPassword(p) => write!(f, "ENCRYPTED PASSWORD {}", p), UserOption::Password(None) => write!(f, "PASSWORD NULL"), UserOption::Password(Some(p)) => write!(f, "PASSWORD {}", p), - UserOption::OAuth => write!(f, "OAUTH"), + UserOption::OAuth(options) => { + write!(f, "({})", display_comma_separated(options.as_slice())) + } } } } @@ -846,7 +848,10 @@ impl ParseTo for UserOptions { UserOption::EncryptedPassword(AstString::parse_to(parser)?), ) } - Keyword::OAUTH => (&mut builder.password, UserOption::OAuth), + Keyword::OAUTH => { + let options = parser.parse_options()?; + (&mut builder.password, UserOption::OAuth(options)) + } _ => { parser.expected( "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \ diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index eb9c31442b85c..5fef18a61bff9 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -159,7 +159,7 @@ pub enum UserAuthenticator { encrypted_password: Vec, salt: [u8; 4], }, - OAuth(String), + OAuth(HashMap), } #[derive(Debug, Deserialize)] @@ -181,7 +181,11 @@ async fn fetch_jwks(url: &str) -> Result { Ok(resp) } -async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result { +async fn validate_jwt( + jwt: &str, + jwks_url: &str, + meta_data: &HashMap, +) -> Result { let header = decode_header(jwt)?; let jwks = fetch_jwks(jwks_url).await?; @@ -194,8 +198,11 @@ async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result { let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?; let validation = Validation::new(Algorithm::from_str(&jwk.alg)?); + let token_data = decode::>(jwt, &decoding_key, &validation)?; - Ok(decode::>(jwt, &decoding_key, &validation).is_ok()) + Ok(meta_data + .iter() + .all(|(k, v)| token_data.claims.get(k) == Some(v))) } impl UserAuthenticator { @@ -206,8 +213,10 @@ impl UserAuthenticator { UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, - UserAuthenticator::OAuth(oauth_jwks_url) => { - validate_jwt(&String::from_utf8_lossy(password), oauth_jwks_url) + UserAuthenticator::OAuth(meta_data) => { + let mut meta_data = meta_data.clone(); + let jwks_url = meta_data.remove("jwks_url").unwrap(); + validate_jwt(&String::from_utf8_lossy(password), &jwks_url, &meta_data) .await .map_err(PsqlError::StartupError)? }