Skip to content

Commit

Permalink
Merge pull request #333 from CoinFabrik/332-improve-unsafe-unwrap-det…
Browse files Browse the repository at this point in the history
…ector

Improve `unsafe-unwrap` detector
  • Loading branch information
tenuki authored Aug 26, 2024
2 parents 408112f + 3ea8ba3 commit 1b84450
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 51 deletions.
1 change: 1 addition & 0 deletions detectors/unsafe-unwrap/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ clippy_utils = { workspace = true }
clippy_wrappers = { workspace = true }
dylint_linting = { workspace = true }
if_chain = { workspace = true }
utils = { workspace = true }

[package.metadata.rust-analyzer]
rustc_private = true
94 changes: 66 additions & 28 deletions detectors/unsafe-unwrap/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
#![feature(rustc_private)]
#![allow(clippy::enum_variant_names)]

extern crate rustc_ast;
extern crate rustc_hir;
extern crate rustc_span;

use std::{collections::HashSet, hash::Hash};

use clippy_utils::higher;
use clippy_utils::higher::IfOrIfLet;
use clippy_wrappers::span_lint_and_help;
use if_chain::if_chain;
use rustc_hir::{
def::Res,
def_id::LocalDefId,
intravisit::{walk_expr, FnKind, Visitor},
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp,
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LetStmt, PathSegment, QPath, UnOp,
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_span::{sym, Span, Symbol};
use std::{collections::HashSet, hash::Hash};
use utils::{fn_returns, get_node_type_opt, match_type_to_str};

const LINT_MESSAGE: &str = "Unsafe usage of `unwrap`";
const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"];
Expand Down Expand Up @@ -65,6 +64,13 @@ dylint_linting::declare_late_lint! {
}
}

// Enum to represent the type being unwrapped
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum UnwrapType {
Option,
Result,
}

