diff --git a/detectors/unsafe-expect/src/lib.rs b/detectors/unsafe-expect/src/lib.rs index 4f3cb283..fba6c168 100644 --- a/detectors/unsafe-expect/src/lib.rs +++ b/detectors/unsafe-expect/src/lib.rs @@ -11,12 +11,13 @@ use rustc_hir::{ def::Res, def_id::LocalDefId, intravisit::{walk_expr, FnKind, Visitor}, - BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp, + BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LangItem, MatchSource, PathSegment, QPath, + UnOp, }; use rustc_lint::{LateContext, LateLintPass}; use rustc_span::{sym, Span, Symbol}; use std::{collections::HashSet, hash::Hash}; -use utils::fn_returns; +use utils::{fn_returns, ConstantAnalyzer}; const LINT_MESSAGE: &str = "Unsafe usage of `expect`"; const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"]; @@ -147,13 +148,14 @@ impl ConditionalChecker { /// Main unsafe-expect visitor struct UnsafeExpectVisitor<'a, 'tcx> { cx: &'a LateContext<'tcx>, + constant_analyzer: ConstantAnalyzer<'a, 'tcx>, conditional_checker: HashSet, checked_exprs: HashSet, linted_spans: HashSet, } -impl UnsafeExpectVisitor<'_, '_> { - fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool { +impl<'a, 'tcx> UnsafeExpectVisitor<'a, 'tcx> { + fn is_panic_inducing_call(&self, func: &Expr<'tcx>) -> bool { if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| { path.segments @@ -164,7 +166,7 @@ impl UnsafeExpectVisitor<'_, '_> { false } - fn get_expect_info(&self, receiver: &Expr<'_>) -> Option { + fn get_expect_info(&self, receiver: &Expr<'tcx>) -> Option { if_chain! { if let ExprKind::Path(QPath::Resolved(_, path)) = &receiver.kind; if let Res::Local(hir_id) = path.res; @@ -204,109 +206,59 @@ impl UnsafeExpectVisitor<'_, '_> { }); } - fn is_literal_or_composed_of_literals(&self, expr: &Expr<'_>) -> bool { - let mut stack = vec![expr]; - - while let Some(current_expr) = stack.pop() { - match current_expr.kind { - ExprKind::Lit(_) => continue, // A literal is fine, continue processing. - ExprKind::Tup(elements) | ExprKind::Array(elements) => { - stack.extend(elements); - } - ExprKind::Struct(_, fields, _) => { - for field in fields { - stack.push(field.expr); - } - } - ExprKind::Repeat(element, _) => { - stack.push(element); - } - _ => return false, // If any element is not a literal or a compound of literals, return false. + fn is_method_call_unsafe(&self, path_segment: &PathSegment, receiver: &Expr<'tcx>) -> bool { + if path_segment.ident.name == sym::expect { + if self.constant_analyzer.is_constant(receiver) { + return false; } - } - - true // If the stack is emptied without finding a non-literal, all elements are literals. - } - fn is_method_call_unsafe( - &self, - path_segment: &PathSegment, - receiver: &Expr, - args: &[Expr], - ) -> bool { - if path_segment.ident.name == sym::expect { return self .get_expect_info(receiver) .map_or(true, |id| !self.checked_exprs.contains(&id)); } - - args.iter().any(|arg| self.contains_unsafe_method_call(arg)) - || self.contains_unsafe_method_call(receiver) + false } - fn contains_unsafe_method_call(&self, expr: &Expr) -> bool { - match &expr.kind { - ExprKind::MethodCall(path_segment, receiver, args, _) => { - self.is_method_call_unsafe(path_segment, receiver, args) + fn check_expr_for_unsafe_expect(&mut self, expr: &Expr<'tcx>) { + if let ExprKind::MethodCall(path_segment, receiver, _, _) = &expr.kind { + if self.is_method_call_unsafe(path_segment, receiver) + && !self.linted_spans.contains(&expr.span) + { + span_lint_and_help( + self.cx, + UNSAFE_EXPECT, + expr.span, + LINT_MESSAGE, + None, + "Please, use a custom error instead of `expect`", + ); + self.linted_spans.insert(expr.span); } - _ => false, } } - fn check_expr_for_unsafe_expect(&mut self, expr: &Expr) { - match &expr.kind { - ExprKind::MethodCall(path_segment, receiver, args, _) => { - if self.is_method_call_unsafe(path_segment, receiver, args) - && !self.linted_spans.contains(&expr.span) - { - span_lint_and_help( - self.cx, - UNSAFE_EXPECT, - expr.span, - LINT_MESSAGE, - None, - "Please, use a custom error instead of `expect`", - ); - self.linted_spans.insert(expr.span); - } - } - ExprKind::Call(func, args) => { - if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { - let is_some_or_ok = path - .segments - .iter() - .any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok)); - let all_literals = args - .iter() - .all(|arg| self.is_literal_or_composed_of_literals(arg)); - if is_some_or_ok && all_literals { - self.checked_exprs.insert(expr.hir_id); - return; - } - } - // Check arguments for unsafe expect - for arg in args.iter() { - self.check_expr_for_unsafe_expect(arg); - } - } - ExprKind::Tup(exprs) | ExprKind::Array(exprs) => { - for expr in exprs.iter() { - self.check_expr_for_unsafe_expect(expr); - } + fn check_for_try(&mut self, expr: &Expr<'tcx>) { + if_chain! { + // Check for match expressions desugared from try + if let ExprKind::Match(expr, _, MatchSource::TryDesugar(_)) = &expr.kind; + if let ExprKind::Call(func, args) = &expr.kind; + // Check for the try trait branch lang item + if let ExprKind::Path(QPath::LangItem(LangItem::TryTraitBranch, _)) = &func.kind; + if let ExprKind::Path(QPath::Resolved(_, path)) = &args[0].kind; + // Get the HirId of the expression that is being checked + if let Res::Local(hir_id) = path.res; + then { + self.checked_exprs.insert(hir_id); } - _ => {} } } } impl<'a, 'tcx> Visitor<'tcx> for UnsafeExpectVisitor<'a, 'tcx> { - fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result { - if let Some(init) = local.init { - self.check_expr_for_unsafe_expect(init); - } - } - fn visit_expr(&mut self, expr: &'tcx Expr<'_>) { + // Check for try desugaring '?' + self.check_for_try(expr); + // If we are inside an `if` or `if let` expression, we analyze its body if !self.conditional_checker.is_empty() { match &expr.kind { @@ -357,8 +309,15 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeExpect { return; } + let mut constant_analyzer = ConstantAnalyzer { + cx, + constants: HashSet::new(), + }; + constant_analyzer.visit_body(body); + let mut visitor = UnsafeExpectVisitor { cx, + constant_analyzer, checked_exprs: HashSet::new(), conditional_checker: HashSet::new(), linted_spans: HashSet::new(),