Skip to content

Commit

Permalink
add Consumer::receive_all
Browse files Browse the repository at this point in the history
Support for batched reads is not uniform among the various backends we
have.
In some cases, partial batches will wait longer or shorter than the
specified deadline.

The odd one out right now is RabbitMQ which supports configuring a
channel with a pre-fetch limit, but you still have to just consume
messages one-by-one (the batching is all under the hood). Additionally,
there doesn't seem to be a way to specify a timeout or deadline
(apparently this is left up to individual clients to decide what makes
the most sense).

As such, RabbitMQ is under-served compared to the rest. The
implementation for batch reads is entirely done in the client code (i.e.
reading messages one-by-one, buffering them, then returning).

One possible solution is to add a max prefetch config option and use
this when initializing a consumer channel, but this would also mean the
`max_messages` argument would be meaningless for RabbitMQ, and using
`receive_all` at all would probably be irrelevant. Callers should
probably just use `receive` instead.

For now, I've left some comments in the rabbit impl and will leave this
as future work.
  • Loading branch information
svix-onelson committed Oct 10, 2023
1 parent 2479040 commit b7d6e51
Show file tree
Hide file tree
Showing 11 changed files with 961 additions and 59 deletions.
73 changes: 59 additions & 14 deletions omniqueue/src/backends/gcp_pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use google_cloud_pubsub::subscription::Subscription;
use serde::Serialize;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use std::{any::TypeId, collections::HashMap};

pub struct GcpPubSubBackend;
Expand Down Expand Up @@ -217,31 +218,75 @@ async fn subscription(client: &Client, subscription_id: &str) -> Result<Subscrip
Ok(subscription)
}

#[async_trait]
impl QueueConsumer for GcpPubSubConsumer {
type Payload = Payload;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let subscription = subscription(&self.client, &self.subscription_id).await?;
let mut stream = subscription
.subscribe(None)
.await
.map_err(QueueError::generic)?;

let mut recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;
impl GcpPubSubConsumer {
fn wrap_recv_msg(&self, mut recv_msg: ReceivedMessage) -> Delivery {
// FIXME: would be nice to avoid having to move the data out here.
// While it's possible to ack via a subscription and an ack_id, nack is only
// possible via a `ReceiveMessage`. This means we either need to hold 2 copies of
// the payload, or move the bytes out so they can be returned _outside of the Acker_.
let payload = recv_msg.message.data.drain(..).collect();
Ok(Delivery {

Delivery {
decoders: self.registry.clone(),
acker: Box::new(GcpPubSubAcker {
recv_msg,
subscription_id: self.subscription_id.clone(),
}),
payload: Some(payload),
})
}
}
}

#[async_trait]
impl QueueConsumer for GcpPubSubConsumer {
type Payload = Payload;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let subscription = subscription(&self.client, &self.subscription_id).await?;
let mut stream = subscription
.subscribe(None)
.await
.map_err(QueueError::generic)?;

let recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;

Ok(self.wrap_recv_msg(recv_msg))
}

async fn receive_all(
&mut self,
max_messages: usize,
deadline: Duration,
) -> Result<Vec<Delivery>, QueueError> {
let subscription = subscription(&self.client, &self.subscription_id).await?;

let mut out = Vec::with_capacity(max_messages);

if let Ok(messages) = subscription.pull(max_messages as _, None).await {
out.extend(messages.into_iter().map(|m| self.wrap_recv_msg(m)));
if out.len() >= max_messages {
return Ok(out);
}

let mut interval = tokio::time::interval(deadline);
interval.tick().await;

loop {
tokio::select! {
_ = interval.tick() => break,
messages = subscription.pull(max_messages.saturating_sub(out.len()) as _, None) => {
if let Ok(messages) = messages {
out.extend(messages.into_iter().map(|m| self.wrap_recv_msg(m)));
if out.len() >= max_messages {
break;
}
}
}
}
}
}

Ok(out)
}
}

Expand Down
57 changes: 48 additions & 9 deletions omniqueue/src/backends/memory_queue.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::time::Duration;
use std::{any::TypeId, collections::HashMap};

