diff --git a/detectors/avoid-panic-error/src/lib.rs b/detectors/avoid-panic-error/src/lib.rs index 58f09097..72895519 100644 --- a/detectors/avoid-panic-error/src/lib.rs +++ b/detectors/avoid-panic-error/src/lib.rs @@ -3,22 +3,21 @@ extern crate rustc_ast; extern crate rustc_span; -use clippy_utils::sym; -use clippy_wrappers::span_lint_and_help; -use if_chain::if_chain; +use clippy_utils::diagnostics::span_lint_and_help; use rustc_ast::{ - ptr::P, - token::{LitKind, TokenKind}, - tokenstream::{TokenStream, TokenTree}, - AttrArgs, AttrKind, Expr, ExprKind, Item, MacCall, StmtKind, + tokenstream::TokenTree, + visit::{walk_expr, Visitor}, + AssocItemKind, AttrArgs, AttrKind, Block, Expr, ExprKind, FnRetTy, Item, ItemKind, MacCall, + ModKind, TyKind, }; use rustc_lint::{EarlyContext, EarlyLintPass}; use rustc_span::{sym, Span}; -const LINT_MESSAGE: &str = "The panic! macro is used to stop execution when a condition is not met. Even when this does not break the execution of the contract, it is recommended to use Result instead of panic! because it will stop the execution of the caller contract"; +const LINT_MESSAGE: &str = "The panic! macro is used in a function that returns Result. \ + Consider using the ? operator or return Err() instead."; dylint_linting::impl_pre_expansion_lint! { - /// ### What it does + /// ### What it does /// The panic! macro is used to stop execution when a condition is not met. /// This is useful for testing and prototyping, but should be avoided in production code /// @@ -62,7 +61,8 @@ dylint_linting::impl_pre_expansion_lint! { AvoidPanicError::default(), { name: "Avoid panic! macro", - long_message: "The use of the panic! macro to stop execution when a condition is not met is useful for testing and prototyping but should be avoided in production code. Using Result as the return type for functions that can fail is the idiomatic way to handle errors in Rust. ", + long_message: "Using panic! in functions that return Result defeats the purpose of error handling. \ + Consider propagating the error using ? or return Err() instead.", severity: "Enhancement", help: "https://coinfabrik.github.io/scout-soroban/docs/detectors/avoid-panic-error", vulnerability_class: "Validations and error handling", @@ -75,123 +75,120 @@ pub struct AvoidPanicError { } impl EarlyLintPass for AvoidPanicError { - fn check_item(&mut self, _cx: &EarlyContext, item: &Item) { - match (is_test_item(item), self.in_test_span) { - (true, None) => self.in_test_span = Some(item.span), - (true, Some(test_span)) => { - if !test_span.contains(item.span) { - self.in_test_span = Some(item.span); + fn check_item(&mut self, cx: &EarlyContext, item: &Item) { + if is_test_item(item) { + self.in_test_span = Some(item.span); + return; + } + + if let Some(test_span) = self.in_test_span { + if !test_span.contains(item.span) { + self.in_test_span = None; + } else { + return; + } + } + + match &item.kind { + ItemKind::Impl(impl_item) => { + for assoc_item in &impl_item.items { + if let AssocItemKind::Fn(fn_item) = &assoc_item.kind { + self.check_function( + cx, + &fn_item.sig.decl.output, + fn_item.body.as_ref().unwrap(), + ); + } } } - (false, None) => {} - (false, Some(test_span)) => { - if !test_span.contains(item.span) { - self.in_test_span = None; + ItemKind::Fn(fn_item) => { + self.check_function(cx, &fn_item.sig.decl.output, fn_item.body.as_ref().unwrap()); + } + ItemKind::Mod(_, ModKind::Loaded(items, _, _)) => { + for item in items { + self.check_item(cx, item); } } - }; - } - - fn check_stmt(&mut self, cx: &EarlyContext<'_>, stmt: &rustc_ast::Stmt) { - if_chain! { - if !self.in_test_item(); - if let StmtKind::MacCall(mac) = &stmt.kind; - then { - check_macro_call(cx, stmt.span, &mac.mac) + ItemKind::Trait(trait_item) => { + for item in &trait_item.items { + if let AssocItemKind::Fn(fn_item) = &item.kind { + self.check_function( + cx, + &fn_item.sig.decl.output, + fn_item.body.as_ref().unwrap(), + ); + } + } } + _ => {} } } +} - fn check_expr(&mut self, cx: &EarlyContext, expr: &Expr) { - if_chain! { - if !self.in_test_item(); - if let ExprKind::MacCall(mac) = &expr.kind; - then { - check_macro_call(cx, expr.span, mac) - } +impl AvoidPanicError { + fn check_function(&self, cx: &EarlyContext, output: &FnRetTy, body: &Block) { + if is_result_type(output) { + let mut visitor = PanicVisitor { cx }; + visitor.visit_block(body); } } } -fn check_macro_call(cx: &EarlyContext, span: Span, mac: &P) { - if_chain! { - if mac.path == sym!(panic); - if let [TokenTree::Token(token, _)] = mac - .args - .tokens - .clone() - .trees() - .collect::>() - .as_slice(); - if let TokenKind::Literal(lit) = token.kind; - if lit.kind == LitKind::Str; - then { - span_lint_and_help( - cx, - AVOID_PANIC_ERROR, - span, - LINT_MESSAGE, - None, - format!("You could use instead an Error enum and then 'return Err(Error::{})'", capitalize_err_msg(lit.symbol.as_str()).replace(' ', "")) - ) +struct PanicVisitor<'a, 'tcx> { + cx: &'a EarlyContext<'tcx>, +} + +impl<'a, 'tcx> Visitor<'tcx> for PanicVisitor<'a, 'tcx> { + fn visit_expr(&mut self, expr: &'tcx Expr) { + if let ExprKind::MacCall(mac) = &expr.kind { + check_macro_call(self.cx, expr.span, mac); } + walk_expr(self, expr); + } +} + +fn check_macro_call(cx: &EarlyContext, span: Span, mac: &MacCall) { + if mac.path == sym::panic { + let suggestion = "Consider using '?' to propagate errors or 'return Err()' to return early with an error"; + span_lint_and_help(cx, AVOID_PANIC_ERROR, span, LINT_MESSAGE, None, suggestion); } } fn is_test_item(item: &Item) -> bool { item.attrs.iter().any(|attr| { - // Find #[cfg(all(test, feature = "e2e-tests"))] - if_chain!( - if let AttrKind::Normal(normal) = &attr.kind; - if let AttrArgs::Delimited(delim_args) = &normal.item.args; - if is_test_token_present(&delim_args.tokens); - then { - return true; - } - ); - - // Find unit or integration tests - if attr.has_name(sym::test) { - return true; - } - - if_chain! { - if attr.has_name(sym::cfg); - if let Some(items) = attr.meta_item_list(); - if let [item] = items.as_slice(); - if let Some(feature_item) = item.meta_item(); - if feature_item.has_name(sym::test); - then { - return true; - } - } - - false + attr.has_name(sym::test) + || (attr.has_name(sym::cfg) + && attr.meta_item_list().map_or(false, |list| { + list.iter().any(|item| item.has_name(sym::test)) + })) + || matches!( + &attr.kind, + AttrKind::Normal(normal) if is_test_token_present(&normal.item.args) + ) }) } -impl AvoidPanicError { - fn in_test_item(&self) -> bool { - self.in_test_span.is_some() +fn is_test_token_present(args: &AttrArgs) -> bool { + if let AttrArgs::Delimited(delim_args) = args { + delim_args.tokens.trees().any( + |tree| matches!(tree, TokenTree::Token(token, _) if token.is_ident_named(sym::test)), + ) + } else { + false } } -fn capitalize_err_msg(s: &str) -> String { - s.split_whitespace() - .map(|word| { - let mut chars = word.chars(); - match chars.next() { - None => String::new(), - Some(f) => f.to_uppercase().collect::() + chars.as_str(), +fn is_result_type(output: &FnRetTy) -> bool { + match output { + FnRetTy::Default(_) => false, + FnRetTy::Ty(ty) => { + if let TyKind::Path(None, path) = &ty.kind { + path.segments + .last() + .map_or(false, |seg| seg.ident.name == sym::Result) + } else { + false } - }) - .collect::>() - .join(" ") -} - -fn is_test_token_present(token_stream: &TokenStream) -> bool { - token_stream.trees().any(|tree| match tree { - TokenTree::Token(token, _) => token.is_ident_named(sym::test), - TokenTree::Delimited(_, _, _, token_stream) => is_test_token_present(token_stream), - }) + } + } } diff --git a/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs index 4a50c7d8..a0a0cda9 100644 --- a/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs +++ b/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs @@ -9,18 +9,18 @@ pub struct AvoidPanicError; #[contracterror] #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] #[repr(u32)] -pub enum Error { +pub enum PanicError { OverflowError = 1, } #[contractimpl] impl AvoidPanicError { - pub fn add(env: Env, value: u32) -> Result { + pub fn add(env: Env, value: u32) -> Result { let storage = env.storage().instance(); let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); match count.checked_add(value) { Some(value) => count = value, - None => return Err(Error::OverflowError), + None => return Err(PanicError::OverflowError), } storage.set(&COUNTER, &count); storage.extend_ttl(100, 100); @@ -32,7 +32,7 @@ impl AvoidPanicError { mod tests { use soroban_sdk::Env; - use crate::{AvoidPanicError, AvoidPanicErrorClient, Error}; + use crate::{AvoidPanicError, AvoidPanicErrorClient, PanicError}; #[test] fn add() { @@ -64,6 +64,6 @@ mod tests { let overflow = client.try_add(&1); // Then - assert_eq!(overflow, Err(Ok(Error::OverflowError))); + assert_eq!(overflow, Err(Ok(PanicError::OverflowError))); } } diff --git a/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs index 7119ed1d..0aa1758f 100644 --- a/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs +++ b/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs @@ -1,14 +1,21 @@ #![no_std] -use soroban_sdk::{contract, contractimpl, symbol_short, Env, Symbol}; + +use soroban_sdk::{contract, contracterror, contractimpl, symbol_short, Env, Symbol}; const COUNTER: Symbol = symbol_short!("COUNTER"); #[contract] pub struct AvoidPanicError; +#[contracterror] +#[derive(Copy, Clone)] +pub enum PanicError { + Overflow = 1, +} + #[contractimpl] impl AvoidPanicError { - pub fn add(env: Env, value: u32) -> u32 { + pub fn add(env: Env, value: u32) -> Result { let storage = env.storage().instance(); let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); match count.checked_add(value) { @@ -17,7 +24,7 @@ impl AvoidPanicError { } storage.set(&COUNTER, &count); storage.extend_ttl(100, 100); - count + Ok(count) } } diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/Cargo.toml b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/Cargo.toml new file mode 100644 index 00000000..e64d4f1f --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/Cargo.toml @@ -0,0 +1,13 @@ +[package] +edition = "2021" +name = "avoid-panic-error-remediated-2" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = { workspace = true } + +[features] +testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/src/lib.rs new file mode 100644 index 00000000..6093d1f5 --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/src/lib.rs @@ -0,0 +1,63 @@ +#![no_std] +use soroban_sdk::{contract, contractimpl, symbol_short, Env, Symbol}; + +const COUNTER: Symbol = symbol_short!("COUNTER"); + +#[contract] +pub struct AvoidPanicError; + +#[contractimpl] +impl AvoidPanicError { + pub fn add(env: Env, value: u32) -> u32 { + let storage = env.storage().instance(); + let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); + match count.checked_add(value) { + Some(value) => count = value, + None => panic!("Overflow error"), + } + storage.set(&COUNTER, &count); + storage.extend_ttl(100, 100); + count + } +} + +#[cfg(test)] +mod tests { + use soroban_sdk::Env; + + use crate::{AvoidPanicError, AvoidPanicErrorClient}; + + #[test] + fn add() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + let first_increment = client.try_add(&1); + let second_increment = client.try_add(&2); + let third_increment = client.try_add(&3); + + // Then + assert_eq!(first_increment, Ok(Ok(1))); + assert_eq!(second_increment, Ok(Ok(3))); + assert_eq!(third_increment, Ok(Ok(6))); + } + + #[test] + #[should_panic(expected = "Overflow error")] + fn overflow() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + client.add(&u32::MAX); + client.add(&1); + + // Then + // panic + } +} diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/Cargo.toml b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/Cargo.toml new file mode 100644 index 00000000..c690ea8e --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/Cargo.toml @@ -0,0 +1,13 @@ +[package] +edition = "2021" +name = "avoid-panic-error-vulnerable-2" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = { workspace = true } + +[features] +testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/src/lib.rs new file mode 100644 index 00000000..0aa1758f --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/src/lib.rs @@ -0,0 +1,70 @@ +#![no_std] + +use soroban_sdk::{contract, contracterror, contractimpl, symbol_short, Env, Symbol}; + +const COUNTER: Symbol = symbol_short!("COUNTER"); + +#[contract] +pub struct AvoidPanicError; + +#[contracterror] +#[derive(Copy, Clone)] +pub enum PanicError { + Overflow = 1, +} + +#[contractimpl] +impl AvoidPanicError { + pub fn add(env: Env, value: u32) -> Result { + let storage = env.storage().instance(); + let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); + match count.checked_add(value) { + Some(value) => count = value, + None => panic!("Overflow error"), + } + storage.set(&COUNTER, &count); + storage.extend_ttl(100, 100); + Ok(count) + } +} + +#[cfg(test)] +mod tests { + use soroban_sdk::Env; + + use crate::{AvoidPanicError, AvoidPanicErrorClient}; + + #[test] + fn add() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + let first_increment = client.add(&1); + let second_increment = client.add(&2); + let third_increment = client.add(&3); + + // Then + assert_eq!(first_increment, 1); + assert_eq!(second_increment, 3); + assert_eq!(third_increment, 6); + } + + #[test] + #[should_panic(expected = "Overflow error")] + fn overflow() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + client.add(&u32::MAX); + client.add(&1); + + // Then + // panic + } +}