diff --git a/foundations/src/batcher/mod.rs b/foundations/src/batcher/mod.rs index b53e2b828..75d342b64 100644 --- a/foundations/src/batcher/mod.rs +++ b/foundations/src/batcher/mod.rs @@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicU64, AtomicUsize}; use std::sync::Arc; use tokio::sync::OnceCell; +use tracing::Instrument; pub mod dataloader; @@ -224,7 +225,7 @@ impl Drop for CancelOnDrop { } struct BatcherInner { - semaphore: Arc, + semaphore: tokio::sync::Semaphore, notify: tokio::sync::Notify, sleep_duration: AtomicU64, batch_id: AtomicU64, @@ -232,7 +233,6 @@ struct BatcherInner { operation: T, name: String, active_batch: tokio::sync::RwLock>>, - queued_batches: tokio::sync::mpsc::Sender>, } struct Batch { @@ -288,14 +288,18 @@ impl From for BatcherError { impl Batch { #[tracing::instrument(skip_all, fields(name = %inner.name))] - async fn run(self, inner: Arc>, ticket: tokio::sync::OwnedSemaphorePermit) { + async fn run(self, inner: Arc>) { self.results .get_or_init(|| async move { - inner.operation.process(self.ops).await.map_err(BatcherError::Batch) + let _ticket = inner + .semaphore + .acquire() + .instrument(tracing::debug_span!("Semaphore")) + .await + .map_err(|_| BatcherError::AcquireSemaphore)?; + Ok(inner.operation.process(self.ops).await.map_err(BatcherError::Batch)?) }) .await; - - drop(ticket); } } @@ -308,8 +312,8 @@ pub struct BatcherConfig { } impl BatcherInner { - fn spawn_batch(self: &Arc, batch: Batch, ticket: tokio::sync::OwnedSemaphorePermit) { - tokio::spawn(batch.run(self.clone(), ticket)); + fn spawn_batch(self: &Arc, batch: Batch) { + tokio::spawn(batch.run(self.clone())); } fn new_batch(&self) -> Batch { @@ -330,7 +334,6 @@ impl BatcherInner { let mut waiters = vec![]; let mut batch = self.active_batch.write().await; let max_documents = self.max_batch_size.load(std::sync::atomic::Ordering::Relaxed); - let mut batches = vec![]; for document in T::Mode::filter_item_iter(documents) { if batch @@ -339,7 +342,7 @@ impl BatcherInner { .unwrap_or(true) { if let Some(b) = batch.take() { - batches.push(b); + self.spawn_batch(b); } *batch = Some(self.new_batch()); @@ -363,10 +366,6 @@ impl BatcherInner { T::Mode::input_add(&mut b.ops, tracker, document); } - for batch in batches { - self.queued_batches.send(batch).await.ok(); - } - waiters } } @@ -375,11 +374,8 @@ impl Batcher { pub fn new(operation: T) -> Self { let config = operation.config(); - let (tx, mut rx) = tokio::sync::mpsc::channel(64); - let inner = Arc::new(BatcherInner { - semaphore: Arc::new(tokio::sync::Semaphore::new(config.concurrency)), - queued_batches: tx.clone(), + semaphore: tokio::sync::Semaphore::new(config.concurrency), notify: tokio::sync::Notify::new(), batch_id: AtomicU64::new(0), active_batch: tokio::sync::RwLock::new(None), @@ -393,43 +389,21 @@ impl Batcher { inner: inner.clone(), _auto_loader_abort: CancelOnDrop( tokio::task::spawn(async move { - let mut next_wakeup = None; loop { - tokio::select! { - Some(batch) = rx.recv() => { - let ticket = inner.semaphore.clone().acquire_owned().await.unwrap(); - inner.spawn_batch(batch, ticket); - }, - _ = async { - if let Some(expires_at) = next_wakeup { - tokio::time::sleep_until(expires_at).await; - } else { - inner.notify.notified().await; - } - } => {}, - _ = inner.notify.notified() => {}, - } - + inner.notify.notified().await; let Some((id, expires_at)) = inner.active_batch.read().await.as_ref().map(|b| (b.id, b.expires_at)) else { continue; }; if expires_at > tokio::time::Instant::now() { - next_wakeup = Some(expires_at); - continue; - } else { - next_wakeup = None; + tokio::time::sleep_until(expires_at).await; } - - let mut batch = inner.active_batch.write().await; - let batch = if batch.as_ref().is_some_and(|b| b.id == id) { - batch.take().unwrap() - } else { - continue; - }; - tx.send(batch).await.ok(); + let mut batch = inner.active_batch.write().await; + if batch.as_ref().is_some_and(|b| b.id == id) { + inner.spawn_batch(batch.take().unwrap()); + } } }) .abort_handle(),