Skip to content

Commit

Permalink
Fix: mark scan as failed on fetch result errors
Browse files Browse the repository at this point in the history
Updated error handling in `get_results` to mark scans as failed if errors
occur during result fetch.

Updated OpenAPI spec to reflect `206` response for partial results in case
of errors.

Added `LambdaScannerBuilder` for testing error scenarios in fetch results.
  • Loading branch information
nichtsfrei committed Nov 4, 2024
1 parent fa73486 commit 011d61f
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 33 deletions.
1 change: 0 additions & 1 deletion rust/doc/openapi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ paths:
description: "Schema of a list of results response"
get results 0-3:
$ref: "#/components/examples/scan_results"

"400":
description: "Bad range format"
"404":
Expand Down
91 changes: 65 additions & 26 deletions rust/src/openvasd/controller/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{fmt::Display, marker::PhantomData, sync::Arc};

use super::{context::Context, ClientIdentifier};

use http::StatusCode;
use hyper::{Method, Request};
use scannerlib::models::scanner::{ScanDeleter, ScanResultFetcher, ScanStarter, ScanStopper};
use scannerlib::models::{scanner::*, Action, Phase, Scan, ScanAction};
Expand Down Expand Up @@ -192,7 +193,7 @@ where
let kp = KnownPaths::from_path(req.uri().path(), &ctx.mode);
// on head requests we just return an empty response, except for /scans
if req.method() == Method::HEAD && kp != KnownPaths::Scans(None) {
return Ok(ctx.response.empty(hyper::StatusCode::OK));
return Ok(ctx.response.empty(StatusCode::OK));
}
let cid: Option<ClientHash> = {
match &*cid {
Expand Down Expand Up @@ -259,27 +260,25 @@ where
"process call",
);
match (req.method(), kp) {
(&Method::HEAD, Scans(None)) => {
Ok(ctx.response.empty(hyper::StatusCode::NO_CONTENT))
}
(&Method::HEAD, Scans(None)) => Ok(ctx.response.empty(StatusCode::NO_CONTENT)),
(&Method::GET, Health(HealthOpts::Alive))
| (&Method::GET, Health(HealthOpts::Started)) => {
Ok(ctx.response.empty(hyper::StatusCode::OK))
Ok(ctx.response.empty(StatusCode::OK))
}
(&Method::GET, Health(HealthOpts::Ready)) => {
let oids = ctx.scheduler.oids().await?;
if oids.count() == 0 {
Ok(ctx.response.empty(hyper::StatusCode::SERVICE_UNAVAILABLE))
Ok(ctx.response.empty(StatusCode::SERVICE_UNAVAILABLE))
} else {
Ok(ctx.response.empty(hyper::StatusCode::OK))
Ok(ctx.response.empty(StatusCode::OK))
}
}
(&Method::GET, Notus(None)) => match &ctx.notus {
Some(notus) => match notus.get_available_os().await {
Ok(result) => Ok(ctx.response.ok(&result)),
Err(err) => Ok(ctx.response.internal_server_error(&err)),
},
None => Ok(ctx.response.empty(hyper::StatusCode::SERVICE_UNAVAILABLE)),
None => Ok(ctx.response.empty(StatusCode::SERVICE_UNAVAILABLE)),
},

(&Method::POST, Notus(Some(os))) => {
Expand All @@ -297,7 +296,7 @@ where
_ => Ok(ctx.response.internal_server_error(&err)),
},
},
None => Ok(ctx.response.empty(hyper::StatusCode::SERVICE_UNAVAILABLE)),
None => Ok(ctx.response.empty(StatusCode::SERVICE_UNAVAILABLE)),
},
Err(resp) => Ok(resp),
}
Expand Down Expand Up @@ -429,7 +428,7 @@ where
};

