Merge pull request #18 from svix/onelson/recv-all
add `Consumer::receive_all`
svix-onelson authored Oct 11, 2023
2 parents 2479040 + 48f7ad8 commit 6ed3b4b
Showing 12 changed files with 1,073 additions and 121 deletions.
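The new method in a nutshell: `receive_all(max_messages, deadline)` returns up to `max_messages` deliveries, waiting no longer than `deadline` for the first one to arrive. A minimal usage sketch against the in-memory backend, modeled on the tests further down; the `demo` wrapper, the `ExType` payload, and the one-second deadline are illustrative only:

use std::time::Duration;

// Sketch only: assumes this crate's QueueBuilder and MemoryQueueBackend, plus
// a serde-serializable ExType payload, as exercised in the tests below.
async fn demo() {
    let (p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
        .build_pair()
        .await
        .unwrap();

    p.send_serde_json(&ExType { a: 1 }).await.unwrap();

    // Ask for up to 10 deliveries, waiting at most one second for the first.
    let deliveries = c.receive_all(10, Duration::from_secs(1)).await.unwrap();
    for d in deliveries {
        let msg: ExType = d.payload_serde_json().unwrap().unwrap();
        // ... handle `msg`, then acknowledge it.
        d.ack().await.unwrap();
    }
}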
54 changes: 40 additions & 14 deletions omniqueue/src/backends/gcp_pubsub.rs
@@ -15,6 +15,7 @@ use google_cloud_pubsub::subscription::Subscription;
 use serde::Serialize;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
+use std::time::Duration;
 use std::{any::TypeId, collections::HashMap};
 
 pub struct GcpPubSubBackend;
@@ -217,31 +218,56 @@ async fn subscription(client: &Client, subscription_id: &str) -> Result<Subscrip
     Ok(subscription)
 }
 
-#[async_trait]
-impl QueueConsumer for GcpPubSubConsumer {
-    type Payload = Payload;
-
-    async fn receive(&mut self) -> Result<Delivery, QueueError> {
-        let subscription = subscription(&self.client, &self.subscription_id).await?;
-        let mut stream = subscription
-            .subscribe(None)
-            .await
-            .map_err(QueueError::generic)?;
-
-        let mut recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;
+impl GcpPubSubConsumer {
+    fn wrap_recv_msg(&self, mut recv_msg: ReceivedMessage) -> Delivery {
         // FIXME: would be nice to avoid having to move the data out here.
         // While it's possible to ack via a subscription and an ack_id, nack is only
         // possible via a `ReceiveMessage`. This means we either need to hold 2 copies of
         // the payload, or move the bytes out so they can be returned _outside of the Acker_.
         let payload = recv_msg.message.data.drain(..).collect();
-        Ok(Delivery {
+
+        Delivery {
             decoders: self.registry.clone(),
             acker: Box::new(GcpPubSubAcker {
                 recv_msg,
                 subscription_id: self.subscription_id.clone(),
             }),
             payload: Some(payload),
-        })
+        }
+    }
+}
+
+#[async_trait]
+impl QueueConsumer for GcpPubSubConsumer {
+    type Payload = Payload;
+
+    async fn receive(&mut self) -> Result<Delivery, QueueError> {
+        let subscription = subscription(&self.client, &self.subscription_id).await?;
+        let mut stream = subscription
+            .subscribe(None)
+            .await
+            .map_err(QueueError::generic)?;
+
+        let recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;
+
+        Ok(self.wrap_recv_msg(recv_msg))
+    }
+
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        let subscription = subscription(&self.client, &self.subscription_id).await?;
+        match tokio::time::timeout(deadline, subscription.pull(max_messages as _, None)).await {
+            Ok(messages) => Ok(messages
+                .map_err(QueueError::generic)?
+                .into_iter()
+                .map(|m| self.wrap_recv_msg(m))
+                .collect()),
+            // Timeout
+            Err(_) => Ok(vec![]),
+        }
     }
 }
 
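Note on the GCP path: `receive_all` bounds the entire `pull` with `tokio::time::timeout`, and an elapsed deadline is deliberately mapped to an empty batch rather than an error. A sketch of the same shape in isolation, assuming only a future that yields a `Result<Vec<T>, E>`; `bounded_poll` is a made-up name:

use std::time::Duration;

// If `fut` completes in time, propagate its own success or error;
// if the deadline elapses first, report an empty batch instead of failing.
async fn bounded_poll<T, E>(
    deadline: Duration,
    fut: impl std::future::Future<Output = Result<Vec<T>, E>>,
) -> Result<Vec<T>, E> {
    match tokio::time::timeout(deadline, fut).await {
        Ok(res) => res,
        Err(_elapsed) => Ok(vec![]),
    }
}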
180 changes: 171 additions & 9 deletions omniqueue/src/backends/memory_queue.rs
@@ -1,3 +1,4 @@
+use std::time::{Duration, Instant};
 use std::{any::TypeId, collections::HashMap};
 
 use async_trait::async_trait;
@@ -90,22 +91,53 @@ pub struct MemoryQueueConsumer {
     tx: broadcast::Sender<Vec<u8>>,
 }
 
-#[async_trait]
-impl QueueConsumer for MemoryQueueConsumer {
-    type Payload = Vec<u8>;
-
-    async fn receive(&mut self) -> Result<Delivery, QueueError> {
-        let payload = self.rx.recv().await.map_err(QueueError::generic)?;
-
-        Ok(Delivery {
+impl MemoryQueueConsumer {
+    fn wrap_payload(&self, payload: Vec<u8>) -> Delivery {
+        Delivery {
             payload: Some(payload.clone()),
             decoders: self.registry.clone(),
             acker: Box::new(MemoryQueueAcker {
                 tx: self.tx.clone(),
                 payload_copy: Some(payload),
                 alredy_acked_or_nacked: false,
             }),
-        })
+        }
+    }
+}
+
+#[async_trait]
+impl QueueConsumer for MemoryQueueConsumer {
+    type Payload = Vec<u8>;
+
+    async fn receive(&mut self) -> Result<Delivery, QueueError> {
+        let payload = self.rx.recv().await.map_err(QueueError::generic)?;
+        Ok(self.wrap_payload(payload))
+    }
+
+    async fn receive_all(
+        &mut self,
+        max_messages: usize,
+        deadline: Duration,
+    ) -> Result<Vec<Delivery>, QueueError> {
+        let mut out = Vec::with_capacity(max_messages);
+        let start = Instant::now();
+        match tokio::time::timeout(deadline, self.rx.recv()).await {
+            Ok(Ok(x)) => out.push(self.wrap_payload(x)),
+            // Timeouts and stream termination
+            Err(_) | Ok(Err(_)) => return Ok(out),
+        }
+
+        if max_messages > 1 {
+            // `try_recv` will break the loop if no ready items are already buffered in the channel.
+            // This should allow us to opportunistically fill up the buffer in the remaining time.
+            while let Ok(x) = self.rx.try_recv() {
+                out.push(self.wrap_payload(x));
+                if out.len() >= max_messages || start.elapsed() >= deadline {
+                    break;
+                }
+            }
+        }
+        Ok(out)
     }
 }
 
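Note on the memory path: the `receive_all` above is a two-phase read. One timed `recv` blocks until the first item or the deadline, then `try_recv` opportunistically drains whatever is already buffered, stopping at `max_messages` or the deadline. A standalone sketch of the same shape on a bare broadcast channel, with an illustrative name:

use std::time::{Duration, Instant};
use tokio::sync::broadcast;

// Collect up to `max` items: wait (bounded) for the first,
// then take only what is already buffered without waiting again.
async fn drain_up_to(
    rx: &mut broadcast::Receiver<Vec<u8>>,
    max: usize,
    deadline: Duration,
) -> Vec<Vec<u8>> {
    let start = Instant::now();
    let mut out = Vec::with_capacity(max);
    match tokio::time::timeout(deadline, rx.recv()).await {
        Ok(Ok(item)) => out.push(item),
        // Deadline hit, or the channel closed/lagged: return what we have.
        Err(_) | Ok(Err(_)) => return out,
    }
    while out.len() < max && start.elapsed() < deadline {
        match rx.try_recv() {
            Ok(item) => out.push(item),
            Err(_) => break, // nothing ready right now
        }
    }
    out
}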
@@ -146,6 +178,7 @@ impl Acker for MemoryQueueAcker {
 #[cfg(test)]
 mod tests {
     use serde::{Deserialize, Serialize};
+    use std::time::{Duration, Instant};
 
     use crate::{
         queue::{consumer::QueueConsumer, producer::QueueProducer, QueueBuilder},
@@ -233,4 +266,133 @@ mod tests {
             TypeA { a: 12 }
         );
     }
+
+    #[derive(Debug, Deserialize, Serialize, PartialEq)]
+    pub struct ExType {
+        a: u8,
+    }
+
+    /// Consumer will return immediately if there are fewer than max messages to start with.
+    #[tokio::test]
+    async fn test_send_recv_all_partial() {
+        let payload = ExType { a: 2 };
+
+        let (p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        p.send_serde_json(&payload).await.unwrap();
+        let deadline = Duration::from_secs(1);
+
+        let now = Instant::now();
+        let mut xs = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(xs.len(), 1);
+        let d = xs.remove(0);
+        assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
+        d.ack().await.unwrap();
+        assert!(now.elapsed() <= deadline);
+    }
+
+    /// Consumer should yield items immediately if there's a full batch ready on the first poll.
+    #[tokio::test]
+    async fn test_send_recv_all_full() {
+        let payload1 = ExType { a: 1 };
+        let payload2 = ExType { a: 2 };
+
+        let (p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        p.send_serde_json(&payload1).await.unwrap();
+        p.send_serde_json(&payload2).await.unwrap();
+        let deadline = Duration::from_secs(1);
+
+        let now = Instant::now();
+        let mut xs = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(xs.len(), 2);
+        let d1 = xs.remove(0);
+        assert_eq!(
+            d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload1
+        );
+        d1.ack().await.unwrap();
+
+        let d2 = xs.remove(0);
+        assert_eq!(
+            d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload2
+        );
+        d2.ack().await.unwrap();
+        // N.b. it's still possible this could turn up false if the test runs too slow.
+        assert!(now.elapsed() < deadline);
+    }
+
+    /// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
+    #[tokio::test]
+    async fn test_send_recv_all_full_then_partial() {
+        let payload1 = ExType { a: 1 };
+        let payload2 = ExType { a: 2 };
+        let payload3 = ExType { a: 3 };
+
+        let (p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        p.send_serde_json(&payload1).await.unwrap();
+        p.send_serde_json(&payload2).await.unwrap();
+        p.send_serde_json(&payload3).await.unwrap();
+
+        let deadline = Duration::from_secs(1);
+        let now1 = Instant::now();
+        let mut xs = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(xs.len(), 2);
+        let d1 = xs.remove(0);
+        assert_eq!(
+            d1.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload1
+        );
+        d1.ack().await.unwrap();
+
+        let d2 = xs.remove(0);
+        assert_eq!(
+            d2.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload2
+        );
+        d2.ack().await.unwrap();
+        assert!(now1.elapsed() < deadline);
+
+        // 2nd call
+        let now2 = Instant::now();
+        let mut ys = c.receive_all(2, deadline).await.unwrap();
+        assert_eq!(ys.len(), 1);
+        let d3 = ys.remove(0);
+        assert_eq!(
+            d3.payload_serde_json::<ExType>().unwrap().unwrap(),
+            payload3
+        );
+        d3.ack().await.unwrap();
+        assert!(now2.elapsed() <= deadline);
+    }
+
+    /// Consumer will NOT wait indefinitely for at least one item.
+    #[tokio::test]
+    async fn test_send_recv_all_late_arriving_items() {
+        let (_p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
+            .build_pair()
+            .await
+            .unwrap();
+
+        let deadline = Duration::from_secs(1);
+        let now = Instant::now();
+        let xs = c.receive_all(2, deadline).await.unwrap();
+        let elapsed = now.elapsed();
+
+        assert_eq!(xs.len(), 0);
+        // Elapsed should be around the deadline, ballpark
+        assert!(elapsed >= deadline);
+        assert!(elapsed <= deadline + Duration::from_millis(200));
+    }
 }