use async_trait::async_trait;
Expand Down Expand Up @@ -90,22 +91,60 @@ pub struct MemoryQueueConsumer {
tx: broadcast::Sender<Vec<u8>>,
}

#[async_trait]
impl QueueConsumer for MemoryQueueConsumer {
type Payload = Vec<u8>;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let payload = self.rx.recv().await.map_err(QueueError::generic)?;

Ok(Delivery {
impl MemoryQueueConsumer {
fn wrap_payload(&self, payload: Vec<u8>) -> Delivery {
Delivery {
payload: Some(payload.clone()),
decoders: self.registry.clone(),
acker: Box::new(MemoryQueueAcker {
tx: self.tx.clone(),
payload_copy: Some(payload),
alredy_acked_or_nacked: false,
}),
})
}
}
}

#[async_trait]
impl QueueConsumer for MemoryQueueConsumer {
type Payload = Vec<u8>;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let payload = self.rx.recv().await.map_err(QueueError::generic)?;
Ok(self.wrap_payload(payload))
}

async fn receive_all(
&mut self,
max_messages: usize,
deadline: Duration,
) -> Result<Vec<Delivery>, QueueError> {
let mut out = Vec::with_capacity(max_messages);

// Await at least one delivery before starting the clock
let msg = self.rx.recv().await;
let delivery = msg
.map(|payload| self.wrap_payload(payload))
.map_err(QueueError::generic)?;
out.push(delivery);

let mut interval = tokio::time::interval(deadline);
// Skip the first tick which is instantaneous
interval.tick().await;
loop {
tokio::select! {
_ = interval.tick() => break,
msg = self.rx.recv() => {
let delivery = msg
.map(|payload| self.wrap_payload(payload)).map_err(QueueError::generic)?;
out.push(delivery);
if out.len() >= max_messages {
break;
}
}
}
}
Ok(out)
}
}

Expand Down
63 changes: 54 additions & 9 deletions omniqueue/src/backends/rabbitmq.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::time::Duration;
use std::{any::TypeId, collections::HashMap};

