diff --git a/detectors/unsafe-unwrap/Cargo.toml b/detectors/unsafe-unwrap/Cargo.toml index cb8e056c..4812ff59 100644 --- a/detectors/unsafe-unwrap/Cargo.toml +++ b/detectors/unsafe-unwrap/Cargo.toml @@ -11,6 +11,7 @@ clippy_utils = { workspace = true } clippy_wrappers = { workspace = true } dylint_linting = { workspace = true } if_chain = { workspace = true } +utils = { workspace = true } [package.metadata.rust-analyzer] rustc_private = true diff --git a/detectors/unsafe-unwrap/src/lib.rs b/detectors/unsafe-unwrap/src/lib.rs index 6f2026d4..63b225ec 100644 --- a/detectors/unsafe-unwrap/src/lib.rs +++ b/detectors/unsafe-unwrap/src/lib.rs @@ -1,23 +1,22 @@ #![feature(rustc_private)] #![allow(clippy::enum_variant_names)] -extern crate rustc_ast; extern crate rustc_hir; extern crate rustc_span; -use std::{collections::HashSet, hash::Hash}; - -use clippy_utils::higher; +use clippy_utils::higher::IfOrIfLet; use clippy_wrappers::span_lint_and_help; use if_chain::if_chain; use rustc_hir::{ def::Res, def_id::LocalDefId, intravisit::{walk_expr, FnKind, Visitor}, - BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp, + BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LetStmt, PathSegment, QPath, UnOp, }; use rustc_lint::{LateContext, LateLintPass}; use rustc_span::{sym, Span, Symbol}; +use std::{collections::HashSet, hash::Hash}; +use utils::{fn_returns, get_node_type_opt, match_type_to_str}; const LINT_MESSAGE: &str = "Unsafe usage of `unwrap`"; const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"]; @@ -65,6 +64,13 @@ dylint_linting::declare_late_lint! { } } +// Enum to represent the type being unwrapped +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +enum UnwrapType { + Option, + Result, +} + /// Represents the type of check performed on method call expressions to determine their safety or behavior. #[derive(Clone, Copy, Hash, Eq, PartialEq)] enum CheckType { @@ -132,7 +138,7 @@ impl ConditionalChecker { fn from_expression(condition: &Expr<'_>) -> HashSet { match condition.kind { // Single `not` expressions are supported - ExprKind::Unary(op, condition) => Self::handle_condition(condition, op == UnOp::Not), + ExprKind::Unary(UnOp::Not, condition) => Self::handle_condition(condition, true), // Multiple `or` expressions are supported ExprKind::Binary(op, left_condition, right_condition) if op.node == BinOpKind::Or => { let mut result = Self::from_expression(left_condition); @@ -150,9 +156,33 @@ struct UnsafeUnwrapVisitor<'a, 'tcx> { cx: &'a LateContext<'tcx>, conditional_checker: HashSet, checked_exprs: HashSet, + returns_result_or_option: bool, } impl UnsafeUnwrapVisitor<'_, '_> { + fn get_help_message(&self, unwrap_type: UnwrapType) -> &'static str { + match (self.returns_result_or_option, unwrap_type) { + (true, UnwrapType::Option) => "Consider using `ok_or` to convert Option to Result", + (true, UnwrapType::Result) => "Consider using the `?` operator for error propagation", + (false, UnwrapType::Option) => { + "Consider pattern matching or using `if let` instead of `unwrap`" + } + (false, UnwrapType::Result) => { + "Consider handling the error case explicitly or using `if let` instead of `unwrap`" + } + } + } + + fn determine_unwrap_type(&self, receiver: &Expr<'_>) -> UnwrapType { + let type_opt = get_node_type_opt(self.cx, &receiver.hir_id); + if let Some(type_) = type_opt { + if match_type_to_str(self.cx, type_, "Result") { + return UnwrapType::Result; + } + } + UnwrapType::Option + } + fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool { if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| { @@ -175,21 +205,23 @@ impl UnsafeUnwrapVisitor<'_, '_> { None } - fn set_conditional_checker(&mut self, conditional_checkers: &HashSet) { - for checker in conditional_checkers { - self.conditional_checker.insert(*checker); - if checker.check_type.is_safe_to_unwrap() { - self.checked_exprs.insert(checker.checked_expr_hir_id); - } - } - } - - fn reset_conditional_checker(&mut self, conditional_checkers: HashSet) { + fn update_conditional_checker( + &mut self, + conditional_checkers: &HashSet, + set: bool, + ) { for checker in conditional_checkers { - if checker.check_type.is_safe_to_unwrap() { - self.checked_exprs.remove(&checker.checked_expr_hir_id); + if set { + self.conditional_checker.insert(*checker); + if checker.check_type.is_safe_to_unwrap() { + self.checked_exprs.insert(checker.checked_expr_hir_id); + } + } else { + if checker.check_type.is_safe_to_unwrap() { + self.checked_exprs.remove(&checker.checked_expr_hir_id); + } + self.conditional_checker.remove(checker); } - self.conditional_checker.remove(&checker); } } @@ -253,18 +285,20 @@ impl UnsafeUnwrapVisitor<'_, '_> { } impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { - fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result { + fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result { if let Some(init) = local.init { match init.kind { ExprKind::MethodCall(path_segment, receiver, args, _) => { if self.is_method_call_unsafe(path_segment, receiver, args) { + let unwrap_type = self.determine_unwrap_type(receiver); + let help_message = self.get_help_message(unwrap_type); span_lint_and_help( self.cx, UNSAFE_UNWRAP, local.span, LINT_MESSAGE, None, - "Please, use a custom error instead of `unwrap`", + help_message, ); } } @@ -300,17 +334,17 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { } // Find `if` or `if let` expressions - if let Some(higher::IfOrIfLet { + if let Some(IfOrIfLet { cond, then: if_expr, r#else: _, - }) = higher::IfOrIfLet::hir(expr) + }) = IfOrIfLet::hir(expr) { // If we are interested in the condition (if it is a CheckType) we traverse the body. let conditional_checker = ConditionalChecker::from_expression(cond); - self.set_conditional_checker(&conditional_checker); + self.update_conditional_checker(&conditional_checker, true); walk_expr(self, if_expr); - self.reset_conditional_checker(conditional_checker); + self.update_conditional_checker(&conditional_checker, false); return; } @@ -320,16 +354,18 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> { if path_segment.ident.name == sym::unwrap; then { let receiver_hir_id = self.get_unwrap_info(receiver); - // If the receiver is `None`, then we asume that the `unwrap` is unsafe + // If the receiver is `None` or `Err`, then we assume that the `unwrap` is unsafe let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id)); if !is_checked_safe { + let unwrap_type = self.determine_unwrap_type(receiver); + let help_message = self.get_help_message(unwrap_type); span_lint_and_help( self.cx, UNSAFE_UNWRAP, expr.span, LINT_MESSAGE, None, - "Please, use a custom error instead of `unwrap`", + help_message, ); } } @@ -344,7 +380,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap { &mut self, cx: &LateContext<'tcx>, _: FnKind<'tcx>, - _: &'tcx FnDecl<'tcx>, + fn_decl: &'tcx FnDecl<'tcx>, body: &'tcx Body<'tcx>, span: Span, _: LocalDefId, @@ -358,6 +394,8 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap { cx, checked_exprs: HashSet::new(), conditional_checker: HashSet::new(), + returns_result_or_option: fn_returns(fn_decl, sym::Result) + || fn_returns(fn_decl, sym::Option), }; walk_expr(&mut visitor, body.value); diff --git a/utils/src/soroban_utils/mod.rs b/utils/src/soroban_utils/mod.rs index 6b97c8e6..189ebaee 100644 --- a/utils/src/soroban_utils/mod.rs +++ b/utils/src/soroban_utils/mod.rs @@ -2,11 +2,11 @@ extern crate rustc_lint; extern crate rustc_middle; extern crate rustc_span; -use std::collections::HashSet; - +use crate::type_utils::match_type_to_str; use rustc_lint::LateContext; -use rustc_middle::ty::{Ty, TyKind}; +use rustc_middle::ty::Ty; use rustc_span::def_id::DefId; +use std::collections::HashSet; /// Constants defining the fully qualified names of Soroban types. const SOROBAN_ENV: &str = "soroban_sdk::Env"; @@ -58,28 +58,19 @@ pub fn is_soroban_function( .all(|pattern| checked_functions.contains(pattern)) } -// Private helper function to match soroban types -fn is_soroban_type(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool { - match expr_type.kind() { - TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str), - TyKind::Ref(_, ty, _) => is_soroban_type(cx, *ty, type_str), - _ => false, - } -} - /// Checks if the provided type is a Soroban environment (`soroban_sdk::Env`). pub fn is_soroban_env(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_ENV) + match_type_to_str(cx, expr_type, SOROBAN_ENV) } /// Checks if the provided type is a Soroban Address (`soroban_sdk::Address`). pub fn is_soroban_address(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_ADDRESS) + match_type_to_str(cx, expr_type, SOROBAN_ADDRESS) } /// Checks if the provided type is a Soroban Map (`soroban_sdk::Map`). pub fn is_soroban_map(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_MAP) + match_type_to_str(cx, expr_type, SOROBAN_MAP) } pub enum SorobanStorageType { @@ -97,14 +88,16 @@ pub fn is_soroban_storage( ) -> bool { match storage_type { SorobanStorageType::Any => { - is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE) - || is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) - || is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE) + || match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) + || match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + } + SorobanStorageType::Instance => match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE), + SorobanStorageType::Temporary => { + match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) } - SorobanStorageType::Instance => is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE), - SorobanStorageType::Temporary => is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE), SorobanStorageType::Persistent => { - is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) } } } diff --git a/utils/src/type_utils.rs b/utils/src/type_utils.rs index 942003b6..ab32ff22 100644 --- a/utils/src/type_utils.rs +++ b/utils/src/type_utils.rs @@ -1,12 +1,35 @@ extern crate rustc_hir; extern crate rustc_lint; extern crate rustc_middle; +extern crate rustc_span; -use rustc_hir::HirId; +use rustc_hir::{FnDecl, FnRetTy, HirId, QPath}; use rustc_lint::LateContext; -use rustc_middle::ty::Ty; +use rustc_middle::ty::{Ty, TyKind}; +use rustc_span::Symbol; /// Get the type of a node, if it exists. pub fn get_node_type_opt<'tcx>(cx: &LateContext<'tcx>, hir_id: &HirId) -> Option> { cx.typeck_results().node_type_opt(*hir_id) } + +/// Match the type of an expression to a string. +pub fn match_type_to_str(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool { + match expr_type.kind() { + TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str), + TyKind::Ref(_, ty, _) => match_type_to_str(cx, *ty, type_str), + _ => false, + } +} + +/// Check the return type of a function. +pub fn fn_returns(decl: &FnDecl<'_>, type_symbol: Symbol) -> bool { + if let FnRetTy::Return(ty) = decl.output { + matches!(ty.kind, rustc_hir::TyKind::Path(QPath::Resolved(_, path)) if path + .segments + .last() + .map_or(false, |seg| seg.ident.name == type_symbol)) + } else { + false + } +}