Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve unsafe-unwrap detector #333

Merged
merged 7 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions detectors/unsafe-unwrap/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ clippy_utils = { workspace = true }
clippy_wrappers = { workspace = true }
dylint_linting = { workspace = true }
if_chain = { workspace = true }
utils = { workspace = true }

[package.metadata.rust-analyzer]
rustc_private = true
94 changes: 66 additions & 28 deletions detectors/unsafe-unwrap/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
#![feature(rustc_private)]
#![allow(clippy::enum_variant_names)]

extern crate rustc_ast;
extern crate rustc_hir;
extern crate rustc_span;

use std::{collections::HashSet, hash::Hash};

use clippy_utils::higher;
use clippy_utils::higher::IfOrIfLet;
use clippy_wrappers::span_lint_and_help;
use if_chain::if_chain;
use rustc_hir::{
def::Res,
def_id::LocalDefId,
intravisit::{walk_expr, FnKind, Visitor},
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp,
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LetStmt, PathSegment, QPath, UnOp,
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_span::{sym, Span, Symbol};
use std::{collections::HashSet, hash::Hash};
use utils::{fn_returns, get_node_type_opt, match_type_to_str};

const LINT_MESSAGE: &str = "Unsafe usage of `unwrap`";
const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"];
Expand Down Expand Up @@ -65,6 +64,13 @@ dylint_linting::declare_late_lint! {
}
}

// Enum to represent the type being unwrapped
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum UnwrapType {
Option,
Result,
}

