From b44abd8cb19914f776d2fd5f122314cd67ca383f Mon Sep 17 00:00:00 2001 From: Troy Benson Date: Sun, 20 Oct 2024 17:54:16 +0000 Subject: [PATCH] test batcher --- foundations/src/batcher/mod.rs | 46 ++++++++++++++++---------- image-processor/src/management/grpc.rs | 6 +++- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/foundations/src/batcher/mod.rs b/foundations/src/batcher/mod.rs index 75d342b64..5b271a451 100644 --- a/foundations/src/batcher/mod.rs +++ b/foundations/src/batcher/mod.rs @@ -5,7 +5,6 @@ use std::sync::atomic::{AtomicU64, AtomicUsize}; use std::sync::Arc; use tokio::sync::OnceCell; -use tracing::Instrument; pub mod dataloader; @@ -225,7 +224,7 @@ impl Drop for CancelOnDrop { } struct BatcherInner { - semaphore: tokio::sync::Semaphore, + semaphore: Arc, notify: tokio::sync::Notify, sleep_duration: AtomicU64, batch_id: AtomicU64, @@ -233,6 +232,7 @@ struct BatcherInner { operation: T, name: String, active_batch: tokio::sync::RwLock>>, + queued_batches: tokio::sync::mpsc::Sender>, } struct Batch { @@ -288,18 +288,14 @@ impl From for BatcherError { impl Batch { #[tracing::instrument(skip_all, fields(name = %inner.name))] - async fn run(self, inner: Arc>) { + async fn run(self, inner: Arc>, ticket: tokio::sync::OwnedSemaphorePermit) { self.results .get_or_init(|| async move { - let _ticket = inner - .semaphore - .acquire() - .instrument(tracing::debug_span!("Semaphore")) - .await - .map_err(|_| BatcherError::AcquireSemaphore)?; - Ok(inner.operation.process(self.ops).await.map_err(BatcherError::Batch)?) + inner.operation.process(self.ops).await.map_err(BatcherError::Batch) }) .await; + + drop(ticket); } } @@ -312,8 +308,8 @@ pub struct BatcherConfig { } impl BatcherInner { - fn spawn_batch(self: &Arc, batch: Batch) { - tokio::spawn(batch.run(self.clone())); + fn spawn_batch(self: &Arc, batch: Batch, ticket: tokio::sync::OwnedSemaphorePermit) { + tokio::spawn(batch.run(self.clone(), ticket)); } fn new_batch(&self) -> Batch { @@ -342,7 +338,7 @@ impl BatcherInner { .unwrap_or(true) { if let Some(b) = batch.take() { - self.spawn_batch(b); + self.queued_batches.send(b).await.ok(); } *batch = Some(self.new_batch()); @@ -374,8 +370,11 @@ impl Batcher { pub fn new(operation: T) -> Self { let config = operation.config(); + let (tx, mut rx) = tokio::sync::mpsc::channel(64); + let inner = Arc::new(BatcherInner { - semaphore: tokio::sync::Semaphore::new(config.concurrency), + semaphore: Arc::new(tokio::sync::Semaphore::new(config.concurrency)), + queued_batches: tx.clone(), notify: tokio::sync::Notify::new(), batch_id: AtomicU64::new(0), active_batch: tokio::sync::RwLock::new(None), @@ -390,6 +389,13 @@ impl Batcher { _auto_loader_abort: CancelOnDrop( tokio::task::spawn(async move { loop { + tokio::select! { + Some(batch) = rx.recv() => { + let ticket = inner.semaphore.clone().acquire_owned().await.unwrap(); + inner.spawn_batch(batch, ticket); + }, + _ = inner.notify.notified() => {}, + } inner.notify.notified().await; let Some((id, expires_at)) = inner.active_batch.read().await.as_ref().map(|b| (b.id, b.expires_at)) else { @@ -399,11 +405,15 @@ impl Batcher { if expires_at > tokio::time::Instant::now() { tokio::time::sleep_until(expires_at).await; } - + let mut batch = inner.active_batch.write().await; - if batch.as_ref().is_some_and(|b| b.id == id) { - inner.spawn_batch(batch.take().unwrap()); - } + let batch = if batch.as_ref().is_some_and(|b| b.id == id) { + batch.take().unwrap() + } else { + continue; + }; + + tx.send(batch).await.ok(); } }) .abort_handle(), diff --git a/image-processor/src/management/grpc.rs b/image-processor/src/management/grpc.rs index 46590d214..d3b5801e0 100644 --- a/image-processor/src/management/grpc.rs +++ b/image-processor/src/management/grpc.rs @@ -7,7 +7,11 @@ impl ManagementServer { #[tracing::instrument(skip_all)] pub async fn run_grpc(&self, addr: std::net::SocketAddr) -> Result<(), tonic::transport::Error> { let server = tonic::transport::Server::builder() - .add_service(scuffle_image_processor_proto::image_processor_server::ImageProcessorServer::new(self.clone())) + .add_service( + scuffle_image_processor_proto::image_processor_server::ImageProcessorServer::new(self.clone()) + .max_decoding_message_size(128 * 1024 * 1024) + .max_encoding_message_size(128 * 1024 * 1024) + ) .serve_with_shutdown(addr, scuffle_foundations::context::Context::global().into_done()); tracing::info!("gRPC management server listening on {}", addr);