diff --git a/detectors/unsafe-expect/src/lib.rs b/detectors/unsafe-expect/src/lib.rs index db94f5a7..e7224e9a 100644 --- a/detectors/unsafe-expect/src/lib.rs +++ b/detectors/unsafe-expect/src/lib.rs @@ -16,7 +16,7 @@ use rustc_hir::{ use rustc_lint::{LateContext, LateLintPass}; use rustc_span::{sym, Span, Symbol}; use std::{collections::HashSet, hash::Hash}; -use utils::returns_result; +use utils::fn_returns; const LINT_MESSAGE: &str = "Unsafe usage of `expect`"; const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"]; @@ -350,8 +350,10 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeExpect { span: Span, _: LocalDefId, ) { - // If the function comes from a macro expansion or does not return a Result<(), ()>, we don't want to analyze it. - if span.from_expansion() || !returns_result(fn_decl) { + // If the function comes from a macro expansion or does not return a Result<(), ()> or Option<()>, we don't want to analyze it. + if span.from_expansion() + || !fn_returns(fn_decl, sym::Result) && !fn_returns(fn_decl, sym::Option) + { return; } diff --git a/utils/src/type_utils.rs b/utils/src/type_utils.rs index 3a4a4b9a..ab32ff22 100644 --- a/utils/src/type_utils.rs +++ b/utils/src/type_utils.rs @@ -6,7 +6,7 @@ extern crate rustc_span; use rustc_hir::{FnDecl, FnRetTy, HirId, QPath}; use rustc_lint::LateContext; use rustc_middle::ty::{Ty, TyKind}; -use rustc_span::sym; +use rustc_span::Symbol; /// Get the type of a node, if it exists. pub fn get_node_type_opt<'tcx>(cx: &LateContext<'tcx>, hir_id: &HirId) -> Option> { @@ -22,13 +22,13 @@ pub fn match_type_to_str(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str } } -/// Check if a function returns a Result type. -pub fn returns_result(decl: &FnDecl<'_>) -> bool { +/// Check the return type of a function. +pub fn fn_returns(decl: &FnDecl<'_>, type_symbol: Symbol) -> bool { if let FnRetTy::Return(ty) = decl.output { matches!(ty.kind, rustc_hir::TyKind::Path(QPath::Resolved(_, path)) if path .segments .last() - .map_or(false, |seg| seg.ident.name == sym::Result)) + .map_or(false, |seg| seg.ident.name == type_symbol)) } else { false }