use async_trait::async_trait;
Expand All @@ -24,6 +25,7 @@ pub struct RabbitMqConfig {
pub publish_exchange: String,
pub publish_routing_key: String,
pub publish_options: BasicPublishOptions,
// FIXME: typos
pub publish_properites: BasicProperties,

pub consume_queue: String,
Expand Down Expand Up @@ -168,6 +170,19 @@ pub struct RabbitMqConsumer {
requeue_on_nack: bool,
}

impl RabbitMqConsumer {
fn wrap_delivery(&self, delivery: lapin::message::Delivery) -> Delivery {
Delivery {
decoders: self.registry.clone(),
payload: Some(delivery.data),
acker: Box::new(RabbitMqAcker {
acker: Some(delivery.acker),
requeue_on_nack: self.requeue_on_nack,
}),
}
}
}

#[async_trait]
impl QueueConsumer for RabbitMqConsumer {
type Payload = Vec<u8>;
Expand All @@ -178,19 +193,49 @@ impl QueueConsumer for RabbitMqConsumer {
.clone()
.map(|l: Result<lapin::message::Delivery, lapin::Error>| {
let l = l.map_err(QueueError::generic)?;

Ok(Delivery {
decoders: self.registry.clone(),
payload: Some(l.data),
acker: Box::new(RabbitMqAcker {
acker: Some(l.acker),
requeue_on_nack: self.requeue_on_nack,
}),
})
Ok(self.wrap_delivery(l))
});

stream.next().await.ok_or(QueueError::NoData)?
}

async fn receive_all(
&mut self,
max_messages: usize,
deadline: Duration,
) -> Result<Vec<Delivery>, QueueError> {
let mut stream = self.consumer.clone();
let mut out = Vec::with_capacity(max_messages);

// FIXME: the real way to do this is to set the pre-fetch count on the channel (which happens much earlier).
// e.g. `channel_rx.basic_qos(10, Default::default())?`
// There is no config for controlling the timeout - it is up to each client impl.
// As written, this rabbit impl breaks the standard behavior of "return as soon as items are available"
// The tests have been modified to reflect this gap.
if let Some(delivery) = stream.next().await {
out.push(self.wrap_delivery(delivery.map_err(QueueError::generic)?));

let mut interval = tokio::time::interval(deadline);
// Skip the instant first period
interval.tick().await;

loop {
tokio::select! {
_ = interval.tick() => break,
delivery = stream.next() => {
if let Some(delivery) = delivery {
out.push(self.wrap_delivery(delivery.map_err(QueueError::generic)?));
if out.len() >= max_messages {
break;
}
}
}
}
}
}

Ok(out)
}
}

pub struct RabbitMqAcker {
Expand Down
82 changes: 67 additions & 15 deletions omniqueue/src/backends/redis/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::time::Duration;
use std::{any::TypeId, collections::HashMap, marker::PhantomData};

use async_trait::async_trait;
use bb8::ManageConnection;
pub use bb8_redis::RedisMultiplexedConnectionManager;
use redis::streams::{StreamReadOptions, StreamReadReply};
use redis::streams::{StreamId, StreamReadOptions, StreamReadReply};

use crate::{
decoding::DecoderRegistry,
Expand Down Expand Up @@ -219,6 +220,31 @@ pub struct RedisStreamConsumer<M: ManageConnection> {
payload_key: String,
}

impl<M> RedisStreamConsumer<M>
where
M: ManageConnection,
M::Connection: redis::aio::ConnectionLike + Send + Sync,
M::Error: 'static + std::error::Error + Send + Sync,
{
fn wrap_entry(&self, entry: StreamId) -> Result<Delivery, QueueError> {
let entry_id = entry.id.clone();
let payload = entry.map.get(&self.payload_key).ok_or(QueueError::NoData)?;
let payload: Vec<u8> = redis::from_redis_value(payload).map_err(QueueError::generic)?;

Ok(Delivery {
payload: Some(payload),
acker: Box::new(RedisStreamAcker {
redis: self.redis.clone(),
queue_key: self.queue_key.clone(),
consumer_group: self.consumer_group.clone(),
entry_id,
already_acked_or_nacked: false,
}),
decoders: self.registry.clone(),
})
}
}

#[async_trait]
impl<M> QueueConsumer for RedisStreamConsumer<M>
where
Expand Down Expand Up @@ -247,21 +273,47 @@ where
let queue = read_out.keys.into_iter().next().ok_or(QueueError::NoData)?;

let entry = queue.ids.into_iter().next().ok_or(QueueError::NoData)?;
self.wrap_entry(entry)
}

let entry_id = entry.id.clone();
let payload = entry.map.get(&self.payload_key).ok_or(QueueError::NoData)?;
let payload: Vec<u8> = redis::from_redis_value(payload).map_err(QueueError::generic)?;
async fn receive_all(
&mut self,
max_messages: usize,
deadline: Duration,
) -> Result<Vec<Delivery>, QueueError> {
let mut conn = self.redis.get().await.map_err(QueueError::generic)?;

Ok(Delivery {
payload: Some(payload),
acker: Box::new(RedisStreamAcker {
redis: self.redis.clone(),
queue_key: self.queue_key.clone(),
consumer_group: self.consumer_group.clone(),
entry_id,
already_acked_or_nacked: false,
}),
decoders: self.registry.clone(),
})
// Ensure an empty vec is never returned
let queue = loop {
let read_out: StreamReadReply = redis::Cmd::xread_options(
&[&self.queue_key],
&[">"],
&StreamReadOptions::default()
.group(&self.consumer_group, &self.consumer_name)
.block(
deadline
.as_millis()
.try_into()
.map_err(QueueError::generic)?,
)
.count(max_messages),
)
.query_async(&mut *conn)
.await
.map_err(QueueError::generic)?;

if let Some(queue) = read_out.keys.into_iter().next() {
if !queue.ids.is_empty() {
break queue;
}
}
};

let mut out = Vec::with_capacity(max_messages);
for entry in queue.ids {
out.push(self.wrap_entry(entry)?);
}

Ok(out)
}
}
Loading

0 comments on commit b7d6e51

Please sign in to comment.