Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Rossil2012 committed Mar 1, 2024
1 parent bde884f commit b5bc317
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 35 deletions.
2 changes: 1 addition & 1 deletion proto/user.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ message AuthInfo {
}
EncryptionType encryption_type = 1;
bytes encrypted_value = 2;
map<string, string> meta_data = 3;
map<string, string> metadata = 3;
}

// User defines a user in the system.
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/handler/alter_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ mod tests {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(),
meta_data: HashMap::new(),
metadata: HashMap::new(),
})
);
}
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/handler/create_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ mod tests {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(),
meta_data: HashMap::new(),
metadata: HashMap::new(),
})
);
frontend
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ impl SessionManager for SessionManagerImpl {
salt,
}
} else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
UserAuthenticator::OAuth(auth_info.meta_data.clone())
UserAuthenticator::OAuth(auth_info.metadata.clone())
} else {
return Err(Box::new(Error::new(
ErrorKind::Unsupported,
Expand Down
18 changes: 9 additions & 9 deletions src/frontend/src/user/user_authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ pub const OAUTH_ISSUER_KEY: &str = "issuer";
/// Build `AuthInfo` for `OAuth`.
#[inline(always)]
pub fn build_oauth_info(options: &Vec<SqlOption>) -> Option<AuthInfo> {
let meta_data: HashMap<String, String> = WithOptions::try_from(options.as_slice())
let metadata: HashMap<String, String> = WithOptions::try_from(options.as_slice())
.ok()?
.into_inner()
.into_iter()
.collect();
if !meta_data.contains_key(OAUTH_JWKS_URL_KEY) || !meta_data.contains_key(OAUTH_ISSUER_KEY) {
if !metadata.contains_key(OAUTH_JWKS_URL_KEY) || !metadata.contains_key(OAUTH_ISSUER_KEY) {
return None;
}
Some(AuthInfo {
encryption_type: EncryptionType::Oauth as i32,
encrypted_value: Vec::new(),
meta_data,
metadata,
})
}

Expand Down Expand Up @@ -79,13 +79,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option<AuthInfo> {
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(),
meta_data: HashMap::new(),
metadata: HashMap::new(),
})
} else if valid_md5_password(password) {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(),
meta_data: HashMap::new(),
metadata: HashMap::new(),
})
} else {
Some(encrypt_default(name, password))
Expand All @@ -98,7 +98,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo {
AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(name, password),
meta_data: HashMap::new(),
metadata: HashMap::new(),
}
}

Expand Down Expand Up @@ -186,18 +186,18 @@ mod tests {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
meta_data: HashMap::new(),
metadata: HashMap::new(),
}),
None,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
meta_data: HashMap::new(),
metadata: HashMap::new(),
}),
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: sha256_hash(user_name, password),
meta_data: HashMap::new(),
metadata: HashMap::new(),
}),
];
let output_passwords = input_passwords
Expand Down
4 changes: 2 additions & 2 deletions src/storage/src/hummock/sstable_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,9 +1020,9 @@ impl SstableWriter for StreamingUploadWriter {
}

async fn finish(mut self, meta: SstableMeta) -> HummockResult<UploadJoinHandle> {
let meta_data = Bytes::from(meta.encode_to_bytes());
let metadata = Bytes::from(meta.encode_to_bytes());

self.object_uploader.write_bytes(meta_data).await?;
self.object_uploader.write_bytes(metadata).await?;
let join_handle = tokio::spawn(async move {
let uploader_memory_usage = self.object_uploader.get_memory_usage();
let _tracker = self.tracker.map(|mut t| {
Expand Down
51 changes: 31 additions & 20 deletions src/utils/pgwire/src/pg_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::result::Result;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Instant;

use bytes::Bytes;
use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use parking_lot::Mutex;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::Statement;
Expand Down Expand Up @@ -161,50 +162,60 @@ pub enum UserAuthenticator {
OAuth(HashMap<String, String>),
}

/// A JWK Set is a JSON object that represents a set of JWKs.
/// The JSON object MUST have a "keys" member, with its value being an array of JWKs.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-5> for more details.
#[derive(Debug, Deserialize)]
struct Jwks {
keys: Vec<Jwk>,
}

#[allow(dead_code)]
/// A JSON Web Key (JWK) is a JSON object that represents a cryptographic key.
/// See <https://www.rfc-editor.org/rfc/rfc7517.html#section-4> for more details.
#[derive(Debug, Deserialize)]
struct Jwk {
kid: String,
alg: String,
n: String,
e: String,
}

async fn fetch_jwks(url: &str) -> Result<Jwks, reqwest::Error> {
let resp: Jwks = reqwest::get(url).await?.json().await?;
Ok(resp)
kid: String, // Key ID
alg: String, // Algorithm
n: String, // Modulus
e: String, // Exponent
}

async fn validate_jwt(
jwt: &str,
jwks_url: &str,
issuer: &str,
meta_data: &HashMap<String, String>,
metadata: &HashMap<String, String>,
) -> Result<bool, BoxedError> {
let header = decode_header(jwt)?;
let jwks = fetch_jwks(jwks_url).await?;
let jwks: Jwks = reqwest::get(jwks_url).await?.json().await?;

// 1. Retrieve the kid from the header to find the right JWK in the JWK Set.
let kid = header.kid.ok_or("kid not found in jwt header")?;
let jwk = jwks
.keys
.into_iter()
.find(|k| k.kid == kid)
.ok_or("kid not found in jwks")?;

// 2. Check if the algorithms are matched.
if Algorithm::from_str(&jwk.alg)? != header.alg {
return Err("alg in jwt header does not match with alg in jwk".into());
}

// 3. Decode the JWT and validate the claims.
let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
let mut validation = Validation::new(header.alg);
validation.set_issuer(&[issuer]);
validation.set_required_spec_claims(&["exp", "iss"]);
let token_data = decode::<HashMap<String, serde_json::Value>>(jwt, &decoding_key, &validation)?;

Ok(meta_data.iter().all(
// 4. Check if the metadata in the token matches.
if !metadata.iter().all(
|(k, v)| matches!(token_data.claims.get(k), Some(serde_json::Value::String(s)) if s == v),
))
) {
return Err("metadata in jwt does not match with metadata declared with user".into());
}
Ok(true)
}

impl UserAuthenticator {
Expand All @@ -215,15 +226,15 @@ impl UserAuthenticator {
UserAuthenticator::Md5WithSalt {
encrypted_password, ..
} => encrypted_password == password,
UserAuthenticator::OAuth(meta_data) => {
let mut meta_data = meta_data.clone();
let jwks_url = meta_data.remove("jwks_url").unwrap();
let issuer = meta_data.remove("issuer").unwrap();
UserAuthenticator::OAuth(metadata) => {
let mut metadata = metadata.clone();
let jwks_url = metadata.remove("jwks_url").unwrap();
let issuer = metadata.remove("issuer").unwrap();
validate_jwt(
&String::from_utf8_lossy(password),
&jwks_url,
&issuer,
&meta_data,
&metadata,
)
.await
.map_err(PsqlError::StartupError)?
Expand Down

0 comments on commit b5bc317

Please sign in to comment.