Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(frontend): implement OAuth authentication #13151

Merged
merged 17 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions proto/user.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ message AuthInfo {
PLAINTEXT = 1;
SHA256 = 2;
MD5 = 3;
OAUTH = 4;
}
EncryptionType encryption_type = 1;
bytes encrypted_value = 2;
map<string, string> metadata = 3;
}

// User defines a user in the system.
Expand Down
21 changes: 18 additions & 3 deletions src/frontend/src/handler/alter_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{AlterUserStatement, ObjectName, UserOption, User
use super::RwPgResponse;
use crate::binder::Binder;
use crate::catalog::CatalogError;
use crate::error::ErrorCode::{InternalError, PermissionDenied};
use crate::error::ErrorCode::{self, InternalError, PermissionDenied};
use crate::error::Result;
use crate::handler::HandlerArgs;
use crate::user::user_authentication::encrypted_password;
use crate::user::user_authentication::{
build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY,
};
use crate::user::user_catalog::UserCatalog;

fn alter_prost_user_info(
Expand Down Expand Up @@ -111,6 +113,16 @@ fn alter_prost_user_info(
}
update_fields.push(UpdateField::AuthInfo);
}
UserOption::OAuth(options) => {
let auth_info = build_oauth_info(options).ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"{} and {} must be provided",
OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY
))
})?;
user_info.auth_info = Some(auth_info);
update_fields.push(UpdateField::AuthInfo)
}
}
}
Ok((user_info, update_fields))
Expand Down Expand Up @@ -181,6 +193,8 @@ pub async fn handle_alter_user(

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;

Expand Down Expand Up @@ -219,7 +233,8 @@ mod tests {
user_info.auth_info,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec()
encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(),
metadata: HashMap::new(),
})
);
}
Expand Down
20 changes: 17 additions & 3 deletions src/frontend/src/handler/create_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{CreateUserStatement, UserOption, UserOptions};
use super::RwPgResponse;
use crate::binder::Binder;
use crate::catalog::{CatalogError, DatabaseId};
use crate::error::ErrorCode::PermissionDenied;
use crate::error::ErrorCode::{self, PermissionDenied};
use crate::error::Result;
use crate::handler::HandlerArgs;
use crate::user::user_authentication::encrypted_password;
use crate::user::user_authentication::{
build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY,
};
use crate::user::user_catalog::UserCatalog;

fn make_prost_user_info(
Expand Down Expand Up @@ -91,6 +93,15 @@ fn make_prost_user_info(
user_info.auth_info = encrypted_password(&user_info.name, &password.0);
}
}
UserOption::OAuth(options) => {
let auth_info = build_oauth_info(options).ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"{} and {} must be provided",
OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY
))
})?;
user_info.auth_info = Some(auth_info);
}
}
}

Expand Down Expand Up @@ -130,6 +141,8 @@ pub async fn handle_create_user(

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use risingwave_common::catalog::DEFAULT_DATABASE_NAME;
use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;
Expand Down Expand Up @@ -157,7 +170,8 @@ mod tests {
user_info.auth_info,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec()
encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(),
metadata: HashMap::new(),
})
);
frontend
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,8 @@ impl SessionManager for SessionManagerImpl {
),
salt,
}
} else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
UserAuthenticator::OAuth(auth_info.metadata.clone())
} else {
return Err(Box::new(Error::new(
ErrorKind::Unsupported,
Expand Down
33 changes: 33 additions & 0 deletions src/frontend/src/user/user_authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;
use risingwave_sqlparser::ast::SqlOption;
use sha2::{Digest, Sha256};

use crate::WithOptions;

// SHA-256 is not supported in PostgreSQL protocol. We need to implement SCRAM-SHA-256 instead
// if necessary.
const SHA256_ENCRYPTED_PREFIX: &str = "SHA-256:";
Expand All @@ -24,6 +29,27 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5";
const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64;
const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32;

pub const OAUTH_JWKS_URL_KEY: &str = "jwks_url";
pub const OAUTH_ISSUER_KEY: &str = "issuer";

/// Build `AuthInfo` for `OAuth`.
#[inline(always)]
pub fn build_oauth_info(options: &Vec<SqlOption>) -> Option<AuthInfo> {
let metadata: HashMap<String, String> = WithOptions::try_from(options.as_slice())
.ok()?
.into_inner()
.into_iter()
.collect();
if !metadata.contains_key(OAUTH_JWKS_URL_KEY) || !metadata.contains_key(OAUTH_ISSUER_KEY) {
return None;
}
Some(AuthInfo {
encryption_type: EncryptionType::Oauth as i32,
encrypted_value: Vec::new(),
metadata,
})
}

/// Try to extract the encryption password from given password. The password is always stored
/// encrypted in the system catalogs. The ENCRYPTED keyword has no effect, but is accepted for
/// backwards compatibility. The method of encryption is by default SHA-256-encrypted. If the
Expand Down Expand Up @@ -53,11 +79,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option<AuthInfo> {
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(),
metadata: HashMap::new(),
})
} else if valid_md5_password(password) {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(),
metadata: HashMap::new(),
})
} else {
Some(encrypt_default(name, password))
Expand All @@ -70,6 +98,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo {
AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(name, password),
metadata: HashMap::new(),
}
}

