Skip to content

Commit

Permalink
Update detector
Browse files Browse the repository at this point in the history
  • Loading branch information
jgcrosta committed Aug 22, 2024
1 parent 7471be1 commit 5be964d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 28 deletions.
1 change: 1 addition & 0 deletions detectors/unsafe-unwrap/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ clippy_utils = { workspace = true }
clippy_wrappers = { workspace = true }
dylint_linting = { workspace = true }
if_chain = { workspace = true }
utils = { workspace = true }

[package.metadata.rust-analyzer]
rustc_private = true
93 changes: 65 additions & 28 deletions detectors/unsafe-unwrap/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
#![feature(rustc_private)]
#![allow(clippy::enum_variant_names)]

extern crate rustc_ast;
extern crate rustc_hir;
extern crate rustc_span;

use std::{collections::HashSet, hash::Hash};

use clippy_utils::higher;
use clippy_utils::higher::IfOrIfLet;
use clippy_wrappers::span_lint_and_help;
use if_chain::if_chain;
use rustc_hir::{
def::Res,
def_id::LocalDefId,
intravisit::{walk_expr, FnKind, Visitor},
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp,
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LetStmt, PathSegment, QPath, UnOp,
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_span::{sym, Span, Symbol};
use std::{collections::HashSet, hash::Hash};
use utils::{get_node_type_opt, match_type_to_str, returns_result};

const LINT_MESSAGE: &str = "Unsafe usage of `unwrap`";
const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"];
Expand Down Expand Up @@ -65,6 +64,13 @@ dylint_linting::declare_late_lint! {
}
}

// Enum to represent the type being unwrapped
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum UnwrapType {
Option,
Result,
}

/// Represents the type of check performed on method call expressions to determine their safety or behavior.
#[derive(Clone, Copy, Hash, Eq, PartialEq)]
enum CheckType {
Expand Down Expand Up @@ -132,7 +138,7 @@ impl ConditionalChecker {
fn from_expression(condition: &Expr<'_>) -> HashSet<Self> {
match condition.kind {
// Single `not` expressions are supported
ExprKind::Unary(op, condition) => Self::handle_condition(condition, op == UnOp::Not),
ExprKind::Unary(UnOp::Not, condition) => Self::handle_condition(condition, true),
// Multiple `or` expressions are supported
ExprKind::Binary(op, left_condition, right_condition) if op.node == BinOpKind::Or => {
let mut result = Self::from_expression(left_condition);
Expand All @@ -150,9 +156,33 @@ struct UnsafeUnwrapVisitor<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
conditional_checker: HashSet<ConditionalChecker>,
checked_exprs: HashSet<HirId>,
returns_result: bool,
}

impl UnsafeUnwrapVisitor<'_, '_> {
fn get_help_message(&self, unwrap_type: UnwrapType) -> &'static str {
match (self.returns_result, unwrap_type) {
(true, UnwrapType::Option) => "Consider using `ok_or` to convert Option to Result",
(true, UnwrapType::Result) => "Consider using the `?` operator for error propagation",
(false, UnwrapType::Option) => {
"Consider pattern matching or using `if let` instead of `unwrap`"
}
(false, UnwrapType::Result) => {
"Consider handling the error case explicitly or using `if let` instead of `unwrap`"
}
}
}

fn determine_unwrap_type(&self, receiver: &Expr<'_>) -> UnwrapType {
let type_opt = get_node_type_opt(self.cx, &receiver.hir_id);
if let Some(type_) = type_opt {
if match_type_to_str(self.cx, type_, "Result") {
return UnwrapType::Result;
}
}
UnwrapType::Option
}

fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool {
if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind {
return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| {
Expand All @@ -175,21 +205,23 @@ impl UnsafeUnwrapVisitor<'_, '_> {
None
}

fn set_conditional_checker(&mut self, conditional_checkers: &HashSet<ConditionalChecker>) {
for checker in conditional_checkers {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
}
}

fn reset_conditional_checker(&mut self, conditional_checkers: HashSet<ConditionalChecker>) {
fn update_conditional_checker(
&mut self,
conditional_checkers: &HashSet<ConditionalChecker>,
set: bool,
) {
for checker in conditional_checkers {
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
if set {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
} else {
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
}
self.conditional_checker.remove(checker);
}
self.conditional_checker.remove(&checker);
}
}

Expand Down Expand Up @@ -253,18 +285,20 @@ impl UnsafeUnwrapVisitor<'_, '_> {
}

impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result {
fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result {
if let Some(init) = local.init {
match init.kind {
ExprKind::MethodCall(path_segment, receiver, args, _) => {
if self.is_method_call_unsafe(path_segment, receiver, args) {
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);
span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
local.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `unwrap`",
help_message,
);
}
}
Expand Down Expand Up @@ -300,17 +334,17 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
}

// Find `if` or `if let` expressions
if let Some(higher::IfOrIfLet {
if let Some(IfOrIfLet {
cond,
then: if_expr,
r#else: _,
}) = higher::IfOrIfLet::hir(expr)
}) = IfOrIfLet::hir(expr)
{
// If we are interested in the condition (if it is a CheckType) we traverse the body.
let conditional_checker = ConditionalChecker::from_expression(cond);
self.set_conditional_checker(&conditional_checker);
self.update_conditional_checker(&conditional_checker, true);
walk_expr(self, if_expr);
self.reset_conditional_checker(conditional_checker);
self.update_conditional_checker(&conditional_checker, false);
return;
}

Expand All @@ -320,16 +354,18 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
if path_segment.ident.name == sym::unwrap;
then {
let receiver_hir_id = self.get_unwrap_info(receiver);
// If the receiver is `None`, then we asume that the `unwrap` is unsafe
// If the receiver is `None` or `Err`, then we assume that the `unwrap` is unsafe
let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id));
if !is_checked_safe {
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);
span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
expr.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `unwrap`",
help_message,
);
}
}
Expand All @@ -344,7 +380,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap {
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'tcx>,
fn_decl: &'tcx FnDecl<'tcx>,
body: &'tcx Body<'tcx>,
span: Span,
_: LocalDefId,
Expand All @@ -358,6 +394,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap {
cx,
checked_exprs: HashSet::new(),
conditional_checker: HashSet::new(),
returns_result: returns_result(fn_decl),
};

walk_expr(&mut visitor, body.value);
Expand Down

0 comments on commit 5be964d

Please sign in to comment.