diff --git a/Cargo.lock b/Cargo.lock index ff0ee47..a837110 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -342,6 +342,51 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.29", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.72" @@ -445,6 +490,11 @@ name = "cc" version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] [[package]] name = "cfg-if" @@ -535,6 +585,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +[[package]] +name = "common-error" +version = "0.1.0" +dependencies = [ + "snafu", + "strum 0.25.0", + "tonic", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -624,6 +683,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "cron" version = "0.12.1" @@ -938,16 +1006,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "eyre" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" -dependencies = [ - "indenter", - "once_cell", -] - [[package]] name = "fastrand" version = "1.9.0" @@ -963,6 +1021,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "flume" version = "0.11.0" @@ -1252,13 +1320,19 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap", + "indexmap 2.2.6", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = 
"hashbrown" version = "0.14.5" @@ -1275,7 +1349,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -1485,6 +1559,18 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.29", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + [[package]] name = "hyper-timeout" version = "0.5.1" @@ -1597,10 +1683,14 @@ dependencies = [ ] [[package]] -name = "indenter" -version = "0.3.3" +name = "indexmap" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] [[package]] name = "indexmap" @@ -1609,7 +1699,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -1689,6 +1779,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.69" @@ -1791,6 +1890,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" +[[package]] +name = "macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "markup5ever" version = "0.11.0" @@ -1814,6 +1922,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md-5" version = "0.10.6" @@ -2042,7 +2156,7 @@ dependencies = [ "http-body-util", "hyper 1.3.1", "hyper-rustls", - "hyper-timeout", + "hyper-timeout 0.5.1", "hyper-util", "jsonwebtoken", "once_cell", @@ -2506,6 +2620,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", +] + [[package]] name = "psl-types" version = "2.0.11" @@ -2879,6 +3002,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "ryu" version = "1.0.18" @@ -2965,7 +3094,7 @@ dependencies = [ "sea-query-binder", "serde", "sqlx", - "strum", + "strum 0.26.2", "thiserror", "tracing", "url", @@ -3216,7 +3345,7 @@ version = "0.9.34+deprecated" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap", + "indexmap 2.2.6", "itoa", "ryu", "serde", @@ -3437,7 +3566,7 @@ dependencies = [ "futures-util", "hashlink", "hex", - "indexmap", + "indexmap 2.2.6", "log", "memchr", "once_cell", @@ -3664,12 +3793,34 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + [[package]] name = "strum" version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.66", +] + [[package]] name = "subtle" version = "2.5.0" @@ -3961,6 +4112,16 @@ dependencies = [ "uuid", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.3.0" @@ -4017,6 +4178,38 @@ dependencies = [ "tokio", ] +[[package]] +name = "tonic" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.7", + "bytes", + "flate2", + "h2", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.29", + "hyper-timeout 0.4.1", + "percent-encoding", + "pin-project", + "prost", + "rustls-pemfile 2.1.2", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", + "zstd", +] + [[package]] name = "tower" version = "0.4.13" @@ -4025,8 +4218,11 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", + "indexmap 1.9.3", "pin-project", "pin-project-lite", + "rand", + "slab", "tokio", "tokio-util", "tower-layer", @@ -4464,10 +4660,11 @@ dependencies = [ "chrono", "chrono-tz", "clap", - "eyre", + "common-error", "fs-err", "hmac", "lazy_static", + "macros", "migration", "octocrab", "regex", @@ -4479,9 +4676,9 @@ dependencies = [ "serde_variant", "serde_yaml", "sha2", + "snafu", "teloxide", "tera", - "thiserror", "time", "tokio", "tokio-cron-scheduler", @@ -4749,3 +4946,32 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zstd" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "6.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 017981c..94c8a6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,10 +6,9 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [workspace] -members = ["./migration", "."] +members = ["./migration", ".", "macros", "common-error"] [dependencies] -thiserror = "1" clap = { version = "4.5.4", features = ["derive"] } serde = { version = "1", features = ["derive"] } @@ -18,7 +17,6 @@ serde_yaml = "0.9" serde_variant = "0.1" lazy_static = "1.4" -eyre = "0.6.12" fs-err = "2.11" tera = "1.19.1" @@ -58,3 +56,6 @@ hmac = "0.12.1" base64 = "0.22.1" chrono-tz = "0.9.0" octocrab = "0.38.0" +snafu = "0.8.3" +macros = { path = "macros" } +common-error = { path = "common-error" } diff --git a/common-error/Cargo.toml b/common-error/Cargo.toml new file mode 100644 index 0000000..9c5c406 --- /dev/null +++ b/common-error/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "common-error" +version = "0.1.0" +edition = "2021" + +[dependencies] +strum = { version = "0.25", features = ["derive"] } +tonic = { version = "0.11", features = ["tls", "gzip", "zstd"] } +snafu = "0.8.3" diff --git a/common-error/src/ext.rs b/common-error/src/ext.rs new file mode 100644 index 0000000..d11556d --- /dev/null +++ b/common-error/src/ext.rs @@ -0,0 +1,40 @@ +use std::sync::Arc; + +pub trait StackError: std::error::Error { + fn debug_fmt(&self, layer: usize, buf: &mut Vec); + + fn next(&self) -> Option<&dyn StackError>; + + fn last(&self) -> &dyn StackError + where + Self: Sized, + { + let Some(mut result) = self.next() else { + return self; + }; + while let Some(err) = result.next() { + result = err; + } + result + } +} + +impl StackError for Arc { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + self.as_ref().debug_fmt(layer, buf) + } + + fn next(&self) -> Option<&dyn StackError> { + self.as_ref().next() + } +} + +impl StackError for Box { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + self.as_ref().debug_fmt(layer, buf) + } + + fn next(&self) -> Option<&dyn StackError> { + self.as_ref().next() + } +} diff --git a/common-error/src/lib.rs b/common-error/src/lib.rs new file mode 100644 index 0000000..310a957 --- /dev/null +++ b/common-error/src/lib.rs @@ -0,0 +1 @@ +pub mod ext; diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 0000000..39a7bc5 --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.85" +quote = "1.0.36" +syn = { version = "2.0", features = [ + "derive", + "parsing", + "printing", + "clone-impls", + "proc-macro", + "extra-traits", + "full", +] } diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 0000000..f32f1a0 --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,7 @@ +mod stack_trace_debug; +use proc_macro::TokenStream; + +#[proc_macro_attribute] +pub fn stack_trace_debug(args: TokenStream, input: TokenStream) -> TokenStream { + stack_trace_debug::stack_trace_style_impl(args.into(), input.into()).into() +} diff --git a/macros/src/stack_trace_debug.rs 
new file mode 100644
index 0000000..d5a5922
--- /dev/null
+++ b/macros/src/stack_trace_debug.rs
@@ -0,0 +1,248 @@
+use proc_macro2::{Span, TokenStream};
+use quote::{quote, quote_spanned};
+use syn::{parenthesized, spanned::Spanned, Attribute, Ident, ItemEnum, Variant};
+
+pub fn stack_trace_style_impl(args: TokenStream, input: TokenStream) -> TokenStream {
+    let input_cloned: TokenStream = input.clone();
+    let error_enum_definition: ItemEnum = syn::parse2(input_cloned).unwrap();
+    let enum_name = error_enum_definition.ident;
+
+    let mut variants = vec![];
+
+    for error_variant in error_enum_definition.variants {
+        let variant = ErrorVariant::from_enum_variant(error_variant);
+        variants.push(variant);
+    }
+
+    let debug_fmt_fn = build_debug_fmt_impl(enum_name.clone(), variants.clone());
+    let next_fn = build_next_impl(enum_name.clone(), variants);
+    let debug_impl = build_debug_impl(enum_name.clone());
+
+    quote! {
+        #args
+        #input
+
+        impl ::common_error::ext::StackError for #enum_name {
+            #debug_fmt_fn
+            #next_fn
+        }
+
+        #debug_impl
+    }
+}
+
+fn build_debug_fmt_impl(enum_name: Ident, variants: Vec<ErrorVariant>) -> TokenStream {
+    let match_arms = variants
+        .iter()
+        .map(|v| v.to_debug_match_arm())
+        .collect::<Vec<_>>();
+
+    quote! {
+        fn debug_fmt(&self, layer: usize, buf: &mut Vec<String>) {
+            use #enum_name::*;
+            match self {
+                #(#match_arms)*
+            }
+        }
+    }
+}
+
+fn build_next_impl(enum_name: Ident, variants: Vec<ErrorVariant>) -> TokenStream {
+    let match_arms = variants
+        .iter()
+        .map(|v| v.to_next_match_arm())
+        .collect::<Vec<_>>();
+
+    quote! {
+        fn next(&self) -> Option<&dyn ::common_error::ext::StackError> {
+            use #enum_name::*;
+            match self {
+                #(#match_arms)*
+            }
+        }
+    }
+}
+
+/// Implement [std::fmt::Debug] via `debug_fmt`
+fn build_debug_impl(enum_name: Ident) -> TokenStream {
+    quote! {
+        impl std::fmt::Debug for #enum_name {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                use ::common_error::ext::StackError;
+                let mut buf = vec![];
+                self.debug_fmt(0, &mut buf);
+                write!(f, "{}", buf.join("\n"))
+            }
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+struct ErrorVariant {
+    name: Ident,
+    fields: Vec<Ident>,
+    has_location: bool,
+    has_source: bool,
+    has_external_cause: bool,
+    display: TokenStream,
+    span: Span,
+    cfg_attr: Option<Attribute>,
+}
+
+impl ErrorVariant {
+    /// Construct self from [Variant]
+    fn from_enum_variant(variant: Variant) -> Self {
+        let span = variant.span();
+        let mut has_location = false;
+        let mut has_source = false;
+        let mut has_external_cause = false;
+
+        for field in &variant.fields {
+            if let Some(ident) = &field.ident {
+                if ident == "location" {
+                    has_location = true;
+                } else if ident == "source" {
+                    has_source = true;
+                } else if ident == "error" {
+                    has_external_cause = true;
+                }
+            }
+        }
+
+        let mut display = None;
+        let mut cfg_attr = None;
+        for attr in variant.attrs {
+            if attr.path().is_ident("snafu") {
+                attr.parse_nested_meta(|meta| {
+                    if meta.path.is_ident("display") {
+                        let content;
+                        parenthesized!(content in meta.input);
+                        let display_ts: TokenStream = content.parse()?;
+                        display = Some(display_ts);
+                        Ok(())
+                    } else {
+                        Err(meta.error("unrecognized repr"))
+                    }
+                })
+                .expect("Each error should contain a display attribute");
+            }
+
+            if attr.path().is_ident("cfg") {
+                cfg_attr = Some(attr);
+            }
+        }
+
+        let field_ident = variant
+            .fields
+            .iter()
+            .map(|f| f.ident.clone().unwrap_or_else(|| Ident::new("_", f.span())))
+            .collect();
+
+        Self {
+            name: variant.ident,
+            fields: field_ident,
+            has_location,
+            has_source,
+            has_external_cause,
+            display: display.unwrap(),
+            span,
+            cfg_attr,
+        }
+    }
+
+    /// Convert self into a match arm that will be used in [build_debug_impl].
+    ///
+    /// The generated match arm will be like:
+    /// ```rust, ignore
+    /// ErrorKindWithSource { source, .. } => {
+    ///     debug_fmt(source, layer + 1, buf);
+    /// },
+    /// ErrorKindWithoutSource { .. } => {
+    ///     buf.push(format!("{layer}: {}, at {}", format!(#display), location));
+    /// }
+    /// ```
+    ///
+    /// The generated code assumes fn `debug_fmt`, var `layer`, var `buf` are in scope.
+    fn to_debug_match_arm(&self) -> TokenStream {
+        let name = &self.name;
+        let fields = &self.fields;
+        let display = &self.display;
+        let cfg = if let Some(cfg) = &self.cfg_attr {
+            quote_spanned!(cfg.span() => #cfg)
+        } else {
+            quote! {}
+        };
+
+        match (self.has_location, self.has_source, self.has_external_cause) {
+            (true, true, _) => quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),*, } => {
+                    buf.push(format!("{layer}: {}, at {}", format!(#display), location));
+                    source.debug_fmt(layer + 1, buf);
+                },
+            },
+            (true, false, true) => quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => {
+                    buf.push(format!("{layer}: {}, at {}", format!(#display), location));
+                    buf.push(format!("{}: {:?}", layer + 1, error));
+                },
+            },
+            (true, false, false) => quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => {
+                    buf.push(format!("{layer}: {}, at {}", format!(#display), location));
+                },
+            },
+            (false, true, _) => quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => {
+                    buf.push(format!("{layer}: {}", format!(#display)));
+                    source.debug_fmt(layer + 1, buf);
+                },
+            },
+            (false, false, true) => quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => {
+                    buf.push(format!("{layer}: {}", format!(#display)));
+                    buf.push(format!("{}: {:?}", layer + 1, error));
+                },
+            },
+            (false, false, false) => quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => {
+                    buf.push(format!("{layer}: {}", format!(#display)));
+                },
+            },
+        }
+    }
+
+    /// Convert self into a match arm that will be used in [build_next_impl].
+    ///
+    /// The generated match arm will be like:
+    /// ```rust, ignore
+    /// ErrorKindWithSource { source, .. } => {
+    ///     Some(source)
+    /// },
+    /// ErrorKindWithoutSource { .. } => {
+    ///     None
+    /// }
+    /// ```
+    fn to_next_match_arm(&self) -> TokenStream {
+        let name = &self.name;
+        let fields = &self.fields;
+        let cfg = if let Some(cfg) = &self.cfg_attr {
+            quote_spanned!(cfg.span() => #cfg)
+        } else {
+            quote! {}
+        };
+
+        if self.has_source {
+            quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => {
+                    Some(source)
+                },
+            }
+        } else {
+            quote_spanned! {
+                self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => {
+                    None
+                }
+            }
+        }
+    }
+}
diff --git a/src/app.rs b/src/app.rs
index c4ca52d..9c8b594 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -1,17 +1,18 @@
 use lazy_static::lazy_static;
 use migration::MigratorTrait;
 use sea_orm::DatabaseConnection;
+use snafu::ResultExt;
 use std::{collections::HashMap, sync::Arc, time::Duration};
-use tokio::{task::JoinSet, time};
+use tokio::{runtime::Runtime, task::JoinSet, time};
 use tokio_cron_scheduler::{Job, JobScheduler};
-use tracing::{info, warn};
+use tracing::{error, info, warn};
 
 use crate::{
     config::Config,
     db,
     environment::Environment,
-    error::Result,
-    grab::{self, Grab},
+    error::{CronSchedulerErrSnafu, DbErrSnafu, Result},
+    grab::{self, Grab, VulnInfo},
     models::_entities::vuln_informations::{self, Model},
     push::{
         self,
@@ -55,9 +56,17 @@ impl WatchVulnApp {
             move |uuid, mut lock| {
                 let self_clone = self_arc.clone();
                 Box::pin(async move {
-                    let res = self_clone.crawling_task(false).await;
+                    let res: Vec<VulnInfo> = self_clone
+                        .crawling_task(false)
+                        .await
+                        .into_iter()
+                        .map(|x| x.into())
+                        .collect();
                     info!("crawling over all count is: {}", res.len());
-                    self_clone.push(res).await;
+                    let rt1 = Runtime::new().unwrap();
+                    rt1.block_on(async move {
+                        self_clone.push(res).await;
+                    });
                     let next_tick = lock.next_tick_for_job(uuid).await;
                     if let Ok(Some(tick)) = next_tick {
                         info!(
@@ -67,7 +76,8 @@ impl WatchVulnApp {
                     }
                 })
             },
-        )?;
+        )
+        .context(CronSchedulerErrSnafu)?;
         Ok(job)
     }
 
@@ -80,11 +90,11 @@ impl WatchVulnApp {
         info!("init finished, local database has {} vulns", local_count);
         self.push_init_msg(local_count).await?;
 
-        let sched = JobScheduler::new().await?;
+        let sched = JobScheduler::new().await.context(CronSchedulerErrSnafu)?;
         let job = self.crawling_job()?;
-        sched.add(job).await?;
-        sched.start().await?;
+        sched.add(job).await.context(CronSchedulerErrSnafu)?;
+        sched.start().await.context(CronSchedulerErrSnafu)?;
 
         loop {
             time::sleep(Duration::from_secs(60)).await;
         }
@@ -96,48 +106,52 @@ impl WatchVulnApp {
         for v in self.grabs.as_ref().values() {
             let grab = v.to_owned();
             if is_init {
-                set.spawn(async move { grab.get_update(INIT_PAGE_LIMIT).await });
+                // set.spawn(async move { grab.get_update(INIT_PAGE_LIMIT).await });
+                set.spawn(async move {
+                    grab.get_update(INIT_PAGE_LIMIT)
+                        .await
+                        .expect("crawling error")
+                });
             } else {
-                set.spawn(async move { grab.get_update(PAGE_LIMIT).await });
+                set.spawn(
+                    async move {
grab.get_update(PAGE_LIMIT).await.expect("crawling error") }, + ); } } let mut new_vulns = Vec::new(); while let Some(set_res) = set.join_next().await { match set_res { - Ok(grabs_res) => match grabs_res { - Ok(res) => { - for v in res { - let create_res = - vuln_informations::Model::creat_or_update(&self.app_context.db, v) - .await; - match create_res { - Ok(m) => { - info!("found new vuln:{}", m.key); - new_vulns.push(m) - } - Err(e) => { - warn!("db model error:{}", e); - } + Ok(grabs_res) => { + for v in grabs_res { + let create_res = + vuln_informations::Model::creat_or_update(&self.app_context.db, v) + .await; + match create_res { + Ok(m) => { + info!("found new vuln:{}", m.key); + new_vulns.push(m) + } + Err(e) => { + warn!("db model error:{:?}", e); } } } - Err(err) => warn!("grab crawling error:{}", err), - }, - Err(e) => warn!("join set error:{}", e), + } + Err(e) => warn!("join set error:{:?}", e), } } new_vulns } - async fn push(&self, vulns: Vec) { + async fn push(&self, vulns: Vec) { for mut vuln in vulns.into_iter() { if vuln.is_valuable { if vuln.pushed { - info!("{} has been pushed, skipped", vuln.key); + info!("{} has been pushed, skipped", vuln.unique_key); continue; } - let key = vuln.key.clone(); + let key = vuln.unique_key.clone(); let title = vuln.title.clone(); if !vuln.cve.is_empty() && self.app_context.config.github_search { let links = search_github_poc(&vuln.cve).await; @@ -150,15 +164,15 @@ impl WatchVulnApp { ) .await { - warn!("update vuln {} github_search error: {}", &vuln.cve, err); + warn!("update vuln {} github_search error: {:?}", &vuln.cve, err); } - vuln.github_search = Some(links); + vuln.github_search = links; } } - let msg = match reader_vulninfo(vuln.into()) { + let msg = match reader_vulninfo(vuln) { Ok(msg) => msg, Err(err) => { - warn!("reader vulninfo {} error {}", key, err); + warn!("reader vulninfo {} error {:?}", key, err); continue; } }; @@ -169,7 +183,7 @@ impl WatchVulnApp { vuln_informations::Model::update_pushed_by_key(&self.app_context.db, key) .await { - warn!("update vuln {} pushed error: {}", msg, err); + warn!("update vuln {} pushed error: {:?}", msg, err); } } } @@ -205,28 +219,28 @@ impl WatchVulnApp { let bot_clone = bot.clone(); let message = msg.clone(); let title = title.clone(); - set.spawn(async move { bot_clone.push_markdown(title, message).await }); + set.spawn(async move { + if let Err(e) = bot_clone.push_markdown(title, message).await { + error!("push to bot error: {:?}", e); + warn!("push to bot error: {:?}", e); + return Err(format!("push to bot error:{}", e)); + } + Ok(()) + }); } let mut is_push = true; while let Some(set_res) = set.join_next().await { - match set_res { - Ok(res) => { - if let Err(e) = res { - is_push = false; - warn!("push message error:{}", e); - } - } - Err(err) => { - is_push = false; - warn!("push join set error:{}", err) - } + if set_res.is_err() { + is_push = false; } } is_push } pub async fn run_migration(&self) -> Result<()> { - migration::Migrator::up(&self.app_context.db, None).await?; + migration::Migrator::up(&self.app_context.db, None) + .await + .context(DbErrSnafu)?; Ok(()) } } @@ -241,7 +255,7 @@ pub struct AppContext { pub async fn create_context(environment: &Environment) -> Result { let config = environment.load()?; - let db = db::connect(&config.database).await?; + let db = db::connect(&config.database).await.context(DbErrSnafu)?; let bot_manager = push::init(config.clone()); Ok(AppContext { environment: environment.clone(), diff --git a/src/config.rs b/src/config.rs index 
668c52f..a63b7cd 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -2,7 +2,7 @@ use std::path::{Path, PathBuf};
 
 use crate::{
     environment::Environment,
-    error::{Error, Result},
+    error::{ConfigErrSnafu, IoSnafu, Result, SerdeYamlErrSnafu},
     logger,
     utils::render_string,
 };
@@ -10,6 +10,7 @@ use fs_err as fs;
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use snafu::{OptionExt, ResultExt};
 
 lazy_static! {
     static ref DEFAULT_FOLDER: PathBuf = PathBuf::from("config");
@@ -39,12 +40,13 @@ impl Config {
         let selected_path = files
             .iter()
             .find(|p| p.exists())
-            .ok_or_else(|| Error::Message("no configuration file found".to_string()))?;
+            .with_context(|| ConfigErrSnafu {
+                msg: format!("config file {env}.local.yaml or {env}.yaml not found"),
+            })?;
 
-        let content = fs::read_to_string(selected_path)?;
+        let content = fs::read_to_string(selected_path).with_context(|_| IoSnafu)?;
         let rendered = render_string(&content, &json!({}))?;
 
-        serde_yaml::from_str(&rendered)
-            .map_err(|err| Error::YAMLFile(err, selected_path.to_string_lossy().to_string()))
+        serde_yaml::from_str(&rendered).with_context(|_| SerdeYamlErrSnafu)
     }
 }
diff --git a/src/error.rs b/src/error.rs
index 1d8480a..dd6187c 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,114 +1,249 @@
 use std::{num::ParseIntError, time::SystemTimeError};
 
 use hmac::digest::crypto_common;
+use macros::stack_trace_debug;
 use migration::sea_orm;
+use scraper::error::SelectorErrorKind;
+use snafu::{Location, Snafu};
+
+pub type Result<T> = std::result::Result<T, AppError>;
+
+#[derive(Snafu)]
+#[snafu(visibility(pub))]
+#[stack_trace_debug]
+pub enum AppError {
+    #[snafu(display("IO error"))]
+    Io {
+        #[snafu(source)]
+        error: std::io::Error,
+        #[snafu(implicit)]
+        location: Location,
+    },
 
-pub type Result<T> = std::result::Result<T, Error>;
+    #[snafu(display("db error"))]
+    DbErr {
+        #[snafu(source)]
+        error: sea_orm::DbErr,
+        #[snafu(implicit)]
+        location: Location,
+    },
 
-#[derive(thiserror::Error, Debug)]
-pub enum Error {
-    #[error("{inner}\n{backtrace}")]
-    WithBacktrace {
-        inner: Box<Self>,
-        backtrace: Box<std::backtrace::Backtrace>,
+    #[snafu(display("db table {} filter by {} not found", table, filter))]
+    DbNotFoundErr {
+        table: String,
+        filter: String,
+        #[snafu(implicit)]
+        location: Location,
     },
 
-    // Model
-    #[error(transparent)]
-    Model(#[from] crate::models::ModelError),
+    #[snafu(display("db table {} filter by {} already exists", table, filter))]
+    DbAlreadyExists {
+        table: String,
+        filter: String,
+        #[snafu(implicit)]
+        location: Location,
+    },
 
-    #[error("{0}")]
-    Message(String),
+    #[snafu(display("cron scheduler error"))]
+    CronSchedulerErr {
+        #[snafu(source)]
+        error: tokio_cron_scheduler::JobSchedulerError,
+        #[snafu(implicit)]
+        location: Location,
+    },
 
-    #[error(transparent)]
-    HmacError(#[from] crypto_common::InvalidLength),
+    #[snafu(display("request url {url} error"))]
+    HttpClientErr {
+        url: String,
+        #[snafu(source)]
+        error: reqwest::Error,
+        #[snafu(implicit)]
+        location: Location,
+    },
 
-    #[error(transparent)]
-    CronScheduler(#[from] tokio_cron_scheduler::JobSchedulerError),
+    #[snafu(display("regex new {re} error"))]
+    RegexErr {
+        re: String,
+        #[snafu(source)]
+        error: regex::Error,
+        #[snafu(implicit)]
+        location: Location,
+    },
 
-    #[error(transparent)]
-    Reqwest(#[from] reqwest::Error),
+    #[snafu(display("regex captures error: {msg}"))]
+    RegexCapturesErr {
+        msg: String,
+        #[snafu(implicit)]
+        location: Location,
+    },
 
-    #[error(transparent)]
-    Regex(#[from] regex::Error),
+    #[snafu(display("parse {num} to int error"))]
+    ParseIntErr {
num: String, + #[snafu(source)] + error: ParseIntError, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - ParseInt(#[from] ParseIntError), + #[snafu(display("parse {url} error"))] + ParseUrlErr { + url: String, + #[snafu(source)] + error: url::ParseError, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - ParseUrl(#[from] url::ParseError), + #[snafu(display("task join error"))] + TaskJoinErr { + #[snafu(source)] + error: tokio::task::JoinError, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("tera error"))] + TeraErr { + #[snafu(source)] + error: tera::Error, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("crypto error"))] + CryptoError { + #[snafu(source)] + error: crypto_common::InvalidLength, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("json error"))] + JsonErr { + #[snafu(source)] + error: serde_json::Error, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - DB(#[from] sea_orm::DbErr), + #[snafu(display("teloxide request error"))] + TeloxideErr { + #[snafu(source)] + error: teloxide::RequestError, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - Join(#[from] tokio::task::JoinError), + #[snafu(display("octocrab search {search} error"))] + OctocrabErr { + search: String, + #[snafu(source)] + error: octocrab::Error, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - Tera(#[from] tera::Error), + #[snafu(display("chrono parse {date} error"))] + ChronoParseErr { + date: String, + #[snafu(source)] + error: chrono::ParseError, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - JSON(serde_json::Error), + #[snafu(display("system time error"))] + SystemTimeErr { + #[snafu(source)] + error: SystemTimeError, + #[snafu(implicit)] + location: Location, + }, - #[error("cannot parse `{1}`: {0}")] - YAMLFile(#[source] serde_yaml::Error, String), + #[snafu(display("Illegal config msg: {msg}"))] + ConfigErr { + msg: String, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - YAML(#[from] serde_yaml::Error), + #[snafu(display("serde yaml error"))] + SerdeYamlErr { + #[snafu(source)] + error: serde_yaml::Error, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - TELOXIDE(#[from] teloxide::RequestError), + #[snafu(display("timestamp {timestamp} to datetime error"))] + DateTimeFromTimestampErr { + timestamp: i64, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - EnvVar(#[from] std::env::VarError), + #[snafu(display("scraper selector error"))] + SelectorError { + #[snafu(source)] + error: SelectorErrorKind<'static>, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - IO(#[from] std::io::Error), + #[snafu(display("html selector nth {} not found", nth))] + SelectNthErr { + nth: usize, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - DateParse(#[from] chrono::ParseError), + #[snafu(display("html element attr {} not found", attr))] + ElementAttrErr { + attr: String, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - SystemTime(#[from] SystemTimeError), + #[snafu(display("ding push markdown message response errorcode {errorcode}"))] + DingPushErr { + errorcode: i64, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - Octocrab(#[from] octocrab::Error), + #[snafu(display("lark push markdown message response code {code}"))] + LarkPushErr { + code: i64, + 
#[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - Any(#[from] Box), + // detail.code != 200 || !detail.success || detail.data.is_empty(), + #[snafu(display("oscs {mps} detail {code} invalid"))] + InvalidOscsDetail { + mps: String, + code: i64, + #[snafu(implicit)] + location: Location, + }, - #[error(transparent)] - Anyhow(#[from] eyre::Report), -} + #[snafu(display("oscs list total invalid"))] + InvalidOscsListTotal { + #[snafu(implicit)] + location: Location, + }, -impl Error { - pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { - Self::Any(Box::new(err)) - } - - pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self { - Self::Message(err.to_string()) - } - - pub fn string(s: &str) -> Self { - Self::Message(s.to_string()) - } - - pub fn bt(self) -> Self { - let backtrace = std::backtrace::Backtrace::capture(); - match backtrace.status() { - std::backtrace::BacktraceStatus::Disabled - | std::backtrace::BacktraceStatus::Unsupported => self, - _ => Self::WithBacktrace { - inner: Box::new(self), - backtrace: Box::new(backtrace), - }, - } - } -} + #[snafu(display("seebug parse html error"))] + ParseSeeBugHtmlErr { + #[snafu(implicit)] + location: Location, + }, -impl From for Error { - fn from(val: serde_json::Error) -> Self { - Self::JSON(val).bt() - } + #[snafu(display("Invalid seebug page num: {}", num))] + InvalidSeebugPageNum { + num: usize, + #[snafu(implicit)] + location: Location, + }, } diff --git a/src/grab/anti.rs b/src/grab/anti.rs index 311bf37..c07ba85 100644 --- a/src/grab/anti.rs +++ b/src/grab/anti.rs @@ -1,12 +1,14 @@ use crate::{ - error::{Error, Result}, + error::{HttpClientErrSnafu, RegexCapturesErrSnafu, RegexErrSnafu, Result}, utils::data_str_format, }; use async_trait::async_trait; + use regex::Regex; use reqwest::header; use serde::{Deserialize, Serialize}; use serde_json::json; +use snafu::{OptionExt, ResultExt}; use tracing::{info, warn}; use crate::utils::http_client::Help; @@ -41,7 +43,7 @@ impl Grab for AntiCrawler { let unique_key = match cve { Ok(unique_key) => unique_key, Err(e) => { - warn!("AntiCrawler get update not found cve error:{}", e); + warn!("AntiCrawler get update not found cve error:{:?}", e); continue; } }; @@ -65,6 +67,7 @@ impl Grab for AntiCrawler { reasons: vec![], github_search: vec![], is_valuable: true, + pushed: false, }; res.push(vuln); } @@ -116,12 +119,15 @@ impl AntiCrawler { "time_range":[] } }); - let anti_response: AntiResponse = self - .help - .post_json(ANTI_LIST_URL, ¶ms) - .await? - .json() - .await?; + + let post_json_res = self.help.post_json(ANTI_LIST_URL, ¶ms).await?; + let anti_response: AntiResponse = + post_json_res + .json() + .await + .with_context(|_| HttpClientErrSnafu { + url: ANTI_LIST_URL.to_string(), + })?; Ok(anti_response) } @@ -161,12 +167,15 @@ impl AntiCrawler { } fn get_cve(&self, title: &str) -> Result { - let res = Regex::new(ANTI_CVEID_REGEXP)?.captures(title); - if let Some(cve) = res { - Ok(cve[0].to_string()) - } else { - Err(Error::Message("cve regex match not found".to_owned())) - } + let res = Regex::new(ANTI_CVEID_REGEXP) + .with_context(|_| RegexErrSnafu { + re: ANTI_CVEID_REGEXP, + })? 
+ .captures(title) + .with_context(|| RegexCapturesErrSnafu { + msg: format!("captures title:{} cve not found", title), + })?; + Ok(res[0].to_string()) } } diff --git a/src/grab/avd.rs b/src/grab/avd.rs index f250980..e835b8b 100644 --- a/src/grab/avd.rs +++ b/src/grab/avd.rs @@ -1,12 +1,15 @@ use async_trait::async_trait; -use eyre::eyre; use regex::Regex; use reqwest::{header, Url}; use scraper::{Html, Selector}; +use snafu::{ensure, OptionExt, ResultExt}; use tracing::{debug, info, warn}; use crate::{ - error::{Error, Result}, + error::{ + ParseIntErrSnafu, ParseUrlErrSnafu, RegexCapturesErrSnafu, RegexErrSnafu, Result, + SelectNthErrSnafu, SelectorSnafu, + }, grab::{Severity, VulnInfo}, utils::http_client::Help, }; @@ -59,17 +62,23 @@ impl AVDCrawler { pub async fn get_page_count(&self) -> Result { let content = self.help.get_html_content(&self.link).await?; - let cap = Regex::new(PAGE_REGEXP)?.captures(&content); - if let Some(res) = cap { - if res.len() == 2 { - let total = res[1].parse::()?; - Ok(total) - } else { - Err(Error::Message("page regex match error".to_owned())) + let captures = Regex::new(PAGE_REGEXP) + .with_context(|_| RegexErrSnafu { re: PAGE_REGEXP })? + .captures(&content) + .with_context(|| RegexCapturesErrSnafu { + msg: "captures page content page info not found".to_string(), + })?; + ensure!( + captures.len() == 2, + RegexCapturesErrSnafu { + msg: "captures page content page len not eq 2".to_string() } - } else { - Err(Error::Message("page regex match not found".to_owned())) - } + ); + captures[1] + .parse::() + .with_context(|_| ParseIntErrSnafu { + num: captures[1].to_string(), + }) } pub async fn parse_page(&self, page: i32) -> Result> { @@ -80,16 +89,20 @@ impl AVDCrawler { for detail in detail_links { let data = self.parse_detail_page(detail.as_ref()).await; match data { - Ok(data) => res.push(data), - Err(err) => warn!("crawing detail {} error {}", detail, err), + Ok(data) => { + if data.cve.is_empty() && data.disclosure.is_empty() { + continue; + } + res.push(data); + } + Err(err) => warn!("crawing detail {} error {:?}", detail, err), } } Ok(res) } fn get_detail_links(&self, document: Html) -> Result> { - let src_url_selector = - Selector::parse("tbody tr td a").map_err(|err| eyre!("parse html error {}", err))?; + let src_url_selector = Selector::parse("tbody tr td a").context(SelectorSnafu)?; let detail_links: Vec = document .select(&src_url_selector) @@ -119,10 +132,6 @@ impl AVDCrawler { tags.push(utilization); } - if cve_id.is_empty() && disclosure.is_empty() { - return Err(eyre!("invalid vuln data in {}", href).into()); - } - let severity = self.get_severity(&document)?; let title = self.get_title(&document)?; @@ -148,12 +157,13 @@ impl AVDCrawler { reasons: vec![], github_search: vec![], is_valuable, + pushed: false, }; Ok(data) } fn get_avd_id(&self, detail_url: &str) -> Result { - let url = Url::parse(detail_url)?; + let url = Url::parse(detail_url).with_context(|_| ParseUrlErrSnafu { url: detail_url })?; let avd_id = url .query_pairs() .filter(|(key, _)| key == "id") @@ -164,8 +174,7 @@ impl AVDCrawler { } fn get_references(&self, document: &Html) -> Result> { - let reference_selector = Selector::parse("td[nowrap='nowrap'] a") - .map_err(|err| eyre!("avd get references selector parse error {}", err))?; + let reference_selector = Selector::parse("td[nowrap='nowrap'] a").context(SelectorSnafu)?; let references = document .select(&reference_selector) .filter_map(|el| el.attr("href")) @@ -175,12 +184,11 @@ impl AVDCrawler { } fn get_solutions(&self, 
document: &Html) -> Result { - let solutions_selector = Selector::parse(".text-detail") - .map_err(|err| eyre!("avd get solutions selector parse error {}", err))?; + let solutions_selector = Selector::parse(".text-detail").context(SelectorSnafu)?; let solutions = document .select(&solutions_selector) .nth(1) - .ok_or_else(|| Error::Message("avd solutions value not found".to_string()))? + .with_context(|| SelectNthErrSnafu { nth: 1_usize })? .text() .map(|el| el.trim()) .collect::>() @@ -189,8 +197,7 @@ impl AVDCrawler { } fn get_description(&self, document: &Html) -> Result { - let description_selector = Selector::parse(".text-detail div") - .map_err(|err| eyre!("avd get description selector parse error {}", err))?; + let description_selector = Selector::parse(".text-detail div").context(SelectorSnafu)?; let description = document .select(&description_selector) .map(|e| e.text().collect::()) @@ -201,11 +208,11 @@ impl AVDCrawler { fn get_title(&self, document: &Html) -> Result { let title_selector = Selector::parse("h5[class='header__title'] .header__title__text") - .map_err(|err| eyre!("avd get title selector parse error {}", err))?; + .context(SelectorSnafu)?; let title = document .select(&title_selector) .nth(0) - .ok_or_else(|| eyre!("avd title value not found"))? + .with_context(|| SelectNthErrSnafu { nth: 0_usize })? .inner_html() .trim() .to_string(); @@ -213,12 +220,12 @@ impl AVDCrawler { } fn get_severity(&self, document: &Html) -> Result { - let level_selector = Selector::parse("h5[class='header__title'] .badge") - .map_err(|err| eyre!("avd get severity selector parse error {}", err))?; + let level_selector = + Selector::parse("h5[class='header__title'] .badge").context(SelectorSnafu)?; let level = document .select(&level_selector) .nth(0) - .ok_or_else(|| eyre!("avd level value not found"))? + .with_context(|| SelectNthErrSnafu { nth: 0_usize })? .inner_html() .trim() .to_string(); @@ -233,12 +240,11 @@ impl AVDCrawler { } fn get_mertric_value(&self, document: &Html, index: usize) -> Result { - let value_selector = - Selector::parse(".metric-value").map_err(|err| eyre!("parse html error {}", err))?; + let value_selector = Selector::parse(".metric-value").context(SelectorSnafu)?; let metric_value = document .select(&value_selector) .nth(index) - .ok_or_else(|| eyre!("avd metric value not found"))? + .with_context(|| SelectNthErrSnafu { nth: index })? 
.inner_html() .trim() .to_string(); @@ -247,7 +253,9 @@ impl AVDCrawler { fn get_cve_id(&self, document: &Html) -> Result { let mut cve_id = self.get_mertric_value(document, 0)?; - if !Regex::new(CVEID_REGEXP)?.is_match(&cve_id) { + let regex = + Regex::new(CVEID_REGEXP).with_context(|_| RegexErrSnafu { re: CVEID_REGEXP })?; + if !regex.is_match(&cve_id) { cve_id = "".to_string(); } Ok(cve_id) diff --git a/src/grab/kev.rs b/src/grab/kev.rs index e9db829..f022856 100644 --- a/src/grab/kev.rs +++ b/src/grab/kev.rs @@ -2,10 +2,11 @@ use async_trait::async_trait; use chrono::{DateTime, FixedOffset}; use reqwest::header::{self}; use serde::{Deserialize, Serialize}; +use snafu::ResultExt; use tracing::info; use super::{Grab, VulnInfo}; -use crate::error::Result; +use crate::error::{HttpClientErrSnafu, Result}; use crate::{grab::Severity, utils::http_client::Help}; const KEV_URL: &str = @@ -23,7 +24,11 @@ pub struct KevCrawler { #[async_trait] impl Grab for KevCrawler { async fn get_update(&self, page_limit: i32) -> Result> { - let kev_list_resp: KevResp = self.help.get_json(KEV_URL).await?.json().await?; + let get_json_res = self.help.get_json(KEV_URL).await?; + let kev_list_resp: KevResp = get_json_res + .json() + .await + .with_context(|_| HttpClientErrSnafu { url: KEV_URL })?; let all_count = kev_list_resp.vulnerabilities.len(); let item_limit = if page_limit as usize * KEV_PAGE_SIZE > all_count { all_count @@ -58,6 +63,7 @@ impl Grab for KevCrawler { github_search: vec![], reasons: vec![], is_valuable, + pushed: false, }; res.push(vuln_info) } @@ -116,7 +122,13 @@ mod tests { #[tokio::test] async fn test_get_key_res() -> Result<()> { let kev = KevCrawler::new(); - let kev_list_resp: KevResp = kev.help.get_json(KEV_URL).await?.json().await?; + let kev_list_resp: KevResp = kev + .help + .get_json(KEV_URL) + .await? 
+ .json() + .await + .with_context(|_| HttpClientErrSnafu { url: KEV_URL })?; let mut vulnerabilities = kev_list_resp.vulnerabilities; vulnerabilities.sort_by(|a, b| b.date_added.cmp(&a.date_added)); println!("{:?}", vulnerabilities); diff --git a/src/grab/mod.rs b/src/grab/mod.rs index 8916051..4a76619 100644 --- a/src/grab/mod.rs +++ b/src/grab/mod.rs @@ -32,11 +32,12 @@ pub struct VulnInfo { pub reasons: Vec, pub github_search: Vec, pub is_valuable: bool, + pub pushed: bool, } impl From for VulnInfo { fn from(v: Model) -> Self { - let severtiy = match v.severtiy.as_str() { + let severity = match v.severtiy.as_str() { "Low" => Severity::Low, "Medium" => Severity::Medium, "High" => Severity::High, @@ -56,7 +57,7 @@ impl From for VulnInfo { unique_key: v.key, title: v.title, description: v.description, - severity: severtiy, + severity, cve: v.cve, disclosure: v.disclosure, references, @@ -66,6 +67,7 @@ impl From for VulnInfo { reasons, github_search, is_valuable: v.is_valuable, + pushed: v.pushed, } } } diff --git a/src/grab/oscs.rs b/src/grab/oscs.rs index 33a8833..b05ef1a 100644 --- a/src/grab/oscs.rs +++ b/src/grab/oscs.rs @@ -2,10 +2,11 @@ use async_trait::async_trait; use chrono::{DateTime, FixedOffset}; use reqwest::header::{self}; use serde::{Deserialize, Serialize}; +use snafu::{ensure, ResultExt}; use tracing::{error, info}; use crate::{ - error::{Error, Result}, + error::{HttpClientErrSnafu, InvalidOscsDetailSnafu, InvalidOscsListTotalSnafu, Result}, grab::{Severity, VulnInfo}, utils::{http_client::Help, timestamp_to_date}, }; @@ -64,12 +65,11 @@ impl OscCrawler { "page": page, "per_page": per_page, }); - let oscs_list_resp: OscsListResp = self - .help - .post_json(OSCS_LIST_URL, ¶ms) - .await? + let post_json_res = self.help.post_json(OSCS_LIST_URL, ¶ms).await?; + let oscs_list_resp: OscsListResp = post_json_res .json() - .await?; + .await + .with_context(|_| HttpClientErrSnafu { url: OSCS_LIST_URL })?; Ok(oscs_list_resp) } @@ -78,9 +78,8 @@ impl OscCrawler { .get_list_resp(OSCS_PAGE_DEFAULT, OSCS_PER_PAGE_DEFAULT) .await?; let total = oscs_list_resp.data.total; - if total <= 0 { - return Err(Error::Message("oscs get total error".to_owned())); - } + ensure!(total > 0, InvalidOscsListTotalSnafu,); + let page_count = total / OSCS_PAGE_SIZE; if page_count == 0 { return Ok(1); @@ -119,9 +118,13 @@ impl OscCrawler { pub async fn parse_detail(&self, mps: &str) -> Result { let detail = self.get_detail_resp(mps).await?; - if detail.code != 200 || !detail.success || detail.data.is_empty() { - return Err(Error::Message(format!("oscs get: {} detail error", mps))); - }; + ensure!( + detail.code == 200 && detail.success && !detail.data.is_empty(), + InvalidOscsDetailSnafu { + mps, + code: detail.code + } + ); let data = detail.data[0].clone(); let severity = self.get_severity(&data.level); let disclosure = timestamp_to_date(data.publish_time)?; @@ -148,6 +151,7 @@ impl OscCrawler { reasons: vec![], github_search: vec![], is_valuable, + pushed: false, }; Ok(data) } @@ -174,12 +178,14 @@ impl OscCrawler { let params = serde_json::json!({ "vuln_no": mps, }); - let detail: OscsDetailResp = self - .help - .post_json(OSCS_DETAIL_URL, ¶ms) - .await? 
- .json() - .await?; + let post_json_res = self.help.post_json(OSCS_DETAIL_URL, ¶ms).await?; + let detail: OscsDetailResp = + post_json_res + .json() + .await + .with_context(|_| HttpClientErrSnafu { + url: OSCS_DETAIL_URL, + })?; Ok(detail) } diff --git a/src/grab/seebug.rs b/src/grab/seebug.rs index 44bc957..3778088 100644 --- a/src/grab/seebug.rs +++ b/src/grab/seebug.rs @@ -1,10 +1,13 @@ use async_trait::async_trait; -use eyre::eyre; use reqwest::header::{self}; use scraper::{ElementRef, Html, Selector}; +use snafu::{ensure, OptionExt, ResultExt}; use tracing::{info, warn}; -use crate::error::{Error, Result}; +use crate::error::{ + ElementAttrErrSnafu, InvalidSeebugPageNumSnafu, ParseIntErrSnafu, ParseSeeBugHtmlErrSnafu, + Result, SelectNthErrSnafu, SelectorSnafu, +}; use crate::grab::{Severity, VulnInfo}; use crate::utils::http_client::Help; @@ -55,71 +58,71 @@ impl SeeBugCrawler { pub async fn get_page_count(&self) -> Result { let document = self.get_document(SEEBUG_LIST_URL).await?; - let selector = Selector::parse("ul.pagination li a") - .map_err(|err| eyre!("parse html error {}", err))?; + let selector = Selector::parse("ul.pagination li a").context(SelectorSnafu)?; let page_nums = document .select(&selector) .map(|el| el.inner_html()) .collect::>(); - if page_nums.len() < 3 { - return Err(Error::Message( - "failed to get seebug pagination node".to_owned(), - )); - } - let total = page_nums[page_nums.len() - 1 - 1].parse::()?; + ensure!( + page_nums.len() >= 3, + InvalidSeebugPageNumSnafu { + num: page_nums.len() + } + ); + let total_str = &page_nums[page_nums.len() - 1 - 1]; + let total = total_str + .parse::() + .with_context(|_| ParseIntErrSnafu { num: total_str })?; Ok(total) } pub async fn parse_page(&self, page: i32) -> Result> { let url = format!("{}?page={}", SEEBUG_LIST_URL, page); let document = self.get_document(&url).await?; - let selector = Selector::parse(".sebug-table tbody tr") - .map_err(|err| eyre!("seebug parse html error {}", err))?; + let selector = Selector::parse(".sebug-table tbody tr").context(SelectorSnafu)?; let tr_elements = document.select(&selector).collect::>(); - if tr_elements.is_empty() { - return Err(Error::Message("failed to get seebug page".into())); - } + ensure!(!tr_elements.is_empty(), ParseSeeBugHtmlErrSnafu); let mut res = Vec::with_capacity(tr_elements.len()); for el in tr_elements { let (href, unique_key) = match self.get_href(el) { Ok((href, unique_key)) => (href, unique_key), Err(e) => { - warn!("seebug get href error {}", e); + warn!("seebug get href error {:?}", e); continue; } }; let disclosure = match self.get_disclosure(el) { Ok(disclosure) => disclosure, Err(e) => { - warn!("seebug get disclosure error {}", e); + warn!("seebug get disclosure error {:?}", e); continue; } }; let severity_title = match self.get_severity_title(el) { Ok(severity_title) => severity_title, Err(e) => { - warn!("seebug get severity title error {}", e); + warn!("seebug get severity title error {:?}", e); continue; } }; let title = match self.get_title(el) { Ok(title) => title, Err(e) => { - warn!("seebug get title error {}", e); + warn!("seebug get title error {:?}", e); continue; } }; let cve_id = match self.get_cve_id(el) { Ok(cve_id) => cve_id, Err(e) => { - warn!("seebug get cve_id error {}", e); + warn!("seebug get cve_id error {:?}", e); "".to_string() } }; let tag = match self.get_tag(el) { Ok(tag) => tag, Err(e) => { - warn!("seebug get tag error {}", e); + warn!("seebug get tag error {:?}", e); continue; } }; @@ -145,6 +148,7 @@ impl SeeBugCrawler 
{ reasons: vec![], github_search: vec![], is_valuable, + pushed: false, }; res.push(data); } @@ -167,15 +171,17 @@ impl SeeBugCrawler { } fn get_href(&self, el: ElementRef) -> Result<(String, String)> { - let selector = Selector::parse("td a").map_err(|err| eyre!("parse html error {}", err))?; + let selector = Selector::parse("td a").context(SelectorSnafu)?; let a_element = el .select(&selector) .nth(0) - .ok_or_else(|| eyre!("value not found"))?; + .with_context(|| SelectNthErrSnafu { nth: 0_usize })?; let href = a_element .value() .attr("href") - .ok_or_else(|| eyre!("href not found"))? + .with_context(|| ElementAttrErrSnafu { + attr: "href".to_string(), + })? .trim(); let href = format!("https://www.seebug.org{}", href); let binding = a_element.inner_html(); @@ -184,73 +190,75 @@ impl SeeBugCrawler { } fn get_disclosure(&self, el: ElementRef) -> Result { - let selector = Selector::parse("td").map_err(|err| eyre!("parse html error {}", err))?; + let selector = Selector::parse("td").context(SelectorSnafu)?; let disclosure = el .select(&selector) .nth(1) - .ok_or_else(|| eyre!("value not found"))? + .with_context(|| SelectNthErrSnafu { nth: 1_usize })? .inner_html(); Ok(disclosure) } fn get_severity_title(&self, el: ElementRef) -> Result { - let selector = - Selector::parse("td div").map_err(|err| eyre!("parse html error {}", err))?; + let selector = Selector::parse("td div").context(SelectorSnafu)?; let td_element = el .select(&selector) .nth(0) - .ok_or_else(|| eyre!("severity_title div not found"))?; + .with_context(|| SelectNthErrSnafu { nth: 0_usize })?; let severity_title = td_element .value() .attr("data-original-title") - .ok_or_else(|| eyre!("href not found"))? + .with_context(|| ElementAttrErrSnafu { + attr: "data-original-title".to_string(), + })? .trim(); Ok(severity_title.to_owned()) } fn get_title(&self, el: ElementRef) -> Result { - let selector = Selector::parse("td a[class='vul-title']") - .map_err(|err| eyre!("parse html error {}", err))?; + let selector = Selector::parse("td a[class='vul-title']").context(SelectorSnafu)?; let title = el .select(&selector) .nth(0) - .ok_or_else(|| eyre!("title not found"))? + .with_context(|| SelectNthErrSnafu { nth: 0_usize })? .inner_html(); Ok(title) } fn get_cve_id(&self, el: ElementRef) -> Result { - let selector = Selector::parse("td i[class='fa fa-id-card ']") - .map_err(|err| eyre!("parse html error {}", err))?; + let selector = Selector::parse("td i[class='fa fa-id-card ']").context(SelectorSnafu)?; let cve_ids = el .select(&selector) .nth(0) - .ok_or_else(|| eyre!("cve id element not found"))? + .with_context(|| SelectNthErrSnafu { nth: 0_usize })? .value() .attr("data-original-title") - .ok_or_else(|| eyre!("data-original-title not found"))? + .with_context(|| ElementAttrErrSnafu { + attr: "data-original-title".to_string(), + })? .trim(); if cve_ids.contains('、') { return Ok(cve_ids .split('、') .nth(0) - .ok_or_else(|| eyre!("cve_ids split not found cve id"))? + .with_context(|| SelectNthErrSnafu { nth: 0_usize })? .to_owned()); } Ok(cve_ids.to_string()) } fn get_tag(&self, el: ElementRef) -> Result { - let selector = Selector::parse("td .fa-file-text-o") - .map_err(|err| eyre!("parse html error {}", err))?; + let selector = Selector::parse("td .fa-file-text-o").context(SelectorSnafu)?; let tag = el .select(&selector) .nth(0) - .ok_or_else(|| eyre!("tag element not found"))? + .with_context(|| SelectNthErrSnafu { nth: 0_usize })? .value() .attr("data-original-title") - .ok_or_else(|| eyre!("tag data-original-title not found"))? 
+ .with_context(|| ElementAttrErrSnafu { + attr: "data-original-title".to_string(), + })? .trim(); Ok(tag.to_string()) } @@ -258,46 +266,34 @@ impl SeeBugCrawler { #[cfg(test)] mod tests { + use crate::error::IoSnafu; + use super::*; use std::fs; - #[tokio::test] - async fn test_seebug_get_cve() -> Result<()> { + fn get_first_element_cve(path: &str) -> Result { let seebug = SeeBugCrawler::new(); - // read fixtures/seebug.html - let html = fs::read_to_string("fixtures/seebug.html")?; + let html = fs::read_to_string(path).context(IoSnafu)?; let document = Html::parse_document(&html); - let selector = Selector::parse(".sebug-table tbody tr") - .map_err(|err| eyre!("seebug parse html error {}", err))?; + let selector = Selector::parse(".sebug-table tbody tr").context(SelectorSnafu)?; let tr_elements = document.select(&selector).collect::>(); - if tr_elements.is_empty() { - return Err(Error::Message("failed to get seebug page".into())); - } - let first = tr_elements - .first() - .ok_or_else(|| Error::Message("failed to get seebug page first element".to_string()))? - .to_owned(); + + ensure!(!tr_elements.is_empty(), ParseSeeBugHtmlErrSnafu); + let first = tr_elements[0].to_owned(); let cve_id = seebug.get_cve_id(first)?; + Ok(cve_id) + } + + #[test] + fn test_seebug_get_cve() -> Result<()> { + let cve_id = get_first_element_cve("fixtures/seebug.html")?; assert_eq!(cve_id, "CVE-2024-23692"); Ok(()) } - #[tokio::test] - async fn test_many_cve_seebug_get_cve() -> Result<()> { - let seebug = SeeBugCrawler::new(); - // read fixtures/seebug.html - let html = fs::read_to_string("fixtures/seebug_many_cve.html")?; - let document = Html::parse_document(&html); - let selector = Selector::parse(".sebug-table tbody tr") - .map_err(|err| eyre!("seebug parse html error {}", err))?; - let tr_elements = document.select(&selector).collect::>(); - if tr_elements.is_empty() { - return Err(Error::Message("failed to get seebug page".into())); - } - let first = tr_elements - .first() - .ok_or_else(|| Error::Message("failed to get seebug page first element".to_string()))? 
- .to_owned(); - let cve_id = seebug.get_cve_id(first)?; + #[test] + fn test_many_cve_seebug_get_cve() -> Result<()> { + let cve_id = get_first_element_cve("fixtures/seebug_many_cve.html")?; + assert_eq!(cve_id, "CVE-2023-50445"); Ok(()) } diff --git a/src/grab/threatbook.rs b/src/grab/threatbook.rs index 106483a..154b3d8 100644 --- a/src/grab/threatbook.rs +++ b/src/grab/threatbook.rs @@ -1,10 +1,11 @@ use async_trait::async_trait; use reqwest::header; use serde::{Deserialize, Serialize}; +use snafu::ResultExt; use tracing::info; use crate::{ - error::Result, + error::{HttpClientErrSnafu, Result}, grab::Severity, utils::{check_over_two_week, http_client::Help}, }; @@ -26,8 +27,12 @@ pub struct ThreadBookCrawler { impl Grab for ThreadBookCrawler { async fn get_update(&self, _page_limit: i32) -> Result> { let crawler = ThreadBookCrawler::new(); - let home_page_resp: ThreadBookHomePage = - crawler.help.get_json(HOME_PAGE_URL).await?.json().await?; + let get_json_res = crawler.help.get_json(HOME_PAGE_URL).await?; + let home_page_resp: ThreadBookHomePage = get_json_res + .json() + .await + .with_context(|_| HttpClientErrSnafu { url: HOME_PAGE_URL })?; + let mut res = Vec::with_capacity(home_page_resp.data.high_risk.len()); for v in home_page_resp.data.high_risk { let mut is_valuable = false; @@ -67,6 +72,7 @@ impl Grab for ThreadBookCrawler { reasons: Vec::new(), github_search: vec![], is_valuable, + pushed: false, }; res.push(vuln); } @@ -148,8 +154,13 @@ mod tests { #[tokio::test] async fn test_get_threat_book_homepage() -> Result<()> { let crawler = ThreadBookCrawler::new(); - let res: ThreadBookHomePage = crawler.help.get_json(HOME_PAGE_URL).await?.json().await?; - info!("{:?}", res); + let get_json_res = crawler.help.get_json(HOME_PAGE_URL).await?; + let home_page_resp: ThreadBookHomePage = get_json_res + .json() + .await + .with_context(|_| HttpClientErrSnafu { url: HOME_PAGE_URL })?; + + info!("{:?}", home_page_resp); Ok(()) } } diff --git a/src/grab/ti.rs b/src/grab/ti.rs index 390c736..a349e49 100644 --- a/src/grab/ti.rs +++ b/src/grab/ti.rs @@ -1,10 +1,11 @@ use async_trait::async_trait; use reqwest::header::{self}; use serde::{Deserialize, Serialize}; +use snafu::ResultExt; use tracing::info; use super::{Grab, Severity, VulnInfo}; -use crate::error::Result; +use crate::error::{HttpClientErrSnafu, Result}; use crate::utils::http_client::Help; const ONE_URL: &str = "https://ti.qianxin.com/alpha-api/v2/vuln/one-day"; @@ -51,7 +52,12 @@ impl TiCrawler { pub async fn get_ti_one_day_resp(&self) -> Result { let params = serde_json::json!({}); - let resp: TiOneDayResp = self.help.post_json(ONE_URL, ¶ms).await?.json().await?; + let post_json = self.help.post_json(ONE_URL, ¶ms).await?; + + let resp: TiOneDayResp = post_json + .json() + .await + .with_context(|_| HttpClientErrSnafu { url: ONE_URL })?; Ok(resp) } @@ -77,6 +83,7 @@ impl TiCrawler { reasons: vec![], github_search: vec![], is_valuable, + pushed: false, }; if vuln_infos .iter() diff --git a/src/models/error.rs b/src/models/error.rs deleted file mode 100644 index 167da21..0000000 --- a/src/models/error.rs +++ /dev/null @@ -1,17 +0,0 @@ -#[derive(thiserror::Error, Debug)] -pub enum ModelError { - #[error("Entity {} already exists", key)] - EntityAlreadyExists { key: String }, - - #[error("Entity not found")] - EntityNotFound, - - #[error("Entity update not found by key: {}", key)] - EntityUpdateNotFound { key: String }, - - #[error(transparent)] - DbErr(#[from] sea_orm::DbErr), - - #[error(transparent)] - Any(#[from] Box), -} diff 
--git a/src/models/mod.rs b/src/models/mod.rs index e82d7e8..82ed940 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,7 +1,3 @@ pub mod _entities; -pub mod error; -pub mod vuln_informations; - -pub use error::*; -pub type ModelResult = std::result::Result; +pub mod vuln_informations; diff --git a/src/models/vuln_informations.rs b/src/models/vuln_informations.rs index 8566d36..de5128f 100644 --- a/src/models/vuln_informations.rs +++ b/src/models/vuln_informations.rs @@ -2,36 +2,49 @@ use sea_orm::{ ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, EntityTrait, PaginatorTrait, QueryFilter, TransactionTrait, }; +use snafu::{ensure, OptionExt, ResultExt}; use tracing::info; -use crate::{grab::VulnInfo, models::ModelError}; +use crate::{ + error::{DbAlreadyExistsSnafu, DbErrSnafu, DbNotFoundErrSnafu, Result}, + grab::VulnInfo, +}; -use super::{ModelResult, _entities::vuln_informations}; +use super::_entities::vuln_informations; const REASON_NEW_CREATED: &str = "漏洞创建"; const REASON_TAG_UPDATED: &str = "标签更新"; const REASON_SEVERITY_UPDATE: &str = "等级更新"; impl super::_entities::vuln_informations::Model { - pub async fn find_by_id(db: &DatabaseConnection, key: &str) -> ModelResult { + pub async fn find_by_id(db: &DatabaseConnection, key: &str) -> Result { let vuln = vuln_informations::Entity::find() .filter(vuln_informations::Column::Key.eq(key)) .one(db) - .await?; - vuln.ok_or_else(|| ModelError::EntityNotFound) + .await + .context(DbErrSnafu)?; + let res = vuln.with_context(|| DbNotFoundErrSnafu { + table: "vuln_informations".to_string(), + filter: key.to_string(), + })?; + Ok(res) } - pub async fn query_count(db: &DatabaseConnection) -> ModelResult { - let count = vuln_informations::Entity::find().count(db).await?; + pub async fn query_count(db: &DatabaseConnection) -> Result { + let count = vuln_informations::Entity::find() + .count(db) + .await + .context(DbErrSnafu)?; Ok(count) } - pub async fn creat_or_update(db: &DatabaseConnection, mut vuln: VulnInfo) -> ModelResult { - let txn = db.begin().await?; + pub async fn creat_or_update(db: &DatabaseConnection, mut vuln: VulnInfo) -> Result { + let txn = db.begin().await.context(DbErrSnafu)?; let v = vuln_informations::Entity::find() .filter(vuln_informations::Column::Key.eq(vuln.unique_key.clone())) .one(&txn) - .await?; + .await + .context(DbErrSnafu)?; if let Some(v) = v { let mut vuln_model: vuln_informations::ActiveModel = v.into(); let mut as_new_vuln = false; @@ -78,26 +91,30 @@ impl super::_entities::vuln_informations::Model { as_new_vuln = true } } - if as_new_vuln { - vuln_model.title = ActiveValue::set(vuln.title); - vuln_model.description = ActiveValue::set(vuln.description); - vuln_model.severtiy = ActiveValue::set(vuln.severity.to_string()); - vuln_model.disclosure = ActiveValue::set(vuln.disclosure); - vuln_model.solutions = ActiveValue::set(vuln.solutions); - vuln_model.references = ActiveValue::set(Some(vuln.references)); - vuln_model.tags = ActiveValue::set(Some(vuln.tags)); - vuln_model.from = ActiveValue::set(vuln.from); - vuln_model.reasons = ActiveValue::set(Some(vuln.reasons)); - vuln_model.is_valuable = ActiveValue::set(vuln.is_valuable); - // if tags or severtiy update should set pushed false, repush - vuln_model.pushed = ActiveValue::set(false); - let m = vuln_model.update(&txn).await?; - txn.commit().await?; - return Ok(m); - } - return Err(ModelError::EntityAlreadyExists { - key: vuln.unique_key, - }); + + ensure!( + as_new_vuln, + DbAlreadyExistsSnafu { + table: 
"vuln_informations".to_string(), + filter: vuln.unique_key.clone() + } + ); + + vuln_model.title = ActiveValue::set(vuln.title); + vuln_model.description = ActiveValue::set(vuln.description); + vuln_model.severtiy = ActiveValue::set(vuln.severity.to_string()); + vuln_model.disclosure = ActiveValue::set(vuln.disclosure); + vuln_model.solutions = ActiveValue::set(vuln.solutions); + vuln_model.references = ActiveValue::set(Some(vuln.references)); + vuln_model.tags = ActiveValue::set(Some(vuln.tags)); + vuln_model.from = ActiveValue::set(vuln.from); + vuln_model.reasons = ActiveValue::set(Some(vuln.reasons)); + vuln_model.is_valuable = ActiveValue::set(vuln.is_valuable); + // if tags or severtiy update should set pushed false, repush + vuln_model.pushed = ActiveValue::set(false); + let m = vuln_model.update(&txn).await.context(DbErrSnafu)?; + txn.commit().await.context(DbErrSnafu)?; + return Ok(m); } vuln.reasons.push(REASON_NEW_CREATED.to_owned()); let v = vuln_informations::ActiveModel { @@ -117,8 +134,9 @@ impl super::_entities::vuln_informations::Model { ..Default::default() } .insert(&txn) - .await?; - txn.commit().await?; + .await + .context(DbErrSnafu)?; + txn.commit().await.context(DbErrSnafu)?; Ok(v) } @@ -126,54 +144,57 @@ impl super::_entities::vuln_informations::Model { db: &DatabaseConnection, key: &str, links: Vec, - ) -> ModelResult<()> { - let txn = db.begin().await?; + ) -> Result<()> { + let txn = db.begin().await.context(DbErrSnafu)?; let v = vuln_informations::Entity::find() .filter(vuln_informations::Column::Key.eq(key)) .one(&txn) - .await?; - if let Some(v) = v { - let mut v: vuln_informations::ActiveModel = v.into(); - v.github_search = ActiveValue::set(Some(links)); - v.update(&txn).await?; - txn.commit().await?; - Ok(()) - } else { - Err(ModelError::EntityUpdateNotFound { - key: key.to_string(), - }) - } + .await + .context(DbErrSnafu)?; + let res = v.with_context(|| DbNotFoundErrSnafu { + table: "vuln_informations".to_string(), + filter: key.to_string(), + })?; + let mut v: vuln_informations::ActiveModel = res.into(); + v.github_search = ActiveValue::set(Some(links)); + v.update(&txn).await.context(DbErrSnafu)?; + txn.commit().await.context(DbErrSnafu)?; + Ok(()) } - pub async fn update_pushed_by_key(db: &DatabaseConnection, key: String) -> ModelResult<()> { - let txn = db.begin().await?; + pub async fn update_pushed_by_key(db: &DatabaseConnection, key: String) -> Result<()> { + let txn = db.begin().await.context(DbErrSnafu)?; let v = vuln_informations::Entity::find() .filter(vuln_informations::Column::Key.eq(key.clone())) .one(&txn) - .await?; - if let Some(v) = v { - let mut v: vuln_informations::ActiveModel = v.into(); - v.pushed = ActiveValue::set(true); - v.update(&txn).await?; - txn.commit().await?; - Ok(()) - } else { - Err(ModelError::EntityUpdateNotFound { key }) - } + .await + .context(DbErrSnafu)?; + + let res = v.with_context(|| DbNotFoundErrSnafu { + table: "vuln_informations".to_string(), + filter: key, + })?; + let mut v: vuln_informations::ActiveModel = res.into(); + v.pushed = ActiveValue::set(true); + v.update(&txn).await.context(DbErrSnafu)?; + txn.commit().await.context(DbErrSnafu)?; + Ok(()) } - pub async fn create(db: &DatabaseConnection, vuln: VulnInfo) -> ModelResult { - let txn = db.begin().await?; - if vuln_informations::Entity::find() + pub async fn create(db: &DatabaseConnection, vuln: VulnInfo) -> Result { + let txn = db.begin().await.context(DbErrSnafu)?; + let res = vuln_informations::Entity::find() 
            .filter(vuln_informations::Column::Key.eq(vuln.unique_key.clone()))
            .one(&txn)
-           .await?
-           .is_some()
-       {
-           return Err(ModelError::EntityAlreadyExists {
-               key: vuln.unique_key,
-           });
-       }
+           .await
+           .context(DbErrSnafu)?;
+       ensure!(
+           res.is_none(),
+           DbAlreadyExistsSnafu {
+               table: "vuln_informations".to_string(),
+               filter: vuln.unique_key.clone(),
+           }
+       );
        let v = vuln_informations::ActiveModel {
            key: ActiveValue::set(vuln.unique_key),
            title: ActiveValue::set(vuln.title),
@@ -189,8 +210,9 @@ impl super::_entities::vuln_informations::Model {
            ..Default::default()
        }
        .insert(&txn)
-       .await?;
-       txn.commit().await?;
+       .await
+       .context(DbErrSnafu)?;
+       txn.commit().await.context(DbErrSnafu)?;
        Ok(v)
    }
}
diff --git a/src/push/dingding.rs b/src/push/dingding.rs
index 4fa3fb5..df2e163 100644
--- a/src/push/dingding.rs
+++ b/src/push/dingding.rs
@@ -1,14 +1,14 @@
 use std::time::SystemTime;
 use crate::{
-    error::{Error, Result},
+    error::{DingPushErrSnafu, HttpClientErrSnafu, Result, SystemTimeErrSnafu},
     utils::{calc_hmac_sha256, http_client::Help},
 };
 use async_trait::async_trait;
 use base64::prelude::*;
 use reqwest::header;
 use serde::{Deserialize, Serialize};
-use tracing::warn;
+use snafu::{ensure, ResultExt};
 use super::MessageBot;
@@ -37,25 +37,26 @@ impl MessageBot for DingDing {
        let sign = self.generate_sign()?;
-       let res: DingResponse = help
+       let send_res = help
            .http_client
            .post(DING_API_URL)
            .query(&sign)
            .json(&message)
            .send()
-           .await?
+           .await
+           .with_context(|_| HttpClientErrSnafu { url: DING_API_URL })?;
+
+       let res: DingResponse = send_res
            .json()
-           .await?;
+           .await
+           .with_context(|_| HttpClientErrSnafu { url: DING_API_URL })?;
-       if res.errcode != 0 {
-           warn!(
-               "ding push markdown message error, err msg is {}",
-               res.errmsg
-           );
-           return Err(Error::Message(
-               "ding push markdown message response errorcode not eq 0".to_owned(),
-           ));
-       }
+       ensure!(
+           res.errcode == 0,
+           DingPushErrSnafu {
+               errorcode: res.errcode
+           }
+       );
        Ok(())
    }
}
@@ -75,7 +76,8 @@ impl DingDing {
    }
    pub fn generate_sign(&self) -> Result {
        let timestamp = SystemTime::now()
-           .duration_since(SystemTime::UNIX_EPOCH)?
+           .duration_since(SystemTime::UNIX_EPOCH)
+           .context(SystemTimeErrSnafu)?
            .as_millis();
        let timestamp_and_secret = &format!("{}\n{}", timestamp, self.secret_token);
        let hmac_sha256 = calc_hmac_sha256(
diff --git a/src/push/lark.rs b/src/push/lark.rs
index 11245d6..de22670 100644
--- a/src/push/lark.rs
+++ b/src/push/lark.rs
@@ -5,9 +5,9 @@ use reqwest::header;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use sha2::Sha256;
-use tracing::warn;
+use snafu::{ensure, ResultExt};
-use crate::error::{Error, Result};
+use crate::error::{CryptoSnafu, HttpClientErrSnafu, LarkPushErrSnafu, Result};
 use crate::utils::http_client::Help;
 use super::MessageBot;
@@ -27,20 +27,20 @@ impl MessageBot for Lark {
        let help = self.get_help();
        let message = self.generate_lark_card(title, msg)?;
        let url = format!("{}/{}", LARK_HOOK_URL, self.access_token);
-       let res: LarkResponse = help
+       let url_clone = url.clone();
+       let send_res = help
            .http_client
-           .post(url)
+           .post(&url)
            .json(&message)
            .send()
-           .await?
+ .await + .with_context(|_| HttpClientErrSnafu { url: url_clone })?; + let res: LarkResponse = send_res .json() - .await?; - if res.code != 0 { - warn!("lark push markdown message error, err msg is {}", res.msg); - return Err(Error::Message( - "lark push markdown message response errorcode not eq 0".to_owned(), - )); - } + .await + .with_context(|_| HttpClientErrSnafu { url })?; + ensure!(res.code == 0, LarkPushErrSnafu { code: res.code }); + Ok(()) } } @@ -61,7 +61,8 @@ impl Lark { pub fn generate_sign(&self, timestamp: i64) -> Result { let timestamp_and_secret = format!("{}\n{}", timestamp, self.secret_token); - let hmac: Hmac = Hmac::new_from_slice(timestamp_and_secret.as_bytes())?; + let hmac: Hmac = + Hmac::new_from_slice(timestamp_and_secret.as_bytes()).context(CryptoSnafu)?; let hmac_code = hmac.finalize().into_bytes(); let sign = BASE64_STANDARD.encode(hmac_code); Ok(sign) diff --git a/src/push/msg_template.rs b/src/push/msg_template.rs index 7a8c262..6a32ce9 100644 --- a/src/push/msg_template.rs +++ b/src/push/msg_template.rs @@ -1,7 +1,8 @@ -use crate::error::Result; +use crate::error::{JsonErrSnafu, Result}; use crate::grab::VulnInfo; use crate::utils::render_string; use serde_json::Value; +use snafu::ResultExt; const VULN_INFO_MSG_TEMPLATE: &str = r####" # {{ title }} @@ -43,7 +44,7 @@ pub fn reader_vulninfo(mut vuln: VulnInfo) -> Result { if vuln.references.len() > MAX_REFERENCE_LENGTH { vuln.references = vuln.references[..MAX_REFERENCE_LENGTH].to_vec(); } - let json_value: Value = serde_json::to_value(vuln)?; + let json_value: Value = serde_json::to_value(vuln).context(JsonErrSnafu)?; let markdown = render_string(VULN_INFO_MSG_TEMPLATE, &json_value)?; Ok(markdown) } @@ -116,6 +117,7 @@ mod tests { reasons, github_search: vec![], is_valuable: false, + pushed: false, }; let res = reader_vulninfo(v)?; println!("{}", res); diff --git a/src/push/telegram.rs b/src/push/telegram.rs index 2944219..e170ec2 100644 --- a/src/push/telegram.rs +++ b/src/push/telegram.rs @@ -1,5 +1,6 @@ -use crate::error::Result; +use crate::error::{Result, TeloxideErrSnafu}; use async_trait::async_trait; +use snafu::ResultExt; use teloxide::{prelude::*, types::ParseMode}; use super::{msg_template::escape_markdown, MessageBot}; @@ -18,7 +19,8 @@ impl MessageBot for Telegram { .send_message(self.chat_id, msg) .parse_mode(ParseMode::MarkdownV2) .send() - .await?; + .await + .context(TeloxideErrSnafu)?; Ok(()) } } diff --git a/src/search.rs b/src/search.rs index db90fa2..954ada6 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,7 +1,11 @@ use regex::Regex; +use snafu::ResultExt; use tracing::{info, warn}; -use crate::{error::Result, utils::get_last_year_data}; +use crate::{ + error::{OctocrabErrSnafu, RegexErrSnafu, Result}, + utils::get_last_year_data, +}; pub async fn search_github_poc(cve_id: &str) -> Vec { let mut res = Vec::new(); @@ -9,13 +13,13 @@ pub async fn search_github_poc(cve_id: &str) -> Vec { match nuclei_res { Ok(nuclei) => res.extend(nuclei), Err(e) => { - warn!("search nucli pr error:{}", e); + warn!("search nucli pr error:{:?}", e); } } match repo_res { Ok(repo) => res.extend(repo), Err(e) => { - warn!("search github repo error:{}", e); + warn!("search github repo error:{:?}", e); } } res @@ -29,14 +33,18 @@ pub async fn search_nuclei_pr(cve_id: &str) -> Result> { .per_page(100) .page(1u32) .send() - .await?; - let re = Regex::new(&format!(r"(?i)(?:\b|/|_){}(?:\b|/|_)", cve_id))?; + .await + .with_context(|_| OctocrabErrSnafu { + search: cve_id.to_owned(), + })?; + let re = 
format!(r"(?i)(?:\b|/|_){}(?:\b|/|_)", cve_id); + let regex = Regex::new(re.as_str()).with_context(|_| RegexErrSnafu { re })?; let links = page .into_iter() .filter(|pull| pull.title.is_some() || pull.body.is_some()) .filter(|pull| { - re.is_match(pull.title.as_ref().unwrap_or(&String::new())) - || re.is_match(pull.body.as_ref().unwrap_or(&String::new())) + regex.is_match(pull.title.as_ref().unwrap_or(&String::new())) + || regex.is_match(pull.body.as_ref().unwrap_or(&String::new())) }) .filter_map(|pull| pull.html_url) .map(|u| u.to_string()) @@ -54,12 +62,16 @@ pub async fn search_github_repo(cve_id: &str) -> Result> { .per_page(100) .page(1u32) .send() - .await?; - let re = Regex::new(&format!(r"(?i)(?:\b|/|_){}(?:\b|/|_)", cve_id))?; + .await + .with_context(|_| OctocrabErrSnafu { + search: cve_id.to_owned(), + })?; + let re = format!(r"(?i)(?:\b|/|_){}(?:\b|/|_)", cve_id); + let regex = Regex::new(re.as_str()).with_context(|_| RegexErrSnafu { re })?; let links = page .into_iter() .filter_map(|r| r.html_url) - .filter(|url| re.captures(url.as_str()).is_some()) + .filter(|url| regex.captures(url.as_str()).is_some()) .map_while(|u| Some(u.to_string())) .collect::>(); Ok(links) diff --git a/src/utils/http_client.rs b/src/utils/http_client.rs index 2a3fab2..3da33f7 100644 --- a/src/utils/http_client.rs +++ b/src/utils/http_client.rs @@ -1,5 +1,6 @@ -use crate::error::Result; +use crate::error::{HttpClientErrSnafu, Result}; use reqwest::header::{self, HeaderMap}; +use snafu::ResultExt; #[derive(Debug, Clone)] pub struct Help { @@ -28,12 +29,27 @@ impl Help { } pub async fn get_json(&self, url: &str) -> Result { - let content = self.http_client.get(url).send().await?; + let content = self + .http_client + .get(url) + .send() + .await + .with_context(|_| HttpClientErrSnafu { url })?; Ok(content) } pub async fn get_html_content(&self, url: &str) -> Result { - let content = self.http_client.get(url).send().await?.text().await?; + let send_res = self + .http_client + .get(url) + .send() + .await + .with_context(|_| HttpClientErrSnafu { url })?; + + let content = send_res + .text() + .await + .with_context(|_| HttpClientErrSnafu { url })?; Ok(content) } @@ -41,7 +57,13 @@ impl Help { where Body: serde::Serialize, { - let content = self.http_client.post(url).json(body).send().await?; + let content = self + .http_client + .post(url) + .json(body) + .send() + .await + .with_context(|_| HttpClientErrSnafu { url })?; Ok(content) } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 5fe0fa6..1168091 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,9 +1,12 @@ pub mod http_client; -use crate::error::{Error, Result}; +use crate::error::{ + ChronoParseErrSnafu, CryptoSnafu, DateTimeFromTimestampErrSnafu, Result, TeraErrSnafu, +}; use chrono::{DateTime, Duration, Local, NaiveDate, Utc}; use hmac::{Hmac, Mac}; use sha2::Sha256; +use snafu::{OptionExt, ResultExt}; use tera::{Context, Tera}; pub fn get_last_year_data() -> String { @@ -13,7 +16,8 @@ pub fn get_last_year_data() -> String { } pub fn check_over_two_week(date: &str) -> Result { - let target_date = NaiveDate::parse_from_str(date, "%Y-%m-%d")?; + let target_date = NaiveDate::parse_from_str(date, "%Y-%m-%d") + .with_context(|_| ChronoParseErrSnafu { date })?; let now = Utc::now().naive_utc().date(); let two_weeks_ago = now - Duration::weeks(2); if target_date >= two_weeks_ago && target_date <= now { @@ -24,26 +28,29 @@ pub fn check_over_two_week(date: &str) -> Result { // data_str_format convernt 20240603 to 2024-06-03 pub fn 
data_str_format(date: &str) -> Result { - let date = NaiveDate::parse_from_str(date, "%Y%m%d")?; + let date = + NaiveDate::parse_from_str(date, "%Y%m%d").with_context(|_| ChronoParseErrSnafu { date })?; let formatted_date = format!("{}", date.format("%Y-%m-%d")); Ok(formatted_date) } pub fn timestamp_to_date(timestamp: i64) -> Result { let dt = DateTime::from_timestamp_millis(timestamp); - if let Some(dt) = dt { - return Ok(dt.format("%Y-%m-%d").to_string()); - } - Err(Error::Message("convert timestamp to date error".to_owned())) + let res = dt.with_context(|| DateTimeFromTimestampErrSnafu { timestamp })?; + Ok(res.format("%Y-%m-%d").to_string()) } pub fn render_string(tera_template: &str, locals: &serde_json::Value) -> Result { - let text = Tera::one_off(tera_template, &Context::from_serialize(locals)?, false)?; - Ok(text) + Tera::one_off( + tera_template, + &Context::from_serialize(locals).with_context(|_| TeraErrSnafu)?, + false, + ) + .with_context(|_| TeraErrSnafu) } pub fn calc_hmac_sha256(key: &[u8], message: &[u8]) -> Result> { - let mut mac = Hmac::::new_from_slice(key)?; + let mut mac = Hmac::::new_from_slice(key).with_context(|_| CryptoSnafu)?; mac.update(message); Ok(mac.finalize().into_bytes().to_vec()) } @@ -60,7 +67,10 @@ mod tests { #[test] pub fn test_check_over_two_week() -> Result<()> { - let res = check_over_two_week("2024-06-03")?; + let now = Utc::now().naive_utc().date(); + let one_weeks_ago = now - Duration::weeks(1); + let data_str = one_weeks_ago.format("%Y-%m-%d").to_string(); + let res = check_over_two_week(&data_str)?; assert!(!res); let res = check_over_two_week("2024-05-03")?; assert!(res);
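
Note on the error type: the hunks above lean on snafu context selectors (DbErrSnafu, DbNotFoundErrSnafu, DbAlreadyExistsSnafu, HttpClientErrSnafu, SystemTimeErrSnafu, and friends) that snafu derives from variants of the crate's error enum. src/error.rs itself is not part of the hunks shown here, so the exact definitions are unknown; the following is a minimal sketch, assuming snafu 0.7+ and with variant names, fields, and display strings inferred from the call sites above, of how such selectors are typically produced:

use snafu::Snafu;

// Sketch only: the real src/error.rs is not shown in this diff, so everything
// below is an assumption reconstructed from how the selectors are used above.
pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug, Snafu)]
#[snafu(visibility(pub))] // selectors are used from other modules (models, push, utils)
pub enum Error {
    // Result-based variant: `.context(DbErrSnafu)?` wraps the sea_orm source error.
    #[snafu(display("database error"))]
    DbErr { source: sea_orm::DbErr },

    // Context-only variants (no source field): paired with `OptionExt::with_context`
    // and `ensure!` in the models code.
    #[snafu(display("no row in {table} matched filter {filter}"))]
    DbNotFoundErr { table: String, filter: String },

    #[snafu(display("row already exists in {table} for filter {filter}"))]
    DbAlreadyExists { table: String, filter: String },

    // HTTP failures keep the requested URL: `.with_context(|_| HttpClientErrSnafu { url })?`.
    #[snafu(display("http request to {url} failed"))]
    HttpClientErr { source: reqwest::Error, url: String },

    #[snafu(display("system clock error"))]
    SystemTimeErr { source: std::time::SystemTimeError },
}

Variants that carry a source field pair with ResultExt (.context(...) for a fixed context, .with_context(|_| ...) when the context is built lazily), while source-less variants pair with OptionExt and ensure!; that split matches how the selectors are used throughout the models, push, and utils hunks.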
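
One behavioural detail worth pinning down, since it is what the res.is_none() check in create relies on: snafu::ensure!(cond, Selector) returns the selected error when the condition is false and otherwise falls through, and OptionExt::with_context turns a None into the selected error. A self-contained, hypothetical test module (not part of this patch; DemoError and insert_if_absent are illustrative names only) sketching both:

#[cfg(test)]
mod snafu_semantics {
    use snafu::{ensure, OptionExt, Snafu};

    // Illustrative error type; not the crate's real Error.
    #[derive(Debug, Snafu)]
    enum DemoError {
        #[snafu(display("row already exists in {table}"))]
        AlreadyExists { table: String },
        #[snafu(display("nothing matched filter {filter}"))]
        NotFound { filter: String },
    }

    // Mirrors the shape of `create`: proceed only when no matching row exists.
    fn insert_if_absent(existing: Option<u32>) -> Result<(), DemoError> {
        ensure!(
            existing.is_none(),
            AlreadyExistsSnafu {
                table: "vuln_informations",
            }
        );
        Ok(())
    }

    #[test]
    fn ensure_errors_when_condition_is_false() {
        assert!(insert_if_absent(Some(1)).is_err());
        assert!(insert_if_absent(None).is_ok());
    }

    #[test]
    fn option_with_context_maps_none_to_error() {
        let missing: Option<u32> = None;
        let err = missing
            .with_context(|| NotFoundSnafu { filter: "key" })
            .unwrap_err();
        assert!(matches!(err, DemoError::NotFound { .. }));
    }
}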