diff --git a/Cargo.lock b/Cargo.lock index f95e94d..077a41c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "adler32" version = "1.2.0" @@ -131,6 +137,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -246,10 +258,12 @@ dependencies = [ "rouille", "serde", "serde_json", + "serde_urlencoded", "sha2", "thiserror", "time", "tracing", + "ureq", ] [[package]] @@ -462,6 +476,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "flate2" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -725,6 +749,15 @@ dependencies = [ "unicase", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "multipart" version = "0.18.0" @@ -1031,6 +1064,21 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rouille" version = "3.6.2" @@ -1068,6 +1116,38 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.23.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" + +[[package]] +name = "rustls-webpki" +version = "0.102.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "ryu" version = "1.0.18" @@ -1118,6 +1198,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sha1_smol" version = "1.0.1" @@ -1147,6 +1239,12 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "strsim" version = "0.11.1" @@ -1385,6 +1483,30 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "url", + "webpki-roots", +] + [[package]] name = "url" version = "2.5.2" @@ -1469,6 +1591,15 @@ version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +[[package]] +name = "webpki-roots" +version = "0.26.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi-util" version = "0.1.9" @@ -1598,3 +1729,9 @@ dependencies = [ "quote", "syn 2.0.77", ] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/Cargo.toml b/Cargo.toml index 6147be7..55cb0d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,10 +28,12 @@ rand = { version = "0.8.5", features = ["std_rng"] } rouille = "3.6.2" serde = { version = "1.0.145", features = ["derive"] } serde_json = "1.0.86" +serde_urlencoded = "0.7.1" sha2 = "0.10.6" thiserror = "1.0.37" time = "0.3.15" tracing = "0.1.37" +ureq = { version = "2.10.1", features = ["json"] } [package.metadata.deb] maintainer-scripts = "debian/" diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..db3a87b --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,243 @@ +use std::collections::HashSet; +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::RwLock; + +use base64::engine::general_purpose::URL_SAFE; +use base64::Engine; +use log::info; +use rand::distributions::Alphanumeric; +use rand::distributions::DistString; +use rouille::input; +use rouille::Request; +use rouille::Response; +use serde::Deserialize; +use serde::Serialize; + +const SESSION_COOKIE_NAME: &str = "session-id"; + +#[derive(Serialize)] +struct GoogleAuthParams<'a> { + response_type: &'a str, + client_id: &'a str, + scope: &'a str, + redirect_uri: &'a str, + state: &'a str, + nonce: &'a str, +} + +#[derive(Deserialize)] +#[allow(dead_code)] +struct CodeResponse { + // A token that can be sent to a Google API. + access_token: String, + // The remaining lifetime of the access token in seconds. + expires_in: i64, + // A JWT that contains identity information about the user that is digitally + // signed by Google. + id_token: String, + // The scopes of access granted by the access_token expressed as a list of + // space-delimited, case-sensitive strings. + scope: String, + // Identifies the type of token returned. At this time, this field always + // has the value Bearer. + token_type: String, + // This field is only present if the access_type parameter was set to offline + // in the authentication request. For details, see Refresh tokens. + refresh_token: Option, +} + +#[derive(Deserialize)] +#[allow(dead_code)] +struct Claims { + // The audience that this ID token is intended for. It must be one of the OAuth + // 2.0 client IDs of your application. + aud: String, + // Expiration time on or after which the ID token must not be accepted. Represented + // in Unix time (integer seconds). + exp: i64, + // The time the ID token was issued. Represented in Unix time (integer seconds). + iat: i64, + // The Issuer Identifier for the Issuer of the response. Always + // https://accounts.google.com or accounts.google.com for Google ID tokens. + iss: String, + // An identifier for the user, unique among all Google accounts and never reused. A + // Google account can have multiple email addresses at different points in time, but + // the sub value is never changed. Use sub within your application as the + // unique-identifier key for the user. Maximum length of 255 case-sensitive ASCII characters. + sub: String, + // Access token hash. Provides validation that the access token is tied to the identity + // token. If the ID token is issued with an access_token value in the server flow, this + // claim is always included. This claim can be used as an alternate mechanism to protect + // against cross-site request forgery attacks, but if you follow Step 1 and Step 3 it is + // not necessary to verify the access token. + at_hash: Option, + // The user's email address. Provided only if you included the email scope in your request. + // The value of this claim may not be unique to this account and could change over time, + // therefore you should not use this value as the primary identifier to link to your user + // record. You also can't rely on the domain of the email claim to identify users of + // Google Workspace or Cloud organizations; use the hd claim instead. + email: Option, + // True if the user's e-mail address has been verified; otherwise false. + email_verified: Option, + // The value of the nonce supplied by your app in the authentication request. You should + // enforce protection against replay attacks by ensuring it is presented only once. + nonce: Option, +} + +fn self_uri(req: &Request) -> String { + if let Some(host) = req.header("Host") { + let prefix = if req.is_secure() { + "https://" + } else { + "http://" + }; + format!("{prefix}{host}") + } else { + "".into() + } +} + +pub struct Authorizer { + nonces: Arc>>, + session_ids: Arc>>, + auth_tokens: HashSet, + allowed_emails: HashSet, + oidc_client_id: String, + oidc_client_secret: String, +} + +impl Authorizer { + pub fn new( + auth_tokens: HashSet, + allowed_emails: HashSet, + oidc_client_id: String, + oidc_client_secret: String, + ) -> Self { + Self { + nonces: Arc::new(Mutex::new(HashSet::new())), + session_ids: Arc::new(RwLock::new(HashSet::new())), + allowed_emails, + auth_tokens, + oidc_client_id, + oidc_client_secret, + } + } +} + +impl Authorizer { + fn is_authorized(&self, req: &Request) -> bool { + if let Some(auth_header) = req.header("Authorization") { + if self.auth_tokens.contains(auth_header) { + return true; + } + } + + if let Some((_, val)) = input::cookies(req).find(|&(n, _)| n == SESSION_COOKIE_NAME) { + // session_ids last for the lifetime of the program for simplicity. + if self.session_ids.read().unwrap().contains(val) { + return true; + } + } + + return false; + } + + fn process_code(&self, req: &Request) -> Response { + let state = match req.get_param("state") { + Some(s) => s, + None => return Response::text("missing state").with_status_code(400), + }; + let nonces = self.nonces.lock().unwrap(); + if !nonces.contains(&state) { + return Response::text("unknown state").with_status_code(400); + } + let code = match req.get_param("code") { + Some(c) => c, + None => return Response::text("missing code").with_status_code(400), + }; + let redirect_uri = self_uri(&req) + "/code"; + let resp = ureq::post("https://oauth2.googleapis.com/token") + .send_form(&[ + // The authorization code that is returned from the initial request. + ("code", &code), + // The client ID that you obtain from the API Console Credentials page, as + // described in Obtain OAuth 2.0 credentials. + ("client_id", &self.oidc_client_id), + // The client secret that you obtain from the API Console Credentials page, + // as described in Obtain OAuth 2.0 credentials. + ("client_secret", &self.oidc_client_secret), + // An authorized redirect URI for the given client_id specified in the API + // Console Credentials page, as described in Set a redirect URI. + ("redirect_uri", &redirect_uri), + // This field must contain a value of authorization_code, as defined in + // the OAuth 2.0 specification. + ("grant_type", "authorization_code"), + ]) + .unwrap(); + let parsed_resp: CodeResponse = resp.into_json().unwrap(); + let jsonclaims = URL_SAFE + .decode(&parsed_resp.id_token.split(".").skip(1).next().unwrap()) + .unwrap(); + let claims: Claims = serde_json::from_slice(&jsonclaims).unwrap(); + + // Check nonces + let nonce = claims.nonce.unwrap_or_default(); + let mut nonces = self.nonces.lock().unwrap(); + if !nonces.contains(&nonce) { + return Response::text("reused nonce").with_status_code(400); + } + nonces.remove(&nonce); + + let email = claims.email.unwrap_or_default(); + + // Make sure user is allowed + if !self.allowed_emails.contains(&email) { + info!("denied {email}"); + return Response::text("not authorized").with_status_code(401); + } + info!("authenticated {email}"); + + // Create session and add to headers + let session_id = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let session_cookie = format!("{SESSION_COOKIE_NAME}={session_id}"); + self.session_ids.write().unwrap().insert(session_id); + + // Now back to where we wanted to go. + Response::redirect_302(self_uri(req)).with_unique_header("Set-Cookie", session_cookie) + } + + pub fn ensure_authorized(&self, req: &Request, next: N) -> Response + where + N: FnOnce(&Request) -> Response, + { + // See if we're handling an earlier auth message. + if req.url() == "/code" { + return self.process_code(req); + } + + if self.is_authorized(&req) { + // No need to do any more auth, call our normal function. + return next(req); + } + + let redirect_uri = self_uri(&req) + "/code"; + // Construct a message for OIDC. + // We omit state because CSRF attacks don't seem like a meaningful problem + // for this specific application. + let nonce = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let params = GoogleAuthParams { + response_type: "code", + client_id: &self.oidc_client_id, + scope: "openid email", + redirect_uri: &redirect_uri, + state: &nonce, + nonce: &nonce, + }; + let encoded = serde_urlencoded::to_string(params).unwrap(); + self.nonces.lock().unwrap().insert(nonce); + + let redirect = format!("https://accounts.google.com/o/oauth2/v2/auth?{encoded}"); + Response::redirect_302(redirect) + } +} diff --git a/src/bin/cecvol.rs b/src/bin/cecvol.rs index f61cece..2350342 100644 --- a/src/bin/cecvol.rs +++ b/src/bin/cecvol.rs @@ -1,4 +1,5 @@ use cecvol::action; +use cecvol::auth; use cecvol::cec; use cecvol::lgip; use cecvol::tv; @@ -14,6 +15,8 @@ use rouille::router; use rouille::Request; use rouille::Response; use serde_json::json; +use std::collections::HashSet; +use std::fs; use std::sync::Arc; use std::sync::Mutex; @@ -212,6 +215,22 @@ struct Args { /// Server MAC address for WoL, in xx:xx:xx:xx:xx:xx form. #[arg(long, env = "SERVER_MAC_ADDR")] server_mac_addr: String, + + /// File with newline-separated tokens acceptable for Authorization header + #[arg(long, env = "AUTH_TOKEN_FILE")] + auth_token_file: Option, + + /// Permitted emails for login + #[arg(long, env = "ALLOWED_EMAILS")] + allowed_emails: Vec, + + /// Client id for OIDC login + #[arg(long, env = "OIDC_CLIENT_ID")] + oidc_client_id: Option, + + /// Client secret for OIDC login + #[arg(long, env = "OIDC_CLIENT_SECRET")] + oidc_client_secret: Option, } #[derive(Clone)] @@ -267,6 +286,37 @@ fn main() -> Result<(), Box> { for (i, s) in args.server_mac_addr.split(":").enumerate() { server_mac_addr[i] = u8::from_str_radix(s, 16)?; } + + let mut auth_tokens = HashSet::new(); + if let Some(filepath) = args.auth_token_file { + let contents = fs::read_to_string(filepath)?; + for line in contents.lines() { + if line.trim() != "" && !line.starts_with("#") { + auth_tokens.insert(line.trim().to_string()); + } + } + } + let mut allowed_emails = HashSet::new(); + for e in args.allowed_emails { + allowed_emails.insert(e); + } + + let authorizer = match (args.oidc_client_id, args.oidc_client_secret) { + (Some(oidc_client_id), Some(oidc_client_secret)) => { + info!("enforcing login"); + Some(auth::Authorizer::new( + auth_tokens, + allowed_emails, + oidc_client_id, + oidc_client_secret, + )) + } + _ => { + info!("not enforcing login"); + None + } + }; + let app_state = AppState { cec: conn, server_mac_addr, @@ -275,12 +325,18 @@ fn main() -> Result<(), Box> { info!("Starting server..."); rouille::start_server(&args.http_addr, move |request| { - let resp = router!(request, - (GET) (/) => {index()}, - (GET) (/varz) => {varz()}, - (POST) (/fulfillment) => {fulfillment(app_state.clone(), request)}, - _ => rouille::Response::empty_404() - ); + let route = |req: &Request| { + router!(req, + (GET) (/) => {index()}, + (GET) (/varz) => {varz()}, + (POST) (/fulfillment) => {fulfillment(app_state.clone(), req)}, + _ => rouille::Response::empty_404() + ) + }; + let resp = match &authorizer { + Some(a) => a.ensure_authorized(request, route), + None => route(request), + }; info!( "{request} {status}", request = request.url(), diff --git a/src/lib.rs b/src/lib.rs index d990f8f..bc0eea6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod action; +pub mod auth; pub mod cec; pub mod lgip; pub mod tv;