diff --git a/omniqueue/src/backends/gcp_pubsub.rs b/omniqueue/src/backends/gcp_pubsub.rs
index 16227b4..7435ce5 100644
--- a/omniqueue/src/backends/gcp_pubsub.rs
+++ b/omniqueue/src/backends/gcp_pubsub.rs
@@ -15,6 +15,7 @@ use google_cloud_pubsub::subscription::Subscription;
 use serde::Serialize;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
+use std::time::Duration;
 use std::{any::TypeId, collections::HashMap};
 
 pub struct GcpPubSubBackend;
@@ -217,31 +218,56 @@ async fn subscription(client: &Client, subscription_id: &str) -> Result<Subscription, QueueError> {
 
-#[async_trait]
-impl QueueConsumer for GcpPubSubConsumer {
-    type Payload = Payload;
-
-    async fn receive(&mut self) -> Result<Delivery, QueueError> {
-        let subscription = subscription(&self.client, &self.subscription_id).await?;
-        let mut stream = subscription
-            .subscribe(None)
-            .await
-            .map_err(QueueError::generic)?;
-
-        let mut recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;
+impl GcpPubSubConsumer {
+    fn wrap_recv_msg(&self, mut recv_msg: ReceivedMessage) -> Delivery {
         // FIXME: would be nice to avoid having to move the data out here.
         // While it's possible to ack via a subscription and an ack_id, nack is only
         // possible via a `ReceivedMessage`. This means we either need to hold 2 copies of
         // the payload, or move the bytes out so they can be returned _outside of the Acker_.
         let payload = recv_msg.message.data.drain(..).collect();
-        Ok(Delivery {
+
+        Delivery {
             decoders: self.registry.clone(),
             acker: Box::new(GcpPubSubAcker {
                 recv_msg,
                 subscription_id: self.subscription_id.clone(),
             }),
             payload: Some(payload),
-        })
+        }
+    }
+}
+
+#[async_trait]
+impl QueueConsumer for GcpPubSubConsumer {
+    type Payload = Payload;
+
+    async fn receive(&mut self) -> Result<Delivery, QueueError> {
+        let subscription = subscription(&self.client, &self.subscription_id).await?;
+        let mut stream = subscription
+            .subscribe(None)
+            .await
+            .map_err(QueueError::generic)?;
+
+        let recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;
+
+        Ok(self.wrap_recv_msg(recv_msg))
+    }
+
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        let subscription = subscription(&self.client, &self.subscription_id).await?;
+        match tokio::time::timeout(deadline, subscription.pull(max_messages as _, None)).await {
+            Ok(messages) => Ok(messages
+                .map_err(QueueError::generic)?
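// A sketch (not from this diff) of how the new batch API is meant to be driven.
// `receive_all` is defined on the `QueueConsumer` trait, so one loop like this serves
// every backend touched below. The helper itself is hypothetical, and it assumes
// `Delivery::ack` surfaces a `QueueError`, as the tests' `.unwrap()` usage suggests.
use std::time::Duration;

use omniqueue::queue::consumer::QueueConsumer;
use omniqueue::QueueError;

async fn drain_queue<C: QueueConsumer>(consumer: &mut C) -> Result<(), QueueError> {
    loop {
        // Wait at most one second; get back anywhere from 0 to 32 deliveries.
        let batch = consumer.receive_all(32, Duration::from_secs(1)).await?;
        if batch.is_empty() {
            return Ok(()); // deadline elapsed with nothing ready
        }
        for delivery in batch {
            // ... process the payload here, then settle the message
            delivery.ack().await?;
        }
    }
}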
+                .into_iter()
+                .map(|m| self.wrap_recv_msg(m))
+                .collect()),
+            // Timeout
+            Err(_) => Ok(vec![]),
+        }
+    }
 }
diff --git a/omniqueue/src/backends/memory_queue.rs b/omniqueue/src/backends/memory_queue.rs
index 0733805..438940a 100644
--- a/omniqueue/src/backends/memory_queue.rs
+++ b/omniqueue/src/backends/memory_queue.rs
@@ -1,3 +1,4 @@
+use std::time::{Duration, Instant};
 use std::{any::TypeId, collections::HashMap};
 
 use async_trait::async_trait;
@@ -90,14 +91,9 @@ pub struct MemoryQueueConsumer {
     tx: broadcast::Sender<Vec<u8>>,
 }
 
-#[async_trait]
-impl QueueConsumer for MemoryQueueConsumer {
-    type Payload = Vec<u8>;
-
-    async fn receive(&mut self) -> Result<Delivery, QueueError> {
-        let payload = self.rx.recv().await.map_err(QueueError::generic)?;
-
-        Ok(Delivery {
+impl MemoryQueueConsumer {
+    fn wrap_payload(&self, payload: Vec<u8>) -> Delivery {
+        Delivery {
             payload: Some(payload.clone()),
             decoders: self.registry.clone(),
             acker: Box::new(MemoryQueueAcker {
@@ -105,7 +101,43 @@ impl QueueConsumer for MemoryQueueConsumer {
                 payload_copy: Some(payload),
                 alredy_acked_or_nacked: false,
             }),
-        })
+        }
+    }
+}
+
+#[async_trait]
+impl QueueConsumer for MemoryQueueConsumer {
+    type Payload = Vec<u8>;
+
+    async fn receive(&mut self) -> Result<Delivery, QueueError> {
+        let payload = self.rx.recv().await.map_err(QueueError::generic)?;
+        Ok(self.wrap_payload(payload))
+    }
+
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        let mut out = Vec::with_capacity(max_messages);
+        let start = Instant::now();
+        match tokio::time::timeout(deadline, self.rx.recv()).await {
+            Ok(Ok(x)) => out.push(self.wrap_payload(x)),
+            // Timeouts and stream termination
+            Err(_) | Ok(Err(_)) => return Ok(out),
+        }
+
+        if max_messages > 1 {
+            // `try_recv` will break the loop if no ready items are already buffered in the channel.
+            // This should allow us to opportunistically fill up the buffer in the remaining time.
+            while let Ok(x) = self.rx.try_recv() {
+                out.push(self.wrap_payload(x));
+                if out.len() >= max_messages || start.elapsed() >= deadline {
+                    break;
+                }
+            }
+        }
+        Ok(out)
+    }
 }
@@ -146,6 +178,7 @@ impl Acker for MemoryQueueAcker {
 #[cfg(test)]
 mod tests {
     use serde::{Deserialize, Serialize};
+    use std::time::{Duration, Instant};
 
     use crate::{
         queue::{consumer::QueueConsumer, producer::QueueProducer, QueueBuilder},
@@ -233,4 +266,133 @@ mod tests {
             TypeA { a: 12 }
         );
     }
+
+    #[derive(Debug, Deserialize, Serialize, PartialEq)]
+    pub struct ExType {
+        a: u8,
+    }
+
+    /// Consumer will return immediately if there are fewer than max messages to start with.
+    #[tokio::test]
+    async fn test_send_recv_all_partial() {
+        let payload = ExType { a: 2 };
+
+        let (p, mut c) = QueueBuilder::<MemoryQueueBackend<Vec<u8>>, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        p.send_serde_json(&payload).await.unwrap();
+        let deadline = Duration::from_secs(1);
+
+        let now = Instant::now();
+        let mut xs = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(xs.len(), 1);
+        let d = xs.remove(0);
+        assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
+        d.ack().await.unwrap();
+        assert!(now.elapsed() <= deadline);
+    }
+
+    /// Consumer should yield items immediately if there's a full batch ready on the first poll.
+    #[tokio::test]
+    async fn test_send_recv_all_full() {
+        let payload1 = ExType { a: 1 };
+        let payload2 = ExType { a: 2 };
+
+        let (p, mut c) = QueueBuilder::<MemoryQueueBackend<Vec<u8>>, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        p.send_serde_json(&payload1).await.unwrap();
+        p.send_serde_json(&payload2).await.unwrap();
+        let deadline = Duration::from_secs(1);
+
+        let now = Instant::now();
+        let mut xs = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(xs.len(), 2);
+        let d1 = xs.remove(0);
+        assert_eq!(
+            d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload1
+        );
+        d1.ack().await.unwrap();
+
+        let d2 = xs.remove(0);
+        assert_eq!(
+            d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload2
+        );
+        d2.ack().await.unwrap();
+        // N.b. it's still possible this could turn up false if the test runs too slow.
+        assert!(now.elapsed() < deadline);
+    }
+
+    /// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
+    #[tokio::test]
+    async fn test_send_recv_all_full_then_partial() {
+        let payload1 = ExType { a: 1 };
+        let payload2 = ExType { a: 2 };
+        let payload3 = ExType { a: 3 };
+
+        let (p, mut c) = QueueBuilder::<MemoryQueueBackend<Vec<u8>>, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        p.send_serde_json(&payload1).await.unwrap();
+        p.send_serde_json(&payload2).await.unwrap();
+        p.send_serde_json(&payload3).await.unwrap();
+
+        let deadline = Duration::from_secs(1);
+        let now1 = Instant::now();
+        let mut xs = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(xs.len(), 2);
+        let d1 = xs.remove(0);
+        assert_eq!(
+            d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload1
+        );
+        d1.ack().await.unwrap();
+
+        let d2 = xs.remove(0);
+        assert_eq!(
+            d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload2
+        );
+        d2.ack().await.unwrap();
+        assert!(now1.elapsed() < deadline);
+
+        // 2nd call
+        let now2 = Instant::now();
+        let mut ys = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(ys.len(), 1);
+        let d3 = ys.remove(0);
+        assert_eq!(
+            d3.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload3
+        );
+        d3.ack().await.unwrap();
+        assert!(now2.elapsed() <= deadline);
+    }
+
+    /// Consumer will NOT wait indefinitely for at least one item.
+    #[tokio::test]
+    async fn test_send_recv_all_late_arriving_items() {
+        let (_p, mut c) = QueueBuilder::<MemoryQueueBackend<Vec<u8>>, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        let deadline = Duration::from_secs(1);
+        let now = Instant::now();
+        let xs = c.receive_all(2, deadline).await.unwrap();
+        let elapsed = now.elapsed();
+
+        assert_eq!(xs.len(), 0);
+        // Elapsed should be around the deadline, ballpark
+        assert!(elapsed >= deadline);
+        assert!(elapsed <= deadline + Duration::from_millis(200));
+    }
 }
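// A standalone sketch (not from this diff) of the "block once, then drain what's ready"
// shape used by the in-memory `receive_all` above, reduced to a bare tokio broadcast
// channel. The helper name and signature are illustrative, not omniqueue API.
use std::time::{Duration, Instant};

use tokio::sync::broadcast;

async fn recv_batch(
    rx: &mut broadcast::Receiver<Vec<u8>>,
    max_messages: usize,
    deadline: Duration,
) -> Vec<Vec<u8>> {
    let start = Instant::now();
    let mut out = Vec::with_capacity(max_messages);

    // Block (bounded by the deadline) for the first item only.
    match tokio::time::timeout(deadline, rx.recv()).await {
        Ok(Ok(first)) => out.push(first),
        // Deadline hit, channel closed, or receiver lagged: return what we have.
        _ => return out,
    }

    // Everything after the first item is opportunistic: `try_recv` never awaits, so
    // this loop only picks up messages that are already buffered in the channel.
    while out.len() < max_messages && start.elapsed() < deadline {
        match rx.try_recv() {
            Ok(item) => out.push(item),
            Err(_) => break,
        }
    }
    out
}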
diff --git a/omniqueue/src/backends/rabbitmq.rs b/omniqueue/src/backends/rabbitmq.rs
index 059d0d2..7b2788b 100644
--- a/omniqueue/src/backends/rabbitmq.rs
+++ b/omniqueue/src/backends/rabbitmq.rs
@@ -1,13 +1,17 @@
+use std::time::{Duration, Instant};
 use std::{any::TypeId, collections::HashMap};
 
 use async_trait::async_trait;
 use futures::StreamExt;
-use lapin::{acker::Acker as LapinAcker, Channel, Connection, Consumer};
-
+use futures_util::FutureExt;
 pub use lapin::{
-    options::{BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions},
+    acker::Acker as LapinAcker,
+    options::{
+        BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions,
+        BasicQosOptions,
+    },
     types::FieldTable,
-    BasicProperties, ConnectionProperties,
+    BasicProperties, Channel, Connection, ConnectionProperties, Consumer,
 };
 
 use crate::{
@@ -17,6 +21,7 @@ use crate::{
     QueueError,
 };
 
+#[derive(Clone)]
 pub struct RabbitMqConfig {
     pub uri: String,
     pub connection_properties: ConnectionProperties,
@@ -24,18 +29,64 @@ pub struct RabbitMqConfig {
     pub publish_exchange: String,
     pub publish_routing_key: String,
     pub publish_options: BasicPublishOptions,
-    pub publish_properites: BasicProperties,
+    pub publish_properties: BasicProperties,
 
     pub consume_queue: String,
     pub consumer_tag: String,
     pub consume_options: BasicConsumeOptions,
     pub consume_arguments: FieldTable,
+    pub consume_prefetch_count: Option<u16>,
 
     pub requeue_on_nack: bool,
 }
 
 pub struct RabbitMqBackend;
 
+async fn consumer(
+    conn: &Connection,
+    cfg: RabbitMqConfig,
+    custom_decoders: DecoderRegistry<Vec<u8>>,
+) -> Result<RabbitMqConsumer, QueueError> {
+    let channel_rx = conn.create_channel().await.map_err(QueueError::generic)?;
+
+    if let Some(n) = cfg.consume_prefetch_count {
+        channel_rx
+            .basic_qos(n, BasicQosOptions::default())
+            .await
+            .map_err(QueueError::generic)?;
+    }
+
+    Ok(RabbitMqConsumer {
+        registry: custom_decoders,
+        consumer: channel_rx
+            .basic_consume(
+                &cfg.consume_queue,
+                &cfg.consumer_tag,
+                cfg.consume_options,
+                cfg.consume_arguments.clone(),
+            )
+            .await
+            .map_err(QueueError::generic)?,
+        requeue_on_nack: cfg.requeue_on_nack,
+    })
+}
+
+async fn producer(
+    conn: &Connection,
+    cfg: RabbitMqConfig,
+    custom_encoders: EncoderRegistry<Vec<u8>>,
+) -> Result<RabbitMqProducer, QueueError> {
+    let channel_tx = conn.create_channel().await.map_err(QueueError::generic)?;
+
+    Ok(RabbitMqProducer {
+        registry: custom_encoders,
+        channel: channel_tx,
+        exchange: cfg.publish_exchange.clone(),
+        routing_key: cfg.publish_routing_key.clone(),
+        options: cfg.publish_options,
+        properties: cfg.publish_properties.clone(),
+    })
+}
+
 #[async_trait]
 impl QueueBackend for RabbitMqBackend {
     type Config = RabbitMqConfig;
@@ -51,35 +102,13 @@ impl QueueBackend for RabbitMqBackend {
         custom_encoders: EncoderRegistry<Vec<u8>>,
         custom_decoders: DecoderRegistry<Vec<u8>>,
     ) -> Result<(RabbitMqProducer, RabbitMqConsumer), QueueError> {
-        let conn = Connection::connect(&cfg.uri, cfg.connection_properties)
+        let conn = Connection::connect(&cfg.uri, cfg.connection_properties.clone())
             .await
             .map_err(QueueError::generic)?;
 
-        let channel_tx = conn.create_channel().await.map_err(QueueError::generic)?;
-        let channel_rx = conn.create_channel().await.map_err(QueueError::generic)?;
-
         Ok((
-            RabbitMqProducer {
-                registry: custom_encoders,
-                channel: channel_tx,
-                exchange: cfg.publish_exchange,
-                routing_key: cfg.publish_routing_key,
-                options: cfg.publish_options,
-                properties: cfg.publish_properites,
-            },
-            RabbitMqConsumer {
-                registry: custom_decoders,
-                consumer: channel_rx
-                    .basic_consume(
-                        &cfg.consume_queue,
-                        &cfg.consumer_tag,
-                        cfg.consume_options,
-                        cfg.consume_arguments,
-                    )
-                    .await
-                    .map_err(QueueError::generic)?,
-                requeue_on_nack: cfg.requeue_on_nack,
-            },
+            producer(&conn, cfg.clone(), custom_encoders).await?,
+            consumer(&conn, cfg.clone(), custom_decoders).await?,
         ))
     }
 
@@ -87,45 +116,22 @@ impl QueueBackend for RabbitMqBackend {
         cfg: RabbitMqConfig,
         custom_encoders: EncoderRegistry<Vec<u8>>,
     ) -> Result<RabbitMqProducer, QueueError> {
-        let conn = Connection::connect(&cfg.uri, cfg.connection_properties)
+        let conn = Connection::connect(&cfg.uri, cfg.connection_properties.clone())
             .await
             .map_err(QueueError::generic)?;
 
-        let channel_tx = conn.create_channel().await.map_err(QueueError::generic)?;
-
-        Ok(RabbitMqProducer {
-            registry: custom_encoders,
-            channel: channel_tx,
-            exchange: cfg.publish_exchange,
-            routing_key: cfg.publish_routing_key,
-            options: cfg.publish_options,
-            properties: cfg.publish_properites,
-        })
+        Ok(producer(&conn, cfg.clone(), custom_encoders).await?)
     }
 
     async fn consuming_half(
         cfg: RabbitMqConfig,
         custom_decoders: DecoderRegistry<Vec<u8>>,
     ) -> Result<RabbitMqConsumer, QueueError> {
-        let conn = Connection::connect(&cfg.uri, cfg.connection_properties)
+        let conn = Connection::connect(&cfg.uri, cfg.connection_properties.clone())
             .await
             .map_err(QueueError::generic)?;
 
-        let channel_rx = conn.create_channel().await.map_err(QueueError::generic)?;
-
-        Ok(RabbitMqConsumer {
-            registry: custom_decoders,
-            consumer: channel_rx
-                .basic_consume(
-                    &cfg.consume_queue,
-                    &cfg.consumer_tag,
-                    cfg.consume_options,
-                    cfg.consume_arguments,
-                )
-                .await
-                .map_err(QueueError::generic)?,
-            requeue_on_nack: cfg.requeue_on_nack,
-        })
+        Ok(consumer(&conn, cfg.clone(), custom_decoders).await?)
     }
 }
 
@@ -168,6 +174,19 @@ pub struct RabbitMqConsumer {
     requeue_on_nack: bool,
 }
 
+impl RabbitMqConsumer {
+    fn wrap_delivery(&self, delivery: lapin::message::Delivery) -> Delivery {
+        Delivery {
+            decoders: self.registry.clone(),
+            payload: Some(delivery.data),
+            acker: Box::new(RabbitMqAcker {
+                acker: Some(delivery.acker),
+                requeue_on_nack: self.requeue_on_nack,
+            }),
+        }
+    }
+}
+
 #[async_trait]
 impl QueueConsumer for RabbitMqConsumer {
     type Payload = Vec<u8>;
@@ -178,19 +197,43 @@ impl QueueConsumer for RabbitMqConsumer {
     async fn receive(&mut self) -> Result<Delivery, QueueError> {
         let mut stream = self
             .consumer
             .clone()
             .map(|l: Result<lapin::message::Delivery, lapin::Error>| {
                 let l = l.map_err(QueueError::generic)?;
-
-                Ok(Delivery {
-                    decoders: self.registry.clone(),
-                    payload: Some(l.data),
-                    acker: Box::new(RabbitMqAcker {
-                        acker: Some(l.acker),
-                        requeue_on_nack: self.requeue_on_nack,
-                    }),
-                })
+                Ok(self.wrap_delivery(l))
             });
 
         stream.next().await.ok_or(QueueError::NoData)?
    }
+
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        let mut stream = self.consumer.clone().map(
+            |l: Result<lapin::message::Delivery, lapin::Error>| -> Result<Delivery, QueueError> {
+                let l = l.map_err(QueueError::generic)?;
+                Ok(self.wrap_delivery(l))
+            },
+        );
+        let start = Instant::now();
+        let mut out = Vec::with_capacity(max_messages);
+        match tokio::time::timeout(deadline, stream.next()).await {
+            Ok(Some(x)) => out.push(x?),
+            // Timeouts and stream termination
+            Err(_) | Ok(None) => return Ok(out),
+        }
+
+        if max_messages > 1 {
+            // `now_or_never` will break the loop if no ready items are already buffered in the stream.
+            // This should allow us to opportunistically fill up the buffer in the remaining time.
+            while let Some(Some(x)) = stream.next().now_or_never() {
+                out.push(x?);
+                if out.len() >= max_messages || start.elapsed() >= deadline {
+                    break;
+                }
+            }
+        }
+        Ok(out)
+    }
 }
 
 pub struct RabbitMqAcker {
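// A sketch (not from this diff) of what `FutureExt::now_or_never` contributes to the
// batching above: it polls a future exactly once and yields `Some(output)` only if the
// future is already ready, so draining a stream with it never blocks. Names here are
// illustrative.
use futures_util::{stream, FutureExt, StreamExt};

async fn drain_ready_items() {
    let mut stream = stream::iter([1u32, 2, 3]);
    let mut out = Vec::new();
    // `stream.next()` is a future; `now_or_never()` resolves it without awaiting.
    while let Some(Some(item)) = stream.next().now_or_never() {
        out.push(item);
    }
    assert_eq!(out, vec![1, 2, 3]); // an `iter` stream is always immediately ready
}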
} + + async fn receive_all( + &mut self, + max_messages: usize, + deadline: Duration, + ) -> Result, QueueError> { + let mut stream = self.consumer.clone().map( + |l: Result| -> Result { + let l = l.map_err(QueueError::generic)?; + Ok(self.wrap_delivery(l)) + }, + ); + let start = Instant::now(); + let mut out = Vec::with_capacity(max_messages); + match tokio::time::timeout(deadline, stream.next()).await { + Ok(Some(x)) => out.push(x?), + // Timeouts and stream termination + Err(_) | Ok(None) => return Ok(out), + } + + if max_messages > 1 { + // `now_or_never` will break the loop if no ready items are already buffered in the stream. + // This should allow us to opportunistically fill up the buffer in the remaining time. + while let Some(Some(x)) = stream.next().now_or_never() { + out.push(x?); + if out.len() >= max_messages || start.elapsed() >= deadline { + break; + } + } + } + Ok(out) + } } pub struct RabbitMqAcker { diff --git a/omniqueue/src/backends/redis/mod.rs b/omniqueue/src/backends/redis/mod.rs index 8b48aa5..7762060 100644 --- a/omniqueue/src/backends/redis/mod.rs +++ b/omniqueue/src/backends/redis/mod.rs @@ -1,9 +1,10 @@ +use std::time::Duration; use std::{any::TypeId, collections::HashMap, marker::PhantomData}; use async_trait::async_trait; use bb8::ManageConnection; pub use bb8_redis::RedisMultiplexedConnectionManager; -use redis::streams::{StreamReadOptions, StreamReadReply}; +use redis::streams::{StreamId, StreamReadOptions, StreamReadReply}; use crate::{ decoding::DecoderRegistry, @@ -219,6 +220,31 @@ pub struct RedisStreamConsumer { payload_key: String, } +impl RedisStreamConsumer +where + M: ManageConnection, + M::Connection: redis::aio::ConnectionLike + Send + Sync, + M::Error: 'static + std::error::Error + Send + Sync, +{ + fn wrap_entry(&self, entry: StreamId) -> Result { + let entry_id = entry.id.clone(); + let payload = entry.map.get(&self.payload_key).ok_or(QueueError::NoData)?; + let payload: Vec = redis::from_redis_value(payload).map_err(QueueError::generic)?; + + Ok(Delivery { + payload: Some(payload), + acker: Box::new(RedisStreamAcker { + redis: self.redis.clone(), + queue_key: self.queue_key.clone(), + consumer_group: self.consumer_group.clone(), + entry_id, + already_acked_or_nacked: false, + }), + decoders: self.registry.clone(), + }) + } +} + #[async_trait] impl QueueConsumer for RedisStreamConsumer where @@ -247,21 +273,40 @@ where let queue = read_out.keys.into_iter().next().ok_or(QueueError::NoData)?; let entry = queue.ids.into_iter().next().ok_or(QueueError::NoData)?; + self.wrap_entry(entry) + } - let entry_id = entry.id.clone(); - let payload = entry.map.get(&self.payload_key).ok_or(QueueError::NoData)?; - let payload: Vec = redis::from_redis_value(payload).map_err(QueueError::generic)?; + async fn receive_all( + &mut self, + max_messages: usize, + deadline: Duration, + ) -> Result, QueueError> { + let mut conn = self.redis.get().await.map_err(QueueError::generic)?; - Ok(Delivery { - payload: Some(payload), - acker: Box::new(RedisStreamAcker { - redis: self.redis.clone(), - queue_key: self.queue_key.clone(), - consumer_group: self.consumer_group.clone(), - entry_id, - already_acked_or_nacked: false, - }), - decoders: self.registry.clone(), - }) + let read_out: StreamReadReply = redis::Cmd::xread_options( + &[&self.queue_key], + &[">"], + &StreamReadOptions::default() + .group(&self.consumer_group, &self.consumer_name) + .block( + deadline + .as_millis() + .try_into() + .map_err(QueueError::generic)?, + ) + .count(max_messages), + ) + 
        .query_async(&mut *conn)
+        .await
+        .map_err(QueueError::generic)?;
+
+        let mut out = Vec::with_capacity(max_messages);
+
+        if let Some(queue) = read_out.keys.into_iter().next() {
+            for entry in queue.ids {
+                out.push(self.wrap_entry(entry)?);
+            }
+        }
+        Ok(out)
     }
 }
diff --git a/omniqueue/src/backends/sqs.rs b/omniqueue/src/backends/sqs.rs
index 1a142e1..5f79ea0 100644
--- a/omniqueue/src/backends/sqs.rs
+++ b/omniqueue/src/backends/sqs.rs
@@ -1,6 +1,8 @@
+use std::time::Duration;
 use std::{any::TypeId, collections::HashMap, sync::Arc};
 
 use async_trait::async_trait;
+use aws_sdk_sqs::types::Message;
 use aws_sdk_sqs::{
     operation::delete_message::DeleteMessageError, types::error::ReceiptHandleIsInvalid, Client,
 };
@@ -240,6 +242,21 @@ pub struct SqsQueueConsumer {
     queue_dsn: String,
 }
 
+impl SqsQueueConsumer {
+    fn wrap_message(&self, message: &Message) -> Delivery {
+        Delivery {
+            decoders: self.bytes_registry.clone(),
+            acker: Box::new(SqsAcker {
+                ack_client: self.client.clone(),
+                queue_dsn: self.queue_dsn.clone(),
+                receipt_handle: message.receipt_handle().map(ToOwned::to_owned),
+                has_been_acked_or_nacked: false,
+            }),
+            payload: Some(message.body().unwrap_or_default().as_bytes().to_owned()),
+        }
+    }
+}
+
 #[async_trait]
 impl QueueConsumer for SqsQueueConsumer {
     type Payload = String;
@@ -257,19 +274,33 @@ impl QueueConsumer for SqsQueueConsumer {
         out.messages()
             .unwrap_or_default()
             .iter()
-            .map(|message| -> Result<Delivery, QueueError> {
-                Ok(Delivery {
-                    decoders: self.bytes_registry.clone(),
-                    acker: Box::new(SqsAcker {
-                        ack_client: self.client.clone(),
-                        queue_dsn: self.queue_dsn.clone(),
-                        receipt_handle: message.receipt_handle().map(ToOwned::to_owned),
-                        has_been_acked_or_nacked: false,
-                    }),
-                    payload: Some(message.body().unwrap_or_default().as_bytes().to_owned()),
-                })
-            })
+            .map(|message| -> Result<Delivery, QueueError> { Ok(self.wrap_message(message)) })
             .next()
             .ok_or(QueueError::NoData)?
     }
+
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        let out = self
+            .client
+            .receive_message()
+            .set_wait_time_seconds(Some(
+                deadline.as_secs().try_into().map_err(QueueError::generic)?,
+            ))
+            .set_max_number_of_messages(Some(
+                max_messages.try_into().map_err(QueueError::generic)?,
+            ))
+            .queue_url(&self.queue_dsn)
+            .send()
+            .await
+            .map_err(QueueError::generic)?;
+
+        Ok(out
+            .messages()
+            .unwrap_or_default()
+            .iter()
+            .map(|message| -> Result<Delivery, QueueError> { Ok(self.wrap_message(message)) })
+            .collect::<Result<Vec<_>, _>>()?)
+    }
 }
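// A note on the SQS mapping above (not from this diff): SQS long polling has whole-second
// granularity and a 20-second ceiling, so the `deadline` is effectively truncated, and
// sub-second deadlines become a short poll. A sketch of the clamping a caller could apply
// up front; the helper is hypothetical:
use std::time::Duration;

fn sqs_wait_time_seconds(deadline: Duration) -> i32 {
    // WaitTimeSeconds must be between 0 and 20 (inclusive).
    deadline.as_secs().min(20) as i32
}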
diff --git a/omniqueue/src/lib.rs b/omniqueue/src/lib.rs
index 309b6fc..ba88db5 100644
--- a/omniqueue/src/lib.rs
+++ b/omniqueue/src/lib.rs
@@ -129,6 +129,8 @@ pub enum QueueError {
 
     #[error("{0}")]
     Generic(Box<dyn std::error::Error + Send + Sync>),
+
+    #[error("{0}")]
+    Unsupported(&'static str),
 }
 
 impl QueueError {
diff --git a/omniqueue/src/queue/consumer.rs b/omniqueue/src/queue/consumer.rs
index 890b750..4f0bab3 100644
--- a/omniqueue/src/queue/consumer.rs
+++ b/omniqueue/src/queue/consumer.rs
@@ -1,4 +1,5 @@
 use async_trait::async_trait;
+use std::time::Duration;
 
 use crate::{decoding::DecoderRegistry, QueueError, QueuePayload};
 
@@ -10,6 +11,12 @@ pub trait QueueConsumer: Send + Sync {
 
     async fn receive(&mut self) -> Result<Delivery, QueueError>;
 
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError>;
+
     fn into_dyn(self, custom_decoders: DecoderRegistry<Vec<u8>>) -> DynConsumer
     where
         Self: 'static + Sized,
@@ -48,6 +55,28 @@ impl<T: QueueConsumer> QueueConsumer for DynConsumerInner<T> {
         })
     }
 
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        let xs = self.inner.receive_all(max_messages, deadline).await?;
+        let mut out = Vec::with_capacity(xs.len());
+        for mut t_payload in xs {
+            let bytes_payload: Option<Vec<u8>> = match t_payload.payload_custom() {
+                Ok(b) => b,
+                Err(QueueError::NoDecoderForThisType) => t_payload.take_payload(),
+                Err(e) => return Err(e),
+            };
+            out.push(Delivery {
+                payload: bytes_payload,
+                decoders: self.custom_decoders.clone(),
+                acker: t_payload.acker,
+            });
+        }
+        Ok(out)
+    }
+
     fn into_dyn(mut self, custom_decoders: DecoderRegistry<Vec<u8>>) -> DynConsumer
     where
         Self: Sized,
@@ -67,6 +96,14 @@ impl QueueConsumer for DynConsumer {
         self.0.receive().await
     }
 
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        self.0.receive_all(max_messages, deadline).await
+    }
+
     fn into_dyn(self, _custom_decoders: DecoderRegistry<Vec<u8>>) -> DynConsumer
     where
         Self: Sized,
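// A sketch (not from this diff) of the apparent intent behind the new
// `QueueError::Unsupported` variant added above: giving a backend or wrapper a way to
// reject an operation, such as batch consumption, that it cannot honor. The message
// string is hypothetical.
use omniqueue::QueueError;

fn batch_consume_unsupported() -> QueueError {
    // The `&'static str` payload is what the derived `Display` impl ("{0}") prints.
    QueueError::Unsupported("receive_all is not supported by this backend")
}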
diff --git a/omniqueue/tests/gcp_pubsub.rs b/omniqueue/tests/gcp_pubsub.rs
index 6a06316..51022e9 100644
--- a/omniqueue/tests/gcp_pubsub.rs
+++ b/omniqueue/tests/gcp_pubsub.rs
@@ -54,6 +54,7 @@ use google_cloud_googleapis::pubsub::v1::DeadLetterPolicy;
 use google_cloud_pubsub::client::{Client, ClientConfig};
 use google_cloud_pubsub::subscription::SubscriptionConfig;
 
+use std::time::{Duration, Instant};
 use omniqueue::backends::gcp_pubsub::{GcpPubSubBackend, GcpPubSubConfig};
 use omniqueue::queue::{
@@ -196,3 +197,112 @@ async fn test_custom_send_recv() {
     d.payload_serde_json::<ExType>().unwrap_err();
     d.ack().await.unwrap();
 }
+
+/// Consumer will return immediately if there are fewer than max messages to start with.
+#[tokio::test]
+async fn test_send_recv_all_partial() {
+    let payload = ExType { a: 2 };
+    let (p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 1);
+    let d = xs.remove(0);
+    assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
+    d.ack().await.unwrap();
+    assert!(now.elapsed() <= deadline);
+}
+
+/// Consumer should yield items immediately if there's a full batch ready on the first poll.
+#[tokio::test]
+async fn test_send_recv_all_full() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let (p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    // N.b. it's still possible this could turn up false if the test runs too slow.
+    assert!(now.elapsed() < deadline);
+}
+
+/// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
+#[tokio::test]
+async fn test_send_recv_all_full_then_partial() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let payload3 = ExType { a: 3 };
+    let (p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    p.send_serde_json(&payload3).await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now1 = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    assert!(now1.elapsed() < deadline);
+
+    // 2nd call
+    let now2 = Instant::now();
+    let mut ys = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(ys.len(), 1);
+    let d3 = ys.remove(0);
+    assert_eq!(
+        d3.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload3
+    );
+    d3.ack().await.unwrap();
+    assert!(now2.elapsed() <= deadline);
+}
+
+/// Consumer will NOT wait indefinitely for at least one item.
+#[tokio::test]
+async fn test_send_recv_all_late_arriving_items() {
+    let (_p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now = Instant::now();
+    let xs = c.receive_all(2, deadline).await.unwrap();
+    let elapsed = now.elapsed();
+
+    assert_eq!(xs.len(), 0);
+    // Elapsed should be around the deadline, ballpark
+    assert!(elapsed >= deadline);
+    assert!(elapsed <= deadline + Duration::from_millis(200));
+}
diff --git a/omniqueue/tests/rabbitmq.rs b/omniqueue/tests/rabbitmq.rs
index 1c81d09..8a7efbe 100644
--- a/omniqueue/tests/rabbitmq.rs
+++ b/omniqueue/tests/rabbitmq.rs
@@ -8,6 +8,7 @@ use omniqueue::{
     queue::{consumer::QueueConsumer, producer::QueueProducer, QueueBackend, QueueBuilder, Static},
 };
 use serde::{Deserialize, Serialize};
+use std::time::{Duration, Instant};
 
 const MQ_URI: &str = "amqp://guest:guest@localhost:5672/%2f";
 
@@ -16,7 +17,10 @@ const MQ_URI: &str = "amqp://guest:guest@localhost:5672/%2f";
 ///
 /// Additionally this will make a temporary queue on that instance for the duration of the test such
 /// as to ensure there is no stealing.w
-async fn make_test_queue(reinsert_on_nack: bool) -> QueueBuilder<RabbitMqBackend, Static> {
+async fn make_test_queue(
+    prefetch_count: Option<u16>,
+    reinsert_on_nack: bool,
+) -> QueueBuilder<RabbitMqBackend, Static> {
     let options = ConnectionProperties::default()
         .with_connection_name(
             std::iter::repeat_with(fastrand::alphanumeric)
@@ -51,11 +55,12 @@ async fn make_test_queue(reinsert_on_nack: bool) -> QueueBuilder<RabbitMqBackend
         connection_properties: options,
         publish_exchange: "".to_owned(),
         publish_routing_key: queue_name.to_owned(),
         publish_options: BasicPublishOptions::default(),
-        publish_properites: BasicProperties::default(),
+        publish_properties: BasicProperties::default(),
         consume_queue: queue_name.to_owned(),
         consumer_tag: "test".to_owned(),
         consume_options: BasicConsumeOptions::default(),
         consume_arguments: FieldTable::default(),
+        consume_prefetch_count: prefetch_count,
         requeue_on_nack: reinsert_on_nack,
     })
@@ -109,3 +114,130 @@ async fn test_custom_send_recv() {
     d.payload_serde_json::<ExType>().unwrap_err();
     d.ack().await.unwrap();
 }
+
+/// Consumer will return immediately if there are fewer than max messages to start with.
+#[tokio::test]
+async fn test_send_recv_all_partial() {
+    let payload = ExType { a: 2 };
+    let (p, mut c) = make_test_queue(None, false)
+        .await
+        .build_pair()
+        .await
+        .unwrap();
+
+    p.send_serde_json(&payload).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 1);
+    let d = xs.remove(0);
+    assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
+    d.ack().await.unwrap();
+    assert!(now.elapsed() <= deadline);
+}
+
+/// Consumer should yield items immediately if there's a full batch ready on the first poll.
+#[tokio::test]
+async fn test_send_recv_all_full() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let (p, mut c) = make_test_queue(None, false)
+        .await
+        .build_pair()
+        .await
+        .unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+
+    // XXX: rabbit's receive_all impl relies on stream items to be in a ready state in order for
+    // them to be batched together. Sleeping to help them settle before we poll.
+    tokio::time::sleep(Duration::from_millis(100)).await;
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    // N.b. it's still possible this could turn up false if the test runs too slow.
+    assert!(now.elapsed() < deadline);
+}
+
+/// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
+#[tokio::test]
+async fn test_send_recv_all_full_then_partial() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let payload3 = ExType { a: 3 };
+    let (p, mut c) = make_test_queue(None, false)
+        .await
+        .build_pair()
+        .await
+        .unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    p.send_serde_json(&payload3).await.unwrap();
+
+    // XXX: rabbit's receive_all impl relies on stream items to be in a ready state in order for
+    // them to be batched together. Sleeping to help them settle before we poll.
+    tokio::time::sleep(Duration::from_millis(100)).await;
+
+    let deadline = Duration::from_secs(1);
+    let now1 = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    assert!(now1.elapsed() < deadline);
+
+    // 2nd call
+    let now2 = Instant::now();
+    let mut ys = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(ys.len(), 1);
+    let d3 = ys.remove(0);
+    assert_eq!(
+        d3.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload3
+    );
+    d3.ack().await.unwrap();
+    assert!(now2.elapsed() <= deadline);
+}
+
+/// Consumer will NOT wait indefinitely for at least one item.
+#[tokio::test]
+async fn test_send_recv_all_late_arriving_items() {
+    let (_p, mut c) = make_test_queue(None, false)
+        .await
+        .build_pair()
+        .await
+        .unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now = Instant::now();
+    let xs = c.receive_all(2, deadline).await.unwrap();
+    let elapsed = now.elapsed();
+
+    assert_eq!(xs.len(), 0);
+    // Elapsed should be around the deadline, ballpark
+    assert!(elapsed >= deadline);
+    assert!(elapsed <= deadline + Duration::from_millis(200));
+}
diff --git a/omniqueue/tests/redis.rs b/omniqueue/tests/redis.rs
index 88369f2..a33edb5 100644
--- a/omniqueue/tests/redis.rs
+++ b/omniqueue/tests/redis.rs
@@ -4,6 +4,7 @@ use omniqueue::{
 };
 use redis::{AsyncCommands, Client, Commands};
 use serde::{Deserialize, Serialize};
+use std::time::{Duration, Instant};
 
 const ROOT_URL: &str = "redis://localhost";
 
@@ -123,3 +124,122 @@ async fn test_custom_send_recv() {
     d.payload_serde_json::<ExType>().unwrap_err();
     d.ack().await.unwrap();
 }
+
+/// Consumer will return immediately if there are fewer than max messages to start with.
+#[tokio::test]
+async fn test_send_recv_all_partial() {
+    let (builder, _drop) = make_test_queue().await;
+
+    let payload = ExType { a: 2 };
+    let (p, mut c) = builder.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 1);
+    let d = xs.remove(0);
+    assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
+    d.ack().await.unwrap();
+    assert!(now.elapsed() <= deadline);
+}
+
+/// Consumer should yield items immediately if there's a full batch ready on the first poll.
+#[tokio::test]
+async fn test_send_recv_all_full() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+
+    let (builder, _drop) = make_test_queue().await;
+
+    let (p, mut c) = builder.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    // N.b. it's still possible this could turn up false if the test runs too slow.
+    assert!(now.elapsed() < deadline);
+}
+
+/// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
+#[tokio::test]
+async fn test_send_recv_all_full_then_partial() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let payload3 = ExType { a: 3 };
+
+    let (builder, _drop) = make_test_queue().await;
+
+    let (p, mut c) = builder.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    p.send_serde_json(&payload3).await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now1 = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    assert!(now1.elapsed() < deadline);
+
+    // 2nd call
+    let now2 = Instant::now();
+    let mut ys = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(ys.len(), 1);
+    let d3 = ys.remove(0);
+    assert_eq!(
+        d3.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload3
+    );
+    d3.ack().await.unwrap();
+    assert!(now2.elapsed() < deadline);
+}
+
+/// Consumer will NOT wait indefinitely for at least one item.
+#[tokio::test]
+async fn test_send_recv_all_late_arriving_items() {
+    let (builder, _drop) = make_test_queue().await;
+
+    let (_p, mut c) = builder.build_pair().await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now = Instant::now();
+    let xs = c.receive_all(2, deadline).await.unwrap();
+    let elapsed = now.elapsed();
+
+    assert_eq!(xs.len(), 0);
+    // Elapsed should be around the deadline, ballpark
+    assert!(elapsed >= deadline);
+    assert!(elapsed <= deadline + Duration::from_millis(200));
+}
diff --git a/omniqueue/tests/redis_cluster.rs b/omniqueue/tests/redis_cluster.rs
index cadc55f..0f0f0b4 100644
--- a/omniqueue/tests/redis_cluster.rs
+++ b/omniqueue/tests/redis_cluster.rs
@@ -4,6 +4,7 @@ use omniqueue::{
 };
 use redis::{cluster::ClusterClient, AsyncCommands, Commands};
 use serde::{Deserialize, Serialize};
+use std::time::{Duration, Instant};
 
 const ROOT_URL: &str = "redis://localhost:6380";
 
@@ -126,3 +127,122 @@ async fn test_custom_send_recv() {
     d.payload_serde_json::<ExType>().unwrap_err();
     d.ack().await.unwrap();
 }
+
+/// Consumer will return immediately if there are fewer than max messages to start with.
+#[tokio::test]
+async fn test_send_recv_all_partial() {
+    let (builder, _drop) = make_test_queue().await;
+
+    let payload = ExType { a: 2 };
+    let (p, mut c) = builder.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 1);
+    let d = xs.remove(0);
+    assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
+    d.ack().await.unwrap();
+    assert!(now.elapsed() <= deadline);
+}
+
+/// Consumer should yield items immediately if there's a full batch ready on the first poll.
+#[tokio::test]
+async fn test_send_recv_all_full() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+
+    let (builder, _drop) = make_test_queue().await;
+
+    let (p, mut c) = builder.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    // N.b. it's still possible this could turn up false if the test runs too slow.
+    assert!(now.elapsed() < deadline);
+}
+
+/// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
+#[tokio::test]
+async fn test_send_recv_all_full_then_partial() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let payload3 = ExType { a: 3 };
+
+    let (builder, _drop) = make_test_queue().await;
+
+    let (p, mut c) = builder.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    p.send_serde_json(&payload3).await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now1 = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    assert!(now1.elapsed() < deadline);
+
+    // 2nd call
+    let now2 = Instant::now();
+    let mut ys = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(ys.len(), 1);
+    let d3 = ys.remove(0);
+    assert_eq!(
+        d3.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload3
+    );
+    d3.ack().await.unwrap();
+    assert!(now2.elapsed() < deadline);
+}
+
+/// Consumer will NOT wait indefinitely for at least one item.
+#[tokio::test]
+async fn test_send_recv_all_late_arriving_items() {
+    let (builder, _drop) = make_test_queue().await;
+
+    let (_p, mut c) = builder.build_pair().await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now = Instant::now();
+    let xs = c.receive_all(2, deadline).await.unwrap();
+    let elapsed = now.elapsed();
+
+    assert_eq!(xs.len(), 0);
+    // Elapsed should be around the deadline, ballpark
+    assert!(elapsed >= deadline);
+    assert!(elapsed <= deadline + Duration::from_millis(200));
+}
diff --git a/omniqueue/tests/sqs.rs b/omniqueue/tests/sqs.rs
index d6fef06..f696a4e 100644
--- a/omniqueue/tests/sqs.rs
+++ b/omniqueue/tests/sqs.rs
@@ -4,6 +4,7 @@ use omniqueue::{
     queue::{consumer::QueueConsumer, producer::QueueProducer, QueueBackend, QueueBuilder, Static},
 };
 use serde::{Deserialize, Serialize};
+use std::time::{Duration, Instant};
 
 const ROOT_URL: &str = "http://localhost:9324";
 const DEFAULT_CFG: [(&str, &str); 3] = [
@@ -113,3 +114,112 @@ async fn test_custom_send_recv() {
     d.payload_serde_json::<ExType>().unwrap_err();
     d.ack().await.unwrap();
 }
+
+/// Consumer will return immediately if there are fewer than max messages to start with.
+#[tokio::test]
+async fn test_send_recv_all_partial() {
+    let payload = ExType { a: 2 };
+    let (p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 1);
+    let d = xs.remove(0);
+    assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
+    d.ack().await.unwrap();
+    assert!(now.elapsed() <= deadline);
+}
+
+/// Consumer should yield items immediately if there's a full batch ready on the first poll.
+#[tokio::test]
+async fn test_send_recv_all_full() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let (p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    let deadline = Duration::from_secs(1);
+
+    let now = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    // N.b. it's still possible this could turn up false if the test runs too slow.
+    assert!(now.elapsed() < deadline);
+}
+
+/// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
+#[tokio::test]
+async fn test_send_recv_all_full_then_partial() {
+    let payload1 = ExType { a: 1 };
+    let payload2 = ExType { a: 2 };
+    let payload3 = ExType { a: 3 };
+    let (p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    p.send_serde_json(&payload1).await.unwrap();
+    p.send_serde_json(&payload2).await.unwrap();
+    p.send_serde_json(&payload3).await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now1 = Instant::now();
+    let mut xs = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(xs.len(), 2);
+    let d1 = xs.remove(0);
+    assert_eq!(
+        d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload1
+    );
+    d1.ack().await.unwrap();
+
+    let d2 = xs.remove(0);
+    assert_eq!(
+        d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload2
+    );
+    d2.ack().await.unwrap();
+    assert!(now1.elapsed() < deadline);
+
+    // 2nd call
+    let now2 = Instant::now();
+    let mut ys = c.receive_all(2, deadline).await.unwrap();
+    assert_eq!(ys.len(), 1);
+    let d3 = ys.remove(0);
+    assert_eq!(
+        d3.payload_serde_json::<ExType>().unwrap().unwrap(),
+        payload3
+    );
+    d3.ack().await.unwrap();
+    assert!(now2.elapsed() < deadline);
+}
+
+/// Consumer will NOT wait indefinitely for at least one item.
+#[tokio::test]
+async fn test_send_recv_all_late_arriving_items() {
+    let (_p, mut c) = make_test_queue().await.build_pair().await.unwrap();
+
+    let deadline = Duration::from_secs(1);
+    let now = Instant::now();
+    let xs = c.receive_all(2, deadline).await.unwrap();
+    let elapsed = now.elapsed();
+
+    assert_eq!(xs.len(), 0);
+    // Elapsed should be around the deadline, ballpark
+    assert!(elapsed >= deadline);
+    assert!(elapsed <= deadline + Duration::from_millis(200));
+}