match ctx.scheduler.get_results(&id, begin, end).await {
Ok(results) => Ok(ctx.response.ok_byte_stream(results).await),
Ok(results) => Ok(ctx.response.byte_stream(StatusCode::OK, results).await),
Err(crate::storage::Error::NotFound) => {
Ok(ctx.response.not_found("scans/results", &id))
}
Expand Down Expand Up @@ -470,6 +469,7 @@ where
pub mod client {
use std::sync::Arc;

use http::StatusCode;
use http_body_util::{BodyExt, Empty, Full};
use hyper::{
body::Bytes, header::HeaderValue, service::HttpService, HeaderMap, Method, Request,
Expand All @@ -482,6 +482,7 @@ pub mod client {
};
use serde::Deserialize;

use crate::storage::inmemory;
use crate::{
controller::{ClientIdentifier, Context},
storage::{file::Storage, NVTStorer, UserNASLStorageForKBandVT},
Expand All @@ -506,11 +507,7 @@ pub mod client {
>,
FSPluginLoader,
)>,
Arc<
UserNASLStorageForKBandVT<
crate::storage::inmemory::Storage<crate::crypt::ChaCha20Crypt>,
>,
>,
Arc<UserNASLStorageForKBandVT<inmemory::Storage<crate::crypt::ChaCha20Crypt>>>,
> {
use crate::file::tests::{example_feeds, nasl_root};
let storage = crate::storage::inmemory::Storage::default();
Expand Down Expand Up @@ -552,6 +549,24 @@ pub mod client {
Client::authenticated(scanner, storage)
}

pub async fn fails_to_fetch_results() -> Client<
scannerlib::scanner::fake::LambdaScanner,
Arc<UserNASLStorageForKBandVT<inmemory::Storage<crate::crypt::ChaCha20Crypt>>>,
> {
use crate::file::tests::example_feeds;
let storage = crate::storage::inmemory::Storage::default();
let storage = Arc::new(UserNASLStorageForKBandVT::new(storage));
storage
.synchronize_feeds(example_feeds().await)
.await
.unwrap();

let scanner = scannerlib::scanner::fake::LambdaScannerBuilder::new()
.with_fetch_results(|_| Err(scanner::Error::Unexpected("no results".to_string())))
.build();
Client::authenticated(scanner, storage)
}

pub async fn file_based_example_feed(
prefix: &str,
) -> Client<
Expand Down Expand Up @@ -648,7 +663,7 @@ pub mod client {
let result = self
.request_empty(Method::GET, KnownPaths::ScanStatus(id.to_string()))
.await;
self.parsed(result).await
self.parsed(result, StatusCode::OK).await
}

pub async fn header(&self) -> TypeResult<HeaderMap<HeaderValue>> {
Expand All @@ -662,14 +677,18 @@ pub mod client {
let result = self
.request_empty(Method::GET, KnownPaths::Scans(Some(id.to_string())))
.await;
self.parsed(result).await
self.parsed(result, StatusCode::OK).await
}

pub async fn scan_results(&self, id: &str) -> TypeResult<Vec<models::Result>> {
pub async fn scan_results(
&self,
id: &str,
status: StatusCode,
) -> TypeResult<Vec<models::Result>> {
let result = self
.request_empty(Method::GET, KnownPaths::ScanResults(id.to_string(), None))
.await;
self.parsed(result).await
self.parsed(result, status).await
}
pub async fn scan_delete(&self, id: &str) -> TypeResult<()> {
let result = self
Expand Down Expand Up @@ -705,7 +724,7 @@ pub mod client {
let result = self
.request_empty(Method::GET, KnownPaths::Scans(None))
.await;
self.parsed(result).await
self.parsed(result, StatusCode::OK).await
}

// TODO: deal with that static stuff that prevents deserializiation based on Bytes
Expand All @@ -731,12 +750,12 @@ pub mod client {
let result = self
.request_json(Method::POST, KnownPaths::Scans(None), scan)
.await;
self.parsed(result).await
self.parsed(result, StatusCode::CREATED).await
}

pub async fn vts(&self) -> TypeResult<Vec<String>> {
let result = self.request_empty(Method::GET, KnownPaths::Vts(None)).await;
self.parsed(result).await
self.parsed(result, StatusCode::OK).await
}

/// Starts a scan and wait until is finished and returns it status and results
Expand Down Expand Up @@ -770,14 +789,19 @@ pub mod client {
}
}

pub async fn parsed<'a, T>(&self, result: HttpResult) -> TypeResult<T>
pub async fn parsed<'a, T>(
&self,
result: HttpResult,
expected_status: StatusCode,
) -> TypeResult<T>
where
T: for<'de> Deserialize<'de>,
{
let resp = result?;
if resp.status() != 200 && resp.status() != 201 {
if resp.status() != expected_status {
return Err(scanner::Error::Unexpected(format!(
"Expected 200 for a body response but got {}",
"Expected {} for a body response but got {}",
expected_status,
resp.status()
)));
}
Expand All @@ -792,6 +816,7 @@ pub mod client {

#[cfg(test)]
pub(super) mod tests {
use http::StatusCode;
use scannerlib::models::{Scan, VT};

#[tokio::test]
Expand Down Expand Up @@ -820,8 +845,22 @@ pub(super) mod tests {
assert!(vts.len() > 2);
let (id, status) = client.scan_finish(&scan).await.unwrap();
assert_eq!(status.status, scannerlib::models::Phase::Succeeded);
let results = client.scan_results(&id).await.unwrap();
let results = client.scan_results(&id, StatusCode::OK).await.unwrap();
assert_eq!(3, results.len());
client.scan_delete(&id).await.unwrap();
}

#[tokio::test]
#[tracing_test::traced_test]
async fn status_of_internal_error_should_be_reflects() {
let client = super::client::fails_to_fetch_results().await;

let mut scan: Scan = Scan::default();
scan.target.hosts.push("localhost".to_string());
let (id, status) = client.scan_finish(&scan).await.unwrap();
assert_eq!(status.status, scannerlib::models::Phase::Failed);
let results = client.scan_results(&id, StatusCode::OK).await.unwrap();
assert_eq!(0, results.len());
client.scan_delete(&id).await.unwrap();
}
}
22 changes: 17 additions & 5 deletions rust/src/openvasd/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ impl Response {
}

#[inline]
fn ok_json_response(&self, body: BodyKind) -> Result {
fn json_response(&self, status: hyper::StatusCode, body: BodyKind) -> Result {
match self
.default_response_builder()
.header("Content-Type", "application/json")
.status(hyper::StatusCode::OK)
.status(status)
.body(body)
{
Ok(resp) => resp,
Expand All @@ -188,8 +188,9 @@ impl Response {
}
}
}

#[inline]
pub async fn ok_byte_stream<T>(&self, mut value: T) -> Result
pub async fn byte_stream<T>(&self, status: hyper::StatusCode, mut value: T) -> Result
where
T: Iterator<Item = Vec<u8>> + Send + 'static,
{
Expand Down Expand Up @@ -227,7 +228,15 @@ impl Response {
tracing::debug!("end send values");
drop(tx);
});
self.ok_json_response(BodyKind::BinaryStream(rx))
self.json_response(status, BodyKind::BinaryStream(rx))
}

#[inline]
pub async fn ok_byte_stream<T>(&self, value: T) -> Result
where
T: Iterator<Item = Vec<u8>> + Send + 'static,
{
self.byte_stream(hyper::StatusCode::OK, value).await
}

#[inline]
Expand Down Expand Up @@ -281,7 +290,10 @@ impl Response {
}

pub fn ok_static(&self, value: &[u8]) -> Result {
self.ok_json_response(BodyKind::Binary(value.to_vec().into()))
self.json_response(
hyper::StatusCode::OK,
BodyKind::Binary(value.to_vec().into()),
)
}

pub fn created<T>(&self, value: &T) -> Result
Expand Down
7 changes: 6 additions & 1 deletion rust/src/openvasd/scheduling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,12 @@ where
};
}
Err(e) => {
tracing::warn!(%scan_id, %e, "unable to fetch results");
// TODO: set scan to failed and inform entry to return 500 instead of 200
// Also may remove from running
tracing::warn!(%scan_id, %e, "unable to fetch results, setting scan to failed");
let mut status = self.db.get_status(&scan_id).await?;
status.status = Phase::Failed;
self.db.update_status(&scan_id, status).await?;
}
};
}
Expand Down
Loading

0 comments on commit 011d61f

Please sign in to comment.