/// Represents the type of check performed on method call expressions to determine their safety or behavior.
#[derive(Clone, Copy, Hash, Eq, PartialEq)]
enum CheckType {
Expand Down Expand Up @@ -132,7 +138,7 @@ impl ConditionalChecker {
fn from_expression(condition: &Expr<'_>) -> HashSet<Self> {
match condition.kind {
// Single `not` expressions are supported
ExprKind::Unary(op, condition) => Self::handle_condition(condition, op == UnOp::Not),
ExprKind::Unary(UnOp::Not, condition) => Self::handle_condition(condition, true),
// Multiple `or` expressions are supported
ExprKind::Binary(op, left_condition, right_condition) if op.node == BinOpKind::Or => {
let mut result = Self::from_expression(left_condition);
Expand All @@ -150,9 +156,33 @@ struct UnsafeUnwrapVisitor<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
conditional_checker: HashSet<ConditionalChecker>,
checked_exprs: HashSet<HirId>,
returns_result_or_option: bool,
}

impl UnsafeUnwrapVisitor<'_, '_> {
fn get_help_message(&self, unwrap_type: UnwrapType) -> &'static str {
match (self.returns_result_or_option, unwrap_type) {
(true, UnwrapType::Option) => "Consider using `ok_or` to convert Option to Result",
(true, UnwrapType::Result) => "Consider using the `?` operator for error propagation",
(false, UnwrapType::Option) => {
"Consider pattern matching or using `if let` instead of `unwrap`"
}
(false, UnwrapType::Result) => {
"Consider handling the error case explicitly or using `if let` instead of `unwrap`"
}
}
}

fn determine_unwrap_type(&self, receiver: &Expr<'_>) -> UnwrapType {
let type_opt = get_node_type_opt(self.cx, &receiver.hir_id);
if let Some(type_) = type_opt {
if match_type_to_str(self.cx, type_, "Result") {
return UnwrapType::Result;
}
}
UnwrapType::Option
}

fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool {
if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind {
return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| {
Expand All @@ -175,21 +205,23 @@ impl UnsafeUnwrapVisitor<'_, '_> {
None
}

fn set_conditional_checker(&mut self, conditional_checkers: &HashSet<ConditionalChecker>) {
for checker in conditional_checkers {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
}
}

fn reset_conditional_checker(&mut self, conditional_checkers: HashSet<ConditionalChecker>) {
fn update_conditional_checker(
&mut self,
conditional_checkers: &HashSet<ConditionalChecker>,
set: bool,
) {
for checker in conditional_checkers {
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
if set {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
} else {
if checker.check_type.is_safe_to_unwrap() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
}
self.conditional_checker.remove(checker);
}
self.conditional_checker.remove(&checker);
}
}

Expand Down Expand Up @@ -253,18 +285,20 @@ impl UnsafeUnwrapVisitor<'_, '_> {
}

impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result {
fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result {
if let Some(init) = local.init {
match init.kind {
ExprKind::MethodCall(path_segment, receiver, args, _) => {
if self.is_method_call_unsafe(path_segment, receiver, args) {
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);
span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
local.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `unwrap`",
help_message,
);
}
}
Expand Down Expand Up @@ -300,17 +334,17 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
}

// Find `if` or `if let` expressions
if let Some(higher::IfOrIfLet {
if let Some(IfOrIfLet {
cond,
then: if_expr,
r#else: _,
}) = higher::IfOrIfLet::hir(expr)
}) = IfOrIfLet::hir(expr)
{
// If we are interested in the condition (if it is a CheckType) we traverse the body.
let conditional_checker = ConditionalChecker::from_expression(cond);
self.set_conditional_checker(&conditional_checker);
self.update_conditional_checker(&conditional_checker, true);
walk_expr(self, if_expr);
self.reset_conditional_checker(conditional_checker);
self.update_conditional_checker(&conditional_checker, false);
return;
}

Expand All @@ -320,16 +354,18 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
if path_segment.ident.name == sym::unwrap;
then {
let receiver_hir_id = self.get_unwrap_info(receiver);
// If the receiver is `None`, then we asume that the `unwrap` is unsafe
// If the receiver is `None` or `Err`, then we assume that the `unwrap` is unsafe
let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id));
if !is_checked_safe {
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);
span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
expr.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `unwrap`",
help_message,
);
}
}
Expand All @@ -344,7 +380,7 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap {
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'tcx>,
fn_decl: &'tcx FnDecl<'tcx>,
body: &'tcx Body<'tcx>,
span: Span,
_: LocalDefId,
Expand All @@ -358,6 +394,8 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap {
cx,
checked_exprs: HashSet::new(),
conditional_checker: HashSet::new(),
returns_result_or_option: fn_returns(fn_decl, sym::Result)
|| fn_returns(fn_decl, sym::Option),
};

walk_expr(&mut visitor, body.value);
Expand Down
35 changes: 14 additions & 21 deletions utils/src/soroban_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ extern crate rustc_lint;
extern crate rustc_middle;
extern crate rustc_span;

use std::collections::HashSet;

use crate::type_utils::match_type_to_str;
use rustc_lint::LateContext;
use rustc_middle::ty::{Ty, TyKind};
use rustc_middle::ty::Ty;
use rustc_span::def_id::DefId;
use std::collections::HashSet;

/// Constants defining the fully qualified names of Soroban types.
const SOROBAN_ENV: &str = "soroban_sdk::Env";
Expand Down Expand Up @@ -58,28 +58,19 @@ pub fn is_soroban_function(
.all(|pattern| checked_functions.contains(pattern))
}

// Private helper function to match soroban types
fn is_soroban_type(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool {
match expr_type.kind() {
TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str),
TyKind::Ref(_, ty, _) => is_soroban_type(cx, *ty, type_str),
_ => false,
}
}

/// Checks if the provided type is a Soroban environment (`soroban_sdk::Env`).
pub fn is_soroban_env(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool {
is_soroban_type(cx, expr_type, SOROBAN_ENV)
match_type_to_str(cx, expr_type, SOROBAN_ENV)
}

/// Checks if the provided type is a Soroban Address (`soroban_sdk::Address`).
pub fn is_soroban_address(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool {
is_soroban_type(cx, expr_type, SOROBAN_ADDRESS)
match_type_to_str(cx, expr_type, SOROBAN_ADDRESS)
}

/// Checks if the provided type is a Soroban Map (`soroban_sdk::Map`).
pub fn is_soroban_map(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool {
is_soroban_type(cx, expr_type, SOROBAN_MAP)
match_type_to_str(cx, expr_type, SOROBAN_MAP)
}

pub enum SorobanStorageType {
Expand All @@ -97,14 +88,16 @@ pub fn is_soroban_storage(
) -> bool {
match storage_type {
SorobanStorageType::Any => {
is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE)
|| is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE)
|| is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE)
|| match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE)
|| match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
}
SorobanStorageType::Instance => match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE),
SorobanStorageType::Temporary => {
match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE)
}
SorobanStorageType::Instance => is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE),
SorobanStorageType::Temporary => is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE),
SorobanStorageType::Persistent => {
is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE)
}
}
}
27 changes: 25 additions & 2 deletions utils/src/type_utils.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
extern crate rustc_hir;
extern crate rustc_lint;
extern crate rustc_middle;
extern crate rustc_span;

use rustc_hir::HirId;
use rustc_hir::{FnDecl, FnRetTy, HirId, QPath};
use rustc_lint::LateContext;
use rustc_middle::ty::Ty;
use rustc_middle::ty::{Ty, TyKind};
use rustc_span::Symbol;

/// Get the type of a node, if it exists.
pub fn get_node_type_opt<'tcx>(cx: &LateContext<'tcx>, hir_id: &HirId) -> Option<Ty<'tcx>> {
cx.typeck_results().node_type_opt(*hir_id)
}

/// Match the type of an expression to a string.
pub fn match_type_to_str(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool {
match expr_type.kind() {
TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str),
TyKind::Ref(_, ty, _) => match_type_to_str(cx, *ty, type_str),
_ => false,
}
}

/// Check the return type of a function.
pub fn fn_returns(decl: &FnDecl<'_>, type_symbol: Symbol) -> bool {
if let FnRetTy::Return(ty) = decl.output {
matches!(ty.kind, rustc_hir::TyKind::Path(QPath::Resolved(_, path)) if path
.segments
.last()
.map_or(false, |seg| seg.ident.name == type_symbol))
} else {
false
}
}