Skip to content

Commit

Permalink
async: adding efficient, poll-based stream support
Browse files Browse the repository at this point in the history
This is a much more efficient, polling-based stream implementation that
doesn't rely on horrible things like spawning threads just to do async.

Changelog-Added: Add async stream support
Signed-off-by: William Casarin <[email protected]>
  • Loading branch information
jb55 committed Dec 14, 2024
1 parent 46ca13d commit 91cff7b
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 53 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ bindgen = []
[dependencies]
flatbuffers = "23.5.26"
libc = "0.2.151"
futures = "0.3.31"
thiserror = "2.0.3"
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
tracing = "0.1.40"
Expand Down
2 changes: 1 addition & 1 deletion nostrdb
Submodule nostrdb updated 2 files
+109 −485 src/nostrdb.c
+1 −2 src/nostrdb.h
12 changes: 3 additions & 9 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use crate::bindings;
#[derive(Copy, Clone)]
pub struct Config {
pub config: bindings::ndb_config,
// We add a flag to know if we've installed a Rust closure so we can clean it up in Drop.
is_rust_closure: bool,
}

impl Default for Config {
Expand All @@ -29,11 +27,7 @@ impl Config {
bindings::ndb_default_config(&mut config);
}

let is_rust_closure = false;
Config {
config,
is_rust_closure,
}
Config { config }
}

//
Expand All @@ -54,7 +48,8 @@ impl Config {
self
}

/// Set a callback for when we have
/// Set a callback to be notified on updated subscriptions. The function
/// will be called with the corresponding subscription id.
pub fn set_sub_callback<F>(mut self, closure: F) -> Self
where
F: FnMut(u64) + 'static,
Expand All @@ -67,7 +62,6 @@ impl Config {

self.config.sub_cb = Some(sub_callback_trampoline);
self.config.sub_cb_ctx = ctx_ptr;
self.is_rust_closure = true;
self
}

Expand Down
87 changes: 87 additions & 0 deletions src/future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use crate::{Ndb, NoteKey, Subscription};

use std::{
pin::Pin,
task::{Context, Poll},
};

use futures::Stream;

/// Per-subscription bookkeeping used to track query futures. It is shared,
/// through the `subs` map on `Ndb`, between a `SubscriptionStream` and the
/// subscription callback that `Ndb::new` registers.
#[derive(Debug, Clone)]
pub(crate) struct SubscriptionState {
/// Set by the subscription callback when an update is reported for this
/// subscription; cleared by the stream when it fetches the notes.
pub ready: bool,
/// Set by `Ndb::unsubscribe`; once true the stream ends by yielding `None`.
pub done: bool,
/// Waker of the task that last polled the stream and found nothing to
/// return. The subscription callback takes it and wakes that task.
pub waker: Option<std::task::Waker>,
}

/// A subscription that you can `.await` on. It implements `Stream`, yielding
/// batches of note keys, which enables very clean integration into Rust's
/// async state machinery.
pub struct SubscriptionStream {
// Database handle. It also carries the shared `subs` map through which the
// stream and the subscription callback exchange readiness and wakers.
ndb: Ndb,
// The subscription this stream yields notes for.
sub_id: Subscription,
// Limit handed to `poll_for_notes` each time the stream fetches a batch.
max_notes: u32,
}

impl SubscriptionStream {
pub fn new(ndb: Ndb, sub_id: Subscription) -> Self {
// Most of the time we only want to fetch a few things. If expecting
// lots of data, use `set_max_notes_per_await`
let max_notes = 32;
SubscriptionStream {
ndb,
sub_id,
max_notes,
}
}

pub fn notes_per_await(mut self, max_notes: u32) -> Self {
self.max_notes = max_notes;
self
}

pub fn sub_id(&self) -> Subscription {
self.sub_id
}
}

impl Drop for SubscriptionStream {
    /// Removes this stream's entry from the shared subscription-state map so
    /// that state for dropped streams does not accumulate.
    fn drop(&mut self) {
        // Recover the guard from a poisoned mutex instead of unwrapping:
        // `drop` may run while another panic is already unwinding, and a
        // second panic at that point aborts the whole process. Removing a map
        // entry is safe regardless of what the panicking holder was doing.
        let mut map = self
            .ndb
            .subs
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        map.remove(&self.sub_id);
    }
}

impl Stream for SubscriptionStream {
    type Item = Vec<NoteKey>;

    /// Yields the next batch of note keys for this subscription.
    ///
    /// Returns `Poll::Ready(None)` once the subscription has been marked as
    /// done, `Poll::Ready(Some(notes))` when notes are available, and
    /// otherwise registers the task's waker and returns `Poll::Pending`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Only shared access is needed, so borrow straight through the pin.
        let me: &Self = &*self;

        // The lock is held for the whole poll so the subscription callback
        // cannot run between the emptiness check and the waker registration
        // below (which would lose the wake-up).
        let mut map = me.ndb.subs.lock().unwrap();
        let sub_state = map.entry(me.sub_id).or_insert(SubscriptionState {
            ready: false,
            done: false,
            waker: None,
        });

