diff --git a/detectors/avoid-panic-error/src/lib.rs b/detectors/avoid-panic-error/src/lib.rs index 58f09097..72895519 100644 --- a/detectors/avoid-panic-error/src/lib.rs +++ b/detectors/avoid-panic-error/src/lib.rs @@ -3,22 +3,21 @@ extern crate rustc_ast; extern crate rustc_span; -use clippy_utils::sym; -use clippy_wrappers::span_lint_and_help; -use if_chain::if_chain; +use clippy_utils::diagnostics::span_lint_and_help; use rustc_ast::{ - ptr::P, - token::{LitKind, TokenKind}, - tokenstream::{TokenStream, TokenTree}, - AttrArgs, AttrKind, Expr, ExprKind, Item, MacCall, StmtKind, + tokenstream::TokenTree, + visit::{walk_expr, Visitor}, + AssocItemKind, AttrArgs, AttrKind, Block, Expr, ExprKind, FnRetTy, Item, ItemKind, MacCall, + ModKind, TyKind, }; use rustc_lint::{EarlyContext, EarlyLintPass}; use rustc_span::{sym, Span}; -const LINT_MESSAGE: &str = "The panic! macro is used to stop execution when a condition is not met. Even when this does not break the execution of the contract, it is recommended to use Result instead of panic! because it will stop the execution of the caller contract"; +const LINT_MESSAGE: &str = "The panic! macro is used in a function that returns Result. \ + Consider using the ? operator or return Err() instead."; dylint_linting::impl_pre_expansion_lint! { - /// ### What it does + /// ### What it does /// The panic! macro is used to stop execution when a condition is not met. /// This is useful for testing and prototyping, but should be avoided in production code /// @@ -62,7 +61,8 @@ dylint_linting::impl_pre_expansion_lint! { AvoidPanicError::default(), { name: "Avoid panic! macro", - long_message: "The use of the panic! macro to stop execution when a condition is not met is useful for testing and prototyping but should be avoided in production code. Using Result as the return type for functions that can fail is the idiomatic way to handle errors in Rust. ", + long_message: "Using panic! in functions that return Result defeats the purpose of error handling. \ + Consider propagating the error using ? or return Err() instead.", severity: "Enhancement", help: "https://coinfabrik.github.io/scout-soroban/docs/detectors/avoid-panic-error", vulnerability_class: "Validations and error handling", @@ -75,123 +75,120 @@ pub struct AvoidPanicError { } impl EarlyLintPass for AvoidPanicError { - fn check_item(&mut self, _cx: &EarlyContext, item: &Item) { - match (is_test_item(item), self.in_test_span) { - (true, None) => self.in_test_span = Some(item.span), - (true, Some(test_span)) => { - if !test_span.contains(item.span) { - self.in_test_span = Some(item.span); + fn check_item(&mut self, cx: &EarlyContext, item: &Item) { + if is_test_item(item) { + self.in_test_span = Some(item.span); + return; + } + + if let Some(test_span) = self.in_test_span { + if !test_span.contains(item.span) { + self.in_test_span = None; + } else { + return; + } + } + + match &item.kind { + ItemKind::Impl(impl_item) => { + for assoc_item in &impl_item.items { + if let AssocItemKind::Fn(fn_item) = &assoc_item.kind { + self.check_function( + cx, + &fn_item.sig.decl.output, + fn_item.body.as_ref().unwrap(), + ); + } } } - (false, None) => {} - (false, Some(test_span)) => { - if !test_span.contains(item.span) { - self.in_test_span = None; + ItemKind::Fn(fn_item) => { + self.check_function(cx, &fn_item.sig.decl.output, fn_item.body.as_ref().unwrap()); + } + ItemKind::Mod(_, ModKind::Loaded(items, _, _)) => { + for item in items { + self.check_item(cx, item); } } - }; - } - - fn check_stmt(&mut self, cx: &EarlyContext<'_>, stmt: &rustc_ast::Stmt) { - if_chain! { - if !self.in_test_item(); - if let StmtKind::MacCall(mac) = &stmt.kind; - then { - check_macro_call(cx, stmt.span, &mac.mac) + ItemKind::Trait(trait_item) => { + for item in &trait_item.items { + if let AssocItemKind::Fn(fn_item) = &item.kind { + self.check_function( + cx, + &fn_item.sig.decl.output, + fn_item.body.as_ref().unwrap(), + ); + } + } } + _ => {} } } +} - fn check_expr(&mut self, cx: &EarlyContext, expr: &Expr) { - if_chain! { - if !self.in_test_item(); - if let ExprKind::MacCall(mac) = &expr.kind; - then { - check_macro_call(cx, expr.span, mac) - } +impl AvoidPanicError { + fn check_function(&self, cx: &EarlyContext, output: &FnRetTy, body: &Block) { + if is_result_type(output) { + let mut visitor = PanicVisitor { cx }; + visitor.visit_block(body); } } } -fn check_macro_call(cx: &EarlyContext, span: Span, mac: &P) { - if_chain! { - if mac.path == sym!(panic); - if let [TokenTree::Token(token, _)] = mac - .args - .tokens - .clone() - .trees() - .collect::>() - .as_slice(); - if let TokenKind::Literal(lit) = token.kind; - if lit.kind == LitKind::Str; - then { - span_lint_and_help( - cx, - AVOID_PANIC_ERROR, - span, - LINT_MESSAGE, - None, - format!("You could use instead an Error enum and then 'return Err(Error::{})'", capitalize_err_msg(lit.symbol.as_str()).replace(' ', "")) - ) +struct PanicVisitor<'a, 'tcx> { + cx: &'a EarlyContext<'tcx>, +} + +impl<'a, 'tcx> Visitor<'tcx> for PanicVisitor<'a, 'tcx> { + fn visit_expr(&mut self, expr: &'tcx Expr) { + if let ExprKind::MacCall(mac) = &expr.kind { + check_macro_call(self.cx, expr.span, mac); } + walk_expr(self, expr); + } +} + +fn check_macro_call(cx: &EarlyContext, span: Span, mac: &MacCall) { + if mac.path == sym::panic { + let suggestion = "Consider using '?' to propagate errors or 'return Err()' to return early with an error"; + span_lint_and_help(cx, AVOID_PANIC_ERROR, span, LINT_MESSAGE, None, suggestion); } } fn is_test_item(item: &Item) -> bool { item.attrs.iter().any(|attr| { - // Find #[cfg(all(test, feature = "e2e-tests"))] - if_chain!( - if let AttrKind::Normal(normal) = &attr.kind; - if let AttrArgs::Delimited(delim_args) = &normal.item.args; - if is_test_token_present(&delim_args.tokens); - then { - return true; - } - ); - - // Find unit or integration tests - if attr.has_name(sym::test) { - return true; - } - - if_chain! { - if attr.has_name(sym::cfg); - if let Some(items) = attr.meta_item_list(); - if let [item] = items.as_slice(); - if let Some(feature_item) = item.meta_item(); - if feature_item.has_name(sym::test); - then { - return true; - } - } - - false + attr.has_name(sym::test) + || (attr.has_name(sym::cfg) + && attr.meta_item_list().map_or(false, |list| { + list.iter().any(|item| item.has_name(sym::test)) + })) + || matches!( + &attr.kind, + AttrKind::Normal(normal) if is_test_token_present(&normal.item.args) + ) }) } -impl AvoidPanicError { - fn in_test_item(&self) -> bool { - self.in_test_span.is_some() +fn is_test_token_present(args: &AttrArgs) -> bool { + if let AttrArgs::Delimited(delim_args) = args { + delim_args.tokens.trees().any( + |tree| matches!(tree, TokenTree::Token(token, _) if token.is_ident_named(sym::test)), + ) + } else { + false } } -fn capitalize_err_msg(s: &str) -> String { - s.split_whitespace() - .map(|word| { - let mut chars = word.chars(); - match chars.next() { - None => String::new(), - Some(f) => f.to_uppercase().collect::() + chars.as_str(), +fn is_result_type(output: &FnRetTy) -> bool { + match output { + FnRetTy::Default(_) => false, + FnRetTy::Ty(ty) => { + if let TyKind::Path(None, path) = &ty.kind { + path.segments + .last() + .map_or(false, |seg| seg.ident.name == sym::Result) + } else { + false } - }) - .collect::>() - .join(" ") -} - -fn is_test_token_present(token_stream: &TokenStream) -> bool { - token_stream.trees().any(|tree| match tree { - TokenTree::Token(token, _) => token.is_ident_named(sym::test), - TokenTree::Delimited(_, _, _, token_stream) => is_test_token_present(token_stream), - }) + } + } } diff --git a/detectors/dynamic-instance-storage/Cargo.toml b/detectors/dynamic-storage/Cargo.toml similarity index 89% rename from detectors/dynamic-instance-storage/Cargo.toml rename to detectors/dynamic-storage/Cargo.toml index 3060bf9f..c564b4ff 100644 --- a/detectors/dynamic-instance-storage/Cargo.toml +++ b/detectors/dynamic-storage/Cargo.toml @@ -1,6 +1,6 @@ [package] edition = "2021" -name = "dynamic-instance-storage" +name = "dynamic-storage" version = "0.1.0" [lib] diff --git a/detectors/dynamic-instance-storage/src/lib.rs b/detectors/dynamic-storage/src/lib.rs similarity index 74% rename from detectors/dynamic-instance-storage/src/lib.rs rename to detectors/dynamic-storage/src/lib.rs index 9319f082..151a0c52 100644 --- a/detectors/dynamic-instance-storage/src/lib.rs +++ b/detectors/dynamic-storage/src/lib.rs @@ -15,26 +15,26 @@ use rustc_middle::ty::{Ty, TyKind}; use rustc_span::{def_id::LocalDefId, Span, Symbol}; use utils::{get_node_type_opt, is_soroban_storage, SorobanStorageType}; -const LINT_MESSAGE: &str = "This function may lead to excessive instance storage growth, which could increase execution costs or potentially cause DoS"; +const LINT_MESSAGE: &str = "Using dynamic types in instance or persistent storage can lead to unnecessary growth or storage-related vulnerabilities."; dylint_linting::impl_late_lint! { - pub DYNAMIC_INSTANCE_STORAGE, + pub DYNAMIC_STORAGE, Warn, LINT_MESSAGE, - DynamicInstanceStorage, + DynamicStorage, { - name: "Dynamic Instance Storage Analyzer", - long_message: "Detects potential misuse of instance storage that could lead to unnecessary growth or storage-related vulnerabilities.", + name: "Dynamic Storage Analyzer", + long_message: "Using dynamic types in instance or persistent storage can lead to unnecessary growth or storage-related vulnerabilities.", severity: "Warning", - help: "https://coinfabrik.github.io/scout-soroban/docs/detectors/dynamic-instance-storage", + help: "https://coinfabrik.github.io/scout-soroban/docs/detectors/dynamic-storage", vulnerability_class: "Resource Management", } } #[derive(Default)] -struct DynamicInstanceStorage; +struct DynamicStorage; -impl<'tcx> LateLintPass<'tcx> for DynamicInstanceStorage { +impl<'tcx> LateLintPass<'tcx> for DynamicStorage { fn check_fn( &mut self, cx: &LateContext<'tcx>, @@ -48,30 +48,31 @@ impl<'tcx> LateLintPass<'tcx> for DynamicInstanceStorage { return; } - let mut storage_warn_visitor = DynamicInstanceStorageVisitor { cx }; + let mut storage_warn_visitor = DynamicStorageVisitor { cx }; storage_warn_visitor.visit_body(body); } } -struct DynamicInstanceStorageVisitor<'a, 'tcx> { +struct DynamicStorageVisitor<'a, 'tcx> { cx: &'a LateContext<'tcx>, } -impl<'a, 'tcx> Visitor<'tcx> for DynamicInstanceStorageVisitor<'a, 'tcx> { +impl<'a, 'tcx> Visitor<'tcx> for DynamicStorageVisitor<'a, 'tcx> { fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { if_chain! { // Detect calls to `set` method if let ExprKind::MethodCall(path, receiver, args, _) = &expr.kind; if path.ident.name == Symbol::intern("set"); - // Get the type of the receiver and check if it is an instance storage + // Get the type of the receiver and check if it is an instance or persistent storage if let Some(receiver_ty) = get_node_type_opt(self.cx, &receiver.hir_id); - if is_soroban_storage(self.cx, receiver_ty, SorobanStorageType::Instance); + if is_soroban_storage(self.cx, receiver_ty, SorobanStorageType::Instance) + || is_soroban_storage(self.cx, receiver_ty, SorobanStorageType::Persistent); // Check if the value being set is a dynamic type if args.len() >= 2; if let Some(value_type) = get_node_type_opt(self.cx, &args[1].hir_id); if is_dynamic_type(self.cx, &value_type); then { - span_lint(self.cx, DYNAMIC_INSTANCE_STORAGE, expr.span, LINT_MESSAGE) + span_lint(self.cx, DYNAMIC_STORAGE, expr.span, LINT_MESSAGE) } } diff --git a/detectors/storage-change-events/Cargo.toml b/detectors/storage-change-events/Cargo.toml new file mode 100644 index 00000000..97569970 --- /dev/null +++ b/detectors/storage-change-events/Cargo.toml @@ -0,0 +1,16 @@ +[package] +edition = "2021" +name = "storage-change-events" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +clippy_utils = { workspace = true } +dylint_linting = { workspace = true } +if_chain = { workspace = true } +utils = { workspace = true } + +[package.metadata.rust-analyzer] +rustc_private = true diff --git a/detectors/storage-change-events/src/lib.rs b/detectors/storage-change-events/src/lib.rs new file mode 100644 index 00000000..b119ff21 --- /dev/null +++ b/detectors/storage-change-events/src/lib.rs @@ -0,0 +1,204 @@ +#![feature(rustc_private)] + +extern crate rustc_hir; +extern crate rustc_middle; +extern crate rustc_span; + +use clippy_utils::diagnostics::span_lint_and_help; + +use rustc_hir::{ + intravisit::{walk_expr, Visitor}, + Expr, ExprKind, +}; +use rustc_lint::{LateContext, LateLintPass}; + +use rustc_span::Span; + +use std::collections::HashMap; +use std::collections::HashSet; +use std::vec; +use utils::{is_soroban_function, FunctionCallVisitor}; + +use rustc_span::def_id::DefId; + +const LINT_MESSAGE: &str = "Consider emiting an event when storage is modified"; + +dylint_linting::impl_late_lint! { + pub STORAGE_CHANGE_EVENTS, + Warn, + "", + StorageChangeEvents::default(), + { + name: "Storage Changed without Emiting an Event", + long_message: "", + severity: "", + help: "", + vulnerability_class: "", + } +} + +#[derive(Default)] +struct StorageChangeEvents { + function_call_graph: HashMap>, + checked_functions: HashSet, + eventless_storage_changers: HashSet, + defids_with_events: HashSet, +} + +/// Used to verify if, starting from a specific parent in the call graph, an event is emitted at any point of the flow. +/// # Params: +/// - fcg: function call graph +/// - parent: the item from which the analysis starts. +/// - check_against: a HashSet that is used to compare the defids. This HashSet is supposed to contain all the defids of the functions that emit events (collected by the `visit_expr` and `check_func` functions). +fn check_events_children( + fcg: &HashMap>, + parent: &DefId, + check_against: &HashSet, +) -> bool { + if check_against.contains(parent) { + return true; + } + let children = fcg.get(parent); + if children.is_some() { + for c in children.unwrap() { + if check_against.contains(c) || check_events_children(fcg, c, check_against) { + return true; + } + } + } + false +} + +/// Used to verify if, starting from a specific parent in the call graph, a function that sets storage in a considered "unsafe" way is called in any part of its flow. +/// # Params: +/// - fcg: function call graph +/// - func: the defid from which the analysis starts. +/// - unsafe_set_storage: a HashSet that is used to compare the defids. This HashSet is supposed to contain all the defids of the functions that are considered "unsafe storage setters". +fn check_storage_setters_calls( + fcg: &HashMap>, + func: &DefId, + unsafe_set_storage: &HashSet, +) -> bool { + if unsafe_set_storage.contains(func) { + return true; + } + let children = fcg.get(func); + if children.is_some() { + for c in children.unwrap() { + if unsafe_set_storage.contains(c) + || check_storage_setters_calls(fcg, c, unsafe_set_storage) + { + return true; + } + } + } + false +} + +impl<'tcx> LateLintPass<'tcx> for StorageChangeEvents { + fn check_crate_post(&mut self, cx: &LateContext<'tcx>) { + // Emit the alerts for the considered "unsafe" functions. + for func in self.function_call_graph.keys() { + // Only take into account those functions that are public and exposed in a soroban contract (entrypoints that can be called externally). We do not advise on functions that are used auxiliarily. + if is_soroban_function(cx, &self.checked_functions, func) { + // Verify if the function itself or the ones it calls (directly or indirectly) emit an event at any point of the flow. + let emits_event_in_flow = check_events_children( + &self.function_call_graph, + func, + &self.defids_with_events, + ); + + // Verify if the function itself or the ones it calls (directly or indirectly) call an unsafe storage setter at any point of the flow. + let calls_unsafe_storage_setter = check_storage_setters_calls( + &self.function_call_graph, + func, + &self.eventless_storage_changers, + ); + + // If both conditions are met, emit an warning. + if !emits_event_in_flow && calls_unsafe_storage_setter { + span_lint_and_help( + cx, + STORAGE_CHANGE_EVENTS, + cx.tcx.hir().span_if_local(*func).unwrap(), + LINT_MESSAGE, + /* cx.tcx.hir().span_if_local(r) */ None, + "", + ); + } + } + } + } + + fn check_fn( + &mut self, + cx: &LateContext<'tcx>, + _: rustc_hir::intravisit::FnKind<'tcx>, + _fn_decl: &'tcx rustc_hir::FnDecl<'tcx>, + body: &'tcx rustc_hir::Body<'tcx>, + span: Span, + local_def_id: rustc_span::def_id::LocalDefId, + ) { + let def_id = local_def_id.to_def_id(); + self.checked_functions.insert(cx.tcx.def_path_str(def_id)); + + if span.from_expansion() { + return; + } + + let mut function_call_visitor = + FunctionCallVisitor::new(cx, def_id, &mut self.function_call_graph); + function_call_visitor.visit_body(body); + + let mut storage_change_events_visitor = StorageChangeEventsVisitor { + cx, + is_storage_changer: false, + emits_event: false, + }; + + storage_change_events_visitor.visit_body(body); + + // If the function modifies the storage and does not emit event, we keep record of its defid as an eventless storage changer. + if storage_change_events_visitor.is_storage_changer + && !storage_change_events_visitor.emits_event + { + self.eventless_storage_changers.insert(def_id); + } + + // If the function emits an event, we storage its defid. + if storage_change_events_visitor.emits_event { + self.defids_with_events.insert(def_id); + } + } +} + +struct StorageChangeEventsVisitor<'a, 'tcx> { + cx: &'a LateContext<'tcx>, + is_storage_changer: bool, + emits_event: bool, +} + +impl<'a, 'tcx> Visitor<'tcx> for StorageChangeEventsVisitor<'a, 'tcx> { + fn visit_expr(&mut self, expr: &'tcx Expr<'_>) { + if let ExprKind::MethodCall(path, receiver, _, _) = expr.kind { + let name = path.ident.name.as_str(); + + let receiver_type = self.cx.typeck_results().node_type(receiver.hir_id); + + // verify if it is an event emission + if name == "events" { + self.emits_event = true; + } + + // verify if it is a storage change + if (name == "set" || name == "update" || name == "remove" || name == "try_update") + && (receiver_type.to_string() == "soroban_sdk::storage::Instance" + || receiver_type.to_string() == "soroban_sdk::storage::Persistent" + || receiver_type.to_string() == "soroban_sdk::storage::Temporary") + { + self.is_storage_changer = true; + } + } + walk_expr(self, expr); + } +} diff --git a/detectors/token-interface-events/Cargo.toml b/detectors/token-interface-events/Cargo.toml new file mode 100644 index 00000000..644b6b4d --- /dev/null +++ b/detectors/token-interface-events/Cargo.toml @@ -0,0 +1,16 @@ +[package] +edition = "2021" +name = "token-interface-events" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +clippy_utils = { workspace = true } +dylint_linting = { workspace = true } +if_chain = { workspace = true } +utils = { workspace = true } + +[package.metadata.rust-analyzer] +rustc_private = true diff --git a/detectors/token-interface-events/src/lib.rs b/detectors/token-interface-events/src/lib.rs new file mode 100644 index 00000000..69e3d264 --- /dev/null +++ b/detectors/token-interface-events/src/lib.rs @@ -0,0 +1,194 @@ +#![feature(rustc_private)] + +extern crate rustc_hir; +extern crate rustc_middle; +extern crate rustc_span; + +use clippy_utils::diagnostics::span_lint_and_help; + +use rustc_hir::{ + intravisit::{walk_expr, Visitor}, + Expr, ExprKind, +}; +use rustc_lint::{LateContext, LateLintPass}; + +use rustc_span::Span; + +use std::collections::HashMap; +use std::collections::HashSet; +use std::vec; +use utils::{verify_token_interface_function, FunctionCallVisitor}; + +use rustc_span::def_id::DefId; + +const LINT_MESSAGE: &str = "This function belongs to the Token Interface and should emit an event"; + +dylint_linting::impl_late_lint! { + pub TOKEN_INTERFACE_EVENTS, + Warn, + "", + TokenInterfaceEvents::default(), + { + name: "Storage Changed without Emiting an Event in Token Interface implementations", + long_message: " It can originate a problem when a canonical function does not emit an event expected by the contract's clients.", + severity: "", + help: "", + vulnerability_class: "", + } +} + +#[derive(Default)] +struct TokenInterfaceEvents { + function_call_graph: HashMap>, + checked_functions: HashSet, + //eventless_storage_changers: HashSet, + defids_with_events: HashSet, + canonical_funcs_def_id: HashSet, + impl_token_interface_trait: bool, +} + +/// Used to verify if, starting from a specific parent in the call graph, an event is emitted at any point of the flow. +/// # Params: +/// - fcg: function call graph +/// - parent: the item from which the analysis starts. +/// - check_against: a HashSet that is used to compare the defids. This HashSet is supposed to contain all the defids of the functions that emit events (collected by the `visit_expr` and `check_func` functions). +fn check_events_children( + fcg: &HashMap>, + parent: &DefId, + check_against: &HashSet, +) -> bool { + if check_against.contains(parent) { + return true; + } + let children = fcg.get(parent); + if children.is_some() { + for c in children.unwrap() { + if check_against.contains(c) || check_events_children(fcg, c, check_against) { + return true; + } + } + } + false +} + +impl<'tcx> LateLintPass<'tcx> for TokenInterfaceEvents { + fn check_item(&mut self, cx: &LateContext<'tcx>, item: &'tcx rustc_hir::Item<'tcx>) { + if let rustc_hir::ItemKind::Impl(impl_block) = item.kind { + if let Some(trait_ref) = impl_block.of_trait { + let trait_def_id = trait_ref.path.res.def_id(); + let trait_name = cx.tcx.def_path_str(trait_def_id); + + if trait_name == "soroban_sdk::token::TokenInterface" { + self.impl_token_interface_trait = true; + } + } + } + } + + fn check_crate_post(&mut self, cx: &LateContext<'tcx>) { + let functions_that_emit_events: [String; 6] = [ + "burn".to_string(), + "approve".to_string(), + "transfer_from".to_string(), + "burn_from".to_string(), + "transfer".to_string(), + "mint".to_string(), + ]; + // Verify if the contract implements the token interface trait. + if !self.impl_token_interface_trait { + return; + } + + // Emit the alerts for the considered "unsafe" functions. + for func in self.function_call_graph.keys() { + // Only take into account those functions that are public and exposed in a soroban contract (entrypoints that can be called externally). We do not advise on functions that are used auxiliarily. + if self.canonical_funcs_def_id.contains(func) + && functions_that_emit_events.contains( + &cx.tcx + .def_path_str(func) + .split("::") + .last() + .unwrap() + .to_string(), + ) + { + // Verify if the function itself or the ones it calls (directly or indirectly) emit an event at any point of the flow. + let emits_event_in_flow = check_events_children( + &self.function_call_graph, + func, + &self.defids_with_events, + ); + + // If both conditions are met, emit an warning. + if !emits_event_in_flow { + span_lint_and_help( + cx, + TOKEN_INTERFACE_EVENTS, + cx.tcx.hir().span_if_local(*func).unwrap(), + LINT_MESSAGE, + /* cx.tcx.hir().span_if_local(r) */ None, + "", + ); + } + } + } + } + + fn check_fn( + &mut self, + cx: &LateContext<'tcx>, + _: rustc_hir::intravisit::FnKind<'tcx>, + fn_decl: &'tcx rustc_hir::FnDecl<'tcx>, + body: &'tcx rustc_hir::Body<'tcx>, + span: Span, + local_def_id: rustc_span::def_id::LocalDefId, + ) { + let def_id = local_def_id.to_def_id(); + self.checked_functions.insert(cx.tcx.def_path_str(def_id)); + + if span.from_expansion() { + return; + } + + let fn_name = cx.tcx.def_path_str(def_id); + + let mut function_call_visitor = + FunctionCallVisitor::new(cx, def_id, &mut self.function_call_graph); + function_call_visitor.visit_body(body); + + // If the function is part of the token interface, I store its defid. + if verify_token_interface_function(fn_name.clone(), fn_decl.inputs, fn_decl.output) { + self.canonical_funcs_def_id.insert(def_id); + } + let mut token_interface_events_visitor = TokenInterfaceEventsVisitor { + _cx: cx, + emits_event: false, + }; + + token_interface_events_visitor.visit_body(body); + + // If the function emits an event, we storage its defid. + if token_interface_events_visitor.emits_event { + self.defids_with_events.insert(def_id); + } + } +} + +struct TokenInterfaceEventsVisitor<'a, 'tcx> { + _cx: &'a LateContext<'tcx>, + emits_event: bool, +} + +impl<'a, 'tcx> Visitor<'tcx> for TokenInterfaceEventsVisitor<'a, 'tcx> { + fn visit_expr(&mut self, expr: &'tcx Expr<'_>) { + if let ExprKind::MethodCall(path, _receiver, _, _) = expr.kind { + let name = path.ident.name.as_str(); + + // verify if it is an event emission + if name == "events" { + self.emits_event = true; + } + } + walk_expr(self, expr); + } +} diff --git a/detectors/unsafe-expect/src/lib.rs b/detectors/unsafe-expect/src/lib.rs index f575eff5..42cd146f 100644 --- a/detectors/unsafe-expect/src/lib.rs +++ b/detectors/unsafe-expect/src/lib.rs @@ -236,7 +236,6 @@ impl UnsafeExpectVisitor<'_, '_> { .get_expect_info(receiver) .map_or(true, |id| !self.checked_exprs.contains(&id)); } - args.iter().any(|arg| self.contains_unsafe_method_call(arg)) || self.contains_unsafe_method_call(receiver) } @@ -249,40 +248,54 @@ impl UnsafeExpectVisitor<'_, '_> { _ => false, } } -} -impl<'a, 'tcx> Visitor<'tcx> for UnsafeExpectVisitor<'a, 'tcx> { - fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result { - if let Some(init) = local.init { - match init.kind { - ExprKind::MethodCall(path_segment, receiver, args, _) => { - if self.is_method_call_unsafe(path_segment, receiver, args) { - span_lint_and_help( - self.cx, - UNSAFE_EXPECT, - local.span, - LINT_MESSAGE, - None, - "Please, use a custom error instead of `expect`", - ); - } + fn check_expr_for_unsafe_expect(&mut self, expr: &Expr) { + match &expr.kind { + ExprKind::MethodCall(path_segment, receiver, args, _) => { + if self.is_method_call_unsafe(path_segment, receiver, args) { + span_lint_and_help( + self.cx, + UNSAFE_EXPECT, + expr.span, + LINT_MESSAGE, + None, + "Please, use a custom error instead of `expect`", + ); } - ExprKind::Call(func, args) => { - if let ExprKind::Path(QPath::Resolved(_, path)) = func.kind { - let is_some_or_ok = path - .segments - .iter() - .any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok)); - let all_literals = args - .iter() - .all(|arg| self.is_literal_or_composed_of_literals(arg)); - if is_some_or_ok && all_literals { - self.checked_exprs.insert(local.pat.hir_id); - } + } + ExprKind::Call(func, args) => { + if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind { + let is_some_or_ok = path + .segments + .iter() + .any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok)); + let all_literals = args + .iter() + .all(|arg| self.is_literal_or_composed_of_literals(arg)); + if is_some_or_ok && all_literals { + self.checked_exprs.insert(expr.hir_id); + return; } } - _ => {} + // Check arguments for unsafe expect + for arg in args.iter() { + self.check_expr_for_unsafe_expect(arg); + } + } + ExprKind::Tup(exprs) | ExprKind::Array(exprs) => { + for expr in exprs.iter() { + self.check_expr_for_unsafe_expect(expr); + } } + _ => {} + } + } +} + +impl<'a, 'tcx> Visitor<'tcx> for UnsafeExpectVisitor<'a, 'tcx> { + fn visit_local(&mut self, local: &'tcx rustc_hir::LetStmt<'tcx>) -> Self::Result { + if let Some(init) = local.init { + self.check_expr_for_unsafe_expect(init); } } @@ -314,25 +327,7 @@ impl<'a, 'tcx> Visitor<'tcx> for UnsafeExpectVisitor<'a, 'tcx> { } // If we find an unsafe `expect`, we raise a warning - if_chain! { - if let ExprKind::MethodCall(path_segment, receiver, _, _) = &expr.kind; - if path_segment.ident.name == sym::expect; - then { - let receiver_hir_id = self.get_expect_info(receiver); - // If the receiver is `None`, then we asume that the `expect` is unsafe - let is_checked_safe = receiver_hir_id.map_or(false, |id| self.checked_exprs.contains(&id)); - if !is_checked_safe { - span_lint_and_help( - self.cx, - UNSAFE_EXPECT, - expr.span, - LINT_MESSAGE, - None, - "Please, use a custom error instead of `expect`", - ); - } - } - } + self.check_expr_for_unsafe_expect(expr); walk_expr(self, expr); } diff --git a/docs/docs/detectors/11-dos-unbounded-operation.md b/docs/docs/detectors/11-dos-unbounded-operation.md index 7a4a8299..49bf2b5b 100644 --- a/docs/docs/detectors/11-dos-unbounded-operation.md +++ b/docs/docs/detectors/11-dos-unbounded-operation.md @@ -1,43 +1,56 @@ # DoS unbounded operation -### What it does +## Description -This detector checks that when using for or while loops, their conditions limit the execution to a constant number of iterations. +- Category: `Denial of Service` +- Severity: `Medium` +- Detector: [`dos-unbounded-operation`](https://github.com/CoinFabrik/scout-soroban/tree/main/detectors/dos-unbounded-operation) +- Test Cases: [`dos-unbounded-operation-1`](https://github.com/CoinFabrik/scout-soroban/tree/main/test-cases/dos-unbounded-operation/dos-unbounded-operation-1) [`dos-unbounded-operation-2`](https://github.com/CoinFabrik/scout-soroban/tree/main/test-cases/dos-unbounded-operation/dos-unbounded-operation-2) [`dos-unbounded-operation-3`](https://github.com/CoinFabrik/scout-soroban/tree/main/test-cases/dos-unbounded-operation/dos-unbounded-operation-3) + +Each block in a Stellar Blockchain has an upper bound on the amount of gas that can be spent, and thus the amount computation that can be done. This is the Block Gas Limit. -### Why is this bad? +## Why is this bad? -If the number of iterations is not limited to a specific range, it could potentially cause out of gas exceptions. +If the number of iterations is not limited to a specific range, it could potentially cause out of gas exceptions. If this happens, gas will leak, the transaction will fail, and there will be a risk of a potential attack on the contract. -### Known problems +## Issue example -False positives are to be expected when using variables that can only be set using controlled flows that limit the values within acceptable ranges. +In the following example, a contract has a function ´unsafe_loop_with_array´, which contains a for loop that iterates over a range of numbers from 0 to the lenght of the array ´unknown_array´. The issue is that if the length of the array is extremely large, it would cause the loop to execute many times, potentially leading to an unusable state of the contract. -### Example +Consider the following `Soroban` contract: ```rust -pub fn unrestricted_loop(for_loop_count: u64) -> u64 { - let mut count = 0; - for i in 0..for_loop_count { - count += i; + pub fn unsafe_loop_with_array(unknown_array: BytesN<8>) -> u32 { + let mut sum = 0; + for i in 0..unknown_array.len() { + sum += i; + } + sum } - count -} ``` +The code example can be found [here](https://github.com/CoinFabrik/scout-soroban/tree/main/test-cases/dos-unbounded-operation/dos-unbounded-operation-3/vulnerable-example). -Use instead: -```rust -const FIXED_COUNT: u64 = 1000; +## Remediated example -pub fn restricted_loop_with_const() -> u64 { - let mut sum = 0; - for i in 0..FIXED_COUNT { - sum += i; +To solve this, instead of relying on an external parameter, we should introduce a known value directly into the loop. +```rust + pub fn safe_loop_with_array() -> u64 { + let mut sum = 0; + let known_array = [0; 8]; + for i in 0..known_array.len() { + sum += i; + } + sum as u64 } - sum -} ``` -### Implementation +The remediated code example can be found [here](https://github.com/CoinFabrik/scout-soroban/tree/main/test-cases/dos-unbounded-operation/dos-unbounded-operation-3/remediated-example). + +## How is it detected? + +This detector checks that when using for or while loops, their conditions limit the execution to a constant number of iterations. + + -The detector's implementation can be found at [this link](https://github.com/CoinFabrik/scout-soroban/tree/main/detectors/dos-unbounded-operation). \ No newline at end of file + diff --git a/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs index 4a50c7d8..a0a0cda9 100644 --- a/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs +++ b/test-cases/avoid-panic-error/avoid-panic-error-1/remediated-example/src/lib.rs @@ -9,18 +9,18 @@ pub struct AvoidPanicError; #[contracterror] #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] #[repr(u32)] -pub enum Error { +pub enum PanicError { OverflowError = 1, } #[contractimpl] impl AvoidPanicError { - pub fn add(env: Env, value: u32) -> Result { + pub fn add(env: Env, value: u32) -> Result { let storage = env.storage().instance(); let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); match count.checked_add(value) { Some(value) => count = value, - None => return Err(Error::OverflowError), + None => return Err(PanicError::OverflowError), } storage.set(&COUNTER, &count); storage.extend_ttl(100, 100); @@ -32,7 +32,7 @@ impl AvoidPanicError { mod tests { use soroban_sdk::Env; - use crate::{AvoidPanicError, AvoidPanicErrorClient, Error}; + use crate::{AvoidPanicError, AvoidPanicErrorClient, PanicError}; #[test] fn add() { @@ -64,6 +64,6 @@ mod tests { let overflow = client.try_add(&1); // Then - assert_eq!(overflow, Err(Ok(Error::OverflowError))); + assert_eq!(overflow, Err(Ok(PanicError::OverflowError))); } } diff --git a/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs index 7119ed1d..0aa1758f 100644 --- a/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs +++ b/test-cases/avoid-panic-error/avoid-panic-error-1/vulnerable-example/src/lib.rs @@ -1,14 +1,21 @@ #![no_std] -use soroban_sdk::{contract, contractimpl, symbol_short, Env, Symbol}; + +use soroban_sdk::{contract, contracterror, contractimpl, symbol_short, Env, Symbol}; const COUNTER: Symbol = symbol_short!("COUNTER"); #[contract] pub struct AvoidPanicError; +#[contracterror] +#[derive(Copy, Clone)] +pub enum PanicError { + Overflow = 1, +} + #[contractimpl] impl AvoidPanicError { - pub fn add(env: Env, value: u32) -> u32 { + pub fn add(env: Env, value: u32) -> Result { let storage = env.storage().instance(); let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); match count.checked_add(value) { @@ -17,7 +24,7 @@ impl AvoidPanicError { } storage.set(&COUNTER, &count); storage.extend_ttl(100, 100); - count + Ok(count) } } diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/Cargo.toml b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/Cargo.toml new file mode 100644 index 00000000..e64d4f1f --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/Cargo.toml @@ -0,0 +1,13 @@ +[package] +edition = "2021" +name = "avoid-panic-error-remediated-2" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = { workspace = true } + +[features] +testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/src/lib.rs new file mode 100644 index 00000000..6093d1f5 --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/remediated-example/src/lib.rs @@ -0,0 +1,63 @@ +#![no_std] +use soroban_sdk::{contract, contractimpl, symbol_short, Env, Symbol}; + +const COUNTER: Symbol = symbol_short!("COUNTER"); + +#[contract] +pub struct AvoidPanicError; + +#[contractimpl] +impl AvoidPanicError { + pub fn add(env: Env, value: u32) -> u32 { + let storage = env.storage().instance(); + let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); + match count.checked_add(value) { + Some(value) => count = value, + None => panic!("Overflow error"), + } + storage.set(&COUNTER, &count); + storage.extend_ttl(100, 100); + count + } +} + +#[cfg(test)] +mod tests { + use soroban_sdk::Env; + + use crate::{AvoidPanicError, AvoidPanicErrorClient}; + + #[test] + fn add() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + let first_increment = client.try_add(&1); + let second_increment = client.try_add(&2); + let third_increment = client.try_add(&3); + + // Then + assert_eq!(first_increment, Ok(Ok(1))); + assert_eq!(second_increment, Ok(Ok(3))); + assert_eq!(third_increment, Ok(Ok(6))); + } + + #[test] + #[should_panic(expected = "Overflow error")] + fn overflow() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + client.add(&u32::MAX); + client.add(&1); + + // Then + // panic + } +} diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/Cargo.toml b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/Cargo.toml new file mode 100644 index 00000000..c690ea8e --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/Cargo.toml @@ -0,0 +1,13 @@ +[package] +edition = "2021" +name = "avoid-panic-error-vulnerable-2" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = { workspace = true } + +[features] +testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/src/lib.rs b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/src/lib.rs new file mode 100644 index 00000000..0aa1758f --- /dev/null +++ b/test-cases/avoid-panic-error/avoid-panic-error-2/vulnerable-example/src/lib.rs @@ -0,0 +1,70 @@ +#![no_std] + +use soroban_sdk::{contract, contracterror, contractimpl, symbol_short, Env, Symbol}; + +const COUNTER: Symbol = symbol_short!("COUNTER"); + +#[contract] +pub struct AvoidPanicError; + +#[contracterror] +#[derive(Copy, Clone)] +pub enum PanicError { + Overflow = 1, +} + +#[contractimpl] +impl AvoidPanicError { + pub fn add(env: Env, value: u32) -> Result { + let storage = env.storage().instance(); + let mut count: u32 = storage.get(&COUNTER).unwrap_or(0); + match count.checked_add(value) { + Some(value) => count = value, + None => panic!("Overflow error"), + } + storage.set(&COUNTER, &count); + storage.extend_ttl(100, 100); + Ok(count) + } +} + +#[cfg(test)] +mod tests { + use soroban_sdk::Env; + + use crate::{AvoidPanicError, AvoidPanicErrorClient}; + + #[test] + fn add() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + let first_increment = client.add(&1); + let second_increment = client.add(&2); + let third_increment = client.add(&3); + + // Then + assert_eq!(first_increment, 1); + assert_eq!(second_increment, 3); + assert_eq!(third_increment, 6); + } + + #[test] + #[should_panic(expected = "Overflow error")] + fn overflow() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, AvoidPanicError); + let client = AvoidPanicErrorClient::new(&env, &contract_id); + + // When + client.add(&u32::MAX); + client.add(&1); + + // Then + // panic + } +} diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/vulnerable-example/src/lib.rs b/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/vulnerable-example/src/lib.rs deleted file mode 100644 index 10f79de7..00000000 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/vulnerable-example/src/lib.rs +++ /dev/null @@ -1,43 +0,0 @@ -#![no_std] -use soroban_sdk::{contract, contractimpl, Env, Symbol, Vec}; - -#[contract] -pub struct VectorStorage; - -#[contractimpl] -impl VectorStorage { - pub fn store_vector(e: Env, data: Vec) { - e.storage() - .instance() - .set(&Symbol::new(&e, "vector_data"), &data); - } - - pub fn get_vector(e: Env) -> Vec { - e.storage() - .instance() - .get(&Symbol::new(&e, "vector_data")) - .unwrap() - } -} - -#[cfg(test)] -mod test { - use super::*; - use soroban_sdk::{vec, Env}; - - #[test] - fn test_vector_storage() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, VectorStorage); - let client = VectorStorageClient::new(&env, &contract_id); - - // When - let test_vec = vec![&env, 1, 2, 3, 4, 5]; - client.store_vector(&test_vec); - - // Then - let retrieved_vec = client.get_vector(); - assert_eq!(test_vec, retrieved_vec); - } -} diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/remediated-example/src/lib.rs b/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/remediated-example/src/lib.rs deleted file mode 100644 index 354b0abc..00000000 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/remediated-example/src/lib.rs +++ /dev/null @@ -1,43 +0,0 @@ -#![no_std] -use soroban_sdk::{contract, contractimpl, Env, String, Symbol}; - -#[contract] -pub struct StringStorage; - -#[contractimpl] -impl StringStorage { - pub fn store_string(e: Env, data: String) { - e.storage() - .persistent() - .set(&Symbol::new(&e, "string_data"), &data); - } - - pub fn get_string(e: Env) -> String { - e.storage() - .persistent() - .get(&Symbol::new(&e, "string_data")) - .unwrap() - } -} - -#[cfg(test)] -mod test { - use super::*; - use soroban_sdk::Env; - - #[test] - fn test_string_storage() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, StringStorage); - let client = StringStorageClient::new(&env, &contract_id); - - // When - let test_string = String::from_str(&env, "Hello, Soroban!"); - client.store_string(&test_string); - - // Then - let retrieved_string = client.get_string(); - assert_eq!(test_string, retrieved_string); - } -} diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/vulnerable-example/src/lib.rs b/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/vulnerable-example/src/lib.rs deleted file mode 100644 index 79a36aa0..00000000 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/vulnerable-example/src/lib.rs +++ /dev/null @@ -1,43 +0,0 @@ -#![no_std] -use soroban_sdk::{contract, contractimpl, Env, String, Symbol}; - -#[contract] -pub struct StringStorage; - -#[contractimpl] -impl StringStorage { - pub fn store_string(e: Env, data: String) { - e.storage() - .instance() - .set(&Symbol::new(&e, "string_data"), &data); - } - - pub fn get_string(e: Env) -> String { - e.storage() - .instance() - .get(&Symbol::new(&e, "string_data")) - .unwrap() - } -} - -#[cfg(test)] -mod test { - use super::*; - use soroban_sdk::Env; - - #[test] - fn test_string_storage() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, StringStorage); - let client = StringStorageClient::new(&env, &contract_id); - - // When - let test_string = String::from_str(&env, "Hello, Soroban!"); - client.store_string(&test_string); - - // Then - let retrieved_string = client.get_string(); - assert_eq!(test_string, retrieved_string); - } -} diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/remediated-example/Cargo.toml b/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/remediated-example/Cargo.toml deleted file mode 100644 index fa1d3059..00000000 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/remediated-example/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -edition = "2021" -name = "dynamic-instance-storage-remediated-3" -version = "0.1.0" - -[lib] -crate-type = ["cdylib"] - -[dependencies] -soroban-sdk = { workspace = true } - -[dev-dependencies] -soroban-sdk = { workspace = true, features = ["testutils"] } - -[features] -testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/vulnerable-example/Cargo.toml b/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/vulnerable-example/Cargo.toml deleted file mode 100644 index 5f0d3c3d..00000000 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/vulnerable-example/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -edition = "2021" -name = "dynamic-instance-storage-vulnerable-3" -version = "0.1.0" - -[lib] -crate-type = ["cdylib"] - -[dependencies] -soroban-sdk = { workspace = true } - -[dev-dependencies] -soroban-sdk = { workspace = true, features = ["testutils"] } - -[features] -testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/dynamic-storage/Cargo.toml b/test-cases/dynamic-storage/Cargo.toml new file mode 100644 index 00000000..6ac24264 --- /dev/null +++ b/test-cases/dynamic-storage/Cargo.toml @@ -0,0 +1,21 @@ +[workspace] +exclude = [".cargo", "target"] +members = ["dynamic-storage-*/*"] +resolver = "2" + +[workspace.dependencies] +soroban-sdk = { version = "=21.5.1" } + +[profile.release] +codegen-units = 1 +debug = 0 +debug-assertions = false +lto = true +opt-level = "z" +overflow-checks = true +panic = "abort" +strip = "symbols" + +[profile.release-with-logs] +debug-assertions = true +inherits = "release" diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/vulnerable-example/Cargo.toml b/test-cases/dynamic-storage/dynamic-storage-1/remediated-example/Cargo.toml similarity index 84% rename from test-cases/dynamic-instance-storage/dynamic-instance-storage-2/vulnerable-example/Cargo.toml rename to test-cases/dynamic-storage/dynamic-storage-1/remediated-example/Cargo.toml index 9241c0f4..f096e323 100644 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/vulnerable-example/Cargo.toml +++ b/test-cases/dynamic-storage/dynamic-storage-1/remediated-example/Cargo.toml @@ -1,6 +1,6 @@ [package] edition = "2021" -name = "dynamic-instance-storage-vulnerable-2" +name = "dynamic-storage-remediated-1" version = "0.1.0" [lib] diff --git a/test-cases/dynamic-storage/dynamic-storage-1/remediated-example/src/lib.rs b/test-cases/dynamic-storage/dynamic-storage-1/remediated-example/src/lib.rs new file mode 100644 index 00000000..db7a3faa --- /dev/null +++ b/test-cases/dynamic-storage/dynamic-storage-1/remediated-example/src/lib.rs @@ -0,0 +1,63 @@ +#![no_std] +use soroban_sdk::{contract, contractimpl, contracttype, Env, Vec}; + +#[contract] +pub struct VectorStorage; + +#[derive(Clone)] +#[contracttype] +pub enum DataKey { + VecElement(u32), +} + +#[contractimpl] +impl VectorStorage { + pub fn store_vector(e: Env, data: Vec) { + for (i, value) in data.iter().enumerate() { + let key = DataKey::VecElement(i as u32); + e.storage().persistent().set(&key, &value); + } + } + + pub fn get_vector(e: Env) -> Vec { + let mut result = Vec::new(&e); + let mut i = 0; + + while let Some(value) = VectorStorage::get_element(e.clone(), i) { + result.push_back(value); + i += 1; + } + + result + } + + pub fn get_element(e: Env, index: u32) -> Option { + let key = DataKey::VecElement(index); + e.storage().persistent().get(&key) + } +} + +#[cfg(test)] +mod test { + use super::*; + use soroban_sdk::{vec, Env}; + + #[test] + fn test_vector_storage() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, VectorStorage); + let client = VectorStorageClient::new(&env, &contract_id); + + // When + let test_vec = vec![&env, 1, 2, 3, 4, 5]; + client.store_vector(&test_vec); + + // Then + let retrieved_vec = client.get_vector(); + assert_eq!(test_vec, retrieved_vec); + + assert_eq!(client.get_element(&2), Some(3)); + assert_eq!(client.get_element(&5), None); + } +} diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/vulnerable-example/Cargo.toml b/test-cases/dynamic-storage/dynamic-storage-1/vulnerable-example/Cargo.toml similarity index 84% rename from test-cases/dynamic-instance-storage/dynamic-instance-storage-1/vulnerable-example/Cargo.toml rename to test-cases/dynamic-storage/dynamic-storage-1/vulnerable-example/Cargo.toml index 7122885e..f1ef1214 100644 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/vulnerable-example/Cargo.toml +++ b/test-cases/dynamic-storage/dynamic-storage-1/vulnerable-example/Cargo.toml @@ -1,6 +1,6 @@ [package] edition = "2021" -name = "dynamic-instance-storage-vulnerable-1" +name = "dynamic-storage-vulnerable-1" version = "0.1.0" [lib] diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/remediated-example/src/lib.rs b/test-cases/dynamic-storage/dynamic-storage-1/vulnerable-example/src/lib.rs similarity index 100% rename from test-cases/dynamic-instance-storage/dynamic-instance-storage-1/remediated-example/src/lib.rs rename to test-cases/dynamic-storage/dynamic-storage-1/vulnerable-example/src/lib.rs diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/remediated-example/Cargo.toml b/test-cases/dynamic-storage/dynamic-storage-2/remediated-example/Cargo.toml similarity index 84% rename from test-cases/dynamic-instance-storage/dynamic-instance-storage-2/remediated-example/Cargo.toml rename to test-cases/dynamic-storage/dynamic-storage-2/remediated-example/Cargo.toml index 7d90d213..e0bfcd9c 100644 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-2/remediated-example/Cargo.toml +++ b/test-cases/dynamic-storage/dynamic-storage-2/remediated-example/Cargo.toml @@ -1,6 +1,6 @@ [package] edition = "2021" -name = "dynamic-instance-storage-remediated-2" +name = "dynamic-storage-remediated-2" version = "0.1.0" [lib] diff --git a/test-cases/dynamic-storage/dynamic-storage-2/remediated-example/src/lib.rs b/test-cases/dynamic-storage/dynamic-storage-2/remediated-example/src/lib.rs new file mode 100644 index 00000000..de7647e6 --- /dev/null +++ b/test-cases/dynamic-storage/dynamic-storage-2/remediated-example/src/lib.rs @@ -0,0 +1,44 @@ +#![no_std] +use soroban_sdk::{contract, contractimpl, Env, Map, Symbol}; + +#[contract] +pub struct MapStorage; + +#[contractimpl] +impl MapStorage { + pub fn store_map(e: Env, data: Map) { + data.iter().for_each(|(key, value)| { + e.storage().persistent().set(&key, &value); + }); + } + + pub fn get_key(e: Env, key: Symbol) -> i32 { + e.storage().persistent().get(&key).unwrap() + } +} + +#[cfg(test)] +mod test { + use super::*; + use soroban_sdk::{symbol_short, Env, Map}; + + #[test] + fn test_map_storage() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, MapStorage); + let client = MapStorageClient::new(&env, &contract_id); + + // When + let mut test_map = Map::new(&env); + test_map.set(symbol_short!("key1"), 1); + test_map.set(symbol_short!("key2"), 2); + client.store_map(&test_map); + + // Then + let key1 = client.get_key(&symbol_short!("key1")); + let key2 = client.get_key(&symbol_short!("key2")); + assert_eq!(test_map.get(symbol_short!("key1")).unwrap(), key1); + assert_eq!(test_map.get(symbol_short!("key2")).unwrap(), key2); + } +} diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/remediated-example/Cargo.toml b/test-cases/dynamic-storage/dynamic-storage-2/vulnerable-example/Cargo.toml similarity index 84% rename from test-cases/dynamic-instance-storage/dynamic-instance-storage-1/remediated-example/Cargo.toml rename to test-cases/dynamic-storage/dynamic-storage-2/vulnerable-example/Cargo.toml index 6da39cf1..93954019 100644 --- a/test-cases/dynamic-instance-storage/dynamic-instance-storage-1/remediated-example/Cargo.toml +++ b/test-cases/dynamic-storage/dynamic-storage-2/vulnerable-example/Cargo.toml @@ -1,6 +1,6 @@ [package] edition = "2021" -name = "dynamic-instance-storage-remediated-1" +name = "dynamic-storage-vulnerable-2" version = "0.1.0" [lib] diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/remediated-example/src/lib.rs b/test-cases/dynamic-storage/dynamic-storage-2/vulnerable-example/src/lib.rs similarity index 100% rename from test-cases/dynamic-instance-storage/dynamic-instance-storage-3/remediated-example/src/lib.rs rename to test-cases/dynamic-storage/dynamic-storage-2/vulnerable-example/src/lib.rs diff --git a/test-cases/dynamic-storage/dynamic-storage-3/remediated-example/Cargo.toml b/test-cases/dynamic-storage/dynamic-storage-3/remediated-example/Cargo.toml new file mode 100644 index 00000000..32465e65 --- /dev/null +++ b/test-cases/dynamic-storage/dynamic-storage-3/remediated-example/Cargo.toml @@ -0,0 +1,16 @@ +[package] +edition = "2021" +name = "dynamic-storage-remediated-3" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = { workspace = true } + +[dev-dependencies] +soroban-sdk = { workspace = true, features = ["testutils"] } + +[features] +testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/dynamic-storage/dynamic-storage-3/remediated-example/src/lib.rs b/test-cases/dynamic-storage/dynamic-storage-3/remediated-example/src/lib.rs new file mode 100644 index 00000000..60f47e41 --- /dev/null +++ b/test-cases/dynamic-storage/dynamic-storage-3/remediated-example/src/lib.rs @@ -0,0 +1,44 @@ +#![no_std] +use soroban_sdk::{contract, contractimpl, Env, Map, Symbol}; + +#[contract] +pub struct MapStorage; + +#[contractimpl] +impl MapStorage { + pub fn store_map(e: Env, data: Map) { + data.iter().for_each(|(key, value)| { + e.storage().instance().set(&key, &value); + }); + } + + pub fn get_key(e: Env, key: Symbol) -> i32 { + e.storage().instance().get(&key).unwrap() + } +} + +#[cfg(test)] +mod test { + use super::*; + use soroban_sdk::{symbol_short, Env, Map}; + + #[test] + fn test_map_storage() { + // Given + let env = Env::default(); + let contract_id = env.register_contract(None, MapStorage); + let client = MapStorageClient::new(&env, &contract_id); + + // When + let mut test_map = Map::new(&env); + test_map.set(symbol_short!("key1"), 1); + test_map.set(symbol_short!("key2"), 2); + client.store_map(&test_map); + + // Then + let key1 = client.get_key(&symbol_short!("key1")); + let key2 = client.get_key(&symbol_short!("key2")); + assert_eq!(test_map.get(symbol_short!("key1")).unwrap(), key1); + assert_eq!(test_map.get(symbol_short!("key2")).unwrap(), key2); + } +} diff --git a/test-cases/dynamic-storage/dynamic-storage-3/vulnerable-example/Cargo.toml b/test-cases/dynamic-storage/dynamic-storage-3/vulnerable-example/Cargo.toml new file mode 100644 index 00000000..37605628 --- /dev/null +++ b/test-cases/dynamic-storage/dynamic-storage-3/vulnerable-example/Cargo.toml @@ -0,0 +1,16 @@ +[package] +edition = "2021" +name = "dynamic-storage-vulnerable-3" +version = "0.1.0" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = { workspace = true } + +[dev-dependencies] +soroban-sdk = { workspace = true, features = ["testutils"] } + +[features] +testutils = ["soroban-sdk/testutils"] diff --git a/test-cases/dynamic-instance-storage/dynamic-instance-storage-3/vulnerable-example/src/lib.rs b/test-cases/dynamic-storage/dynamic-storage-3/vulnerable-example/src/lib.rs similarity index 100% rename from test-cases/dynamic-instance-storage/dynamic-instance-storage-3/vulnerable-example/src/lib.rs rename to test-cases/dynamic-storage/dynamic-storage-3/vulnerable-example/src/lib.rs diff --git a/test-cases/dynamic-instance-storage/Cargo.toml b/test-cases/storage-change-events/Cargo.toml similarity index 80% rename from test-cases/dynamic-instance-storage/Cargo.toml rename to test-cases/storage-change-events/Cargo.toml index 2af55099..c2fabcfe 100644 --- a/test-cases/dynamic-instance-storage/Cargo.toml +++ b/test-cases/storage-change-events/Cargo.toml @@ -1,10 +1,11 @@ [workspace] exclude = [".cargo", "target"] -members = ["dynamic-instance-storage-*/*"] +members = ["storage-change-events-*/*"] resolver = "2" [workspace.dependencies] soroban-sdk = { version = "=21.4.0" } +soroban-token-sdk = { version = "21.4.0" } [profile.release] codegen-units = 1 @@ -19,3 +20,4 @@ strip = "symbols" [profile.release-with-logs] debug-assertions = true inherits = "release" + diff --git a/test-cases/storage-change-events/storage-change-events-1/remediated-example/Cargo.toml b/test-cases/storage-change-events/storage-change-events-1/remediated-example/Cargo.toml new file mode 100644 index 00000000..dd120b7f --- /dev/null +++ b/test-cases/storage-change-events/storage-change-events-1/remediated-example/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "storage-change-events-remediated-1" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = "=21.4.0" +soroban-token-sdk = { version = "21.4.0" } + + +[dev_dependencies] +soroban-sdk = { version = "=21.4.0", features = ["testutils"] } +soroban-token-sdk = { version = "21.4.0" } + + +[features] +testutils = ["soroban-sdk/testutils"] + +[profile.release] +opt-level = "z" +overflow-checks = true +debug = 0 +strip = "symbols" +debug-assertions = false +panic = "abort" +codegen-units = 1 +lto = true + +[profile.release-with-logs] +inherits = "release" +debug-assertions = true + diff --git a/test-cases/storage-change-events/storage-change-events-1/remediated-example/src/lib.rs b/test-cases/storage-change-events/storage-change-events-1/remediated-example/src/lib.rs new file mode 100644 index 00000000..dd92fafe --- /dev/null +++ b/test-cases/storage-change-events/storage-change-events-1/remediated-example/src/lib.rs @@ -0,0 +1,82 @@ +#![no_std] + +use soroban_sdk::{ + contract, contracterror, contractimpl, contracttype, symbol_short, Address, Env, Symbol, +}; + +#[derive(Clone, Debug)] +#[contracttype] +pub struct CounterState { + admin: Address, + count: u32, +} + +const STATE: Symbol = symbol_short!("STATE"); +const COUNTER: Symbol = symbol_short!("COUNTER"); + +#[contracterror] +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +#[repr(u32)] +pub enum SCError { + AlreadyInitialized = 1, + NotInitialized = 2, + FailedToRetrieveState = 3, +} + +#[contract] +pub struct StorageChangeEvents; + +#[contractimpl] +impl StorageChangeEvents { + pub fn initialize(env: Env, admin: Address) -> Result<(), SCError> { + let current_state = Self::get_state(env.clone()); + if current_state.is_ok() { + return Err(SCError::AlreadyInitialized); + } + + env.storage().instance().set( + &STATE, + &CounterState { + admin: admin.clone(), + count: 0, + }, + ); + + env.events() + .publish((COUNTER, symbol_short!("init")), admin); + Ok(()) + } + + pub fn increase_counter(env: Env) -> Result<(), SCError> { + let mut counter = Self::get_state(env.clone())?; + counter.count += 1; + env.storage().instance().set(&STATE, &counter); + env.events() + .publish((COUNTER, symbol_short!("increase")), counter.count); + Ok(()) + } + + pub fn set_counter_indirectly(env: Env, number: u32) -> Result<(), SCError> { + let mut counter = Self::get_state(env.clone())?; + counter.admin.require_auth(); + counter.count = number; + Self::set_counter(env, counter); + + Ok(()) + } + + fn set_counter(env: Env, counter: CounterState) { + env.storage().instance().set(&STATE, &counter); + env.events() + .publish((COUNTER, symbol_short!("set")), counter.count); + } + + pub fn get_state(env: Env) -> Result { + let state_op: Option = env.storage().instance().get(&STATE); + if let Some(state) = state_op { + Ok(state) + } else { + Err(SCError::FailedToRetrieveState) + } + } +} diff --git a/test-cases/storage-change-events/storage-change-events-1/vulnerable-example/Cargo.toml b/test-cases/storage-change-events/storage-change-events-1/vulnerable-example/Cargo.toml new file mode 100644 index 00000000..e459c2e3 --- /dev/null +++ b/test-cases/storage-change-events/storage-change-events-1/vulnerable-example/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "storage-change-events-vulnerable-1" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = "=21.4.0" +soroban-token-sdk = { version = "21.4.0" } + + +[dev_dependencies] +soroban-sdk = { version = "=21.4.0", features = ["testutils"] } +soroban-token-sdk = { version = "21.4.0" } + + +[features] +testutils = ["soroban-sdk/testutils"] + +[profile.release] +opt-level = "z" +overflow-checks = true +debug = 0 +strip = "symbols" +debug-assertions = false +panic = "abort" +codegen-units = 1 +lto = true + +[profile.release-with-logs] +inherits = "release" +debug-assertions = true + diff --git a/test-cases/storage-change-events/storage-change-events-1/vulnerable-example/src/lib.rs b/test-cases/storage-change-events/storage-change-events-1/vulnerable-example/src/lib.rs new file mode 100644 index 00000000..97079247 --- /dev/null +++ b/test-cases/storage-change-events/storage-change-events-1/vulnerable-example/src/lib.rs @@ -0,0 +1,72 @@ +#![no_std] + +use soroban_sdk::{ + contract, contracterror, contractimpl, contracttype, symbol_short, Address, Env, Symbol, +}; + +#[derive(Clone, Debug)] +#[contracttype] +pub struct CounterState { + admin: Address, + count: u32, +} + +const STATE: Symbol = symbol_short!("STATE"); + +#[contracterror] +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +#[repr(u32)] +pub enum SCError { + AlreadyInitialized = 1, + NotInitialized = 2, + FailedToRetrieveState = 3, +} + +#[contract] +pub struct StorageChangeEvents; + +#[contractimpl] +impl StorageChangeEvents { + pub fn initialize(env: Env, admin: Address) -> Result<(), SCError> { + let current_state = Self::get_state(env.clone()); + if current_state.is_ok() { + return Err(SCError::AlreadyInitialized); + } + + env.storage() + .instance() + .set(&STATE, &CounterState { admin, count: 0 }); + + Ok(()) + } + + pub fn increase_counter(env: Env) -> Result<(), SCError> { + let mut counter = Self::get_state(env.clone())?; + counter.count += 1; + env.storage().instance().set(&STATE, &counter); + + Ok(()) + } + + pub fn set_counter_indirectly(env: Env, number: u32) -> Result<(), SCError> { + let mut counter = Self::get_state(env.clone())?; + counter.admin.require_auth(); + counter.count = number; + Self::set_counter(env, counter); + + Ok(()) + } + + fn set_counter(env: Env, counter: CounterState) { + env.storage().instance().set(&STATE, &counter); + } + + pub fn get_state(env: Env) -> Result { + let state_op: Option = env.storage().instance().get(&STATE); + if let Some(state) = state_op { + Ok(state) + } else { + Err(SCError::FailedToRetrieveState) + } + } +} diff --git a/test-cases/token-interface-events/Cargo.toml b/test-cases/token-interface-events/Cargo.toml new file mode 100644 index 00000000..257b9d87 --- /dev/null +++ b/test-cases/token-interface-events/Cargo.toml @@ -0,0 +1,23 @@ +[workspace] +exclude = [".cargo", "target"] +members = ["token-interface-events-*/*"] +resolver = "2" + +[workspace.dependencies] +soroban-sdk = { version = "=21.4.0" } +soroban-token-sdk = { version = "21.4.0" } + +[profile.release] +codegen-units = 1 +debug = 0 +debug-assertions = false +lto = true +opt-level = "z" +overflow-checks = true +panic = "abort" +strip = "symbols" + +[profile.release-with-logs] +debug-assertions = true +inherits = "release" + diff --git a/test-cases/token-interface-events/token-interface-events-1/remediated-example/Cargo.toml b/test-cases/token-interface-events/token-interface-events-1/remediated-example/Cargo.toml new file mode 100644 index 00000000..cd132465 --- /dev/null +++ b/test-cases/token-interface-events/token-interface-events-1/remediated-example/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "token-interface-events-remediated-1" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = "=21.4.0" +soroban-token-sdk = { version = "21.4.0" } + + +[dev_dependencies] +soroban-sdk = { version = "=21.4.0", features = ["testutils"] } +soroban-token-sdk = { version = "21.4.0" } + + +[features] +testutils = ["soroban-sdk/testutils"] + +[profile.release] +opt-level = "z" +overflow-checks = true +debug = 0 +strip = "symbols" +debug-assertions = false +panic = "abort" +codegen-units = 1 +lto = true + +[profile.release-with-logs] +inherits = "release" +debug-assertions = true + diff --git a/test-cases/token-interface-events/token-interface-events-1/remediated-example/src/lib.rs b/test-cases/token-interface-events/token-interface-events-1/remediated-example/src/lib.rs new file mode 100644 index 00000000..5417416e --- /dev/null +++ b/test-cases/token-interface-events/token-interface-events-1/remediated-example/src/lib.rs @@ -0,0 +1,225 @@ +#![no_std] + +use soroban_sdk::{ + contract, contracterror, contractimpl, contracttype, token, Address, Env, String, +}; + +use soroban_sdk::token::TokenInterface; +use soroban_token_sdk::TokenUtils; + +#[derive(Clone)] +#[contracttype] +pub struct TokenMetadata { + pub decimals: u32, + pub name: String, + pub symbol: String, + pub admin: Address, +} + +#[derive(Clone, Default)] +#[contracttype] +pub struct AllowanceFromSpender { + pub amount: i128, + pub expiration_ledger: u32, +} + +#[derive(Clone)] +#[contracttype] +pub enum DataKey { + Balance(Address), + TokenMetadata, + AllowanceFromSpender(Address, Address), +} + +#[contracterror] +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +#[repr(u32)] +pub enum VTError { + AlreadyInitialized = 1, + NotInitialized = 2, +} + +#[contract] +pub struct TokenInterfaceEvents; + +#[contractimpl] +impl TokenInterfaceEvents { + pub fn initialize( + env: Env, + admin: Address, + decimals: u32, + name: String, + symbol: String, + ) -> Result<(), VTError> { + let current_token_metadata: Option = + env.storage().instance().get(&DataKey::TokenMetadata); + if current_token_metadata.is_some() { + return Err(VTError::AlreadyInitialized); + } else { + env.storage().instance().set( + &DataKey::TokenMetadata, + &TokenMetadata { + decimals, + name, + symbol, + admin, + }, + ); + } + + Ok(()) + } + + pub fn get_metadata(env: Env) -> TokenMetadata { + env.storage() + .instance() + .get(&DataKey::TokenMetadata) + .unwrap() + } + + pub fn mint(env: Env, to: Address, amount: i128) { + Self::get_metadata(env.clone()).admin.require_auth(); + let previous_balance: i128 = env + .clone() + .storage() + .instance() + .get(&DataKey::Balance(to.clone())) + .unwrap_or(0); + env.storage() + .instance() + .set(&DataKey::Balance(to), &(previous_balance + amount)); + } + + fn get_allowance(env: Env, from: Address, spender: Address) -> AllowanceFromSpender { + env.storage() + .instance() + .get(&DataKey::AllowanceFromSpender(from, spender)) + .unwrap_or_default() + } + + fn set_allowance( + env: Env, + from: Address, + spender: Address, + amount: i128, + expiration_ledger: u32, + ) { + env.storage().instance().set( + &DataKey::AllowanceFromSpender(from.clone(), spender.clone()), + &AllowanceFromSpender { + amount, + expiration_ledger, + }, + ); + TokenUtils::new(&env) + .events() + .approve(from, spender, amount, expiration_ledger); + } +} + +#[contractimpl] +impl token::TokenInterface for TokenInterfaceEvents { + fn allowance(env: Env, from: Address, spender: Address) -> i128 { + let allowance = Self::get_allowance(env.clone(), from, spender); + if allowance.expiration_ledger < env.ledger().sequence() { + 0 + } else { + allowance.amount + } + } + + fn approve(env: Env, from: Address, spender: Address, amount: i128, expiration_ledger: u32) { + from.require_auth(); + assert!(env.ledger().sequence() < expiration_ledger || amount == 0); + + // This function emits the event, so the warning will not come up + Self::set_allowance(env, from, spender, amount, expiration_ledger); + } + + fn balance(env: Env, id: Address) -> i128 { + env.storage() + .instance() + .get(&DataKey::Balance(id)) + .unwrap_or(0) + } + + fn transfer(env: Env, from: Address, to: Address, amount: i128) { + from.require_auth(); + let from_balance = Self::balance(env.clone(), from.clone()); + let to_balance = Self::balance(env.clone(), to.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from.clone()), &(from_balance - amount)); + env.storage() + .instance() + .set(&DataKey::Balance(to.clone()), &(to_balance + amount)); + + TokenUtils::new(&env).events().transfer(from, to, amount); + } + + fn transfer_from(env: Env, spender: Address, from: Address, to: Address, amount: i128) { + let spender_allowance = Self::allowance(env.clone(), from.clone(), spender.clone()); + assert!(spender_allowance >= amount); + + let from_balance = Self::balance(env.clone(), from.clone()); + let to_balance = Self::balance(env.clone(), to.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from.clone()), &(from_balance - amount)); + env.storage() + .instance() + .set(&DataKey::Balance(to.clone()), &(to_balance + amount)); + + let mut allowance = Self::get_allowance(env.clone(), from.clone(), spender.clone()); + allowance.amount -= amount; + + env.storage().instance().set( + &DataKey::AllowanceFromSpender(from.clone(), spender), + &allowance, + ); + + TokenUtils::new(&env).events().transfer(from, to, amount); + } + + fn burn(env: Env, from: Address, amount: i128) { + from.require_auth(); + let from_balance = Self::balance(env.clone(), from.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from.clone()), &(from_balance - amount)); + TokenUtils::new(&env).events().burn(from, amount); + } + + fn burn_from(env: Env, spender: Address, from: Address, amount: i128) { + let spender_allowance = Self::allowance(env.clone(), from.clone(), spender.clone()); + assert!(spender_allowance >= amount); + let from_balance = Self::balance(env.clone(), from.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from.clone()), &(from_balance - amount)); + + let mut allowance = Self::get_allowance(env.clone(), from.clone(), spender.clone()); + allowance.amount -= amount; + env.storage().instance().set( + &DataKey::AllowanceFromSpender(from.clone(), spender), + &allowance, + ); + TokenUtils::new(&env).events().burn(from, amount); + } + + fn decimals(env: Env) -> u32 { + Self::get_metadata(env).decimals + } + + fn name(env: Env) -> String { + Self::get_metadata(env).name + } + + fn symbol(env: Env) -> String { + Self::get_metadata(env).symbol + } +} diff --git a/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/Cargo.toml b/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/Cargo.toml new file mode 100644 index 00000000..8fd75f45 --- /dev/null +++ b/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "token-interface-events-vulnerable-1" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +soroban-sdk = "=21.4.0" +soroban-token-sdk = { version = "21.4.0" } + + +[dev_dependencies] +soroban-sdk = { version = "=21.4.0", features = ["testutils"] } +soroban-token-sdk = { version = "21.4.0" } + + +[features] +testutils = ["soroban-sdk/testutils"] + +[profile.release] +opt-level = "z" +overflow-checks = true +debug = 0 +strip = "symbols" +debug-assertions = false +panic = "abort" +codegen-units = 1 +lto = true + +[profile.release-with-logs] +inherits = "release" +debug-assertions = true + diff --git a/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/src/lib.rs b/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/src/lib.rs new file mode 100644 index 00000000..44f752e0 --- /dev/null +++ b/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/src/lib.rs @@ -0,0 +1,201 @@ +#![no_std] + +use soroban_sdk::{ + contract, contracterror, contractimpl, contracttype, token, Address, Env, String, +}; + +use soroban_sdk::token::TokenInterface; + +#[derive(Clone, Debug)] +#[contracttype] +pub struct TokenMetadata { + pub decimals: u32, + pub name: String, + pub symbol: String, + pub admin: Address, +} + +#[derive(Clone, Default)] +#[contracttype] +pub struct AllowanceFromSpender { + pub amount: i128, + pub expiration_ledger: u32, +} + +#[derive(Clone)] +#[contracttype] +pub enum DataKey { + Balance(Address), + TokenMetadata, + AllowanceFromSpender(Address, Address), +} + +#[contracterror] +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +#[repr(u32)] +pub enum VTError { + AlreadyInitialized = 1, + NotInitialized = 2, +} + +#[contract] +pub struct TokenInterfaceEvents; + +#[contractimpl] +impl TokenInterfaceEvents { + pub fn initialize( + env: Env, + admin: Address, + decimals: u32, + name: String, + symbol: String, + ) -> Result<(), VTError> { + let current_token_metadata: Option = + env.storage().instance().get(&DataKey::TokenMetadata); + if current_token_metadata.is_some() { + return Err(VTError::AlreadyInitialized); + } else { + env.storage().instance().set( + &DataKey::TokenMetadata, + &TokenMetadata { + decimals, + name, + symbol, + admin, + }, + ); + } + + Ok(()) + } + + pub fn get_metadata(env: Env) -> TokenMetadata { + env.storage() + .instance() + .get(&DataKey::TokenMetadata) + .unwrap() + } + + pub fn mint(env: Env, to: Address, amount: i128) { + Self::get_metadata(env.clone()).admin.require_auth(); + let previous_balance: i128 = env + .clone() + .storage() + .instance() + .get(&DataKey::Balance(to.clone())) + .unwrap_or(0); + env.storage() + .instance() + .set(&DataKey::Balance(to), &(previous_balance + amount)); + } + + fn get_allowance(env: Env, from: Address, spender: Address) -> AllowanceFromSpender { + env.storage() + .instance() + .get(&DataKey::AllowanceFromSpender(from, spender)) + .unwrap_or_default() + } +} + +#[contractimpl] +impl token::TokenInterface for TokenInterfaceEvents { + fn allowance(env: Env, from: Address, spender: Address) -> i128 { + let allowance = Self::get_allowance(env.clone(), from, spender); + if allowance.expiration_ledger < env.ledger().sequence() { + 0 + } else { + allowance.amount + } + } + + fn approve(env: Env, from: Address, spender: Address, amount: i128, expiration_ledger: u32) { + from.require_auth(); + assert!(env.ledger().sequence() < expiration_ledger || amount == 0); + env.storage().instance().set( + &DataKey::AllowanceFromSpender(from.clone(), spender.clone()), + &AllowanceFromSpender { + amount, + expiration_ledger, + }, + ); + } + + fn balance(env: Env, id: Address) -> i128 { + env.storage() + .instance() + .get(&DataKey::Balance(id)) + .unwrap_or(0) + } + + fn transfer(env: Env, from: Address, to: Address, amount: i128) { + from.require_auth(); + let from_balance = Self::balance(env.clone(), from.clone()); + let to_balance = Self::balance(env.clone(), to.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from), &(from_balance - amount)); + env.storage() + .instance() + .set(&DataKey::Balance(to), &(to_balance + amount)); + } + + fn transfer_from(env: Env, spender: Address, from: Address, to: Address, amount: i128) { + let spender_allowance = Self::allowance(env.clone(), from.clone(), spender.clone()); + assert!(spender_allowance >= amount); + + let from_balance = Self::balance(env.clone(), from.clone()); + let to_balance = Self::balance(env.clone(), to.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from.clone()), &(from_balance - amount)); + env.storage() + .instance() + .set(&DataKey::Balance(to.clone()), &(to_balance + amount)); + + let mut allowance = Self::get_allowance(env.clone(), from.clone(), spender.clone()); + allowance.amount -= amount; + + env.storage() + .instance() + .set(&DataKey::AllowanceFromSpender(from, spender), &allowance); + } + + fn burn(env: Env, from: Address, amount: i128) { + from.require_auth(); + let from_balance = Self::balance(env.clone(), from.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from), &(from_balance - amount)); + } + + fn burn_from(env: Env, spender: Address, from: Address, amount: i128) { + let spender_allowance = Self::allowance(env.clone(), from.clone(), spender.clone()); + assert!(spender_allowance >= amount); + let from_balance = Self::balance(env.clone(), from.clone()); + assert!(from_balance >= amount); + env.storage() + .instance() + .set(&DataKey::Balance(from.clone()), &(from_balance - amount)); + + let mut allowance = Self::get_allowance(env.clone(), from.clone(), spender.clone()); + allowance.amount -= amount; + env.storage() + .instance() + .set(&DataKey::AllowanceFromSpender(from, spender), &allowance); + } + fn decimals(env: Env) -> u32 { + Self::get_metadata(env).decimals + } + fn name(env: Env) -> String { + Self::get_metadata(env).name + } + fn symbol(env: Env) -> String { + Self::get_metadata(env).symbol + } +} + +#[cfg(test)] +mod test; diff --git a/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/src/test.rs b/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/src/test.rs new file mode 100644 index 00000000..a55f91c4 --- /dev/null +++ b/test-cases/token-interface-events/token-interface-events-1/vulnerable-example/src/test.rs @@ -0,0 +1,148 @@ +use super::*; +use soroban_sdk::testutils::{Address as _, Ledger}; +use soroban_sdk::{Address, Env}; + +fn initialize_env<'a>() -> (Env, TokenInterfaceEventsClient<'a>, Address, [Address; 5]) { + let env = Env::default(); + let token_contract = TokenInterfaceEventsClient::new( + &env, + &env.register_contract(None, TokenInterfaceEvents {}), + ); + let admin = Address::generate(&env); + let decimals: u32 = 3; + let name: String = String::from_str(&env, "TestToken"); + let symbol: String = String::from_str(&env, "TTK"); + let users = [ + Address::generate(&env), + Address::generate(&env), + Address::generate(&env), + Address::generate(&env), + Address::generate(&env), + ]; + + token_contract.initialize(&admin, &decimals, &name, &symbol); + (env, token_contract, admin, users) +} +#[test] +fn test_init_token() { + let env = Env::default(); + let token_contract = TokenInterfaceEventsClient::new( + &env, + &env.register_contract(None, TokenInterfaceEvents {}), + ); + let admin = Address::generate(&env); + let decimals: u32 = 9; + let name: String = String::from_str(&env, "TestToken"); + let symbol: String = String::from_str(&env, "TTK"); + + token_contract.initialize(&admin, &decimals, &name, &symbol); + let token_metadata = token_contract.get_metadata(); + assert_eq!(token_metadata.admin, admin); + assert_eq!(token_metadata.decimals, 9); + assert_eq!(token_metadata.name, name); + assert_eq!(token_metadata.symbol, symbol); +} + +#[test] + +fn test_mint_token() { + let (env, token_contract, _admin, users) = initialize_env(); + env.mock_all_auths(); + + let mut balance = token_contract.balance(&users[0]); + assert_eq!(balance, 0); + token_contract.mint(&users[0], &100000); + + balance = token_contract.balance(&users[0]); + assert_eq!(balance, 100000); +} + +#[test] + +fn test_transfer_burn_token() { + let (env, token_contract, _admin, users) = initialize_env(); + env.mock_all_auths(); + token_contract.mint(&users[0], &100_000); + let previous_balance_user_0 = token_contract.balance(&users[0]); + let previous_balance_user_1 = token_contract.balance(&users[1]); + assert_eq!(previous_balance_user_0, 100_000); + assert_eq!(previous_balance_user_1, 0); + + let transfer_amount = 50_000; + token_contract.transfer(&users[0], &users[1], &transfer_amount); + let mut balance_user_0 = token_contract.balance(&users[0]); + let balance_user_1 = token_contract.balance(&users[1]); + assert_eq!(balance_user_0, previous_balance_user_0 - transfer_amount); + assert_eq!(balance_user_1, previous_balance_user_1 + transfer_amount); + + token_contract.burn(&users[0], &10_000); + balance_user_0 = token_contract.balance(&users[0]); + assert_eq!( + balance_user_0, + previous_balance_user_0 - transfer_amount - 10_000 + ); +} + +#[test] + +fn test_allowance() { + let (env, token_contract, _admin, users) = initialize_env(); + env.mock_all_auths(); + token_contract.mint(&users[0], &500_000); // 500 tokens (3 decimals) + let from = users[0].clone(); + let spender = users[1].clone(); + let to = users[2].clone(); + + let mut current_allowance = token_contract.allowance(&from, &spender); + assert_eq!(current_allowance, 0); + let allowance_amount = 100_000; + let expiration_ledger = 300; + token_contract.approve(&from, &spender, &allowance_amount, &expiration_ledger); + current_allowance = token_contract.allowance(&from, &spender); + assert_eq!(current_allowance, 100_000); + + let transfer_amount = 20_000; + // transfer from 20 tokens - sequence is still 0, allowance should be valid + token_contract.transfer_from(&spender, &from, &to, &transfer_amount); + + let mut from_balance = token_contract.balance(&from); + let to_balance = token_contract.balance(&to); + assert_eq!(from_balance, 500_000 - 20_000); + assert_eq!(to_balance, transfer_amount); + + current_allowance = token_contract.allowance(&from, &spender); + + assert_eq!(current_allowance, allowance_amount - transfer_amount); + + token_contract.burn_from(&spender, &from, &10_000); + from_balance = token_contract.balance(&from); + assert_eq!(from_balance, 470_000); + + current_allowance = token_contract.allowance(&from, &spender); + assert_eq!(current_allowance, 70_000); + + // advance time to verify allowance is now invalid + env.ledger().with_mut(|info| { + info.sequence_number += 500; + }); + + current_allowance = token_contract.allowance(&users[0], &users[1]); + assert_eq!(current_allowance, 0); +} + +#[should_panic] +#[test] + +fn test_no_allowance() { + let (env, token_contract, _admin, users) = initialize_env(); + env.mock_all_auths(); + token_contract.mint(&users[0], &500_000); // 500 tokens (3 decimals) + let from = users[0].clone(); + let spender = users[1].clone(); + + token_contract.burn_from(&spender, &from, &100_000); + + let from_balance = token_contract.balance(&from); + + assert_eq!(from_balance, 500_000); +} diff --git a/test-cases/unsafe-expect/unsafe-expect-1/remediated-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-1/remediated-example/src/lib.rs index 37dcc366..080f7932 100644 --- a/test-cases/unsafe-expect/unsafe-expect-1/remediated-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-1/remediated-example/src/lib.rs @@ -31,9 +31,10 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_of(env: Env, owner: Address) -> i128 { + pub fn balance_of(env: Env, owner: Address) -> Option { let state = Self::get_state(env); - state.balances.get(owner).unwrap_or(0) + let balance = state.balances.get(owner).unwrap_or(0); + Some(balance) } /// Return the current state. @@ -68,7 +69,7 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(TOTAL_SUPPLY, balance); + assert_eq!(TOTAL_SUPPLY, balance.unwrap()); } #[test] @@ -82,6 +83,6 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(0, balance); + assert_eq!(0, balance.unwrap()); } } diff --git a/test-cases/unsafe-expect/unsafe-expect-1/vulnerable-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-1/vulnerable-example/src/lib.rs index f3f14ad8..064ceb62 100644 --- a/test-cases/unsafe-expect/unsafe-expect-1/vulnerable-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-1/vulnerable-example/src/lib.rs @@ -28,9 +28,10 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_of(env: Env, owner: Address) -> i128 { + pub fn balance_of(env: Env, owner: Address) -> Option { let state = Self::get_state(env); - state.balances.get(owner).expect("could not get balance") + let balance = state.balances.get(owner).expect("could not get balance"); + Some(balance) } /// Return the current state. @@ -49,6 +50,8 @@ mod tests { use soroban_sdk::Env; + extern crate std; + use crate::{UnsafeExpect, UnsafeExpectClient, TOTAL_SUPPLY}; #[test] @@ -63,22 +66,6 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(TOTAL_SUPPLY, balance); - } - - #[test] - #[should_panic(expected = "could not get balance")] - fn balance_of_expect_works() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, UnsafeExpect); - let client = UnsafeExpectClient::new(&env, &contract_id); - - // When - Balance not set - - // Then - let _balance = client.balance_of(&contract_id); - - // Test should panic + assert_eq!(TOTAL_SUPPLY, balance.unwrap()); } } diff --git a/test-cases/unsafe-expect/unsafe-expect-2/remediated-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-2/remediated-example/src/lib.rs index d487edd3..2343c64e 100644 --- a/test-cases/unsafe-expect/unsafe-expect-2/remediated-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-2/remediated-example/src/lib.rs @@ -31,13 +31,11 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_of(env: Env, owner: Address) -> i128 { + pub fn balance_of(env: Env, owner: Address) -> Option { let state = Self::get_state(env); let balance = state.balances.get(owner); - if balance.is_none() { - return 0; - } - balance.expect("could not get balance") + balance?; + Some(balance.expect("could not get balance")) } /// Return the current state. @@ -72,7 +70,7 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(TOTAL_SUPPLY, balance); + assert_eq!(TOTAL_SUPPLY, balance.unwrap()); } #[test] @@ -86,6 +84,6 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(0, balance); + assert_eq!(None, balance); } } diff --git a/test-cases/unsafe-expect/unsafe-expect-2/vulnerable-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-2/vulnerable-example/src/lib.rs index f3f14ad8..975d25d8 100644 --- a/test-cases/unsafe-expect/unsafe-expect-2/vulnerable-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-2/vulnerable-example/src/lib.rs @@ -28,9 +28,10 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_of(env: Env, owner: Address) -> i128 { + pub fn balance_of(env: Env, owner: Address) -> Option { let state = Self::get_state(env); - state.balances.get(owner).expect("could not get balance") + let balance = state.balances.get(owner).expect("could not get balance"); + Some(balance) } /// Return the current state. @@ -63,22 +64,6 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(TOTAL_SUPPLY, balance); - } - - #[test] - #[should_panic(expected = "could not get balance")] - fn balance_of_expect_works() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, UnsafeExpect); - let client = UnsafeExpectClient::new(&env, &contract_id); - - // When - Balance not set - - // Then - let _balance = client.balance_of(&contract_id); - - // Test should panic + assert_eq!(TOTAL_SUPPLY, balance.unwrap()); } } diff --git a/test-cases/unsafe-expect/unsafe-expect-3/remediated-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-3/remediated-example/src/lib.rs index 10b9fb40..0468a16c 100644 --- a/test-cases/unsafe-expect/unsafe-expect-3/remediated-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-3/remediated-example/src/lib.rs @@ -31,14 +31,14 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_of(env: Env, owner: Address) -> i128 { + pub fn balance_of(env: Env, owner: Address) -> Option { let state = Self::get_state(env); let balance = state.balances.get(owner); let mut return_value = 0; if balance.is_some() { return_value = balance.expect("could not get balance"); } - return_value + Some(return_value) } /// Return the current state. @@ -73,7 +73,7 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(TOTAL_SUPPLY, balance); + assert_eq!(TOTAL_SUPPLY, balance.unwrap()); } #[test] @@ -87,6 +87,6 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(0, balance); + assert_eq!(0, balance.unwrap()); } } diff --git a/test-cases/unsafe-expect/unsafe-expect-3/vulnerable-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-3/vulnerable-example/src/lib.rs index f3f14ad8..975d25d8 100644 --- a/test-cases/unsafe-expect/unsafe-expect-3/vulnerable-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-3/vulnerable-example/src/lib.rs @@ -28,9 +28,10 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_of(env: Env, owner: Address) -> i128 { + pub fn balance_of(env: Env, owner: Address) -> Option { let state = Self::get_state(env); - state.balances.get(owner).expect("could not get balance") + let balance = state.balances.get(owner).expect("could not get balance"); + Some(balance) } /// Return the current state. @@ -63,22 +64,6 @@ mod tests { // Then let balance = client.balance_of(&contract_id); - assert_eq!(TOTAL_SUPPLY, balance); - } - - #[test] - #[should_panic(expected = "could not get balance")] - fn balance_of_expect_works() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, UnsafeExpect); - let client = UnsafeExpectClient::new(&env, &contract_id); - - // When - Balance not set - - // Then - let _balance = client.balance_of(&contract_id); - - // Test should panic + assert_eq!(TOTAL_SUPPLY, balance.unwrap()); } } diff --git a/test-cases/unsafe-expect/unsafe-expect-4/remediated-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-4/remediated-example/src/lib.rs index 41ccefdd..7fb0f307 100644 --- a/test-cases/unsafe-expect/unsafe-expect-4/remediated-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-4/remediated-example/src/lib.rs @@ -75,9 +75,9 @@ mod tests { client .mock_all_auths() .set_balance(&contract_id, &TOTAL_SUPPLY); - let balance = client.balance_of(&contract_id); // Then + let balance = client.balance_of(&contract_id); assert_eq!(TOTAL_SUPPLY, balance.0); assert_eq!(TOTAL_SUPPLY, balance.1); } @@ -90,9 +90,9 @@ mod tests { let client = UnsafeExpectClient::new(&env, &contract_id); // When - Balance not set - let balance = client.balance_of(&contract_id); // Then + let balance = client.balance_of(&contract_id); assert_eq!(0, balance.0); assert_eq!(0, balance.1); } diff --git a/test-cases/unsafe-expect/unsafe-expect-4/vulnerable-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-4/vulnerable-example/src/lib.rs index 7a98f717..542ca1d4 100644 --- a/test-cases/unsafe-expect/unsafe-expect-4/vulnerable-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-4/vulnerable-example/src/lib.rs @@ -28,16 +28,17 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_of(env: Env, owner: Address) -> (i128, i128) { + pub fn balance_of(env: Env, owner: Address) -> Option<(i128, i128)> { let state = Self::get_state(env); // For similarity with the remediated-example, we will return the same value twice. - ( + let balances = ( state .balances .get(owner.clone()) .expect("could not get balance"), state.balances.get(owner).expect("could not get balance"), - ) + ); + Some(balances) } /// Return the current state. @@ -67,26 +68,10 @@ mod tests { // When client.set_balance(&contract_id, &TOTAL_SUPPLY); - let balances = client.balance_of(&contract_id); // Then + let balances = client.balance_of(&contract_id).unwrap(); assert_eq!(TOTAL_SUPPLY, balances.0); assert_eq!(TOTAL_SUPPLY, balances.1); } - - #[test] - #[should_panic(expected = "could not get balance")] - fn balance_of_expect_works() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, UnsafeExpect); - let client = UnsafeExpectClient::new(&env, &contract_id); - - // When - Balance not set - let _balance_1 = client.balance_of(&contract_id); - - // Then - - // Test should panic - } } diff --git a/test-cases/unsafe-expect/unsafe-expect-5/remediated-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-5/remediated-example/src/lib.rs index d3acce06..37159e23 100644 --- a/test-cases/unsafe-expect/unsafe-expect-5/remediated-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-5/remediated-example/src/lib.rs @@ -28,17 +28,14 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_add(env: Env, owner: Address) -> i128 { + pub fn balance_add(env: Env, owner: Address) -> Option { let state = Self::get_state(env); let balance = state.balances.get(owner.clone()); - if balance.is_none() { - return 0; - } + balance?; let balance_plus_2 = 2i128.checked_add(balance.expect("could not get balance")); - if balance_plus_2.is_none() { - return 0; - } - balance_plus_2.unwrap() + balance_plus_2?; + let balances = balance_plus_2.unwrap(); + Some(balances) } /// Return the current state. @@ -71,7 +68,7 @@ mod tests { let balances = client.balance_add(&contract_id); // Then - assert_eq!(TOTAL_SUPPLY + 2, balances); + assert_eq!(TOTAL_SUPPLY + 2, balances.unwrap()); } #[test] @@ -85,6 +82,6 @@ mod tests { let balance = client.balance_add(&contract_id); // Then - assert_eq!(0, balance); + assert_eq!(None, balance); } } diff --git a/test-cases/unsafe-expect/unsafe-expect-5/vulnerable-example/src/lib.rs b/test-cases/unsafe-expect/unsafe-expect-5/vulnerable-example/src/lib.rs index a15ec3c7..ae70acc0 100644 --- a/test-cases/unsafe-expect/unsafe-expect-5/vulnerable-example/src/lib.rs +++ b/test-cases/unsafe-expect/unsafe-expect-5/vulnerable-example/src/lib.rs @@ -28,14 +28,13 @@ impl UnsafeExpect { } // Returns the balance of a given account. - pub fn balance_add(env: Env, owner: Address) -> i128 { + pub fn balance_add(env: Env, owner: Address) -> Option { let state = Self::get_state(env); let balance = state.balances.get(owner.clone()); let balance_plus_2 = 2i128.checked_add(balance.expect("could not get balance")); - if balance_plus_2.is_none() { - return 0; - } - balance_plus_2.unwrap() + balance_plus_2?; + let balance = balance_plus_2.unwrap(); + Some(balance) } /// Return the current state. @@ -68,22 +67,6 @@ mod tests { let balances = client.balance_add(&contract_id); // Then - assert_eq!(TOTAL_SUPPLY + 2, balances); - } - - #[test] - #[should_panic(expected = "could not get balance")] - fn balance_of_expect_works() { - // Given - let env = Env::default(); - let contract_id = env.register_contract(None, UnsafeExpect); - let client = UnsafeExpectClient::new(&env, &contract_id); - - // When - Balance not set - let _balance_1 = client.balance_add(&contract_id); - - // Then - - // Test should panic + assert_eq!(TOTAL_SUPPLY + 2, balances.unwrap()); } } diff --git a/utils/src/lib.rs b/utils/src/lib.rs index d119cbad..7d1e5cf0 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -14,3 +14,6 @@ pub use constant_analyzer::*; mod type_utils; pub use type_utils::*; + +mod token_interface_utils; +pub use token_interface_utils::*; diff --git a/utils/src/soroban_utils/mod.rs b/utils/src/soroban_utils/mod.rs index 6b97c8e6..189ebaee 100644 --- a/utils/src/soroban_utils/mod.rs +++ b/utils/src/soroban_utils/mod.rs @@ -2,11 +2,11 @@ extern crate rustc_lint; extern crate rustc_middle; extern crate rustc_span; -use std::collections::HashSet; - +use crate::type_utils::match_type_to_str; use rustc_lint::LateContext; -use rustc_middle::ty::{Ty, TyKind}; +use rustc_middle::ty::Ty; use rustc_span::def_id::DefId; +use std::collections::HashSet; /// Constants defining the fully qualified names of Soroban types. const SOROBAN_ENV: &str = "soroban_sdk::Env"; @@ -58,28 +58,19 @@ pub fn is_soroban_function( .all(|pattern| checked_functions.contains(pattern)) } -// Private helper function to match soroban types -fn is_soroban_type(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool { - match expr_type.kind() { - TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str), - TyKind::Ref(_, ty, _) => is_soroban_type(cx, *ty, type_str), - _ => false, - } -} - /// Checks if the provided type is a Soroban environment (`soroban_sdk::Env`). pub fn is_soroban_env(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_ENV) + match_type_to_str(cx, expr_type, SOROBAN_ENV) } /// Checks if the provided type is a Soroban Address (`soroban_sdk::Address`). pub fn is_soroban_address(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_ADDRESS) + match_type_to_str(cx, expr_type, SOROBAN_ADDRESS) } /// Checks if the provided type is a Soroban Map (`soroban_sdk::Map`). pub fn is_soroban_map(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_MAP) + match_type_to_str(cx, expr_type, SOROBAN_MAP) } pub enum SorobanStorageType { @@ -97,14 +88,16 @@ pub fn is_soroban_storage( ) -> bool { match storage_type { SorobanStorageType::Any => { - is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE) - || is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) - || is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE) + || match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) + || match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + } + SorobanStorageType::Instance => match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE), + SorobanStorageType::Temporary => { + match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) } - SorobanStorageType::Instance => is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE), - SorobanStorageType::Temporary => is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE), SorobanStorageType::Persistent => { - is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) } } } diff --git a/utils/src/token_interface_utils/mod.rs b/utils/src/token_interface_utils/mod.rs new file mode 100644 index 00000000..1cfb03f5 --- /dev/null +++ b/utils/src/token_interface_utils/mod.rs @@ -0,0 +1,115 @@ +extern crate rustc_hir; +extern crate rustc_middle; +extern crate rustc_span; + +use rustc_hir::FnRetTy; +use rustc_hir::QPath; +use rustc_hir::Ty; +use rustc_hir::TyKind; + +// Used to check if the parameters for a function match the data types for a specific token interface function. +pub fn check_params( + fn_params: &[Ty], + expected_types: Vec, + fn_return: FnRetTy, + expected_return: Option, +) -> bool { + let mut param_types: Vec = Vec::new(); + for i in fn_params { + if let TyKind::Path(QPath::Resolved(_, path)) = i.kind { + let param_type = path.segments[0].ident.to_string(); + param_types.push(param_type.clone()); + } + } + if expected_return.is_none() { + if let FnRetTy::DefaultReturn(_) = fn_return { + return param_types == expected_types; + } + } else { + if let FnRetTy::Return(ty) = fn_return { + if let TyKind::Path(qpath) = &ty.kind { + if let QPath::Resolved(_, path) = qpath { + if let Some(first_segment) = path.segments.first() { + return first_segment.ident.to_string() == expected_return.unwrap() + && param_types == expected_types; + } + } + } + } + } + false +} + +// Used to verify if a function matches a token interface standard function. +pub fn verify_token_interface_function( + fn_name: String, + fn_params: &[Ty], + fn_return: FnRetTy, +) -> bool { + let function = fn_name.split("::").last().unwrap(); + if function.eq("mint") { + return true; + } + let (types, expected_return): (Vec, Option) = match function { + "allowance" => ( + ["Env", "Address", "Address"] + .iter() + .map(|&s| s.to_string()) + .collect(), + Some("i128".to_string()), + ), + "approve" => ( + ["Env", "Address", "Address", "i128", "u32"] + .iter() + .map(|&s| s.to_string()) + .collect(), + None, + ), + "balance" => ( + ["Env", "Address"].iter().map(|&s| s.to_string()).collect(), + Some("i128".to_string()), + ), + "transfer" => ( + ["Env", "Address", "Address", "i128"] + .iter() + .map(|&s| s.to_string()) + .collect(), + None, + ), + "transfer_from" => ( + ["Env", "Address", "Address", "Address", "i128"] + .iter() + .map(|&s| s.to_string()) + .collect(), + None, + ), + "burn" => ( + ["Env", "Address", "i128"] + .iter() + .map(|&s| s.to_string()) + .collect(), + None, + ), + "burn_from" => ( + ["Env", "Address", "Address", "i128"] + .iter() + .map(|&s| s.to_string()) + .collect(), + None, + ), + "decimals" => ( + ["Env"].iter().map(|&s| s.to_string()).collect(), + Some("u32".to_string()), + ), + "name" => ( + ["Env"].iter().map(|&s| s.to_string()).collect(), + Some("String".to_string()), + ), + "symbol" => ( + ["Env"].iter().map(|&s| s.to_string()).collect(), + Some("String".to_string()), + ), + _ => return false, + }; + check_params(fn_params, types, fn_return, expected_return) +} diff --git a/utils/src/type_utils.rs b/utils/src/type_utils.rs index 942003b6..ab32ff22 100644 --- a/utils/src/type_utils.rs +++ b/utils/src/type_utils.rs @@ -1,12 +1,35 @@ extern crate rustc_hir; extern crate rustc_lint; extern crate rustc_middle; +extern crate rustc_span; -use rustc_hir::HirId; +use rustc_hir::{FnDecl, FnRetTy, HirId, QPath}; use rustc_lint::LateContext; -use rustc_middle::ty::Ty; +use rustc_middle::ty::{Ty, TyKind}; +use rustc_span::Symbol; /// Get the type of a node, if it exists. pub fn get_node_type_opt<'tcx>(cx: &LateContext<'tcx>, hir_id: &HirId) -> Option> { cx.typeck_results().node_type_opt(*hir_id) } + +/// Match the type of an expression to a string. +pub fn match_type_to_str(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool { + match expr_type.kind() { + TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str), + TyKind::Ref(_, ty, _) => match_type_to_str(cx, *ty, type_str), + _ => false, + } +} + +/// Check the return type of a function. +pub fn fn_returns(decl: &FnDecl<'_>, type_symbol: Symbol) -> bool { + if let FnRetTy::Return(ty) = decl.output { + matches!(ty.kind, rustc_hir::TyKind::Path(QPath::Resolved(_, path)) if path + .segments + .last() + .map_or(false, |seg| seg.ident.name == type_symbol)) + } else { + false + } +}