diff --git a/detectors/unsafe-unwrap/Cargo.toml b/detectors/unsafe-unwrap/Cargo.toml index cb8e056c..4812ff59 100644 --- a/detectors/unsafe-unwrap/Cargo.toml +++ b/detectors/unsafe-unwrap/Cargo.toml @@ -11,6 +11,7 @@ clippy_utils = { workspace = true } clippy_wrappers = { workspace = true } dylint_linting = { workspace = true } if_chain = { workspace = true } +utils = { workspace = true } [package.metadata.rust-analyzer] rustc_private = true diff --git a/detectors/unsafe-unwrap/src/lib.rs b/detectors/unsafe-unwrap/src/lib.rs index 6f2026d4..becba5ec 100644 --- a/detectors/unsafe-unwrap/src/lib.rs +++ b/detectors/unsafe-unwrap/src/lib.rs @@ -1,23 +1,22 @@ #![feature(rustc_private)] #![allow(clippy::enum_variant_names)] -extern crate rustc_ast; extern crate rustc_hir; extern crate rustc_span; -use std::{collections::HashSet, hash::Hash}; - -use clippy_utils::higher; +use clippy_utils::higher::IfOrIfLet; use clippy_wrappers::span_lint_and_help; use if_chain::if_chain; use rustc_hir::{ def::Res, def_id::LocalDefId, intravisit::{walk_expr, FnKind, Visitor}, - BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp, + BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LetStmt, PathSegment, QPath, UnOp, }; use rustc_lint::{LateContext, LateLintPass}; use rustc_span::{sym, Span, Symbol}; +use std::{collections::HashSet, hash::Hash}; +use utils::{get_node_type_opt, match_type_to_str, returns_result}; const LINT_MESSAGE: &str = "Unsafe usage of `unwrap`"; const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"]; @@ -65,6 +64,13 @@ dylint_linting::declare_late_lint! { } } +// Enum to represent the type being unwrapped +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +enum UnwrapType { + Option, + Result, +} + /// Represents the type of check performed on method call expressions to determine their safety or behavior. #[derive(Clone, Copy, Hash, Eq, PartialEq)] enum CheckType { @@ -132,7 +138,7 @@ impl ConditionalChecker { fn from_expression(condition: &Expr<'_>) -> HashSet { match condition.kind { // Single `not` expressions are supported - ExprKind::Unary(op, condition) => Self::handle_condition(condition, op == UnOp::Not), + ExprKind::Unary(UnOp::Not, condition) => Self::handle_condition(condition, true), // Multiple `or` expressions are supported ExprKind::Binary(op, left_condition, right_condition) if op.node == BinOpKind::Or => { let mut result = Self::from_expression(left_condition); @@ -150,9 +156,33 @@ struct UnsafeUnwrapVisitor<'a, 'tcx> { cx: &'a LateContext<'tcx>, conditional_checker: HashSet, checked_exprs: HashSet, + returns_result: bool, } impl UnsafeUnwrapVisitor<'_, '_> { + fn get_help_message(&self, unwrap_type: UnwrapType) -> &'static str { + match (self.returns_result, unwrap_type) { + (true, UnwrapType::Option) => "Consider using `ok_or` to convert Option to Result", + (true, UnwrapType::Result) => "Consider using the `?` operator for error propagation", + (false, UnwrapType::Option) => { + "Consider pattern matching or using `if let` instead of `unwrap`" + } + (false, UnwrapType::Result) => { + "Consider handling the error case explicitly or using `if let` instead of `unwrap`" + } + } + } + + fn determine_unwrap_type(&self, receiver: &Expr<'_>) -> UnwrapType { + let type_opt = get_node_type_opt(self.cx, &receiver.hir_id); + if let Some(type_) = type_opt { + if match_type_to_str(self.cx, type_, "Result") { + return UnwrapType::Result; + } + } + UnwrapType::Option + } + fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool { if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| { @@ -175,21 +205,23 @@ impl UnsafeUnwrapVisitor<'_, '_> { None } - fn set_conditional_checker(&mut self, conditional_checkers: &HashSet) { - for checker in conditional_checkers { - self.conditional_checker.insert(*checker); - if checker.check_type.is_safe_to_unwrap() { - self.checked_exprs.insert(checker.checked_expr_hir_id); - } - } - } - - fn reset_conditional_checker(&mut self, conditional_checkers: HashSet) { + fn update_conditional_checker( + &mut self, + conditional_checkers: &HashSet, + set: bool, + ) { for checker in conditional_checkers { - if checker.check_type.is_safe_to_unwrap() { - self.checked_exprs.remove(&checker.checked_expr_hir_id); + if set { + self.conditional_checker.insert(*checker); + if checker.check_type.is_safe_to_unwrap() { + self.checked_exprs.insert(checker.checked_expr_hir_id); + } + } else { + if checker.check_type.is_safe_to_unwrap() { + self.checked_exprs.remove(&checker.checked_expr_hir_id); + } + self.conditional_checker.remove(checker); } - self.conditional_checker.remove(&checker); } } @@ -253,18 +285,20 @@ impl UnsafeUnwrapVisitor<'_, '_> { } impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { - fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result { + fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result { if let Some(init) = local.init { match init.kind { ExprKind::MethodCall(path_segment, receiver, args, _) => { if self.is_method_call_unsafe(path_segment, receiver, args) { + let unwrap_type = self.determine_unwrap_type(receiver); + let help_message = self.get_help_message(unwrap_type); span_lint_and_help( self.cx, UNSAFE_UNWRAP, local.span, LINT_MESSAGE, None, - "Please, use a custom error instead of `unwrap`", + help_message, ); } } @@ -300,17 +334,17 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { } // Find `if` or `if let` expressions - if let Some(higher::IfOrIfLet { + if let Some(IfOrIfLet { cond, then: if_expr, r#else: _, - }) = higher::IfOrIfLet::hir(expr) + }) = IfOrIfLet::hir(expr) { // If we are interested in the condition (if it is a CheckType) we traverse the body. let conditional_checker = ConditionalChecker::from_expression(cond); - self.set_conditional_checker(&conditional_checker); + self.update_conditional_checker(&conditional_checker, true); walk_expr(self, if_expr); - self.reset_conditional_checker(conditional_checker); + self.update_conditional_checker(&conditional_checker, false); return; } @@ -320,16 +354,18 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { if path_segment.ident.name == sym::unwrap; then { let receiver_hir_id = self.get_unwrap_info(receiver); - // If the receiver is `None`, then we asume that the `unwrap` is unsafe + // If the receiver is `None` or `Err`, then we assume that the `unwrap` is unsafe let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id)); if !is_checked_safe { + let unwrap_type = self.determine_unwrap_type(receiver); + let help_message = self.get_help_message(unwrap_type); span_lint_and_help( self.cx, UNSAFE_UNWRAP, expr.span, LINT_MESSAGE, None, - "Please, use a custom error instead of `unwrap`", + help_message, ); } } @@ -344,7 +380,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap { &mut self, cx: &LateContext<'tcx>, _: FnKind<'tcx>, - _: &'tcx FnDecl<'tcx>, + fn_decl: &'tcx FnDecl<'tcx>, body: &'tcx Body<'tcx>, span: Span, _: LocalDefId, @@ -358,6 +394,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap { cx, checked_exprs: HashSet::new(), conditional_checker: HashSet::new(), + returns_result: returns_result(fn_decl), }; walk_expr(&mut visitor, body.value);