/// Represents the type of check performed on method call expressions to determine their safety or behavior.
#[derive(Clone, Copy, Hash, Eq, PartialEq)]
enum CheckType {
Expand Down Expand Up @@ -132,7 +138,7 @@ impl ConditionalChecker {
fn from_expression(condition: &Expr<'_>) -> HashSet<Self> {
match condition.kind {
// Single `not` expressions are supported
ExprKind::Unary(op, condition) => Self::handle_condition(condition, op == UnOp::Not),
ExprKind::Unary(UnOp::Not, condition) => Self::handle_condition(condition, true),
// Multiple `or` expressions are supported
ExprKind::Binary(op, left_condition, right_condition) if op.node == BinOpKind::Or => {
let mut result = Self::from_expression(left_condition);
Expand All @@ -150,9 +156,33 @@ struct UnsafeUnwrapVisitor<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
conditional_checker: HashSet<ConditionalChecker>,
checked_exprs: HashSet<HirId>,
returns_result_or_option: bool,
}

impl UnsafeUnwrapVisitor<'_, '_> {
fn get_help_message(&self, unwrap_type: UnwrapType) -> &'static str {
match (self.returns_result_or_option, unwrap_type) {
(true, UnwrapType::Option) => "Consider using `ok_or` to convert Option to Result",
(true, UnwrapType::Result) => "Consider using the `?` operator for error propagation",
(false, UnwrapType::Option) => {
"Consider pattern matching or using `if let` instead of `unwrap`"
}
(false, UnwrapType::Result) => {
"Consider handling the error case explicitly or using `if let` instead of `unwrap`"
}
}
}

fn determine_unwrap_type(&self, receiver: &Expr<'_>) -> UnwrapType {
let type_opt = get_node_type_opt(self.cx, &receiver.hir_id);
if let Some(type_) = type_opt {
if match_type_to_str(self.cx, type_, "Result") {
return UnwrapType::Result;
}
}
UnwrapType::Option
}

fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool {
if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind {
return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| {
Expand All @@ -175,21 +205,23 @@ impl UnsafeUnwrapVisitor<'_, '_> {
None
}

fn set_conditional_checker(&mut self, conditional_checkers: &HashSet<ConditionalChecker>) {
for checker in conditional_checkers {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
}
}

fn reset_conditional_checker(&mut self, conditional_checkers: HashSet<ConditionalChecker>) {
fn update_conditional_checker(
&mut self,
conditional_checkers: &HashSet<ConditionalChecker>,
set: bool,
) {
for checker in conditional_checkers {
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
if set {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
} else {
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
}
self.conditional_checker.remove(checker);
}
self.conditional_checker.remove(&checker);
}
}

Expand Down Expand Up @@ -253,18 +285,20 @@ impl UnsafeUnwrapVisitor<'_, '_> {
}

impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result {
fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result {
if let Some(init) = local.init {
match init.kind {
ExprKind::MethodCall(path_segment, receiver, args, _) => {
if self.is_method_call_unsafe(path_segment, receiver, args) {
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);
span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
local.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `unwrap`",
help_message,
);
}
}
Expand Down Expand Up @@ -300,17 +334,17 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
}

// Find `if` or `if let` expressions
if let Some(higher::IfOrIfLet {
if let Some(IfOrIfLet {
cond,
then: if_expr,
r#else: _,
}) = higher::IfOrIfLet::hir(expr)
}) = IfOrIfLet::hir(expr)
{
// If we are interested in the condition (if it is a CheckType) we traverse the body.
let conditional_checker = ConditionalChecker::from_expression(cond);
self.set_conditional_checker(&conditional_checker);
self.update_conditional_checker(&conditional_checker, true);
walk_expr(self, if_expr);
self.reset_conditional_checker(conditional_checker);
self.update_conditional_checker(&conditional_checker, false);
return;
}

Expand All @@ -320,16 +354,18 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
if path_segment.ident.name == sym::unwrap;
then {
let receiver_hir_id = self.get_unwrap_info(receiver);
// If the receiver is `None`, then we asume that the `unwrap` is unsafe
// If the receiver is `None` or `Err`, then we assume that the `unwrap` is unsafe
let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id));
if !is_checked_safe {
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);
span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
expr.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `unwrap`",
help_message,
);
}
}
Expand All @@ -344,7 +380,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap {
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'tcx>,
fn_decl: &'tcx FnDecl<'tcx>,
body: &'tcx Body<'tcx>,
span: Span,
_: LocalDefId,
Expand All @@ -358,6 +394,8 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap {
cx,
checked_exprs: HashSet::new(),
conditional_checker: HashSet::new(),
returns_result_or_option: fn_returns(fn_decl, sym::Result)
|| fn_returns(fn_decl, sym::Option),
};

walk_expr(&mut visitor, body.value);
Expand Down
35 changes: 14 additions & 21 deletions utils/src/soroban_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ extern crate rustc_lint;
extern crate rustc_middle;
extern crate rustc_span;

use std::collections::HashSet;

use crate::type_utils::match_type_to_str;
use rustc_lint::LateContext;
use rustc_middle::ty::{Ty, TyKind};
use rustc_middle::ty::Ty;
use rustc_span::def_id::DefId;
use std::collections::HashSet;

/// Constants defining the fully qualified names of Soroban types.
const SOROBAN_ENV: &str = "soroban_sdk::Env";
Expand Down Expand Up @@ -58,28 +58,19 @@ pub fn is_soroban_function(
.all(|pattern| checked_functions.contains(pattern))
}

// Private helper function to match soroban types
fn is_soroban_type(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool {
match expr_type.kind() {
TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str),
TyKind::Ref(_, ty, _) => is_soroban_type(cx, *ty, type_str),
_ => false,
}
}

/// Checks if the provided type is a Soroban environment (`soroban_sdk::Env`).
pub fn is_soroban_env(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool {
is_soroban_type(cx, expr_type, SOROBAN_ENV)
match_type_to_str(cx, expr_type, SOROBAN_ENV)
}

/// Checks if the provided type is a Soroban Address (`soroban_sdk::Address`).
pub fn is_soroban_address(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool {
is_soroban_type(cx, expr_type, SOROBAN_ADDRESS)
match_type_to_str(cx, expr_type, SOROBAN_ADDRESS)
}

/// Checks if the provided type is a Soroban Map (`soroban_sdk::Map`).
pub fn is_soroban_map(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool {
is_soroban_type(cx, expr_type, SOROBAN_MAP)
match_type_to_str(cx, expr_type, SOROBAN_MAP)
}

pub enum SorobanStorageType {
Expand All @@ -97,14 +88,16 @@ pub fn is_soroban_storage(
) -> bool {
match storage_type {
SorobanStorageType::Any => {
is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE)
|| is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE)
|| is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE)
|| match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE)
|| match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
}
SorobanStorageType::Instance => match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE),
SorobanStorageType::Temporary => {
match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE)
}
SorobanStorageType::Instance => is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE),
SorobanStorageType::Temporary => is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE),
SorobanStorageType::Persistent => {
is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
}
}
}
27 changes: 25 additions & 2 deletions utils/src/type_utils.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
extern crate rustc_hir;
extern crate rustc_lint;
extern crate rustc_middle;
extern crate rustc_span;

use rustc_hir::HirId;
use rustc_hir::{FnDecl, FnRetTy, HirId, QPath};
use rustc_lint::LateContext;
use rustc_middle::ty::Ty;
use rustc_middle::ty::{Ty, TyKind};
use rustc_span::Symbol;

/// Get the type of a node, if it exists.
pub fn get_node_type_opt<'tcx>(cx: &LateContext<'tcx>, hir_id: &HirId) -> Option<Ty<'tcx>> {
cx.typeck_results().node_type_opt(*hir_id)
}

/// Match the type of an expression to a string.
pub fn match_type_to_str(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool {
match expr_type.kind() {
TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str),
TyKind::Ref(_, ty, _) => match_type_to_str(cx, *ty, type_str),
_ => false,
}
}

/// Check the return type of a function.
pub fn fn_returns(decl: &FnDecl<'_>, type_symbol: Symbol) -> bool {
if let FnRetTy::Return(ty) = decl.output {
matches!(ty.kind, rustc_hir::TyKind::Path(QPath::Resolved(_, path)) if path
.segments
.last()
.map_or(false, |seg| seg.ident.name == type_symbol))
} else {
false
}
}

0 comments on commit 1b84450

Please sign in to comment.