From 0ae804c0fe4f4d9d92b891457f6e325b02178cb6 Mon Sep 17 00:00:00 2001 From: Jose Garcia Crosta Date: Mon, 9 Sep 2024 11:43:54 -0300 Subject: [PATCH] Improve detectors --- detectors/unsafe-unwrap/src/lib.rs | 103 ++++++++++++++--------------- 1 file changed, 51 insertions(+), 52 deletions(-) diff --git a/detectors/unsafe-unwrap/src/lib.rs b/detectors/unsafe-unwrap/src/lib.rs index 63b225ec..f3e77c3a 100644 --- a/detectors/unsafe-unwrap/src/lib.rs +++ b/detectors/unsafe-unwrap/src/lib.rs @@ -156,6 +156,7 @@ struct UnsafeUnwrapVisitor<'a, 'tcx> { cx: &'a LateContext<'tcx>, conditional_checker: HashSet, checked_exprs: HashSet, + linted_spans: HashSet, returns_result_or_option: bool, } @@ -282,42 +283,59 @@ impl UnsafeUnwrapVisitor<'_, '_> { _ => false, } } -} -impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { - fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result { - if let Some(init) = local.init { - match init.kind { - ExprKind::MethodCall(path_segment, receiver, args, _) => { - if self.is_method_call_unsafe(path_segment, receiver, args) { - let unwrap_type = self.determine_unwrap_type(receiver); - let help_message = self.get_help_message(unwrap_type); - span_lint_and_help( - self.cx, - UNSAFE_UNWRAP, - local.span, - LINT_MESSAGE, - None, - help_message, - ); - } + fn check_expr_for_unsafe_unwrap(&mut self, expr: &Expr) { + match &expr.kind { + ExprKind::MethodCall(path_segment, receiver, args, _) => { + if self.is_method_call_unsafe(path_segment, receiver, args) + && !self.linted_spans.contains(&expr.span) + { + let unwrap_type = self.determine_unwrap_type(receiver); + let help_message = self.get_help_message(unwrap_type); + span_lint_and_help( + self.cx, + UNSAFE_UNWRAP, + expr.span, + LINT_MESSAGE, + None, + help_message, + ); + self.linted_spans.insert(expr.span); } - ExprKind::Call(func, args) => { - if let ExprKind::Path(QPath::Resolved(_, path)) = func.kind { - let is_some_or_ok = path - .segments - .iter() - .any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok)); - let all_literals = args - .iter() - .all(|arg| self.is_literal_or_composed_of_literals(arg)); - if is_some_or_ok && all_literals { - self.checked_exprs.insert(local.pat.hir_id); - } + } + ExprKind::Call(func, args) => { + if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { + let is_some_or_ok = path + .segments + .iter() + .any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok)); + let all_literals = args + .iter() + .all(|arg| self.is_literal_or_composed_of_literals(arg)); + if is_some_or_ok && all_literals { + self.checked_exprs.insert(expr.hir_id); + return; } } - _ => {} + // Check arguments for unsafe expect + for arg in args.iter() { + self.check_expr_for_unsafe_unwrap(arg); + } + } + ExprKind::Tup(exprs) | ExprKind::Array(exprs) => { + for expr in exprs.iter() { + self.check_expr_for_unsafe_unwrap(expr); + } } + _ => {} + } + } +} + +impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { + fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result { + if let Some(init) = local.init { + self.check_expr_for_unsafe_unwrap(init); } } @@ -349,27 +367,7 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { } // If we find an unsafe `unwrap`, we raise a warning - if_chain! { - if let ExprKind::MethodCall(path_segment, receiver, _, _) = &expr.kind; - if path_segment.ident.name == sym::unwrap; - then { - let receiver_hir_id = self.get_unwrap_info(receiver); - // If the receiver is `None` or `Err`, then we assume that the `unwrap` is unsafe - let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id)); - if !is_checked_safe { - let unwrap_type = self.determine_unwrap_type(receiver); - let help_message = self.get_help_message(unwrap_type); - span_lint_and_help( - self.cx, - UNSAFE_UNWRAP, - expr.span, - LINT_MESSAGE, - None, - help_message, - ); - } - } - } + self.check_expr_for_unsafe_unwrap(expr); walk_expr(self, expr); } @@ -394,6 +392,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap { cx, checked_exprs: HashSet::new(), conditional_checker: HashSet::new(), + linted_spans: HashSet::new(), returns_result_or_option: fn_returns(fn_decl, sym::Result) || fn_returns(fn_decl, sym::Option), };