Skip to content

Commit

Permalink
assert violation detector + test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sofiazcoaga committed Apr 3, 2024
1 parent 4e26132 commit 7f15e8f
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 0 deletions.
20 changes: 20 additions & 0 deletions detectors/assert-violation/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[package]
name = "assert-violation"
version = "0.1.0"
edition = "2021"

[lib]
crate-type = ["cdylib"]

[dependencies]
scout-audit-clippy-utils = { workspace = true }
dylint_linting = { workspace = true }
if_chain = { workspace = true }

scout-audit-internal = { workspace = true }

[dev-dependencies]
dylint_testing = { workspace = true }

[package.metadata.rust-analyzer]
rustc_private = true
Empty file.
161 changes: 161 additions & 0 deletions detectors/assert-violation/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#![feature(rustc_private)]

extern crate rustc_ast;
extern crate rustc_span;

use if_chain::if_chain;
use rustc_ast::{
ptr::P,
tokenstream::{TokenStream, TokenTree},
AttrArgs, AttrKind, Expr, ExprKind, Item, MacCall, Stmt, StmtKind,
};
use rustc_lint::{EarlyContext, EarlyLintPass};
use rustc_span::{sym, Span};
use scout_audit_clippy_utils::sym;
use scout_audit_internal::{DetectorImpl, InkDetector as Detector};

dylint_linting::impl_pre_expansion_lint! {
/// ### What it does
/// Checks for `assert!` usage.
/// ### Why is this bad?
/// `assert!` causes a panic, and panicking it's not a good practice. Instead, use proper error handling.
/// ### Example
/// ```rust
/// #[ink(message)]
/// pub fn assert_if_greater_than_10(&self, value: u128) -> bool {
/// assert!(value <= 10, "value should be less than 10");
/// true
/// }
/// ```
/// Use instead:
///```rust
/// #[derive(Debug, PartialEq, Eq, scale::Encode, scale::Decode)]
/// #[cfg_attr(feature = "std", derive(scale_info::TypeInfo))]
/// pub enum Error {
/// GreaterThan10,
/// }
///
/// #[ink(message)]
/// pub fn revert_if_greater_than_10(&self, value: u128) -> Result<bool, Error> {
/// if value <= 10 {
/// return Ok(true)
/// } else {
/// return Err(Error::GreaterThan10)
/// }
/// }
///```
pub ASSERT_VIOLATION,
Warn,
"",
AssertViolation::default()
}

#[derive(Default)]
pub struct AssertViolation {
in_test_span: Option<Span>,
}

impl AssertViolation {
fn in_test_item(&self) -> bool {
self.in_test_span.is_some()
}
}

impl EarlyLintPass for AssertViolation {
fn check_item(&mut self, _cx: &EarlyContext, item: &Item) {
match (is_test_item(item), self.in_test_span) {
(true, None) => self.in_test_span = Some(item.span),
(true, Some(test_span)) => {
if !test_span.contains(item.span) {
self.in_test_span = Some(item.span);
}
}
(false, None) => {}
(false, Some(test_span)) => {
if !test_span.contains(item.span) {
self.in_test_span = None;
}
}
};
}

fn check_stmt(&mut self, cx: &EarlyContext, stmt: &Stmt) {
if self.in_test_item() {
return;
}

if let StmtKind::MacCall(mac) = &stmt.kind {
check_macro_call(cx, stmt.span, &mac.mac)
}
}
fn check_expr(&mut self, cx: &EarlyContext, expr: &Expr) {
if self.in_test_item() {
return;
}

if let ExprKind::MacCall(mac) = &expr.kind {
check_macro_call(cx, expr.span, mac)
}
}
}

fn check_macro_call(cx: &EarlyContext, span: Span, mac: &P<MacCall>) {
if [
sym!(assert),
sym!(assert_eq),
sym!(assert_ne),
sym!(debug_assert),
sym!(debug_assert_eq),
sym!(debug_assert_ne),
]
.iter()
.any(|sym| &mac.path == sym)
{
Detector::AssertViolation.span_lint_and_help(
cx,
ASSERT_VIOLATION,
span,
"You could use instead an Error enum.",
);
}
}

fn is_test_item(item: &Item) -> bool {
item.attrs.iter().any(|attr| {
// Find #[cfg(all(test, feature = "e2e-tests"))]
if_chain!(
if let AttrKind::Normal(normal) = &attr.kind;
if let AttrArgs::Delimited(delim_args) = &normal.item.args;
if is_test_token_present(&delim_args.tokens);
then {
return true;
}
);

// Find unit or integration tests
if attr.has_name(sym::test) {
return true;
}

if_chain! {
if attr.has_name(sym::cfg);
if let Some(items) = attr.meta_item_list();
if let [item] = items.as_slice();
if let Some(feature_item) = item.meta_item();
if feature_item.has_name(sym::test);
then {
return true;
}
}

false
})
}

