From 9966b81619001ed3b9ef369f6036871edc355379 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:18:54 +0800 Subject: [PATCH] feat(frontend): implement OAuth authentication (#13151) Co-authored-by: August --- Cargo.lock | 4 + proto/user.proto | 2 + src/frontend/src/handler/alter_user.rs | 21 ++++- src/frontend/src/handler/create_user.rs | 20 ++++- src/frontend/src/session.rs | 2 + src/frontend/src/user/user_authentication.rs | 33 ++++++++ src/sqlparser/src/ast/statement.rs | 12 ++- src/sqlparser/src/keywords.rs | 1 + src/sqlparser/src/parser.rs | 2 +- src/storage/src/hummock/sstable_store.rs | 4 +- src/utils/pgwire/Cargo.toml | 4 + src/utils/pgwire/src/pg_protocol.rs | 10 +-- src/utils/pgwire/src/pg_server.rs | 84 +++++++++++++++++++- 13 files changed, 179 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2f780d819e49..1e34470c2300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7324,12 +7324,16 @@ dependencies = [ "bytes", "futures", "itertools 0.12.0", + "jsonwebtoken 9.2.0", "madsim-tokio", "openssl", "panic-message", "parking_lot 0.12.1", + "reqwest", "risingwave_common", "risingwave_sqlparser", + "serde", + "serde_json", "tempfile", "thiserror", "thiserror-ext", diff --git a/proto/user.proto b/proto/user.proto index b132df55dcc1..383e78cb57b2 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -14,9 +14,11 @@ message AuthInfo { PLAINTEXT = 1; SHA256 = 2; MD5 = 3; + OAUTH = 4; } EncryptionType encryption_type = 1; bytes encrypted_value = 2; + map metadata = 3; } // User defines a user in the system. diff --git a/src/frontend/src/handler/alter_user.rs b/src/frontend/src/handler/alter_user.rs index 810b71c6dcec..431a217a20cf 100644 --- a/src/frontend/src/handler/alter_user.rs +++ b/src/frontend/src/handler/alter_user.rs @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{AlterUserStatement, ObjectName, UserOption, User use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::CatalogError; -use crate::error::ErrorCode::{InternalError, PermissionDenied}; +use crate::error::ErrorCode::{self, InternalError, PermissionDenied}; use crate::error::Result; use crate::handler::HandlerArgs; -use crate::user::user_authentication::encrypted_password; +use crate::user::user_authentication::{ + build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY, +}; use crate::user::user_catalog::UserCatalog; fn alter_prost_user_info( @@ -111,6 +113,16 @@ fn alter_prost_user_info( } update_fields.push(UpdateField::AuthInfo); } + UserOption::OAuth(options) => { + let auth_info = build_oauth_info(options).ok_or_else(|| { + ErrorCode::InvalidParameterValue(format!( + "{} and {} must be provided", + OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY + )) + })?; + user_info.auth_info = Some(auth_info); + update_fields.push(UpdateField::AuthInfo) + } } } Ok((user_info, update_fields)) @@ -181,6 +193,8 @@ pub async fn handle_alter_user( #[cfg(test)] mod tests { + use std::collections::HashMap; + use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; @@ -219,7 +233,8 @@ mod tests { user_info.auth_info, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, - encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec() + encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(), + metadata: HashMap::new(), }) ); } diff --git a/src/frontend/src/handler/create_user.rs b/src/frontend/src/handler/create_user.rs index 434927a21827..bfdc33e6db80 100644 --- a/src/frontend/src/handler/create_user.rs +++ b/src/frontend/src/handler/create_user.rs @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{CreateUserStatement, UserOption, UserOptions}; use super::RwPgResponse; use crate::binder::Binder; use crate::catalog::{CatalogError, DatabaseId}; -use crate::error::ErrorCode::PermissionDenied; +use crate::error::ErrorCode::{self, PermissionDenied}; use crate::error::Result; use crate::handler::HandlerArgs; -use crate::user::user_authentication::encrypted_password; +use crate::user::user_authentication::{ + build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY, +}; use crate::user::user_catalog::UserCatalog; fn make_prost_user_info( @@ -91,6 +93,15 @@ fn make_prost_user_info( user_info.auth_info = encrypted_password(&user_info.name, &password.0); } } + UserOption::OAuth(options) => { + let auth_info = build_oauth_info(options).ok_or_else(|| { + ErrorCode::InvalidParameterValue(format!( + "{} and {} must be provided", + OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY + )) + })?; + user_info.auth_info = Some(auth_info); + } } } @@ -130,6 +141,8 @@ pub async fn handle_create_user( #[cfg(test)] mod tests { + use std::collections::HashMap; + use risingwave_common::catalog::DEFAULT_DATABASE_NAME; use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; @@ -157,7 +170,8 @@ mod tests { user_info.auth_info, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, - encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec() + encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(), + metadata: HashMap::new(), }) ); frontend diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 67a5da01e121..30d1b02df7c0 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -976,6 +976,8 @@ impl SessionManager for SessionManagerImpl { ), salt, } + } else if auth_info.encryption_type == EncryptionType::Oauth as i32 { + UserAuthenticator::OAuth(auth_info.metadata.clone()) } else { return Err(Box::new(Error::new( ErrorKind::Unsupported, diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index d033e4c79811..b0cefc1faedc 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -12,10 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use risingwave_pb::user::auth_info::EncryptionType; use risingwave_pb::user::AuthInfo; +use risingwave_sqlparser::ast::SqlOption; use sha2::{Digest, Sha256}; +use crate::WithOptions; + // SHA-256 is not supported in PostgreSQL protocol. We need to implement SCRAM-SHA-256 instead // if necessary. const SHA256_ENCRYPTED_PREFIX: &str = "SHA-256:"; @@ -24,6 +29,27 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5"; const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64; const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32; +pub const OAUTH_JWKS_URL_KEY: &str = "jwks_url"; +pub const OAUTH_ISSUER_KEY: &str = "issuer"; + +/// Build `AuthInfo` for `OAuth`. +#[inline(always)] +pub fn build_oauth_info(options: &Vec) -> Option { + let metadata: HashMap = WithOptions::try_from(options.as_slice()) + .ok()? + .into_inner() + .into_iter() + .collect(); + if !metadata.contains_key(OAUTH_JWKS_URL_KEY) || !metadata.contains_key(OAUTH_ISSUER_KEY) { + return None; + } + Some(AuthInfo { + encryption_type: EncryptionType::Oauth as i32, + encrypted_value: Vec::new(), + metadata, + }) +} + /// Try to extract the encryption password from given password. The password is always stored /// encrypted in the system catalogs. The ENCRYPTED keyword has no effect, but is accepted for /// backwards compatibility. The method of encryption is by default SHA-256-encrypted. If the @@ -53,11 +79,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option { Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(), + metadata: HashMap::new(), }) } else if valid_md5_password(password) { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(), + metadata: HashMap::new(), }) } else { Some(encrypt_default(name, password)) @@ -70,6 +98,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo { AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(name, password), + metadata: HashMap::new(), } } @@ -81,6 +110,7 @@ pub fn encrypted_raw_password(info: &AuthInfo) -> String { EncryptionType::Plaintext => "", EncryptionType::Sha256 => SHA256_ENCRYPTED_PREFIX, EncryptionType::Md5 => MD5_ENCRYPTED_PREFIX, + EncryptionType::Oauth => "", }; format!("{}{}", prefix, encrypted_pwd) } @@ -156,15 +186,18 @@ mod tests { Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), + metadata: HashMap::new(), }), None, Some(AuthInfo { encryption_type: EncryptionType::Md5 as i32, encrypted_value: md5_hash(user_name, password), + metadata: HashMap::new(), }), Some(AuthInfo { encryption_type: EncryptionType::Sha256 as i32, encrypted_value: sha256_hash(user_name, password), + metadata: HashMap::new(), }), ]; let output_passwords = input_passwords diff --git a/src/sqlparser/src/ast/statement.rs b/src/sqlparser/src/ast/statement.rs index 85ce47a7bc7b..1b73edc1150d 100644 --- a/src/sqlparser/src/ast/statement.rs +++ b/src/sqlparser/src/ast/statement.rs @@ -755,6 +755,7 @@ pub enum UserOption { NoLogin, EncryptedPassword(AstString), Password(Option), + OAuth(Vec), } impl fmt::Display for UserOption { @@ -771,6 +772,9 @@ impl fmt::Display for UserOption { UserOption::EncryptedPassword(p) => write!(f, "ENCRYPTED PASSWORD {}", p), UserOption::Password(None) => write!(f, "PASSWORD NULL"), UserOption::Password(Some(p)) => write!(f, "PASSWORD {}", p), + UserOption::OAuth(options) => { + write!(f, "({})", display_comma_separated(options.as_slice())) + } } } } @@ -858,10 +862,14 @@ impl ParseTo for UserOptions { UserOption::EncryptedPassword(AstString::parse_to(parser)?), ) } + Keyword::OAUTH => { + let options = parser.parse_options()?; + (&mut builder.password, UserOption::OAuth(options)) + } _ => { parser.expected( "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \ - | NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL", + | NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH", token, )?; unreachable!() @@ -871,7 +879,7 @@ impl ParseTo for UserOptions { } else { parser.expected( "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN | NOLOGIN \ - | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL", + | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH", token, )? } diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index a82c1c1c04c6..a3cc9013a21e 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -344,6 +344,7 @@ define_keywords!( NULLIF, NULLS, NUMERIC, + OAUTH, OBJECT, OCCURRENCES_REGEX, OCTET_LENGTH, diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index c15237199330..22f035002414 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -2379,7 +2379,7 @@ impl Parser { // | CREATEDB | NOCREATEDB // | CREATEUSER | NOCREATEUSER // | LOGIN | NOLOGIN - // | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL + // | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL | OAUTH fn parse_create_user(&mut self) -> Result { Ok(Statement::CreateUser(CreateUserStatement::parse_to(self)?)) } diff --git a/src/storage/src/hummock/sstable_store.rs b/src/storage/src/hummock/sstable_store.rs index c603b7d8f503..f0cacf863fcc 100644 --- a/src/storage/src/hummock/sstable_store.rs +++ b/src/storage/src/hummock/sstable_store.rs @@ -1020,9 +1020,9 @@ impl SstableWriter for StreamingUploadWriter { } async fn finish(mut self, meta: SstableMeta) -> HummockResult { - let meta_data = Bytes::from(meta.encode_to_bytes()); + let metadata = Bytes::from(meta.encode_to_bytes()); - self.object_uploader.write_bytes(meta_data).await?; + self.object_uploader.write_bytes(metadata).await?; let join_handle = tokio::spawn(async move { let uploader_memory_usage = self.object_uploader.get_memory_usage(); let _tracker = self.tracker.map(|mut t| { diff --git a/src/utils/pgwire/Cargo.toml b/src/utils/pgwire/Cargo.toml index 0e5b4e98faef..47840b0cf498 100644 --- a/src/utils/pgwire/Cargo.toml +++ b/src/utils/pgwire/Cargo.toml @@ -21,11 +21,15 @@ byteorder = "1.5" bytes = "1" futures = { version = "0.3", default-features = false, features = ["alloc"] } itertools = "0.12" +jsonwebtoken = "9" openssl = "0.10.60" panic-message = "0.3" parking_lot = "0.12" +reqwest = { version = "0.11" } risingwave_common = { workspace = true } risingwave_sqlparser = { workspace = true } +serde = { version = "1", features = ["derive"] } +serde_json = "1" thiserror = "1" thiserror-ext = { workspace = true } tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] } diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 83da9f5dc058..18411b1a0235 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -387,7 +387,7 @@ where match msg { FeMessage::Ssl => self.process_ssl_msg().await?, FeMessage::Startup(msg) => self.process_startup_msg(msg)?, - FeMessage::Password(msg) => self.process_password_msg(msg)?, + FeMessage::Password(msg) => self.process_password_msg(msg).await?, FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?, FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?, FeMessage::Terminate => self.process_terminate(), @@ -508,7 +508,7 @@ where })?; self.ready_for_query()?; } - UserAuthenticator::ClearText(_) => { + UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => { self.stream .write_no_flush(&BeMessage::AuthenticationCleartextPassword)?; } @@ -523,11 +523,9 @@ where Ok(()) } - fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> { + async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> { let authenticator = self.session.as_ref().unwrap().user_authenticator(); - if !authenticator.authenticate(&msg.password) { - return Err(PsqlError::PasswordError); - } + authenticator.authenticate(&msg.password).await?; self.stream.write_no_flush(&BeMessage::AuthenticationOk)?; self.stream .write_parameter_status_msg_no_flush(&ParameterStatus::default())?; diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index e545c8a2d724..7f6dd41368d4 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -12,20 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::future::Future; use std::io; use std::result::Result; +use std::str::FromStr; use std::sync::Arc; use std::time::Instant; use bytes::Bytes; +use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation}; use parking_lot::Mutex; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; +use serde::Deserialize; use thiserror_ext::AsReport; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::error::PsqlResult; +use crate::error::{PsqlError, PsqlResult}; use crate::net::{AddressRef, Listener}; use crate::pg_field_descriptor::PgFieldDescriptor; use crate::pg_message::TransactionStatus; @@ -155,17 +159,91 @@ pub enum UserAuthenticator { encrypted_password: Vec, salt: [u8; 4], }, + OAuth(HashMap), +} + +/// A JWK Set is a JSON object that represents a set of JWKs. +/// The JSON object MUST have a "keys" member, with its value being an array of JWKs. +/// See for more details. +#[derive(Debug, Deserialize)] +struct Jwks { + keys: Vec, +} + +/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key. +/// See for more details. +#[derive(Debug, Deserialize)] +struct Jwk { + kid: String, // Key ID + alg: String, // Algorithm + n: String, // Modulus + e: String, // Exponent +} + +async fn validate_jwt( + jwt: &str, + jwks_url: &str, + issuer: &str, + metadata: &HashMap, +) -> Result { + let header = decode_header(jwt)?; + let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?; + + // 1. Retrieve the kid from the header to find the right JWK in the JWK Set. + let kid = header.kid.ok_or("kid not found in jwt header")?; + let jwk = jwks + .keys + .into_iter() + .find(|k| k.kid == kid) + .ok_or("kid not found in jwks")?; + + // 2. Check if the algorithms are matched. + if Algorithm::from_str(&jwk.alg)? != header.alg { + return Err("alg in jwt header does not match with alg in jwk".into()); + } + + // 3. Decode the JWT and validate the claims. + let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?; + let mut validation = Validation::new(header.alg); + validation.set_issuer(&[issuer]); + validation.set_required_spec_claims(&["exp", "iss"]); + let token_data = decode::>(jwt, &decoding_key, &validation)?; + + // 4. Check if the metadata in the token matches. + if !metadata.iter().all( + |(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v), + ) { + return Err("metadata in jwt does not match with metadata declared with user".into()); + } + Ok(true) } impl UserAuthenticator { - pub fn authenticate(&self, password: &[u8]) -> bool { - match self { + pub async fn authenticate(&self, password: &[u8]) -> PsqlResult<()> { + let success = match self { UserAuthenticator::None => true, UserAuthenticator::ClearText(text) => password == text, UserAuthenticator::Md5WithSalt { encrypted_password, .. } => encrypted_password == password, + UserAuthenticator::OAuth(metadata) => { + let mut metadata = metadata.clone(); + let jwks_url = metadata.remove("jwks_url").unwrap(); + let issuer = metadata.remove("issuer").unwrap(); + validate_jwt( + &String::from_utf8_lossy(password), + &jwks_url, + &issuer, + &metadata, + ) + .await + .map_err(PsqlError::StartupError)? + } + }; + if !success { + return Err(PsqlError::PasswordError); } + Ok(()) } }