Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve unsafe-expect detector #336

Merged
merged 8 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions detectors/unsafe-expect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ clippy_utils = { workspace = true }
clippy_wrappers = { workspace = true }
dylint_linting = { workspace = true }
if_chain = { workspace = true }
utils = { workspace = true }

[package.metadata.rust-analyzer]
rustc_private = true
143 changes: 71 additions & 72 deletions detectors/unsafe-expect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
extern crate rustc_hir;
extern crate rustc_span;

use std::{collections::HashSet, hash::Hash};

use clippy_utils::higher;
use clippy_utils::higher::IfOrIfLet;
use clippy_wrappers::span_lint_and_help;
use if_chain::if_chain;
use rustc_hir::{
Expand All @@ -17,6 +15,8 @@ use rustc_hir::{
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_span::{sym, Span, Symbol};
use std::{collections::HashSet, hash::Hash};
use utils::fn_returns;

const LINT_MESSAGE: &str = "Unsafe usage of `expect`";
const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"];
Expand Down Expand Up @@ -174,21 +174,23 @@ impl UnsafeExpectVisitor<'_, '_> {
None
}

fn set_conditional_checker(&mut self, conditional_checkers: &HashSet<ConditionalChecker>) {
for checker in conditional_checkers {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_expect() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
}
}

fn reset_conditional_checker(&mut self, conditional_checkers: HashSet<ConditionalChecker>) {
fn update_conditional_checker(
&mut self,
conditional_checkers: &HashSet<ConditionalChecker>,
set: bool,
) {
for checker in conditional_checkers {
if checker.check_type.is_safe_to_expect() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
if set {
self.conditional_checker.insert(*checker);
if checker.check_type.is_safe_to_expect() {
self.checked_exprs.insert(checker.checked_expr_hir_id);
}
} else {
if checker.check_type.is_safe_to_expect() {
self.checked_exprs.remove(&checker.checked_expr_hir_id);
}
self.conditional_checker.remove(checker);
}
self.conditional_checker.remove(&checker);
}
}

Expand Down Expand Up @@ -236,7 +238,6 @@ impl UnsafeExpectVisitor<'_, '_> {
.get_expect_info(receiver)
.map_or(true, |id| !self.checked_exprs.contains(&id));
}

args.iter().any(|arg| self.contains_unsafe_method_call(arg))
|| self.contains_unsafe_method_call(receiver)
}
Expand All @@ -249,40 +250,54 @@ impl UnsafeExpectVisitor<'_, '_> {
_ => false,
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for UnsafeExpectVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result {
if let Some(init) = local.init {
match init.kind {
ExprKind::MethodCall(path_segment, receiver, args, _) => {
if self.is_method_call_unsafe(path_segment, receiver, args) {
span_lint_and_help(
self.cx,
UNSAFE_EXPECT,
local.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `expect`",
);
}
fn check_expr_for_unsafe_expect(&mut self, expr: &Expr) {
match &expr.kind {
ExprKind::MethodCall(path_segment, receiver, args, _) => {
if self.is_method_call_unsafe(path_segment, receiver, args) {
span_lint_and_help(
self.cx,
UNSAFE_EXPECT,
expr.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `expect`",
);
}
ExprKind::Call(func, args) => {
if let ExprKind::Path(QPath::Resolved(_, path)) = func.kind {
let is_some_or_ok = path
.segments
.iter()
.any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok));
let all_literals = args
.iter()
.all(|arg| self.is_literal_or_composed_of_literals(arg));
if is_some_or_ok && all_literals {
self.checked_exprs.insert(local.pat.hir_id);
}
}
ExprKind::Call(func, args) => {
if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind {
let is_some_or_ok = path
.segments
.iter()
.any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok));
let all_literals = args
.iter()
.all(|arg| self.is_literal_or_composed_of_literals(arg));
if is_some_or_ok && all_literals {
self.checked_exprs.insert(expr.hir_id);
return;
}
}
_ => {}
// Check arguments for unsafe expect
for arg in args.iter() {
self.check_expr_for_unsafe_expect(arg);
}
}
ExprKind::Tup(exprs) | ExprKind::Array(exprs) => {
for expr in exprs.iter() {
self.check_expr_for_unsafe_expect(expr);
}
}
_ => {}
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for UnsafeExpectVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result {
if let Some(init) = local.init {
self.check_expr_for_unsafe_expect(init);
}
}

Expand All @@ -299,40 +314,22 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeExpectVisitor<'a, 'tcx> {
}

// Find `if` or `if let` expressions
if let Some(higher::IfOrIfLet {
if let Some(IfOrIfLet {
cond,
then: if_expr,
r#else: _,
}) = higher::IfOrIfLet::hir(expr)
}) = IfOrIfLet::hir(expr)
{
// If we are interested in the condition (if it is a CheckType) we traverse the body.
let conditional_checker = ConditionalChecker::from_expression(cond);
self.set_conditional_checker(&conditional_checker);
self.update_conditional_checker(&conditional_checker, true);
walk_expr(self, if_expr);
self.reset_conditional_checker(conditional_checker);
self.update_conditional_checker(&conditional_checker, false);
return;
}

// If we find an unsafe `expect`, we raise a warning
if_chain! {
if let ExprKind::MethodCall(path_segment, receiver, _, _) = &expr.kind;
if path_segment.ident.name == sym::expect;
then {
let receiver_hir_id = self.get_expect_info(receiver);
// If the receiver is `None`, then we asume that the `expect` is unsafe
let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id));
if !is_checked_safe {
span_lint_and_help(
self.cx,
UNSAFE_EXPECT,
expr.span,
LINT_MESSAGE,
None,
"Please, use a custom error instead of `expect`",
);
}
}
}
self.check_expr_for_unsafe_expect(expr);

walk_expr(self, expr);
}
Expand All @@ -343,13 +340,15 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeExpect {
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'tcx>,
fn_decl: &'tcx FnDecl<'tcx>,
body: &'tcx Body<'tcx>,
span: Span,
_: LocalDefId,
) {
// If the function comes from a macro expansion, we don't want to analyze it.
if span.from_expansion() {
// If the function comes from a macro expansion or does not return a Result<(), ()> or Option<()>, we don't want to analyze it.
if span.from_expansion()
|| !fn_returns(fn_decl, sym::Result) && !fn_returns(fn_decl, sym::Option)
{
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ impl UnsafeExpect {
}

// Returns the balance of a given account.
pub fn balance_of(env: Env, owner: Address) -> i128 {
pub fn balance_of(env: Env, owner: Address) -> Option<i128> {
let state = Self::get_state(env);
state.balances.get(owner).unwrap_or(0)
let balance = state.balances.get(owner).unwrap_or(0);
Some(balance)
}

/// Return the current state.
Expand Down Expand Up @@ -68,7 +69,7 @@ mod tests {

// Then
let balance = client.balance_of(&contract_id);
assert_eq!(TOTAL_SUPPLY, balance);
assert_eq!(TOTAL_SUPPLY, balance.unwrap());
}

#[test]
Expand All @@ -82,6 +83,6 @@ mod tests {

// Then
let balance = client.balance_of(&contract_id);
assert_eq!(0, balance);
assert_eq!(0, balance.unwrap());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ impl UnsafeExpect {
}

// Returns the balance of a given account.
pub fn balance_of(env: Env, owner: Address) -> i128 {
pub fn balance_of(env: Env, owner: Address) -> Option<i128> {
let state = Self::get_state(env);
state.balances.get(owner).expect("could not get balance")
let balance = state.balances.get(owner).expect("could not get balance");
Some(balance)
}

/// Return the current state.
Expand All @@ -49,6 +50,8 @@ mod tests {

use soroban_sdk::Env;

extern crate std;

use crate::{UnsafeExpect, UnsafeExpectClient, TOTAL_SUPPLY};

#[test]
Expand All @@ -63,22 +66,6 @@ mod tests {

// Then
let balance = client.balance_of(&contract_id);
assert_eq!(TOTAL_SUPPLY, balance);
}

#[test]
#[should_panic(expected = "could not get balance")]
fn balance_of_expect_works() {
// Given
let env = Env::default();
let contract_id = env.register_contract(None, UnsafeExpect);
let client = UnsafeExpectClient::new(&env, &contract_id);

// When - Balance not set

// Then
let _balance = client.balance_of(&contract_id);

// Test should panic
assert_eq!(TOTAL_SUPPLY, balance.unwrap());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@ impl UnsafeExpect {
}

// Returns the balance of a given account.
pub fn balance_of(env: Env, owner: Address) -> i128 {
pub fn balance_of(env: Env, owner: Address) -> Option<i128> {
let state = Self::get_state(env);
let balance = state.balances.get(owner);
if balance.is_none() {
return 0;
}
balance.expect("could not get balance")
balance?;
Some(balance.expect("could not get balance"))
}

/// Return the current state.
Expand Down Expand Up @@ -72,7 +70,7 @@ mod tests {

// Then
let balance = client.balance_of(&contract_id);
assert_eq!(TOTAL_SUPPLY, balance);
assert_eq!(TOTAL_SUPPLY, balance.unwrap());
}

#[test]
Expand All @@ -86,6 +84,6 @@ mod tests {

// Then
let balance = client.balance_of(&contract_id);
assert_eq!(0, balance);
assert_eq!(None, balance);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ impl UnsafeExpect {
}

// Returns the balance of a given account.
pub fn balance_of(env: Env, owner: Address) -> i128 {
pub fn balance_of(env: Env, owner: Address) -> Option<i128> {
let state = Self::get_state(env);
state.balances.get(owner).expect("could not get balance")
let balance = state.balances.get(owner).expect("could not get balance");
Some(balance)
}

/// Return the current state.
Expand Down Expand Up @@ -63,22 +64,6 @@ mod tests {

// Then
let balance = client.balance_of(&contract_id);
assert_eq!(TOTAL_SUPPLY, balance);
}

#[test]
#[should_panic(expected = "could not get balance")]
fn balance_of_expect_works() {
// Given
let env = Env::default();
let contract_id = env.register_contract(None, UnsafeExpect);
let client = UnsafeExpectClient::new(&env, &contract_id);

// When - Balance not set

// Then
let _balance = client.balance_of(&contract_id);

// Test should panic
assert_eq!(TOTAL_SUPPLY, balance.unwrap());
}
}
Loading