diff --git a/detectors/unsafe-unwrap/src/lib.rs b/detectors/unsafe-unwrap/src/lib.rs index f3e77c3a..66916981 100644 --- a/detectors/unsafe-unwrap/src/lib.rs +++ b/detectors/unsafe-unwrap/src/lib.rs @@ -11,12 +11,12 @@ use rustc_hir::{ def::Res, def_id::LocalDefId, intravisit::{walk_expr, FnKind, Visitor}, - BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LetStmt, PathSegment, QPath, UnOp, + BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp, }; use rustc_lint::{LateContext, LateLintPass}; use rustc_span::{sym, Span, Symbol}; use std::{collections::HashSet, hash::Hash}; -use utils::{fn_returns, get_node_type_opt, match_type_to_str}; +use utils::{fn_returns, get_node_type_opt, match_type_to_str, ConstantAnalyzer}; const LINT_MESSAGE: &str = "Unsafe usage of `unwrap`"; const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"]; @@ -154,13 +154,14 @@ impl ConditionalChecker { /// Main unsafe-unwrap visitor struct UnsafeUnwrapVisitor<'a, 'tcx> { cx: &'a LateContext<'tcx>, + constant_analyzer: ConstantAnalyzer<'a, 'tcx>, conditional_checker: HashSet, checked_exprs: HashSet, linted_spans: HashSet, returns_result_or_option: bool, } -impl UnsafeUnwrapVisitor<'_, '_> { +impl<'a, 'tcx> UnsafeUnwrapVisitor<'a, 'tcx> { fn get_help_message(&self, unwrap_type: UnwrapType) -> &'static str { match (self.returns_result_or_option, unwrap_type) { (true, UnwrapType::Option) => "Consider using `ok_or` to convert Option to Result", @@ -174,7 +175,7 @@ impl UnsafeUnwrapVisitor<'_, '_> { } } - fn determine_unwrap_type(&self, receiver: &Expr<'_>) -> UnwrapType { + fn determine_unwrap_type(&self, receiver: &Expr<'tcx>) -> UnwrapType { let type_opt = get_node_type_opt(self.cx, &receiver.hir_id); if let Some(type_) = type_opt { if match_type_to_str(self.cx, type_, "Result") { @@ -184,7 +185,7 @@ impl UnsafeUnwrapVisitor<'_, '_> { UnwrapType::Option } - fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool { + fn is_panic_inducing_call(&self, func: &Expr<'tcx>) -> bool { if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| { path.segments @@ -195,7 +196,7 @@ impl UnsafeUnwrapVisitor<'_, '_> { false } - fn get_unwrap_info(&self, receiver: &Expr<'_>) -> Option { + fn get_unwrap_info(&self, receiver: &Expr<'tcx>) -> Option { if_chain! { if let ExprKind::Path(QPath::Resolved(_, path)) = &receiver.kind; if let Res::Local(hir_id) = path.res; @@ -235,110 +236,43 @@ impl UnsafeUnwrapVisitor<'_, '_> { }); } - fn is_literal_or_composed_of_literals(&self, expr: &Expr<'_>) -> bool { - let mut stack = vec![expr]; - - while let Some(current_expr) = stack.pop() { - match current_expr.kind { - ExprKind::Lit(_) => continue, // A literal is fine, continue processing. - ExprKind::Tup(elements) | ExprKind::Array(elements) => { - stack.extend(elements); - } - ExprKind::Struct(_, fields, _) => { - for field in fields { - stack.push(field.expr); - } - } - ExprKind::Repeat(element, _) => { - stack.push(element); - } - _ => return false, // If any element is not a literal or a compound of literals, return false. + fn is_method_call_unsafe(&mut self, path_segment: &PathSegment, receiver: &Expr<'tcx>) -> bool { + if path_segment.ident.name == sym::unwrap { + if self.constant_analyzer.is_constant(receiver) { + return false; } - } - true // If the stack is emptied without finding a non-literal, all elements are literals. - } - - fn is_method_call_unsafe( - &self, - path_segment: &PathSegment, - receiver: &Expr, - args: &[Expr], - ) -> bool { - if path_segment.ident.name == sym::unwrap { return self .get_unwrap_info(receiver) .map_or(true, |id| !self.checked_exprs.contains(&id)); } - args.iter().any(|arg| self.contains_unsafe_method_call(arg)) - || self.contains_unsafe_method_call(receiver) - } - - fn contains_unsafe_method_call(&self, expr: &Expr) -> bool { - match &expr.kind { - ExprKind::MethodCall(path_segment, receiver, args, _) => { - self.is_method_call_unsafe(path_segment, receiver, args) - } - _ => false, - } + false } - fn check_expr_for_unsafe_unwrap(&mut self, expr: &Expr) { - match &expr.kind { - ExprKind::MethodCall(path_segment, receiver, args, _) => { - if self.is_method_call_unsafe(path_segment, receiver, args) - && !self.linted_spans.contains(&expr.span) - { - let unwrap_type = self.determine_unwrap_type(receiver); - let help_message = self.get_help_message(unwrap_type); - span_lint_and_help( - self.cx, - UNSAFE_UNWRAP, - expr.span, - LINT_MESSAGE, - None, - help_message, - ); - self.linted_spans.insert(expr.span); - } - } - ExprKind::Call(func, args) => { - if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { - let is_some_or_ok = path - .segments - .iter() - .any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok)); - let all_literals = args - .iter() - .all(|arg| self.is_literal_or_composed_of_literals(arg)); - if is_some_or_ok && all_literals { - self.checked_exprs.insert(expr.hir_id); - return; - } - } - // Check arguments for unsafe expect - for arg in args.iter() { - self.check_expr_for_unsafe_unwrap(arg); - } - } - ExprKind::Tup(exprs) | ExprKind::Array(exprs) => { - for expr in exprs.iter() { - self.check_expr_for_unsafe_unwrap(expr); - } + fn check_expr_for_unsafe_unwrap(&mut self, expr: &Expr<'tcx>) { + if let ExprKind::MethodCall(path_segment, receiver, _, _) = &expr.kind { + if self.is_method_call_unsafe(path_segment, receiver) + && !self.linted_spans.contains(&expr.span) + { + let unwrap_type = self.determine_unwrap_type(receiver); + let help_message = self.get_help_message(unwrap_type); + + span_lint_and_help( + self.cx, + UNSAFE_UNWRAP, + expr.span, + LINT_MESSAGE, + None, + help_message, + ); + self.linted_spans.insert(expr.span); } - _ => {} } } } impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { - fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result { - if let Some(init) = local.init { - self.check_expr_for_unsafe_unwrap(init); - } - } - fn visit_expr(&mut self, expr: &'tcx Expr<'_>) { // If we are inside an `if` or `if let` expression, we analyze its body if !self.conditional_checker.is_empty() { @@ -388,8 +322,15 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap { return; } + let mut constant_analyzer = ConstantAnalyzer { + cx, + constants: HashSet::new(), + }; + constant_analyzer.visit_body(body); + let mut visitor = UnsafeUnwrapVisitor { cx, + constant_analyzer, checked_exprs: HashSet::new(), conditional_checker: HashSet::new(), linted_spans: HashSet::new(), diff --git a/utils/src/constant_analyzer/mod.rs b/utils/src/constant_analyzer/mod.rs index a9e9457e..7b92b111 100644 --- a/utils/src/constant_analyzer/mod.rs +++ b/utils/src/constant_analyzer/mod.rs @@ -22,13 +22,23 @@ impl<'a, 'tcx> ConstantAnalyzer<'a, 'tcx> { fn is_qpath_constant(&self, path: &QPath) -> bool { if let QPath::Resolved(_, path) = path { match path.res { - Res::Def(def_kind, _) => matches!( - def_kind, - DefKind::AnonConst - | DefKind::AssocConst - | DefKind::Const - | DefKind::InlineConst - ), + Res::Def(def_kind, def_id) => { + matches!( + def_kind, + DefKind::AnonConst + | DefKind::AssocConst + | DefKind::Const + | DefKind::InlineConst + ) || { + // Allow both Some and Ok variant constructors + if let DefKind::Ctor(..) = def_kind { + let def_path = self.cx.tcx.def_path_str(def_id); + def_path.ends_with("::Some") || def_path.ends_with("::Ok") + } else { + false + } + } + } Res::Local(hir_id) => self.constants.contains(&hir_id), _ => false, } @@ -61,6 +71,9 @@ impl<'a, 'tcx> ConstantAnalyzer<'a, 'tcx> { ExprKind::Struct(_, expr_fields, _) => expr_fields .iter() .all(|field_expr| self.is_expr_constant(field_expr.expr)), + ExprKind::Call(func, args) => { + self.is_expr_constant(func) && args.iter().all(|arg| self.is_expr_constant(arg)) + } _ => false, } }