Skip to content

Commit

Permalink
feat(frontend): implement OAuth authentication (#13151)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rossil2012 committed Mar 5, 2024
1 parent 5e6e2ca commit be74d82
Show file tree
Hide file tree
Showing 13 changed files with 179 additions and 20 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions proto/user.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ message AuthInfo {
PLAINTEXT = 1;
SHA256 = 2;
MD5 = 3;
OAUTH = 4;
}
EncryptionType encryption_type = 1;
bytes encrypted_value = 2;
map<string, string> metadata = 3;
}

// User defines a user in the system.
Expand Down
21 changes: 18 additions & 3 deletions src/frontend/src/handler/alter_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::error::ErrorCode::{InternalError, PermissionDenied};
use risingwave_common::error::ErrorCode::{self, InternalError, PermissionDenied};
use risingwave_common::error::Result;
use risingwave_pb::user::update_user_request::UpdateField;
use risingwave_pb::user::UserInfo;
Expand All @@ -23,7 +23,9 @@ use super::RwPgResponse;
use crate::binder::Binder;
use crate::catalog::CatalogError;
use crate::handler::HandlerArgs;
use crate::user::user_authentication::encrypted_password;
use crate::user::user_authentication::{
build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY,
};
use crate::user::user_catalog::UserCatalog;

fn alter_prost_user_info(
Expand Down Expand Up @@ -111,6 +113,16 @@ fn alter_prost_user_info(
}
update_fields.push(UpdateField::AuthInfo);
}
UserOption::OAuth(options) => {
let auth_info = build_oauth_info(options).ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"{} and {} must be provided",
OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY
))
})?;
user_info.auth_info = Some(auth_info);
update_fields.push(UpdateField::AuthInfo)
}
}
}
Ok((user_info, update_fields))
Expand Down Expand Up @@ -181,6 +193,8 @@ pub async fn handle_alter_user(

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;

Expand Down Expand Up @@ -219,7 +233,8 @@ mod tests {
user_info.auth_info,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec()
encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(),
metadata: HashMap::new(),
})
);
}
Expand Down
20 changes: 17 additions & 3 deletions src/frontend/src/handler/create_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::error::ErrorCode::PermissionDenied;
use risingwave_common::error::ErrorCode::{self, PermissionDenied};
use risingwave_common::error::Result;
use risingwave_pb::user::grant_privilege::{Action, ActionWithGrantOption, Object};
use risingwave_pb::user::{GrantPrivilege, UserInfo};
Expand All @@ -23,7 +23,9 @@ use super::RwPgResponse;
use crate::binder::Binder;
use crate::catalog::{CatalogError, DatabaseId};
use crate::handler::HandlerArgs;
use crate::user::user_authentication::encrypted_password;
use crate::user::user_authentication::{
build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY,
};
use crate::user::user_catalog::UserCatalog;

fn make_prost_user_info(
Expand Down Expand Up @@ -91,6 +93,15 @@ fn make_prost_user_info(
user_info.auth_info = encrypted_password(&user_info.name, &password.0);
}
}
UserOption::OAuth(options) => {
let auth_info = build_oauth_info(options).ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"{} and {} must be provided",
OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY
))
})?;
user_info.auth_info = Some(auth_info);
}
}
}

Expand Down Expand Up @@ -130,6 +141,8 @@ pub async fn handle_create_user(

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use risingwave_common::catalog::DEFAULT_DATABASE_NAME;
use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;
Expand Down Expand Up @@ -157,7 +170,8 @@ mod tests {
user_info.auth_info,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec()
encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(),
metadata: HashMap::new(),
})
);
frontend
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,8 @@ impl SessionManager for SessionManagerImpl {
),
salt,
}
} else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
UserAuthenticator::OAuth(auth_info.metadata.clone())
} else {
return Err(Box::new(Error::new(
ErrorKind::Unsupported,
Expand Down
33 changes: 33 additions & 0 deletions src/frontend/src/user/user_authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;
use risingwave_sqlparser::ast::SqlOption;
use sha2::{Digest, Sha256};

use crate::WithOptions;

// SHA-256 is not supported in PostgreSQL protocol. We need to implement SCRAM-SHA-256 instead
// if necessary.
const SHA256_ENCRYPTED_PREFIX: &str = "SHA-256:";
Expand All @@ -24,6 +29,27 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5";
const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64;
const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32;

pub const OAUTH_JWKS_URL_KEY: &str = "jwks_url";
pub const OAUTH_ISSUER_KEY: &str = "issuer";

/// Build `AuthInfo` for `OAuth`.
#[inline(always)]
pub fn build_oauth_info(options: &Vec<SqlOption>) -> Option<AuthInfo> {
let metadata: HashMap<String, String> = WithOptions::try_from(options.as_slice())
.ok()?
.into_inner()
.into_iter()
.collect();
if !metadata.contains_key(OAUTH_JWKS_URL_KEY) || !metadata.contains_key(OAUTH_ISSUER_KEY) {
return None;
}
Some(AuthInfo {
encryption_type: EncryptionType::Oauth as i32,
encrypted_value: Vec::new(),
metadata,
})
}

/// Try to extract the encryption password from given password. The password is always stored
/// encrypted in the system catalogs. The ENCRYPTED keyword has no effect, but is accepted for
/// backwards compatibility. The method of encryption is by default SHA-256-encrypted. If the
Expand Down Expand Up @@ -53,11 +79,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option<AuthInfo> {
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(),
metadata: HashMap::new(),
})
} else if valid_md5_password(password) {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(),
metadata: HashMap::new(),
})
} else {
Some(encrypt_default(name, password))
Expand All @@ -70,6 +98,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo {
AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(name, password),
metadata: HashMap::new(),
}
}

