diff --git a/utils/src/soroban_utils/mod.rs b/utils/src/soroban_utils/mod.rs index 6b97c8e6..189ebaee 100644 --- a/utils/src/soroban_utils/mod.rs +++ b/utils/src/soroban_utils/mod.rs @@ -2,11 +2,11 @@ extern crate rustc_lint; extern crate rustc_middle; extern crate rustc_span; -use std::collections::HashSet; - +use crate::type_utils::match_type_to_str; use rustc_lint::LateContext; -use rustc_middle::ty::{Ty, TyKind}; +use rustc_middle::ty::Ty; use rustc_span::def_id::DefId; +use std::collections::HashSet; /// Constants defining the fully qualified names of Soroban types. const SOROBAN_ENV: &str = "soroban_sdk::Env"; @@ -58,28 +58,19 @@ pub fn is_soroban_function( .all(|pattern| checked_functions.contains(pattern)) } -// Private helper function to match soroban types -fn is_soroban_type(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool { - match expr_type.kind() { - TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str), - TyKind::Ref(_, ty, _) => is_soroban_type(cx, *ty, type_str), - _ => false, - } -} - /// Checks if the provided type is a Soroban environment (`soroban_sdk::Env`). pub fn is_soroban_env(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_ENV) + match_type_to_str(cx, expr_type, SOROBAN_ENV) } /// Checks if the provided type is a Soroban Address (`soroban_sdk::Address`). pub fn is_soroban_address(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_ADDRESS) + match_type_to_str(cx, expr_type, SOROBAN_ADDRESS) } /// Checks if the provided type is a Soroban Map (`soroban_sdk::Map`). pub fn is_soroban_map(cx: &LateContext<'_>, expr_type: Ty<'_>) -> bool { - is_soroban_type(cx, expr_type, SOROBAN_MAP) + match_type_to_str(cx, expr_type, SOROBAN_MAP) } pub enum SorobanStorageType { @@ -97,14 +88,16 @@ pub fn is_soroban_storage( ) -> bool { match storage_type { SorobanStorageType::Any => { - is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE) - || is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) - || is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE) + || match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) + || match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + } + SorobanStorageType::Instance => match_type_to_str(cx, expr_type, SOROBAN_INSTANCE_STORAGE), + SorobanStorageType::Temporary => { + match_type_to_str(cx, expr_type, SOROBAN_TEMPORARY_STORAGE) } - SorobanStorageType::Instance => is_soroban_type(cx, expr_type, SOROBAN_INSTANCE_STORAGE), - SorobanStorageType::Temporary => is_soroban_type(cx, expr_type, SOROBAN_TEMPORARY_STORAGE), SorobanStorageType::Persistent => { - is_soroban_type(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) + match_type_to_str(cx, expr_type, SOROBAN_PERSISTENT_STORAGE) } } } diff --git a/utils/src/type_utils.rs b/utils/src/type_utils.rs index 942003b6..3a4a4b9a 100644 --- a/utils/src/type_utils.rs +++ b/utils/src/type_utils.rs @@ -1,12 +1,35 @@ extern crate rustc_hir; extern crate rustc_lint; extern crate rustc_middle; +extern crate rustc_span; -use rustc_hir::HirId; +use rustc_hir::{FnDecl, FnRetTy, HirId, QPath}; use rustc_lint::LateContext; -use rustc_middle::ty::Ty; +use rustc_middle::ty::{Ty, TyKind}; +use rustc_span::sym; /// Get the type of a node, if it exists. pub fn get_node_type_opt<'tcx>(cx: &LateContext<'tcx>, hir_id: &HirId) -> Option> { cx.typeck_results().node_type_opt(*hir_id) } + +/// Match the type of an expression to a string. +pub fn match_type_to_str(cx: &LateContext<'_>, expr_type: Ty<'_>, type_str: &str) -> bool { + match expr_type.kind() { + TyKind::Adt(adt_def, _) => cx.tcx.def_path_str(adt_def.did()).contains(type_str), + TyKind::Ref(_, ty, _) => match_type_to_str(cx, *ty, type_str), + _ => false, + } +} + +/// Check if a function returns a Result type. +pub fn returns_result(decl: &FnDecl<'_>) -> bool { + if let FnRetTy::Return(ty) = decl.output { + matches!(ty.kind, rustc_hir::TyKind::Path(QPath::Resolved(_, path)) if path + .segments + .last() + .map_or(false, |seg| seg.ident.name == sym::Result)) + } else { + false + } +}