diff --git a/Cargo.lock b/Cargo.lock index d02c5d46..215d0450 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4500,7 +4500,7 @@ dependencies = [ [[package]] name = "scuffle-batching" -version = "0.0.3" +version = "0.0.4" dependencies = [ "criterion", "futures", diff --git a/Cargo.toml b/Cargo.toml index ea224b26..dc01c323 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ scuffle-signal = { path = "crates/signal", version = "0.0.2" } scuffle-http = { path = "crates/http", version = "0.0.4" } scuffle-metrics = { path = "crates/metrics", version = "0.0.4" } scuffle-pprof = { path = "crates/pprof", version = "0.0.2" } -scuffle-batching = { path = "crates/batching", version = "0.0.3" } +scuffle-batching = { path = "crates/batching", version = "0.0.4" } scuffle-postcompile = { path = "crates/postcompile", version = "0.0.5" } scuffle-ffmpeg = { path = "crates/ffmpeg", version = "0.0.2" } scuffle-h3-webtransport = { path = "crates/h3-webtransport", version = "0.0.2" } diff --git a/crates/batching/Cargo.toml b/crates/batching/Cargo.toml index 98c12f59..91d7c63a 100644 --- a/crates/batching/Cargo.toml +++ b/crates/batching/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scuffle-batching" -version = "0.0.3" +version = "0.0.4" edition = "2021" repository = "https://github.com/scufflecloud/scuffle" authors = ["Scuffle "] diff --git a/crates/batching/src/dataloader.rs b/crates/batching/src/dataloader.rs index 352aaa59..d45e2cad 100644 --- a/crates/batching/src/dataloader.rs +++ b/crates/batching/src/dataloader.rs @@ -161,7 +161,7 @@ where let mut count = 0; { - let mut new_batch = false; + let mut new_batch = true; let mut batch = self.current_batch.lock().await; for item in items { @@ -555,4 +555,25 @@ mod tests { assert!(start.elapsed() >= std::time::Duration::from_millis(5)); assert!(start.elapsed() < std::time::Duration::from_millis(20)); } + + #[tokio::test] + async fn already_batch() { + let requests = Arc::new(AtomicUsize::new(0)); + + let fetcher = TestFetcher { + values: HashMap::from_iter(vec![("a", 1), ("b", 2), ("c", 3)]), + delay: std::time::Duration::from_millis(5), + requests: requests.clone(), + capacity: 2, + }; + + let loader = DataLoader::builder().batch_size(10).concurrency(1).build(fetcher); + + let start = std::time::Instant::now(); + let (a, b) = tokio::join!(loader.load("a"), loader.load("b")); + assert_eq!(a, Ok(Some(1))); + assert_eq!(b, Ok(Some(2))); + assert!(start.elapsed() < std::time::Duration::from_millis(15)); + assert_eq!(requests.load(std::sync::atomic::Ordering::Relaxed), 1); + } }