Expand All @@ -81,6 +110,7 @@ pub fn encrypted_raw_password(info: &AuthInfo) -> String {
EncryptionType::Plaintext => "",
EncryptionType::Sha256 => SHA256_ENCRYPTED_PREFIX,
EncryptionType::Md5 => MD5_ENCRYPTED_PREFIX,
EncryptionType::Oauth => "",
};
format!("{}{}", prefix, encrypted_pwd)
}
Expand Down Expand Up @@ -156,15 +186,18 @@ mod tests {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
metadata: HashMap::new(),
}),
None,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
metadata: HashMap::new(),
}),
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: sha256_hash(user_name, password),
metadata: HashMap::new(),
}),
];
let output_passwords = input_passwords
Expand Down
12 changes: 10 additions & 2 deletions src/sqlparser/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,7 @@ pub enum UserOption {
NoLogin,
EncryptedPassword(AstString),
Password(Option<AstString>),
OAuth(Vec<SqlOption>),
}

impl fmt::Display for UserOption {
Expand All @@ -771,6 +772,9 @@ impl fmt::Display for UserOption {
UserOption::EncryptedPassword(p) => write!(f, "ENCRYPTED PASSWORD {}", p),
UserOption::Password(None) => write!(f, "PASSWORD NULL"),
UserOption::Password(Some(p)) => write!(f, "PASSWORD {}", p),
UserOption::OAuth(options) => {
write!(f, "({})", display_comma_separated(options.as_slice()))
}
}
}
}
Expand Down Expand Up @@ -858,10 +862,14 @@ impl ParseTo for UserOptions {
UserOption::EncryptedPassword(AstString::parse_to(parser)?),
)
}
Keyword::OAUTH => {
let options = parser.parse_options()?;
(&mut builder.password, UserOption::OAuth(options))
}
_ => {
parser.expected(
"SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \
| NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL",
| NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH",
token,
)?;
unreachable!()
Expand All @@ -871,7 +879,7 @@ impl ParseTo for UserOptions {
} else {
parser.expected(
"SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN | NOLOGIN \
| CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL",
| CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH",
token,
)?
}
Expand Down
1 change: 1 addition & 0 deletions src/sqlparser/src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ define_keywords!(
NULLIF,
NULLS,
NUMERIC,
OAUTH,
OBJECT,
OCCURRENCES_REGEX,
OCTET_LENGTH,
Expand Down
2 changes: 1 addition & 1 deletion src/sqlparser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2379,7 +2379,7 @@ impl Parser {
// | CREATEDB | NOCREATEDB
// | CREATEUSER | NOCREATEUSER
// | LOGIN | NOLOGIN
// | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL
// | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL | OAUTH
fn parse_create_user(&mut self) -> Result<Statement, ParserError> {
Ok(Statement::CreateUser(CreateUserStatement::parse_to(self)?))
}
Expand Down
4 changes: 2 additions & 2 deletions src/storage/src/hummock/sstable_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,9 +1020,9 @@ impl SstableWriter for StreamingUploadWriter {
}

async fn finish(mut self, meta: SstableMeta) -> HummockResult<UploadJoinHandle> {
let meta_data = Bytes::from(meta.encode_to_bytes());
let metadata = Bytes::from(meta.encode_to_bytes());

self.object_uploader.write_bytes(meta_data).await?;
self.object_uploader.write_bytes(metadata).await?;
let join_handle = tokio::spawn(async move {
let uploader_memory_usage = self.object_uploader.get_memory_usage();
let _tracker = self.tracker.map(|mut t| {
Expand Down
4 changes: 4 additions & 0 deletions src/utils/pgwire/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ byteorder = "1.5"
bytes = "1"
futures = { version = "0.3", default-features = false, features = ["alloc"] }
itertools = "0.12"
jsonwebtoken = "9"
openssl = "0.10.60"
panic-message = "0.3"
parking_lot = "0.12"
reqwest = { version = "0.11" }
risingwave_common = { workspace = true }
risingwave_sqlparser = { workspace = true }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
thiserror-ext = { workspace = true }
tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] }
Expand Down
10 changes: 4 additions & 6 deletions src/utils/pgwire/src/pg_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ where
match msg {
FeMessage::Ssl => self.process_ssl_msg().await?,
FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
FeMessage::Password(msg) => self.process_password_msg(msg)?,
FeMessage::Password(msg) => self.process_password_msg(msg).await?,
FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?,
FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
FeMessage::Terminate => self.process_terminate(),
Expand Down Expand Up @@ -508,7 +508,7 @@ where
})?;
self.ready_for_query()?;
}
UserAuthenticator::ClearText(_) => {
UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
self.stream
.write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
}
Expand All @@ -523,11 +523,9 @@ where
Ok(())
}

fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
let authenticator = self.session.as_ref().unwrap().user_authenticator();
if !authenticator.authenticate(&msg.password) {
return Err(PsqlError::PasswordError);
}
authenticator.authenticate(&msg.password).await?;
self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
self.stream
.write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
Expand Down
Loading
Loading