        // we've unsubscribed
        if sub_state.done {
            return Poll::Ready(None);
        }

        // Always ask the database instead of trusting `ready` alone: the
        // callback can only set the flag when a map entry already exists, so
        // an update that arrived before the first poll (or after a previous
        // stream for this subscription was dropped) would otherwise never be
        // observed and the stream would stay pending forever.
        sub_state.ready = false;
        let notes = me.ndb.poll_for_notes(me.sub_id, me.max_notes);
        if !notes.is_empty() {
            return Poll::Ready(Some(notes));
        }

        // Nothing available yet: park until the callback wakes us.
        sub_state.waker = Some(cx.waker().clone());
        Poll::Pending
    }
}
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ mod bindings;
mod ndb_profile;

mod block;

mod future;

mod config;
mod error;
mod filter;
Expand All @@ -30,6 +33,8 @@ pub use block::{Block, BlockType, Blocks, Mention};
pub use config::Config;
pub use error::{Error, FilterError};
pub use filter::{Filter, FilterBuilder};
pub(crate) use future::SubscriptionState;
pub use future::SubscriptionStream;
pub use ndb::Ndb;
pub use ndb_profile::{NdbProfile, NdbProfileRecord};
pub use ndb_str::{NdbStr, NdbStrVariant};
Expand Down
124 changes: 82 additions & 42 deletions src/ndb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@ use std::ptr;

use crate::{
bindings, Blocks, Config, Error, Filter, Note, NoteKey, ProfileKey, ProfileRecord, QueryResult,
Result, Subscription, Transaction,
Result, Subscription, SubscriptionState, SubscriptionStream, Transaction,
};
use futures::StreamExt;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fs;
use std::os::raw::c_int;
use std::path::Path;
use std::sync::Arc;
use tokio::task; // Make sure to import the task module
use std::sync::{Arc, Mutex};
use tracing::debug;

#[derive(Debug)]
struct NdbRef {
ndb: *mut bindings::ndb,

/// Have we configured a rust closure for our callback? If so we need
/// to clean that up when this is dropped
has_rust_closure: bool,
rust_cb_ctx: *mut ::std::os::raw::c_void,
}

Expand All @@ -34,18 +32,23 @@ impl Drop for NdbRef {
unsafe {
bindings::ndb_destroy(self.ndb);

if self.has_rust_closure && !self.rust_cb_ctx.is_null() {
if !self.rust_cb_ctx.is_null() {
// Rebuild the Box from the raw pointer and drop it.
let _ = Box::from_raw(self.rust_cb_ctx as *mut Box<dyn FnMut()>);
}
}
}
}

/// Maps each subscription to the state used to signal and wake its stream.
type SubMap = HashMap<Subscription, SubscriptionState>;

/// A nostrdb context. Construct one of these with [Ndb::new].
///
/// Both fields are `Arc`s, so clones share the same underlying database
/// handle and the same subscription-state map.
#[derive(Debug, Clone)]
pub struct Ndb {
// Shared owner of the raw `ndb` pointer and of the subscription callback
// context; see `NdbRef`.
refs: Arc<NdbRef>,

/// Track query future states. Written by the subscription callback
/// installed in `Ndb::new` and by `unsubscribe`, and read by
/// `SubscriptionStream` when it is polled.
pub(crate) subs: Arc<Mutex<SubMap>>,
}

