diff --git a/src/eth/storage/hybrid/hybrid_state.rs b/src/eth/storage/hybrid/hybrid_state.rs index d3b7a93ec..0ca75f92c 100644 --- a/src/eth/storage/hybrid/hybrid_state.rs +++ b/src/eth/storage/hybrid/hybrid_state.rs @@ -1,6 +1,5 @@ use core::fmt; use std::sync::Arc; -use tokio::join; use anyhow::Context; use indexmap::IndexMap; @@ -9,7 +8,7 @@ use sqlx::types::BigDecimal; use sqlx::FromRow; use sqlx::Pool; use sqlx::Postgres; -use tokio::sync::Mutex; +use tokio::join; use super::rocks_db::RocksDb; use crate::eth::primitives::Account; diff --git a/src/eth/storage/hybrid/mod.rs b/src/eth/storage/hybrid/mod.rs index 70e1e7237..d0efbb3ef 100644 --- a/src/eth/storage/hybrid/mod.rs +++ b/src/eth/storage/hybrid/mod.rs @@ -6,7 +6,6 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; -use anyhow::anyhow; use anyhow::Context; use async_trait::async_trait; use itertools::Itertools; @@ -20,11 +19,11 @@ use sqlx::QueryBuilder; use sqlx::Row; use tokio::sync::mpsc; use tokio::sync::mpsc::channel; -use tokio::sync::Mutex; -use tokio::sync::MutexGuard; use tokio::sync::RwLock; use tokio::sync::RwLockReadGuard; use tokio::sync::RwLockWriteGuard; +use tokio::sync::Semaphore; +use tokio::sync::SemaphorePermit; use tokio::task::JoinSet; use self::hybrid_state::HybridStorageState; @@ -62,7 +61,7 @@ pub struct HybridPermanentStorage { pool: Arc>, block_number: AtomicU64, task_sender: mpsc::Sender, - tasks_pending: Arc>, + tasks_pending: Arc, // TODO change to Mutex<()> } #[derive(Debug)] @@ -86,7 +85,7 @@ impl HybridPermanentStorage { tracing::error!(reason = ?e, "failed to start postgres client"); anyhow::anyhow!("failed to start postgres client") })?; - let tasks_pending = Arc::new(Mutex::new(())); + let tasks_pending = Arc::new(Semaphore::new(1)); let worker_tasks_pending = Arc::clone(&tasks_pending); let (task_sender, task_receiver) = channel::(32); let worker_pool = Arc::new(connection_pool.clone()); @@ -124,30 +123,22 @@ impl HybridPermanentStorage { mut receiver: tokio::sync::mpsc::Receiver, pool: Arc>, connections: u32, - tasks_pending: Arc>, + tasks_pending: Arc, ) { // Define the maximum number of concurrent tasks. Adjust this number based on your requirements. let max_concurrent_tasks: usize = (connections).try_into().unwrap_or(10usize); tracing::info!("Starting worker with max_concurrent_tasks: {}", max_concurrent_tasks); let mut futures = JoinSet::new(); - let mut pending_tasks_guard = None; - while let Ok(block_task_opt) = recv_block_task(&mut receiver, &mut pending_tasks_guard, !futures.is_empty()).await { - if let Some(block_task) = block_task_opt { - let pool_clone = Arc::clone(&pool); - - if futures.len() < max_concurrent_tasks { - futures.spawn(query_executor::commit_eventually(pool_clone, block_task)); - - if pending_tasks_guard.is_none() { - pending_tasks_guard = Some(tasks_pending.lock().await); - } - } else if let Some(_res) = futures.join_next().await { - futures.spawn(query_executor::commit_eventually(pool_clone, block_task)); + let mut permit = None; + while let Some(block_task) = recv_block_task(&mut receiver, &mut permit).await { + let pool_clone = Arc::clone(&pool); + if futures.len() < max_concurrent_tasks { + futures.spawn(query_executor::commit_eventually(pool_clone, block_task)); + if permit.is_none() { + permit = Some(tasks_pending.acquire().await.expect("semaphore has closed")); } - } else { - let timeout = Duration::from_millis(100); - tokio::time::sleep(timeout).await; - while futures.try_join_next().is_some() {} + } else if let Some(_res) = futures.join_next().await { + futures.spawn(query_executor::commit_eventually(pool_clone, block_task)); } } } @@ -467,7 +458,7 @@ impl PermanentStorage for HybridPermanentStorage { accounts_changes.5.push(account.code_hash); } - let _ = self.tasks_pending.lock().await; + let _ = self.tasks_pending.acquire().await.expect("semaphore has closed"); sqlx::query!( "INSERT INTO public.neo_accounts (block_number, address, bytecode, balance, nonce, code_hash) SELECT * FROM UNNEST($1::bigint[], $2::bytea[], $3::bytea[], $4::numeric[], $5::numeric[], $6::bytea[]) @@ -505,7 +496,7 @@ impl PermanentStorage for HybridPermanentStorage { state.transactions.retain(|_, t| t.block_number <= block_number); state.logs.retain(|_, l| l.block_number <= block_number); - let _ = self.tasks_pending.lock().await; + let _ = self.tasks_pending.acquire().await.expect("semaphore has closed"); sqlx::query!(r#"DELETE FROM neo_blocks WHERE block_number > $1"#, block_number as _) .execute(&*self.pool) @@ -535,29 +526,16 @@ impl PermanentStorage for HybridPermanentStorage { } } -/// This function blocks if the mpsc is empty AND either: -/// 1. We have the pending_tasks_guard and there are no tasks pending -/// 2. We don't have the pending_tasks_guard -/// Otherwise this function is non-blocking until we can finish the pending tasks and release the lock. -async fn recv_block_task( - receiver: &mut tokio::sync::mpsc::Receiver, - pending_tasks_guard: &mut Option>, - pending_tasks: bool, -) -> anyhow::Result> { +async fn recv_block_task(receiver: &mut tokio::sync::mpsc::Receiver, permit: &mut Option>) -> Option { match receiver.try_recv() { - Ok(block_task) => Ok(Some(block_task)), - Err(mpsc::error::TryRecvError::Empty) => - if pending_tasks_guard.is_some() { - if !pending_tasks { - let guard = std::mem::take(pending_tasks_guard); - drop(guard); - Ok(receiver.recv().await) - } else { - Ok(None) - } - } else { - Ok(receiver.recv().await) - }, - Err(mpsc::error::TryRecvError::Disconnected) => Err(anyhow!(mpsc::error::TryRecvError::Disconnected)), + Ok(block_task) => Some(block_task), + Err(mpsc::error::TryRecvError::Empty) => { + if permit.is_some() { + let perm = std::mem::take(permit); + drop(perm); + } + receiver.recv().await + } + Err(_) => None, } }