From 91cff7bf6dda2a9acfd21201f7d227fc8927dac0 Mon Sep 17 00:00:00 2001 From: William Casarin Date: Mon, 9 Dec 2024 16:39:37 -0800 Subject: [PATCH] async: adding efficient, poll-based stream support This is a much more efficient, polling-based stream implementation that doesn't rely on horrible things like spawning threads just to do async. Changelog-Added: Add async stream support Signed-off-by: William Casarin --- Cargo.toml | 1 + nostrdb | 2 +- src/config.rs | 12 ++--- src/future.rs | 87 +++++++++++++++++++++++++++++++ src/lib.rs | 5 ++ src/ndb.rs | 124 +++++++++++++++++++++++++++++--------------- src/subscription.rs | 8 ++- 7 files changed, 186 insertions(+), 53 deletions(-) create mode 100644 src/future.rs diff --git a/Cargo.toml b/Cargo.toml index 0b3b285..a7da8e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ bindgen = [] [dependencies] flatbuffers = "23.5.26" libc = "0.2.151" +futures = "0.3.31" thiserror = "2.0.3" tokio = { version = "1", features = ["rt-multi-thread", "macros"] } tracing = "0.1.40" diff --git a/nostrdb b/nostrdb index 3260fa1..3d471ea 160000 --- a/nostrdb +++ b/nostrdb @@ -1 +1 @@ -Subproject commit 3260fa14639cf2adfec69b6a2bb000047f038e18 +Subproject commit 3d471ea97a1bd904b2b7ceae936e9f40f2bb7a80 diff --git a/src/config.rs b/src/config.rs index 6c5d889..cf0f566 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,8 +3,6 @@ use crate::bindings; #[derive(Copy, Clone)] pub struct Config { pub config: bindings::ndb_config, - // We add a flag to know if we've installed a Rust closure so we can clean it up in Drop. - is_rust_closure: bool, } impl Default for Config { @@ -29,11 +27,7 @@ impl Config { bindings::ndb_default_config(&mut config); } - let is_rust_closure = false; - Config { - config, - is_rust_closure, - } + Config { config } } // @@ -54,7 +48,8 @@ impl Config { self } - /// Set a callback for when we have + /// Set a callback to be notified on updated subscriptions. 
The function + /// will be called with the corresponding subscription id. pub fn set_sub_callback(mut self, closure: F) -> Self where F: FnMut(u64) + 'static, @@ -67,7 +62,6 @@ impl Config { self.config.sub_cb = Some(sub_callback_trampoline); self.config.sub_cb_ctx = ctx_ptr; - self.is_rust_closure = true; self } diff --git a/src/future.rs b/src/future.rs new file mode 100644 index 0000000..08d3c61 --- /dev/null +++ b/src/future.rs @@ -0,0 +1,87 @@ +use crate::{Ndb, NoteKey, Subscription}; + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures::Stream; + +/// Used to track query futures +#[derive(Debug, Clone)] +pub(crate) struct SubscriptionState { + pub ready: bool, + pub done: bool, + pub waker: Option, +} + +/// A subscription that you can .await on. This enables very clean +/// integration into Rust's async state machinery. +pub struct SubscriptionStream { + // some handle or state + // e.g., a reference to a non-blocking API or a shared atomic state + ndb: Ndb, + sub_id: Subscription, + max_notes: u32, +} + +impl SubscriptionStream { + pub fn new(ndb: Ndb, sub_id: Subscription) -> Self { + // Most of the time we only want to fetch a few things. 
If expecting + // lots of data, use `notes_per_await` + let max_notes = 32; + SubscriptionStream { + ndb, + sub_id, + max_notes, + } + } + + pub fn notes_per_await(mut self, max_notes: u32) -> Self { + self.max_notes = max_notes; + self + } + + pub fn sub_id(&self) -> Subscription { + self.sub_id + } +} + +impl Drop for SubscriptionStream { + fn drop(&mut self) { + // Perform cleanup here, like removing the subscription from the global map + let mut map = self.ndb.subs.lock().unwrap(); + map.remove(&self.sub_id); + } +} + +impl Stream for SubscriptionStream { + type Item = Vec; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let pinned = std::pin::pin!(self); + let me = pinned.as_ref().get_ref(); + let mut map = me.ndb.subs.lock().unwrap(); + let sub_state = map.entry(me.sub_id).or_insert(SubscriptionState { + ready: false, + done: false, + waker: None, + }); + + // we've unsubscribed + if sub_state.done { + return Poll::Ready(None); + } + + if sub_state.ready { + // Reset ready, fetch notes + sub_state.ready = false; + let notes = me.ndb.poll_for_notes(me.sub_id, me.max_notes); + return Poll::Ready(Some(notes)); + } + + // Not ready yet, store waker + sub_state.waker = Some(cx.waker().clone()); + std::task::Poll::Pending + } +} diff --git a/src/lib.rs b/src/lib.rs index 0b9d15e..aa465c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,9 @@ mod bindings; mod ndb_profile; mod block; + +mod future; + mod config; mod error; mod filter; @@ -30,6 +33,8 @@ pub use block::{Block, BlockType, Blocks, Mention}; pub use config::Config; pub use error::{Error, FilterError}; pub use filter::{Filter, FilterBuilder}; +pub(crate) use future::SubscriptionState; +pub use future::SubscriptionStream; pub use ndb::Ndb; pub use ndb_profile::{NdbProfile, NdbProfileRecord}; pub use ndb_str::{NdbStr, NdbStrVariant}; diff --git a/src/ndb.rs b/src/ndb.rs index c5ef1f4..520a87f 100644 --- a/src/ndb.rs +++ b/src/ndb.rs @@ -3,22 +3,20 @@ use std::ptr; use 
crate::{ bindings, Blocks, Config, Error, Filter, Note, NoteKey, ProfileKey, ProfileRecord, QueryResult, - Result, Subscription, Transaction, + Result, Subscription, SubscriptionState, SubscriptionStream, Transaction, }; +use futures::StreamExt; +use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::fs; use std::os::raw::c_int; use std::path::Path; -use std::sync::Arc; -use tokio::task; // Make sure to import the task module +use std::sync::{Arc, Mutex}; use tracing::debug; #[derive(Debug)] struct NdbRef { ndb: *mut bindings::ndb, - - /// Have we configured a rust closure for our callback? If so we need - /// to clean that up when this is dropped - has_rust_closure: bool, rust_cb_ctx: *mut ::std::os::raw::c_void, } @@ -34,7 +32,7 @@ impl Drop for NdbRef { unsafe { bindings::ndb_destroy(self.ndb); - if self.has_rust_closure && !self.rust_cb_ctx.is_null() { + if !self.rust_cb_ctx.is_null() { // Rebuild the Box from the raw pointer and drop it. let _ = Box::from_raw(self.rust_cb_ctx as *mut Box); } @@ -42,10 +40,15 @@ impl Drop for NdbRef { } } +type SubMap = HashMap; + /// A nostrdb context. Construct one of these with [Ndb::new]. 
#[derive(Debug, Clone)] pub struct Ndb { refs: Arc, + + /// Track query future states + pub(crate) subs: Arc>, } impl Ndb { @@ -65,7 +68,30 @@ impl Ndb { let min_mapsize = 1024 * 1024 * 512; let mut mapsize = config.config.mapsize; - let mut config = *config; + let config = *config; + + let prev_callback = config.config.sub_cb; + let prev_callback_ctx = config.config.sub_cb_ctx; + let subs = Arc::new(Mutex::new(SubMap::default())); + let subs_clone = subs.clone(); + + // We need to register our own callback so that we can wake + // query futures + let mut config = config.set_sub_callback(move |sub_id: u64| { + let mut map = subs_clone.lock().unwrap(); + if let Some(s) = map.get_mut(&Subscription::new(sub_id)) { + s.ready = true; + if let Some(w) = s.waker.take() { + w.wake(); + } + } + + if let Some(pcb) = prev_callback { + unsafe { + pcb(prev_callback_ctx, sub_id); + }; + } + }); let result = loop { let result = @@ -90,15 +116,10 @@ impl Ndb { return Err(Error::DbOpenFailed); } - let has_rust_closure = !config.config.sub_cb_ctx.is_null(); let rust_cb_ctx = config.config.sub_cb_ctx; - let refs = Arc::new(NdbRef { - ndb, - has_rust_closure, - rust_cb_ctx, - }); + let refs = Arc::new(NdbRef { ndb, rust_cb_ctx }); - Ok(Ndb { refs }) + Ok(Ndb { refs, subs }) } /// Ingest a relay-sent event in the form `["EVENT","subid", {"id:"...}]` @@ -155,9 +176,17 @@ impl Ndb { unsafe { bindings::ndb_num_subscriptions(self.as_ptr()) as u32 } } - pub fn unsubscribe(&self, sub: Subscription) -> Result<()> { + pub fn unsubscribe(&mut self, sub: Subscription) -> Result<()> { let r = unsafe { bindings::ndb_unsubscribe(self.as_ptr(), sub.id()) }; + // mark the subscription as done if it exists in our stream map + { + let mut map = self.subs.lock().unwrap(); + if let Entry::Occupied(mut entry) = map.entry(sub) { + entry.get_mut().done = true; + } + } + if r == 0 { Err(Error::SubscriptionError) } else { @@ -204,32 +233,11 @@ impl Ndb { sub_id: Subscription, max_notes: u32, ) -> Result> { - 
let ndb = self.clone(); - let handle = task::spawn_blocking(move || { - let mut vec: Vec = vec![]; - vec.reserve_exact(max_notes as usize); - let res = unsafe { - bindings::ndb_wait_for_notes( - ndb.as_ptr(), - sub_id.id(), - vec.as_mut_ptr(), - max_notes as c_int, - ) - }; - if res == 0 { - Err(Error::SubscriptionError) - } else { - unsafe { - vec.set_len(res as usize); - }; - Ok(vec) - } - }); + let mut stream = SubscriptionStream::new(self.clone(), sub_id).notes_per_await(max_notes); - match handle.await { - Ok(Ok(res)) => Ok(res.into_iter().map(NoteKey::new).collect()), - Ok(Err(err)) => Err(err), - Err(_) => Err(Error::SubscriptionError), + match stream.next().await { + Some(res) => Ok(res), + None => Err(Error::SubscriptionError), } } @@ -527,4 +535,36 @@ mod tests { // we should definitely clean this up... especially on windows test_util::cleanup_db(&db); } + + #[tokio::test] + async fn test_stream() { + let db = "target/testdbs/test_callback"; + test_util::cleanup_db(&db); + + { + let ndb = Ndb::new(db, &Config::new()).expect("ndb"); + + let filter = Filter::new().kinds(vec![1]).build(); + let filters = vec![filter]; + + let mut sub = ndb.subscribe(&filters).expect("sub_id").stream(&ndb); + + let res = sub.next(); + + ndb.process_event(r#"["EVENT","b",{"id": "702555e52e82cc24ad517ba78c21879f6e47a7c0692b9b20df147916ae8731a3","pubkey": "32bf915904bfde2d136ba45dde32c88f4aca863783999faea2e847a8fafd2f15","created_at": 1702675561,"kind": 1,"tags": [],"content": "hello, world","sig": "2275c5f5417abfd644b7bc74f0388d70feb5d08b6f90fa18655dda5c95d013bfbc5258ea77c05b7e40e0ee51d8a2efa931dc7a0ec1db4c0a94519762c6625675"}]"#).expect("process ok"); + + let res = res.await.expect("await ok"); + assert_eq!(res, vec![NoteKey::new(1)]); + + let txn = Transaction::new(&ndb).expect("txn"); + let res = ndb.query(&txn, &filters, 1).expect("query ok"); + assert_eq!(res.len(), 1); + assert_eq!( + hex::encode(res[0].note.id()), + 
"702555e52e82cc24ad517ba78c21879f6e47a7c0692b9b20df147916ae8731a3" + ); + } + + test_util::cleanup_db(&db); + } } diff --git a/src/subscription.rs b/src/subscription.rs index 8e77d6a..905642b 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -1,4 +1,6 @@ -#[derive(Debug, Clone, Copy, Eq, PartialEq)] +use crate::{Ndb, SubscriptionStream}; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub struct Subscription(u64); impl Subscription { @@ -8,4 +10,8 @@ impl Subscription { pub fn id(self) -> u64 { self.0 } + + pub fn stream(&self, ndb: &Ndb) -> SubscriptionStream { + SubscriptionStream::new(ndb.clone(), *self) + } }