diff --git a/rust/.gitignore b/rust/.gitignore index a2ef7f4edf..d4661763be 100644 --- a/rust/.gitignore +++ b/rust/.gitignore @@ -1,2 +1,4 @@ target/ libsrc/ +*.rsa +*.cert diff --git a/rust/examples/openvasd/config.example.toml b/rust/examples/openvasd/config.example.toml index 4a8df6d761..1b5920a52b 100644 --- a/rust/examples/openvasd/config.example.toml +++ b/rust/examples/openvasd/config.example.toml @@ -1,6 +1,8 @@ [feed] # path to the openvas feed. This is required for the /vts endpoint. path = "/var/lib/openvas/plugins" +# disables or enables the signnature check +signature_check = true [feed.check_interval] # how often the feed should be checked for updates diff --git a/rust/examples/tls/Self-Signed mTLS Method/server_certificates.sh b/rust/examples/tls/Self-Signed mTLS Method/server_certificates.sh index 3da90b334c..db6efe2079 100644 --- a/rust/examples/tls/Self-Signed mTLS Method/server_certificates.sh +++ b/rust/examples/tls/Self-Signed mTLS Method/server_certificates.sh @@ -3,57 +3,65 @@ # # SPDX-License-Identifier: GPL-2.0-or-later -set -xe - -openssl req -nodes \ - -x509 \ - -days 3650 \ - -newkey rsa:4096 \ - -keyout ca.key \ - -out ca.cert \ - -sha256 \ - -batch \ - -subj "/CN=ponytown RSA CA" - -openssl req -nodes \ - -newkey rsa:3072 \ - -keyout inter.key \ - -out inter.req \ - -sha256 \ - -batch \ - -subj "/CN=ponytown RSA level 2 intermediate" - -openssl req -nodes \ - -newkey rsa:2048 \ - -keyout end.key \ - -out end.req \ - -sha256 \ - -batch \ - -subj "/CN=testserver.com" - -openssl rsa \ - -in end.key \ - -out server.rsa - -openssl x509 -req \ - -in inter.req \ - -out inter.cert \ - -CA ca.cert \ - -CAkey ca.key \ - -sha256 \ - -days 3650 \ - -set_serial 123 \ - -extensions v3_inter -extfile ../openssl.cnf - -openssl x509 -req \ - -in end.req \ - -out end.cert \ - -CA inter.cert \ - -CAkey inter.key \ - -sha256 \ - -days 2000 \ - -set_serial 456 \ - -extensions v3_end -extfile ../openssl.cnf - -cat end.cert inter.cert ca.cert > server.pem -rm *.key *.cert *.req +set -e +generate_certificates() +{ + out="$1" + name="$2" + printf "generating $out for $name\t" + openssl req -nodes \ + -x509 \ + -days 3650 \ + -newkey rsa:4096 \ + -keyout ca.key \ + -out ca.cert \ + -sha256 \ + -batch \ + -subj "/CN=$name RSA CA" + + openssl req -nodes \ + -newkey rsa:3072 \ + -keyout inter.key \ + -out inter.req \ + -sha256 \ + -batch \ + -subj "/CN=$name RSA level 2 intermediate" + + openssl req -nodes \ + -newkey rsa:2048 \ + -keyout end.key \ + -out end.req \ + -sha256 \ + -batch \ + -subj "/CN=testserver.com" + + openssl rsa \ + -in end.key \ + -out $out.rsa + + openssl x509 -req \ + -in inter.req \ + -out inter.cert \ + -CA ca.cert \ + -CAkey ca.key \ + -sha256 \ + -days 3650 \ + -set_serial 123 \ + -extensions v3_inter -extfile openssl.cnf + + openssl x509 -req \ + -in end.req \ + -out end.cert \ + -CA inter.cert \ + -CAkey inter.key \ + -sha256 \ + -days 2000 \ + -set_serial 456 \ + -extensions v3_end -extfile openssl.cnf + + cat end.cert inter.cert ca.cert > $out.pem + rm *.key *.cert *.req + printf "done\n" +} + +generate_certificates "server" "ponytown" diff --git a/rust/infisto/src/base.rs b/rust/infisto/src/base.rs index fa0980348f..11949a00d6 100644 --- a/rust/infisto/src/base.rs +++ b/rust/infisto/src/base.rs @@ -282,6 +282,18 @@ impl IndexedFileStorer { .map_err(Error::Remove)?; Ok(()) } + + /// Removes base dir and all its content. + /// + /// # Safety + /// Does remove the whole base dir and its content. + /// Do not use carelessly. + pub unsafe fn remove_base(self) -> Result<(), Error> { + fs::remove_dir_all(self.base) + .map_err(|e| e.kind()) + .map_err(Error::Remove) + .map(|_| ()) + } } impl IndexedByteStorage for IndexedFileStorer { diff --git a/rust/openvasd/Cargo.toml b/rust/openvasd/Cargo.toml index 40fdd9ccec..11cec1e0ff 100644 --- a/rust/openvasd/Cargo.toml +++ b/rust/openvasd/Cargo.toml @@ -20,7 +20,7 @@ serde_json = "1.0.96" serde = { version = "1.0.163", features = ["derive"] } uuid = {version = "1", features = ["v4", "fast-rng", "serde"]} hyper-rustls = "0.24.0" -rustls = "0.21.1" +rustls = {version = "0.21.1", features = ["secret_extraction", "dangerous_configuration"]} tokio-rustls = "0.24.0" futures-util = "0.3.28" rustls-pemfile = "1.0.2" diff --git a/rust/openvasd/src/config.rs b/rust/openvasd/src/config.rs index 1547d687db..b36d3685ec 100644 --- a/rust/openvasd/src/config.rs +++ b/rust/openvasd/src/config.rs @@ -188,8 +188,8 @@ impl Config { where P: AsRef + std::fmt::Display, { - let config = std::fs::read_to_string(path).unwrap_or_default(); - toml::from_str(&config).unwrap_or_default() + let config = std::fs::read_to_string(path).unwrap(); + toml::from_str(&config).unwrap() } pub fn load() -> Self { diff --git a/rust/openvasd/src/controller/context.rs b/rust/openvasd/src/controller/context.rs index 50a1e389df..a68e64aa95 100644 --- a/rust/openvasd/src/controller/context.rs +++ b/rust/openvasd/src/controller/context.rs @@ -100,7 +100,8 @@ impl ContextBuilder { if let Some(fp) = self.feed_config.as_ref() { let loader = nasl_interpreter::FSPluginLoader::new(fp.path.clone()); let dispatcher: DefaultDispatcher = DefaultDispatcher::default(); - let version = feed::version(&loader, &dispatcher).unwrap(); + let version = + feed::version(&loader, &dispatcher).unwrap_or_else(|_| String::from("UNDEFINED")); self.response.set_feed_version(&version); } self diff --git a/rust/openvasd/src/controller/entry.rs b/rust/openvasd/src/controller/entry.rs index 4ba9a1faa9..22cac6a641 100644 --- a/rust/openvasd/src/controller/entry.rs +++ b/rust/openvasd/src/controller/entry.rs @@ -6,12 +6,18 @@ //! //! All known paths must be handled in the entrypoint function. -use std::{fmt::Display, sync::Arc}; +use std::{ + fmt::Display, + sync::{Arc, RwLock}, +}; -use super::context::Context; +use super::{context::Context, ClientIdentifier}; use hyper::{Body, Method, Request, Response}; -use crate::scan::{self, Error, ScanDeleter, ScanStarter, ScanStopper}; +use crate::{ + controller::ClientHash, + scan::{self, Error, ScanDeleter, ScanStarter, ScanStopper}, +}; enum HealthOpts { /// Ready @@ -38,6 +44,10 @@ enum KnownPaths { } impl KnownPaths { + pub fn requires_id(&self) -> bool { + !matches!(self, Self::Health(_) | Self::Vts) + } + #[tracing::instrument] /// Parses a path and returns the corresponding `KnownPaths` variant. fn from_path(path: &str) -> Self { @@ -60,13 +70,20 @@ impl KnownPaths { Some("alive") => KnownPaths::Health(HealthOpts::Alive), Some("started") => KnownPaths::Health(HealthOpts::Started), _ => KnownPaths::Unknown, - } + }, _ => { tracing::trace!("Unknown path: {path}"); KnownPaths::Unknown } } } + + fn scan_id(&self) -> Option<&str> { + match self { + Self::Scans(Some(id)) | Self::ScanResults(id, _) | Self::ScanStatus(id) => Some(id), + _ => None, + } + } } impl Display for KnownPaths { @@ -89,12 +106,10 @@ impl Display for KnownPaths { } /// Is used to handle all incomng requests. -/// -/// First it will be checked if a known path is requested and if the method is supported. -/// Than corresponding functions will be called to handle the request. pub async fn entrypoint<'a, S, DB>( req: Request, ctx: Arc>, + cid: Arc>, ) -> Result, Error> where S: ScanStarter @@ -102,40 +117,65 @@ where + ScanDeleter + scan::ScanResultFetcher + std::marker::Send - + 'static - + std::marker::Sync, + + std::marker::Sync + + 'static, DB: crate::storage::Storage + std::marker::Send + 'static + std::marker::Sync, { use KnownPaths::*; // on head requests we just return an empty response without checking the api key - tracing::trace!( - "{} {}:{:?}", - req.method(), - req.uri().path(), - req.uri().query() - ); if req.method() == Method::HEAD { return Ok(ctx.response.empty(hyper::StatusCode::OK)); } let kp = KnownPaths::from_path(req.uri().path()); - if let Some(key) = ctx.api_key.as_ref() { - match req.headers().get("x-api-key") { - Some(v) if v == key => {} - Some(v) => { - tracing::debug!("{} {} invalid key: {:?}", req.method(), kp, v); - return Ok(ctx.response.unauthorized()); - } - _ => { - tracing::debug!("{} {} unauthorized", req.method(), kp); - return Ok(ctx.response.unauthorized()); + let cid: Option = { + let cid = cid.read().unwrap(); + match &*cid { + ClientIdentifier::Unknown => { + if let Some(key) = ctx.api_key.as_ref() { + match req.headers().get("x-api-key") { + Some(v) if v == key => ctx.api_key.as_ref().map(|x| x.into()), + Some(v) => { + tracing::debug!("{} {} invalid key: {:?}", req.method(), kp, v); + None + } + _ => None, + } + } else { + None + } } + ClientIdentifier::Known(cid) => Some(cid.clone()), } + }; + + if kp.requires_id() && cid.is_none() { + tracing::debug!("{} {} unauthorized", req.method(), kp); + return Ok(ctx.response.unauthorized()); } + let cid = cid.unwrap_or_default(); + if let Some(scan_id) = kp.scan_id() { + if !ctx.db.is_client_allowed(scan_id.to_owned(), &cid).await? { + tracing::debug!( + "client {:x?} is not allowed to operate on scan {} ", + &cid.0, + scan_id + ); + // we return 404 instead of 401 to not leak any ids + return Ok(ctx.response.not_found("scans", scan_id)); + } + } + + tracing::debug!( + "{} {}:{:?}", + req.method(), + req.uri().path(), + req.uri().query(), + ); match (req.method(), kp) { - (&Method::GET, Health(HealthOpts::Alive)) | - (&Method::GET, Health(HealthOpts::Started)) => - Ok(ctx.response.empty(hyper::StatusCode::OK)), + (&Method::GET, Health(HealthOpts::Alive)) | (&Method::GET, Health(HealthOpts::Started)) => { + Ok(ctx.response.empty(hyper::StatusCode::OK)) + } (&Method::GET, Health(HealthOpts::Ready)) => { let oids = ctx.db.oids().await?; if oids.count() == 0 { @@ -156,6 +196,7 @@ where let resp = ctx.response.created(&id); scan.scan_id = Some(id.clone()); ctx.db.insert_scan(scan).await?; + ctx.db.add_scan_client_id(id.clone(), cid).await?; tracing::debug!("Scan with ID {} created", &id); Ok(resp) } @@ -202,7 +243,7 @@ where } (&Method::GET, Scans(None)) => { if ctx.enable_get_scans { - match ctx.db.get_scan_ids().await { + match ctx.db.get_scans_of_client_id(&cid).await { Ok(scans) => Ok(ctx.response.ok(&scans)), Err(e) => Ok(ctx.response.internal_server_error(&e)), } @@ -226,7 +267,8 @@ where ctx.scanner.stop_scan(id.clone()).await?; } ctx.db.remove_scan(&id).await?; - ctx.scanner.delete_scan(id).await?; + ctx.scanner.delete_scan(id.clone()).await?; + ctx.db.remove_scan_id(id).await?; Ok(ctx.response.no_content()) } Err(crate::storage::Error::NotFound) => Ok(ctx.response.not_found("scans", &id)), diff --git a/rust/openvasd/src/controller/mod.rs b/rust/openvasd/src/controller/mod.rs index fb9ec3e864..7942a91c0f 100644 --- a/rust/openvasd/src/controller/mod.rs +++ b/rust/openvasd/src/controller/mod.rs @@ -20,40 +20,56 @@ pub(crate) fn quit_on_poison() -> T { /// Combines all traits needed for a scanner. pub trait Scanner: ScanStarter + ScanStopper + ScanDeleter + ScanResultFetcher {} -impl Scanner for T where T: ScanStarter + ScanStopper + ScanDeleter + ScanResultFetcher {} - -macro_rules! make_svc { - ($controller:expr) => {{ - // start background service - use std::sync::Arc; +#[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct ClientHash([u8; 32]); - tokio::spawn(crate::controller::results::fetch(Arc::clone(&$controller))); - tokio::spawn(crate::controller::feed::fetch(Arc::clone(&$controller))); +impl From for ClientHash +where + T: AsRef<[u8]>, +{ + fn from(value: T) -> Self { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(value); + let hash = hasher.finalize(); + Self(hash.into()) + } +} - use hyper::service::{make_service_fn, service_fn}; - make_service_fn(|_conn| { - let controller = Arc::clone($controller); - async { - Ok::<_, crate::scan::Error>(service_fn(move |req| { - crate::controller::entrypoint(req, Arc::clone(&controller)) - })) - } - }) - }}; +/// Contains information about an authorization model of a connection (e.g. mtls) +#[derive(Default, Debug)] +pub enum ClientIdentifier { + /// When there in no information available + #[default] + Unknown, + /// Contains a hashed number of an identifier + /// + /// openvasd uses the identifier as a key for results. This key is usually calculated by an + /// subject of a known client certificate. Based on that we don't need more information. + Known(ClientHash), +} +/// Is used to transfer the information if there is an identifier present within the connection +pub trait ClientGossiper { + /// Gets the identifier + /// + /// Based on the concurrent nature, the actual information is boxed within a Arc and locked for + /// concurrent read write accesses. + fn client_identifier(&self) -> &std::sync::Arc>; } -pub(crate) use make_svc; +impl Scanner for T where T: ScanStarter + ScanStopper + ScanDeleter + ScanResultFetcher {} #[cfg(test)] mod tests { use super::context::Context; use super::entry::entrypoint; use crate::{ - controller::{ContextBuilder, NoOpScanner}, + controller::{ClientIdentifier, ContextBuilder, NoOpScanner}, storage::file, }; use async_trait::async_trait; use hyper::{Body, Method, Request, Response}; + use infisto::base::IndexedFileStorer; use std::sync::{Arc, RwLock}; #[derive(Debug, Clone)] @@ -142,7 +158,8 @@ mod tests { .method(Method::HEAD) .body(Body::empty()) .unwrap(); - let resp = entrypoint(req, Arc::clone(&controller)).await.unwrap(); + let cid = Arc::new(RwLock::new(ClientIdentifier::Unknown)); + let resp = entrypoint(req, Arc::clone(&controller), cid).await.unwrap(); assert_eq!(resp.headers().get("api-version").unwrap(), "1"); assert_eq!(resp.headers().get("authentication").unwrap(), ""); } @@ -157,7 +174,9 @@ mod tests { .method(Method::GET) .body(Body::empty()) .unwrap(); - entrypoint(req, Arc::clone(&ctx)).await.unwrap() + + let cid = Arc::new(RwLock::new(ClientIdentifier::Known("42".into()))); + entrypoint(req, Arc::clone(&ctx), cid).await.unwrap() } async fn get_scan(id: &str, ctx: Arc>) -> Response @@ -170,7 +189,8 @@ mod tests { .method(Method::GET) .body(Body::empty()) .unwrap(); - entrypoint(req, Arc::clone(&ctx)).await.unwrap() + let cid = Arc::new(RwLock::new(ClientIdentifier::Known("42".into()))); + entrypoint(req, Arc::clone(&ctx), cid).await.unwrap() } async fn post_scan(scan: &models::Scan, ctx: Arc>) -> Response @@ -183,7 +203,8 @@ mod tests { .method(Method::POST) .body(serde_json::to_string(&scan).unwrap().into()) .unwrap(); - entrypoint(req, Arc::clone(&ctx)).await.unwrap() + let cid = Arc::new(RwLock::new(ClientIdentifier::Known("42".into()))); + entrypoint(req, Arc::clone(&ctx), cid).await.unwrap() } async fn start_scan(id: &str, ctx: Arc>) -> Response @@ -199,7 +220,8 @@ mod tests { .method(Method::POST) .body(serde_json::to_string(action).unwrap().into()) .unwrap(); - entrypoint(req, Arc::clone(&ctx)).await.unwrap() + let cid = Arc::new(RwLock::new(ClientIdentifier::Known("42".into()))); + entrypoint(req, Arc::clone(&ctx), cid).await.unwrap() } async fn post_scan_id(scan: &models::Scan, ctx: Arc>) -> String @@ -216,8 +238,10 @@ mod tests { #[tokio::test] async fn add_scan_with_id_fails() { - let mut scan: models::Scan = models::Scan::default(); - scan.scan_id = Some(String::new()); + let scan: models::Scan = models::Scan { + scan_id: Some(String::new()), + ..Default::default() + }; let ctx = Arc::new(Context::default()); let resp = post_scan(&scan, Arc::clone(&ctx)).await; assert_eq!(resp.status(), hyper::http::StatusCode::BAD_REQUEST); @@ -235,7 +259,8 @@ mod tests { .method(Method::DELETE) .body(Body::empty()) .unwrap(); - entrypoint(req, Arc::clone(&controller)).await.unwrap(); + let cid = Arc::new(RwLock::new(ClientIdentifier::Known("42".into()))); + entrypoint(req, Arc::clone(&controller), cid).await.unwrap(); let resp = get_scan(&id, Arc::clone(&controller)).await; assert_eq!(resp.status(), 404); } @@ -267,17 +292,19 @@ mod tests { .method(Method::GET) .body(Body::empty()) .unwrap(); - let resp = entrypoint(req, Arc::clone(&ctx)).await.unwrap(); + let cid = Arc::new(RwLock::new(ClientIdentifier::Known("42".into()))); + let resp = entrypoint(req, Arc::clone(&ctx), cid).await.unwrap(); let resp = hyper::body::to_bytes(resp.into_body()).await.unwrap(); - let resp = serde_json::from_slice::>(&resp).unwrap(); - resp + + serde_json::from_slice::>(&resp).unwrap() } let scan: models::Scan = models::Scan::default(); let scanner = FakeScanner { count: Arc::new(RwLock::new(0)), }; let ns = std::time::Duration::from_nanos(10); - let storage = file::unencrypted("/tmp/aha").unwrap(); + let root = "/tmp/openvasd/fetch_results"; + let storage = file::unencrypted(root).unwrap(); let ctx = ContextBuilder::new() .result_config(ns) .storage(storage) @@ -321,6 +348,12 @@ mod tests { for (i, r) in resp.iter().enumerate() { assert_eq!(r.id, i + 4900); } + unsafe { + IndexedFileStorer::init(root) + .unwrap() + .remove_base() + .unwrap() + }; } #[tokio::test] @@ -332,6 +365,17 @@ mod tests { .build(); let controller = Arc::new(ctx); let resp = post_scan(&scan, Arc::clone(&controller)).await; + + assert_eq!(resp.status(), 201); + + let req = Request::builder() + .uri("/scans") + .method(Method::POST) + .body(serde_json::to_string(&scan).unwrap().into()) + .unwrap(); + let cid = Arc::new(RwLock::new(ClientIdentifier::Unknown)); + let resp = entrypoint(req, Arc::clone(&controller), cid).await.unwrap(); + assert_eq!(resp.status(), 401); let req = Request::builder() .uri("/scans") @@ -339,7 +383,8 @@ mod tests { .method(Method::POST) .body(serde_json::to_string(&scan).unwrap().into()) .unwrap(); - let resp = entrypoint(req, Arc::clone(&controller)).await.unwrap(); + let cid = Arc::new(RwLock::new(ClientIdentifier::Unknown)); + let resp = entrypoint(req, Arc::clone(&controller), cid).await.unwrap(); assert_eq!(resp.status(), 201); } } diff --git a/rust/openvasd/src/main.rs b/rust/openvasd/src/main.rs index ac8cd86c66..ce7eb537a5 100644 --- a/rust/openvasd/src/main.rs +++ b/rust/openvasd/src/main.rs @@ -2,23 +2,104 @@ // // SPDX-License-Identifier: GPL-2.0-or-later -mod config; -mod controller; -mod crypt; -mod feed; -mod request; -mod response; -mod scan; -mod storage; -mod tls; - -pub async fn run<'a, DB>( +use controller::ClientGossiper; +use futures_util::ready; + +pub mod config; +pub mod controller; +pub mod crypt; +pub mod feed; +pub mod request; +pub mod response; +pub mod scan; +pub mod storage; +pub mod tls; + +struct AddrIncomingWrapper(hyper::server::conn::AddrIncoming); + +impl hyper::server::accept::Accept for AddrIncomingWrapper { + type Conn = AddrStreamWrapper; + type Error = std::io::Error; + + fn poll_accept( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + use core::task::Poll; + use std::pin::Pin; + let pin = self.get_mut(); + match ready!(Pin::new(&mut pin.0).poll_accept(cx)) { + Some(Ok(sock)) => std::task::Poll::Ready(Some(Ok(AddrStreamWrapper::new(sock)))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} +struct AddrStreamWrapper( + hyper::server::conn::AddrStream, + std::sync::Arc>, +); +impl AddrStreamWrapper { + fn new(sock: hyper::server::conn::AddrStream) -> AddrStreamWrapper { + AddrStreamWrapper( + sock, + std::sync::Arc::new(std::sync::RwLock::new( + controller::ClientIdentifier::Unknown, + )), + ) + } +} + +impl ClientGossiper for AddrStreamWrapper { + fn client_identifier( + &self, + ) -> &std::sync::Arc> { + &self.1 + } +} + +impl tokio::io::AsyncRead for AddrStreamWrapper { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let pin = self.get_mut(); + std::pin::Pin::new(&mut pin.0).poll_read(cx, buf) + } +} + +impl tokio::io::AsyncWrite for AddrStreamWrapper { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let pin = self.get_mut(); + std::pin::Pin::new(&mut pin.0).poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pin = self.get_mut(); + std::pin::Pin::new(&mut pin.0).poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pin = self.get_mut(); + std::pin::Pin::new(&mut pin.0).poll_shutdown(cx) + } +} + +fn create_context( db: DB, - config: config::Config, -) -> Result<(), Box> -where - DB: crate::storage::Storage + std::marker::Send + 'static + std::marker::Sync, -{ + config: &config::Config, +) -> controller::Context { let scanner = scan::OSPDWrapper::new(config.ospd.socket.clone(), config.ospd.read_timeout); let rc = config.ospd.result_check_interval; let fc = ( @@ -26,30 +107,81 @@ where config.feed.check_interval, config.feed.signature_check, ); - let ctx = controller::ContextBuilder::new() + controller::ContextBuilder::new() .result_config(rc) .feed_config(fc) .scanner(scanner) .api_key(config.endpoints.key.clone()) .enable_get_scans(config.endpoints.enable_get_scans) .storage(db) - .build(); + .build() +} + +async fn serve<'a, S, DB, I>( + ctx: controller::Context, + inc: I, +) -> Result<(), Box> +where + S: scan::ScanStarter + + scan::ScanStopper + + scan::ScanDeleter + + scan::ScanResultFetcher + + std::marker::Send + + std::marker::Sync + + 'static, + I: hyper::server::accept::Accept, + I::Error: Into>, + I::Conn: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + ClientGossiper + 'static, + DB: crate::storage::Storage + std::marker::Send + 'static + std::marker::Sync, +{ let controller = std::sync::Arc::new(ctx); + let make_svc = { + use std::sync::Arc; + + tokio::spawn(crate::controller::results::fetch(Arc::clone(&controller))); + tokio::spawn(crate::controller::feed::fetch(Arc::clone(&controller))); + + use hyper::service::{make_service_fn, service_fn}; + make_service_fn(|conn| { + let controller = Arc::clone(&controller); + let conn = conn as &dyn ClientGossiper; + let cis = Arc::clone(conn.client_identifier()); + async { + Ok::<_, crate::scan::Error>(service_fn(move |req| { + controller::entrypoint(req, Arc::clone(&controller), cis.clone()) + })) + } + }) + }; + let server = hyper::Server::builder(inc).serve(make_svc); + server.await?; + Ok(()) +} + +pub async fn run<'a, S, DB>( + ctx: controller::Context, + config: &config::Config, +) -> Result<(), Box> +where + S: scan::ScanStarter + + scan::ScanStopper + + scan::ScanDeleter + + scan::ScanResultFetcher + + std::marker::Send + + std::marker::Sync + + 'static, + DB: crate::storage::Storage + std::marker::Send + 'static + std::marker::Sync, +{ let addr = config.listener.address; let incoming = hyper::server::conn::AddrIncoming::bind(&addr)?; let addr = incoming.local_addr(); - - if let Some(tlsc) = tls::tls_config(&config)? { - tracing::trace!("TLS enabled"); - let make_svc = crate::controller::make_svc!(&controller); - let server = hyper::Server::builder(tls::TlsAcceptor::new(tlsc, incoming)).serve(make_svc); + if let Some((roots, certs, key)) = tls::tls_config(config)? { tracing::info!("listening on https://{}", addr); - server.await?; + let inc = tls::TlsAcceptor::new(roots, certs, key, incoming); + serve(ctx, inc).await?; } else { - let make_svc = crate::controller::make_svc!(&controller); - let server = hyper::Server::builder(incoming).serve(make_svc); tracing::info!("listening on http://{}", addr); - server.await?; + serve(ctx, AddrIncomingWrapper(incoming)).await?; } Ok(()) } @@ -59,7 +191,8 @@ async fn main() -> Result<(), Box> { let config = config::Config::load(); let filter = tracing_subscriber::EnvFilter::builder() .with_default_directive(tracing::metadata::LevelFilter::INFO.into()) - .parse_lossy(config.log.level.clone()); + .parse_lossy(format!("info,openvasd={}", &config.log.level)); + tracing::debug!("config: {:?}", config); tracing_subscriber::fmt().with_env_filter(filter).init(); if !config.ospd.socket.exists() { tracing::warn!("OSPD socket {} does not exist. Some commands will not work until the socket is created!", config.ospd.socket.display()); @@ -67,23 +200,30 @@ async fn main() -> Result<(), Box> { match config.storage.storage_type { config::StorageType::InMemory => { tracing::info!("using in memory store. No sensitive data will be stored on disk."); - run(storage::inmemory::Storage::default(), config).await + + let ctx = create_context(storage::inmemory::Storage::default(), &config); + run(ctx, &config).await } config::StorageType::FileSystem => { if let Some(key) = &config.storage.fs.key { tracing::info!( "using in file storage. Sensitive data will be encrypted stored on disk." ); - run( + + let ctx = create_context( storage::file::encrypted(&config.storage.fs.path, key)?, - config, - ) - .await + &config, + ); + run(ctx, &config).await } else { tracing::warn!( "using in file storage. Sensitive data will be stored on disk without any encryption." ); - run(storage::file::unencrypted(&config.storage.fs.path)?, config).await + let ctx = create_context( + storage::file::unencrypted(&config.storage.fs.path)?, + &config, + ); + run(ctx, &config).await } } } diff --git a/rust/openvasd/src/storage/file.rs b/rust/openvasd/src/storage/file.rs index 015df5560f..04b65a02af 100644 --- a/rust/openvasd/src/storage/file.rs +++ b/rust/openvasd/src/storage/file.rs @@ -257,6 +257,86 @@ where } } +#[async_trait] +impl ScanIDClientMapper for Storage +where + S: infisto::base::IndexedByteStorage + std::marker::Sync + std::marker::Send + Clone + 'static, +{ + async fn add_scan_client_id( + &self, + scan_id: String, + client_id: ClientHash, + ) -> Result<(), Error> { + let key = "idmap"; + let storage = Arc::clone(&self.storage); + + tokio::task::spawn_blocking(move || { + let idt = infisto::serde::Serialization::serialize((client_id, scan_id))?; + let mut storage = storage.write().unwrap(); + storage.append(key, idt)?; + Ok(()) + }) + .await + .unwrap() + } + async fn remove_scan_id(&self, scan_id: I) -> Result<(), Error> + where + I: AsRef + Send + 'static, + { + let key = "idmap"; + let storage = Arc::clone(&self.storage); + + tokio::task::spawn_blocking(move || { + use infisto::serde::Serialization; + let mut storage = storage.write().unwrap(); + let sid = scan_id.as_ref(); + + let ids: Vec> = + storage.by_range(key, infisto::base::Range::All)?; + let new: Vec> = ids + .into_iter() + .map(|x| x.deserialize()) + .filter_map(|x| x.ok()) + .filter(|(_, x)| x != sid) + .map(infisto::serde::Serialization::serialize) + .filter_map(|x| x.ok()) + .collect(); + + storage.remove(key)?; + storage.append_all(key, &new)?; + Ok(()) + }) + .await + .unwrap() + } + + async fn get_scans_of_client_id(&self, client_id: &ClientHash) -> Result, Error> { + let key = "idmap"; + let storage = Arc::clone(&self.storage); + let client_id = client_id.clone(); + + tokio::task::spawn_blocking(move || { + use infisto::serde::Serialization; + let storage = storage.read().unwrap(); + + let ids: Vec> = + match storage.by_range(key, infisto::base::Range::All) { + Ok(x) => x, + Err(_) => vec![], + }; + let new: Vec = ids + .into_iter() + .map(|x| x.deserialize()) + .filter_map(|x| x.ok()) + .filter(|(x, _)| x == &client_id) + .map(|(_, x)| x) + .collect(); + Ok(new) + }) + .await + .unwrap() + } +} #[async_trait] impl OIDStorer for Storage where @@ -308,6 +388,7 @@ where #[cfg(test)] mod tests { + use infisto::base::IndexedByteStorage; use models::Scan; use super::*; @@ -393,8 +474,10 @@ mod tests { async fn file_storage_test() { let mut scans = Vec::with_capacity(100); for i in 0..100 { - let mut scan = Scan::default(); - scan.scan_id = Some(i.to_string()); + let scan = Scan { + scan_id: Some(i.to_string()), + ..Default::default() + }; scans.push(scan); } @@ -419,8 +502,10 @@ mod tests { .await .unwrap(); - let mut status = models::Status::default(); - status.status = models::Phase::Running; + let status = models::Status { + status: models::Phase::Running, + ..Default::default() + }; let results = vec![models::Result::default()]; storage @@ -444,4 +529,50 @@ mod tests { let ids = storage.get_scan_ids().await.unwrap(); assert_eq!(0, ids.len()); } + + #[tokio::test] + async fn id_mapper() { + let storage = + infisto::base::CachedIndexFileStorer::init("/tmp/openvasd/file_storage_id_mapper_test") + .unwrap(); + + let storage = crate::storage::file::Storage::new(storage); + storage + .add_scan_client_id("s1".to_owned(), "0".into()) + .await + .unwrap(); + storage + .add_scan_client_id("s2".to_owned(), "0".into()) + .await + .unwrap(); + storage + .add_scan_client_id("s3".to_owned(), "0".into()) + .await + .unwrap(); + storage + .add_scan_client_id("s4".to_owned(), "1".into()) + .await + .unwrap(); + assert_eq!( + storage.get_scans_of_client_id(&"0".into()).await.unwrap(), + vec!["s1", "s2", "s3"] + ); + assert_eq!( + storage.get_scans_of_client_id(&"1".into()).await.unwrap(), + vec!["s4"] + ); + storage.remove_scan_id("s2").await.unwrap(); + assert_eq!( + storage.get_scans_of_client_id(&"0".into()).await.unwrap(), + vec!["s1", "s3"] + ); + assert!(!storage.is_client_allowed("s1", &"1".into()).await.unwrap()); + assert!(storage.is_client_allowed("s4", &"1".into()).await.unwrap()); + + let mut storage = + infisto::base::IndexedFileStorer::init("/tmp/openvasd/file_storage_id_mapper_test") + .unwrap(); + let key = "idmap"; + storage.remove(key).unwrap(); + } } diff --git a/rust/openvasd/src/storage/inmemory.rs b/rust/openvasd/src/storage/inmemory.rs index b206bad154..714355f514 100644 --- a/rust/openvasd/src/storage/inmemory.rs +++ b/rust/openvasd/src/storage/inmemory.rs @@ -18,6 +18,7 @@ pub struct Storage { scans: RwLock>, oids: RwLock>, hash: RwLock, + client_id: RwLock>, crypter: E, } @@ -31,6 +32,7 @@ where scans: RwLock::new(HashMap::new()), oids: RwLock::new(vec![]), hash: RwLock::new(String::new()), + client_id: RwLock::new(vec![]), crypter, } } @@ -63,6 +65,50 @@ impl Default for Storage { } } +#[async_trait] +impl ScanIDClientMapper for Storage +where + E: crate::crypt::Crypt + Send + Sync + 'static, +{ + async fn add_scan_client_id( + &self, + scan_id: String, + client_id: ClientHash, + ) -> Result<(), Error> { + let mut ids = self.client_id.write().await; + ids.push((client_id, scan_id)); + + Ok(()) + } + + async fn remove_scan_id(&self, scan_id: I) -> Result<(), Error> + where + I: AsRef + Send + 'static, + { + let mut ids = self.client_id.write().await; + let ssid = scan_id.as_ref(); + let mut to_remove = vec![]; + for (i, (_, sid)) in ids.iter().enumerate() { + if sid == ssid { + to_remove.push(i); + } + } + for i in to_remove { + ids.remove(i); + } + + Ok(()) + } + + async fn get_scans_of_client_id(&self, client_id: &ClientHash) -> Result, Error> { + let ids = self.client_id.read().await; + Ok(ids + .iter() + .filter(|(cid, _)| cid == client_id) + .map(|(_, s)| s.to_owned()) + .collect()) + } +} #[async_trait] impl ScanStorer for Storage where @@ -214,6 +260,40 @@ mod tests { use super::*; + #[tokio::test] + async fn id_mapper() { + let storage = Storage::default(); + storage + .add_scan_client_id("s1".to_owned(), "0".into()) + .await + .unwrap(); + storage + .add_scan_client_id("s2".to_owned(), "0".into()) + .await + .unwrap(); + storage + .add_scan_client_id("s3".to_owned(), "0".into()) + .await + .unwrap(); + storage + .add_scan_client_id("s4".to_owned(), "1".into()) + .await + .unwrap(); + assert_eq!( + storage.get_scans_of_client_id(&"0".into()).await.unwrap(), + vec!["s1", "s2", "s3"] + ); + assert_eq!( + storage.get_scans_of_client_id(&"1".into()).await.unwrap(), + vec!["s4"] + ); + storage.remove_scan_id("s2").await.unwrap(); + assert_eq!( + storage.get_scans_of_client_id(&"0".into()).await.unwrap(), + vec!["s1", "s3"] + ); + } + #[tokio::test] async fn store_delete_scan() { let storage = Storage::default(); @@ -229,10 +309,12 @@ mod tests { async fn encrypt_decrypt_passwords() { let storage = Storage::default(); let mut scan = Scan::default(); - let mut pw = models::Credential::default(); - pw.credential_type = models::CredentialType::UP { - username: "test".to_string(), - password: "test".to_string(), + let pw = models::Credential { + credential_type: models::CredentialType::UP { + username: "test".to_string(), + password: "test".to_string(), + }, + ..Default::default() }; scan.target.credentials = vec![pw]; diff --git a/rust/openvasd/src/storage/mod.rs b/rust/openvasd/src/storage/mod.rs index ec6602650a..778a623573 100644 --- a/rust/openvasd/src/storage/mod.rs +++ b/rust/openvasd/src/storage/mod.rs @@ -4,7 +4,7 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; -use crate::{crypt, scan::FetchResult}; +use crate::{controller::ClientHash, crypt, scan::FetchResult}; #[derive(Debug)] pub enum Error { @@ -43,6 +43,31 @@ impl From for Error { } } +#[async_trait] +pub trait ScanIDClientMapper { + async fn add_scan_client_id(&self, scan_id: String, client_id: ClientHash) + -> Result<(), Error>; + async fn remove_scan_id(&self, scan_id: I) -> Result<(), Error> + where + I: AsRef + Send + 'static; + + async fn get_scans_of_client_id(&self, client_id: &ClientHash) -> Result, Error>; + + async fn is_client_allowed(&self, scan_id: I, client_id: &ClientHash) -> Result + where + I: AsRef + Send + 'static, + { + let scans = self.get_scans_of_client_id(client_id).await?; + let sid = scan_id.as_ref(); + for id in scans { + if id == sid { + return Ok(true); + } + } + Ok(false) + } +} + #[async_trait] /// A trait for getting the progress of a scan, the scan itself with decrypted credentials and /// encrypted as well as results. @@ -113,7 +138,13 @@ pub trait AppendFetchResult { #[async_trait] /// Combines the traits `ProgressGetter`, `ScanStorer` and `AppendFetchResult`. -pub trait Storage: ProgressGetter + ScanStorer + AppendFetchResult + OIDStorer {} +pub trait Storage: + ProgressGetter + ScanStorer + AppendFetchResult + OIDStorer + ScanIDClientMapper +{ +} #[async_trait] -impl Storage for T where T: ProgressGetter + ScanStorer + AppendFetchResult + OIDStorer {} +impl Storage for T where + T: ProgressGetter + ScanStorer + AppendFetchResult + OIDStorer + ScanIDClientMapper +{ +} diff --git a/rust/openvasd/src/tls.rs b/rust/openvasd/src/tls.rs index 46e0663758..a595599b6c 100644 --- a/rust/openvasd/src/tls.rs +++ b/rust/openvasd/src/tls.rs @@ -29,17 +29,20 @@ use core::task::{Context, Poll}; use futures_util::{ready, Future}; use hyper::server::accept::Accept; use hyper::server::conn::{AddrIncoming, AddrStream}; -use rustls::server::{AllowAnyAuthenticatedClient, NoClientAuth}; +use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerifier}; use rustls::RootCertStore; use rustls_pemfile::{read_one, Item}; -use std::path::{Path, PathBuf}; + +use std::path::Path; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::{fs, io}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::rustls::ServerConfig; +use crate::controller::{ClientGossiper, ClientIdentifier}; + enum State { Handshaking(tokio_rustls::Accept), Streaming(tokio_rustls::server::TlsStream), @@ -53,17 +56,38 @@ enum State { /// On streaming the connection will be read/written. pub struct TlsStream { state: State, + pub client_identifier: Arc>, } impl TlsStream { - fn new(stream: AddrStream, config: Arc) -> TlsStream { + fn new( + stream: AddrStream, + + roots: RootCertStore, + certs: Vec, + key: rustls::PrivateKey, + ) -> TlsStream { + let client_identifier = Arc::new(RwLock::new(ClientIdentifier::default())); + let inner = AllowAnyAuthenticatedClient::new(roots); + let verifier = ClientSnitch::new(inner, client_identifier.clone()).boxed(); + let config: Arc = Arc::new(server_config(verifier, certs, key).unwrap()); + let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); TlsStream { state: State::Handshaking(accept), + client_identifier, } } } +impl ClientGossiper for TlsStream { + fn client_identifier( + &self, + ) -> &std::sync::Arc> { + &self.client_identifier + } +} + impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, @@ -122,13 +146,25 @@ impl AsyncWrite for TlsStream { /// Handles the actual tls connection based on the given address and config. pub struct TlsAcceptor { - config: Arc, + roots: RootCertStore, + certs: Vec, + key: rustls::PrivateKey, incoming: AddrIncoming, } impl TlsAcceptor { - pub fn new(config: Arc, incoming: AddrIncoming) -> TlsAcceptor { - TlsAcceptor { config, incoming } + pub fn new( + roots: RootCertStore, + certs: Vec, + key: rustls::PrivateKey, + incoming: AddrIncoming, + ) -> TlsAcceptor { + TlsAcceptor { + roots, + certs, + key, + incoming, + } } } @@ -142,14 +178,95 @@ impl Accept for TlsAcceptor { ) -> Poll>> { let pin = self.get_mut(); match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { - Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), + Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new( + sock, + pin.roots.clone(), + pin.certs.clone(), + pin.key.clone(), + )))), Some(Err(e)) => Poll::Ready(Some(Err(e))), None => Poll::Ready(None), } } } -/// Creates a rustls ServerConfig based on the given config. +struct ClientSnitch { + inner: AllowAnyAuthenticatedClient, + client_identifier: Arc>, +} + +impl ClientSnitch { + /// Construct a new `AllowAnyAnonymousOrAuthenticatedClient`. + /// + /// `roots` is the list of trust anchors to use for certificate validation. + pub fn new( + inner: AllowAnyAuthenticatedClient, + client_identifier: Arc>, + ) -> Self { + Self { + inner, + client_identifier, + } + } + + /// Update the verifier to validate client certificates against the provided DER format + /// unparsed certificate revocation lists (CRLs). + #[allow(dead_code)] + pub fn with_crls( + self, + crls: impl IntoIterator, + client_identifier: Arc>, + ) -> Result { + // This function is needed to keep it functioning like the original verifier. + Ok(Self { + inner: self.inner.with_crls(crls)?, + client_identifier, + }) + } + + /// Wrap this verifier in an [`Arc`] and coerce it to `dyn ClientCertVerifier` + #[inline(always)] + pub fn boxed(self) -> Arc { + // This function is needed to keep it functioning like the original verifier. + Arc::new(self) + } +} + +impl rustls::server::ClientCertVerifier for ClientSnitch { + fn offer_client_auth(&self) -> bool { + self.inner.offer_client_auth() + } + + fn client_auth_mandatory(&self) -> bool { + false + } + + fn client_auth_root_subjects(&self) -> &[rustls::DistinguishedName] { + self.inner.client_auth_root_subjects() + } + + fn verify_client_cert( + &self, + end_entity: &rustls::Certificate, + intermediates: &[rustls::Certificate], + now: std::time::SystemTime, + ) -> Result { + match self + .inner + .verify_client_cert(end_entity, intermediates, now) + { + Ok(r) => { + let mut ci = self.client_identifier.write().unwrap(); + *ci = ClientIdentifier::Known(end_entity.into()); + Ok(r) + } + Err(_) => todo!(), + } + } +} +/// Data required to create a TlsConfig +type TlsData = (RootCertStore, Vec, rustls::PrivateKey); +/// Creates a root cert store, the certificates and a private key so that a tls configuration can be created. /// /// When the tls certificate cannot be loaded it will return None. /// When client certificates are provided it will return a ServerConfig with @@ -157,60 +274,34 @@ impl Accept for TlsAcceptor { /// client authentication. pub fn tls_config( config: &crate::config::Config, -) -> Result>, Box> { +) -> Result, Box> { if let Some(certs_path) = &config.tls.certs { match load_certs(certs_path) { Ok(certs) => { if let Some(key_path) = &config.tls.key { let key = load_private_key(key_path)?; - let verifier = { - if let Some(client_certs_dir) = &config.tls.client_certs { - let client_certs: Vec = std::fs::read_dir(client_certs_dir)? - .filter_map(|entry| { - let entry = entry.ok()?; - let file_type = entry.file_type().ok()?; - if file_type.is_file() - || file_type.is_symlink() && !file_type.is_dir() - { - Some(entry.path()) - } else { - None - } - }) - .collect(); - if client_certs.is_empty() { - tracing::info!( - "no client certs found, starting without certificate based client auth" - ); - NoClientAuth::boxed() + let client_certs = if let Some(cpath) = &config.tls.client_certs { + let rd = std::fs::read_dir(cpath)?; + rd.filter_map(|entry| { + let entry = entry.ok()?; + let file_type = entry.file_type().ok()?; + if file_type.is_file() || file_type.is_symlink() && !file_type.is_dir() + { + Some(entry.path()) } else { - tracing::info!( - "client certs found, starting with certificate based client auth" - ); - let mut client_auth_roots = RootCertStore::empty(); - for root in client_certs.iter().flat_map(load_certs).flatten() { - client_auth_roots.add(&root)?; - } - AllowAnyAuthenticatedClient::new(client_auth_roots).boxed() + None } - } else { - tracing::info!( - "no client certs found, starting without certificate based client auth" - ); - NoClientAuth::boxed() - } + }) + .collect() + } else { + vec![] }; + let mut roots = RootCertStore::empty(); + for root in client_certs.iter().flat_map(load_certs).flatten() { + roots.add(&root)?; + } - let mut cfg = rustls::ServerConfig::builder() - .with_safe_defaults() - //.with_client_cert_verifier() - .with_client_cert_verifier(verifier) - .with_single_cert(certs, key) - .map_err(|e| error(format!("{}", e)))?; - // Configure ALPN to accept HTTP/2, HTTP/1.1, and HTTP/1.0 in that order. - cfg.alpn_protocols = - vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; - Ok(Some(std::sync::Arc::new(cfg))) + Ok(Some((roots, certs, key))) } else { Err(error("TLS enabled, but private key is missing".to_string()).into()) } @@ -223,6 +314,20 @@ pub fn tls_config( } } +fn server_config( + verifier: Arc, + certs: Vec, + key: rustls::PrivateKey, +) -> Result> { + let mut cfg = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(verifier) + .with_single_cert(certs, key) + .map_err(|e| error(format!("{}", e)))?; + cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + Ok(cfg) +} + fn error(err: String) -> io::Error { io::Error::new(io::ErrorKind::Other, err) }