fn is_test_token_present(token_stream: &TokenStream) -> bool {
token_stream.trees().any(|tree| match tree {
TokenTree::Token(token, _) => token.is_ident_named(sym::test),
TokenTree::Delimited(_, _, _, token_stream) => is_test_token_present(token_stream),
})
}
33 changes: 33 additions & 0 deletions test-cases/assert-violation/remediated-example/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[package]
name = "assert_violation"
version = "0.1.0"
edition = "2021"


[lib]
crate-type = ["cdylib"]
path = "lib.rs"

[dependencies]
soroban-sdk = "=20.0.0"

[dev_dependencies]
soroban-sdk = { version = "=20.0.0", features = ["testutils"] }

[features]
testutils = ["soroban-sdk/testutils"]

[profile.release]
opt-level = "z"
overflow-checks = true
debug = 0
strip = "symbols"
debug-assertions = false
panic = "abort"
codegen-units = 1
lto = true

[profile.release-with-logs]
inherits = "release"
debug-assertions = true

62 changes: 62 additions & 0 deletions test-cases/assert-violation/remediated-example/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#![no_std]

use soroban_sdk::{contracterror, contract, contractimpl, contracttype, Env, Symbol, symbol_short};

#[contracterror]
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
#[repr(u32)]
pub enum AVError {
GreaterThan10 = 1,
}

#[derive(Debug, Clone)]
#[contracttype]
pub struct State {
value: u128,
}

const STATE: Symbol = symbol_short!("STATE");
#[contract]
pub struct AssertViolation;

#[contractimpl]
impl AssertViolation {
pub fn init(env: Env, init_value: u128) -> State {
let state = State {
value: init_value
};

env.storage().instance().set(&STATE, &state);
state

}

pub fn assert_if_greater_than_10(_env: Env, value: u128) -> Result<bool, AVError> {
if value <= 10 {
Ok(true)
} else {
Err(AVError::GreaterThan10)
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn does_not_revert_if_greater() {
let env = Env::default();
let contract = AssertViolationClient::new(&env, &env.register_contract(None, AssertViolation{}));
assert_eq!(contract.assert_if_greater_than_10(&5), true);
}

#[test]
#[should_panic(expected = 1)] // The custom error number is 1
fn reverts_if_greater() {
let env = Env::default();
let contract = AssertViolationClient::new(&env, &env.register_contract(None, AssertViolation{}));
contract.assert_if_greater_than_10(&11);
}

}
33 changes: 33 additions & 0 deletions test-cases/assert-violation/vulnerable-example/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[package]
name = "assert_violation"
version = "0.1.0"
edition = "2021"


[lib]
crate-type = ["cdylib"]
path = "lib.rs"

[dependencies]
soroban-sdk = "=20.0.0"

[dev_dependencies]
soroban-sdk = { version = "=20.0.0", features = ["testutils"] }

[features]
testutils = ["soroban-sdk/testutils"]

[profile.release]
opt-level = "z"
overflow-checks = true
debug = 0
strip = "symbols"
debug-assertions = false
panic = "abort"
codegen-units = 1
lto = true

[profile.release-with-logs]
inherits = "release"
debug-assertions = true

52 changes: 52 additions & 0 deletions test-cases/assert-violation/vulnerable-example/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#![no_std]

use soroban_sdk::{contract, contractimpl, contracttype, Env, Symbol, symbol_short};

#[derive(Debug, Clone)]
#[contracttype]
pub struct State {
value: u128,
}

const STATE: Symbol = symbol_short!("STATE");
#[contract]
pub struct AssertViolation;

#[contractimpl]
impl AssertViolation {
pub fn init(env: Env, init_value: u128) -> State {
let state = State {
value: init_value
};

env.storage().instance().set(&STATE, &state);
state

}

pub fn assert_if_greater_than_10(_env: Env, value: u128) -> bool {
assert!(value <= 10, "value should be less than 10");
true
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn does_not_revert_if_greater() {
let env = Env::default();
let contract = AssertViolationClient::new(&env, &env.register_contract(None, AssertViolation{}));
assert_eq!(contract.assert_if_greater_than_10(&5), true);
}

#[test]
#[should_panic(expected = "value should be less than 10")]
fn reverts_if_greater() {
let env = Env::default();
let contract = AssertViolationClient::new(&env, &env.register_contract(None, AssertViolation{}));
contract.assert_if_greater_than_10(&11);
}

}

0 comments on commit 7f15e8f

Please sign in to comment.