Expand All @@ -81,6 +110,7 @@ pub fn encrypted_raw_password(info: &AuthInfo) -> String {
EncryptionType::Plaintext => "",
EncryptionType::Sha256 => SHA256_ENCRYPTED_PREFIX,
EncryptionType::Md5 => MD5_ENCRYPTED_PREFIX,
EncryptionType::Oauth => "",
};
format!("{}{}", prefix, encrypted_pwd)
}
Expand Down Expand Up @@ -156,15 +186,18 @@ mod tests {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
metadata: HashMap::new(),
}),
None,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
metadata: HashMap::new(),
}),
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: sha256_hash(user_name, password),
metadata: HashMap::new(),
}),
];
let output_passwords = input_passwords
Expand Down
12 changes: 10 additions & 2 deletions src/sqlparser/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ pub enum UserOption {
NoLogin,
EncryptedPassword(AstString),
Password(Option<AstString>),
OAuth(Vec<SqlOption>),
}

impl fmt::Display for UserOption {
Expand All @@ -731,6 +732,9 @@ impl fmt::Display for UserOption {
UserOption::EncryptedPassword(p) => write!(f, "ENCRYPTED PASSWORD {}", p),
UserOption::Password(None) => write!(f, "PASSWORD NULL"),
UserOption::Password(Some(p)) => write!(f, "PASSWORD {}", p),
UserOption::OAuth(options) => {
write!(f, "({})", display_comma_separated(options.as_slice()))
}
}
}
}
Expand Down Expand Up @@ -818,10 +822,14 @@ impl ParseTo for UserOptions {
UserOption::EncryptedPassword(AstString::parse_to(parser)?),
)
}
Keyword::OAUTH => {
let options = parser.parse_options()?;
(&mut builder.password, UserOption::OAuth(options))
}
_ => {
parser.expected(
"SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \
| NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL",
| NOLOGIN | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH",
token,
)?;
unreachable!()
Expand All @@ -831,7 +839,7 @@ impl ParseTo for UserOptions {
} else {
parser.expected(
"SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN | NOLOGIN \
| CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL",
| CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH",
token,
)?
}
Expand Down
1 change: 1 addition & 0 deletions src/sqlparser/src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ define_keywords!(
NULLIF,
NULLS,
NUMERIC,
OAUTH,
OBJECT,
OCCURRENCES_REGEX,
OCTET_LENGTH,
Expand Down
2 changes: 1 addition & 1 deletion src/sqlparser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2379,7 +2379,7 @@ impl Parser {
// | CREATEDB | NOCREATEDB
// | CREATEUSER | NOCREATEUSER
// | LOGIN | NOLOGIN
// | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL
// | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL | OAUTH
fn parse_create_user(&mut self) -> Result<Statement, ParserError> {
Ok(Statement::CreateUser(CreateUserStatement::parse_to(self)?))
}
Expand Down
4 changes: 2 additions & 2 deletions src/storage/src/hummock/sstable_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,9 +1020,9 @@ impl SstableWriter for StreamingUploadWriter {
}

async fn finish(mut self, meta: SstableMeta) -> HummockResult<UploadJoinHandle> {
let meta_data = Bytes::from(meta.encode_to_bytes());
let metadata = Bytes::from(meta.encode_to_bytes());

self.object_uploader.write_bytes(meta_data).await?;
self.object_uploader.write_bytes(metadata).await?;
let join_handle = tokio::spawn(async move {
let uploader_memory_usage = self.object_uploader.get_memory_usage();
let _tracker = self.tracker.map(|mut t| {
Expand Down
4 changes: 4 additions & 0 deletions src/utils/pgwire/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ byteorder = "1.5"
bytes = "1"
futures = { version = "0.3", default-features = false, features = ["alloc"] }
itertools = "0.12"
jsonwebtoken = "9"
openssl = "0.10.60"
panic-message = "0.3"
parking_lot = "0.12"
reqwest = { version = "0.11" }
risingwave_common = { workspace = true }
risingwave_sqlparser = { workspace = true }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
thiserror-ext = { workspace = true }
tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] }
Expand Down
10 changes: 4 additions & 6 deletions src/utils/pgwire/src/pg_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ where
match msg {
FeMessage::Ssl => self.process_ssl_msg().await?,
FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
FeMessage::Password(msg) => self.process_password_msg(msg)?,
FeMessage::Password(msg) => self.process_password_msg(msg).await?,
FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?,
FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
FeMessage::Terminate => self.process_terminate(),
Expand Down Expand Up @@ -508,7 +508,7 @@ where
})?;
self.ready_for_query()?;
}
UserAuthenticator::ClearText(_) => {
UserAuthenticator::ClearText(_) | UserAuthenticator::OAuth(_) => {
self.stream
.write_no_flush(&BeMessage::AuthenticationCleartextPassword)?;
}
Expand All @@ -523,11 +523,9 @@ where
Ok(())
}

fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
async fn process_password_msg(&mut self, msg: FePasswordMessage) -> PsqlResult<()> {
let authenticator = self.session.as_ref().unwrap().user_authenticator();
if !authenticator.authenticate(&msg.password) {
return Err(PsqlError::PasswordError);
}
authenticator.authenticate(&msg.password).await?;
self.stream.write_no_flush(&BeMessage::AuthenticationOk)?;
self.stream
.write_parameter_status_msg_no_flush(&ParameterStatus::default())?;
Expand Down
Loading

0 comments on commit be74d82

Please sign in to comment.