impl Ndb {
Expand All @@ -65,7 +68,30 @@ impl Ndb {

let min_mapsize = 1024 * 1024 * 512;
let mut mapsize = config.config.mapsize;
let mut config = *config;
let config = *config;

let prev_callback = config.config.sub_cb;
let prev_callback_ctx = config.config.sub_cb_ctx;
let subs = Arc::new(Mutex::new(SubMap::default()));
let subs_clone = subs.clone();

// We need to register our own callback so that we can wake
// query futures
let mut config = config.set_sub_callback(move |sub_id: u64| {
let mut map = subs_clone.lock().unwrap();
if let Some(s) = map.get_mut(&Subscription::new(sub_id)) {
s.ready = true;
if let Some(w) = s.waker.take() {
w.wake();
}
}

if let Some(pcb) = prev_callback {
unsafe {
pcb(prev_callback_ctx, sub_id);
};
}
});

let result = loop {
let result =
Expand All @@ -90,15 +116,10 @@ impl Ndb {
return Err(Error::DbOpenFailed);
}

let has_rust_closure = !config.config.sub_cb_ctx.is_null();
let rust_cb_ctx = config.config.sub_cb_ctx;
let refs = Arc::new(NdbRef {
ndb,
has_rust_closure,
rust_cb_ctx,
});
let refs = Arc::new(NdbRef { ndb, rust_cb_ctx });

Ok(Ndb { refs })
Ok(Ndb { refs, subs })
}

/// Ingest a relay-sent event in the form `["EVENT","subid", {"id:"...}]`
Expand Down Expand Up @@ -155,9 +176,17 @@ impl Ndb {
unsafe { bindings::ndb_num_subscriptions(self.as_ptr()) as u32 }
}

pub fn unsubscribe(&self, sub: Subscription) -> Result<()> {
pub fn unsubscribe(&mut self, sub: Subscription) -> Result<()> {
let r = unsafe { bindings::ndb_unsubscribe(self.as_ptr(), sub.id()) };

// mark the subscription as done if it exists in our stream map
{
let mut map = self.subs.lock().unwrap();
if let Entry::Occupied(mut entry) = map.entry(sub) {
entry.get_mut().done = true;
}
}

if r == 0 {
Err(Error::SubscriptionError)
} else {
Expand Down Expand Up @@ -204,32 +233,11 @@ impl Ndb {
sub_id: Subscription,
max_notes: u32,
) -> Result<Vec<NoteKey>> {
let ndb = self.clone();
let handle = task::spawn_blocking(move || {
let mut vec: Vec<u64> = vec![];
vec.reserve_exact(max_notes as usize);
let res = unsafe {
bindings::ndb_wait_for_notes(
ndb.as_ptr(),
sub_id.id(),
vec.as_mut_ptr(),
max_notes as c_int,
)
};
if res == 0 {
Err(Error::SubscriptionError)
} else {
unsafe {
vec.set_len(res as usize);
};
Ok(vec)
}
});
let mut stream = SubscriptionStream::new(self.clone(), sub_id).notes_per_await(max_notes);

match handle.await {
Ok(Ok(res)) => Ok(res.into_iter().map(NoteKey::new).collect()),
Ok(Err(err)) => Err(err),
Err(_) => Err(Error::SubscriptionError),
match stream.next().await {
Some(res) => Ok(res),
None => Err(Error::SubscriptionError),
}
}

Expand Down Expand Up @@ -527,4 +535,36 @@ mod tests {
// we should definitely clean this up... especially on windows
test_util::cleanup_db(&db);
}

/// Subscribes to kind-1 notes, turns the subscription into a stream,
/// ingests one event and checks that the stream yields its note key and that
/// the note can then be queried back.
#[tokio::test]
async fn test_stream() {
    // Use a directory unique to this test. The previous path was named after
    // `test_callback`; sharing a database directory with another test is
    // unsafe because cargo runs tests in parallel and both would clean up
    // and write to the same files, while the assertion on `NoteKey::new(1)`
    // below relies on a fresh database.
    let db = "target/testdbs/test_stream";
    test_util::cleanup_db(&db);

    {
        let ndb = Ndb::new(db, &Config::new()).expect("ndb");

        let filter = Filter::new().kinds(vec![1]).build();
        let filters = vec![filter];

        let mut sub = ndb.subscribe(&filters).expect("sub_id").stream(&ndb);

        // Create the future first but do not poll it until the event has been
        // handed to the database.
        let res = sub.next();

        ndb.process_event(r#"["EVENT","b",{"id": "702555e52e82cc24ad517ba78c21879f6e47a7c0692b9b20df147916ae8731a3","pubkey": "32bf915904bfde2d136ba45dde32c88f4aca863783999faea2e847a8fafd2f15","created_at": 1702675561,"kind": 1,"tags": [],"content": "hello, world","sig": "2275c5f5417abfd644b7bc74f0388d70feb5d08b6f90fa18655dda5c95d013bfbc5258ea77c05b7e40e0ee51d8a2efa931dc7a0ec1db4c0a94519762c6625675"}]"#).expect("process ok");

        let res = res.await.expect("await ok");
        assert_eq!(res, vec![NoteKey::new(1)]);

        let txn = Transaction::new(&ndb).expect("txn");
        let res = ndb.query(&txn, &filters, 1).expect("query ok");
        assert_eq!(res.len(), 1);
        assert_eq!(
            hex::encode(res[0].note.id()),
            "702555e52e82cc24ad517ba78c21879f6e47a7c0692b9b20df147916ae8731a3"
        );
    }

    test_util::cleanup_db(&db);
}
}
8 changes: 7 additions & 1 deletion src/subscription.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
use crate::{Ndb, SubscriptionStream};

/// Handle identifying a nostrdb subscription; a thin wrapper around the
/// numeric subscription id. It derives `Hash` and `Eq` so it can be used as a
/// map key.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct Subscription(u64);

impl Subscription {
Expand All @@ -8,4 +10,8 @@ impl Subscription {
/// Returns the raw numeric subscription id.
pub fn id(self) -> u64 {
self.0
}

/// Creates a [`SubscriptionStream`] for this subscription, using a clone of
/// `ndb` as its database handle.
pub fn stream(&self, ndb: &Ndb) -> SubscriptionStream {
SubscriptionStream::new(ndb.clone(), *self)
}
}

0 comments on commit 91cff7b

Please sign in to comment.