diff --git a/Cargo.toml b/Cargo.toml index 5dad60e..ff2faf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,17 +8,27 @@ homepage = "https://github.com/jlizen/compute-heavy-future-executor" rust-version = "1.70" exclude = ["/.github", "/examples", "/scripts"] readme = "README.md" -description = "Additional executor patterns for use handling compute-bounded, blocking futures." +description = "Additional executor patterns for handling compute-bounded, blocking futures." categories = ["asynchronous"] +[features] +tokio = ["tokio/rt"] +tokio_block_in_place = ["tokio", "tokio/rt-multi-thread"] +secondary_tokio_runtime = ["tokio", "tokio/rt-multi-thread", "dep:libc", "dep:num_cpus"] + [dependencies] +libc = { version = "0.2.168", optional = true } log = "0.4.22" +num_cpus = { version = "1.0", optional = true } +tokio = { version = "1.0", features = ["macros", "sync"] } + +[dev-dependencies] +tokio = { version = "1.0", features = ["full"]} +futures-util = "0.3.31" + +[package.metadata.docs.rs] +all-features = true -[target.'cfg(compute_heavy_executor_tokio)'.dependencies] -tokio = { version = "1.0", features = ["rt", "rt-multi-thread", "macros", "sync"]} -libc = { version = "0.2.168"} -num_cpus = { version = "1.0"} [lints.rust] -# calling libraries can use the convention of `cfg(compute_heavy_executor)` to enable usage of this crate -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(compute_heavy_executor)', 'cfg(compute_heavy_executor_tokio)'] } \ No newline at end of file +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tokio_unstable)'] } diff --git a/README.md b/README.md index eec64dc..484c00a 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,2 @@ # compute-heavy-future-executor -Experimental crate for adding special handling for frequently blocking futures +Experimental crate that adds additional executor patterns to use with frequently blocking futures. diff --git a/src/block_in_place.rs b/src/block_in_place.rs new file mode 100644 index 0000000..fd62710 --- /dev/null +++ b/src/block_in_place.rs @@ -0,0 +1,45 @@ +use crate::{ + concurrency_limit::ConcurrencyLimit, + error::{Error, InvalidConfig}, + ComputeHeavyFutureExecutor, +}; + +use tokio::runtime::{Handle, RuntimeFlavor}; + +pub(crate) struct BlockInPlaceExecutor { + concurrency_limit: ConcurrencyLimit, +} + +impl BlockInPlaceExecutor { + pub(crate) fn new(max_concurrency: Option) -> Result { + match Handle::current().runtime_flavor() { + RuntimeFlavor::MultiThread => Ok(()), + #[cfg(tokio_unstable)] + RuntimeFlavor::MultiThreadAlt => Ok(()), + flavor => Err(Error::InvalidConfig(InvalidConfig { + field: "current tokio runtime flavor", + received: format!("{flavor:#?}"), + expected: "MultiThread", + }))?, + }?; + + Ok(Self { + concurrency_limit: ConcurrencyLimit::new(max_concurrency), + }) + } +} + +impl ComputeHeavyFutureExecutor for BlockInPlaceExecutor { + async fn execute(&self, fut: F) -> Result + where + F: std::future::Future + Send + 'static, + O: Send + 'static, + { + let _permit = self.concurrency_limit.acquire_permit().await; + + Ok(tokio::task::block_in_place(move || { + tokio::runtime::Handle::current().block_on(async { fut.await }) + })) + // permit implicitly drops + } +} diff --git a/src/concurrency_limit.rs b/src/concurrency_limit.rs new file mode 100644 index 0000000..1b9281e --- /dev/null +++ b/src/concurrency_limit.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +use crate::error::Error; + +/// Wrapper around semaphore that turns it into a non-op if no limit is provided +/// or semaphore channel is closed +pub(crate) struct ConcurrencyLimit { + semaphore: Option>, +} + +impl ConcurrencyLimit { + /// Accepts none in case no concurrency + pub(crate) fn new(limit: Option) -> Self { + let semaphore = limit.map(|limit| Arc::new(Semaphore::new(limit))); + + Self { semaphore } + } + + /// Waits on a permit to the semaphore if configured, otherwise immediately returns. + /// + /// Internally turns errors into a no-op (`None`) and outputs log lines. + pub(crate) async fn acquire_permit(&self) -> Option { + match self.semaphore.clone() { + Some(semaphore) => { + match semaphore + .acquire_owned() + .await + .map_err(|err| Error::Semaphore(err)) + { + Ok(permit) => Some(permit), + Err(err) => { + log::error!("failed to acquire permit: {err}"); + None + } + } + } + None => None, + } + } +} diff --git a/src/current_context.rs b/src/current_context.rs new file mode 100644 index 0000000..6badb65 --- /dev/null +++ b/src/current_context.rs @@ -0,0 +1,27 @@ +use crate::{concurrency_limit::ConcurrencyLimit, error::Error, ComputeHeavyFutureExecutor}; + +pub(crate) struct CurrentContextExecutor { + concurrency_limit: ConcurrencyLimit, +} + +impl CurrentContextExecutor { + pub(crate) fn new(max_concurrency: Option) -> Self { + Self { + concurrency_limit: ConcurrencyLimit::new(max_concurrency), + } + } +} + +impl ComputeHeavyFutureExecutor for CurrentContextExecutor { + async fn execute(&self, fut: F) -> Result + where + F: std::future::Future + Send + 'static, + O: Send + 'static, + { + let _permit = self.concurrency_limit.acquire_permit().await; + + Ok(fut.await) + + // implicit permit drop + } +} diff --git a/src/custom_executor.rs b/src/custom_executor.rs new file mode 100644 index 0000000..300c432 --- /dev/null +++ b/src/custom_executor.rs @@ -0,0 +1,55 @@ +use std::{future::Future, pin::Pin}; + +use crate::{ + concurrency_limit::ConcurrencyLimit, error::Error, make_future_cancellable, + ComputeHeavyFutureExecutor, +}; + +/// A closure that accepts an arbitrary future and polls it to completion +/// via its preferred strategy. +pub type CustomExecutorClosure = Box< + dyn Fn( + Pin + Send + 'static>>, + ) -> Box< + dyn Future>> + + Send + + 'static, + > + Send + + Sync, +>; + +pub(crate) struct CustomExecutor { + closure: CustomExecutorClosure, + concurrency_limit: ConcurrencyLimit, +} + +impl CustomExecutor { + pub(crate) fn new(closure: CustomExecutorClosure, max_concurrency: Option) -> Self { + Self { + closure, + concurrency_limit: ConcurrencyLimit::new(max_concurrency), + } + } +} + +impl ComputeHeavyFutureExecutor for CustomExecutor { + async fn execute(&self, fut: F) -> Result + where + F: Future + Send + 'static, + O: Send + 'static, + { + let _permit = self.concurrency_limit.acquire_permit().await; + + let (wrapped_future, rx) = make_future_cancellable(fut); + + // if our custom executor future resolves to an error, we know it will never send + // the response so we immediately return + if let Err(err) = Box::into_pin((self.closure)(Box::pin(wrapped_future))).await { + return Err(Error::BoxError(err)); + } + + rx.await.map_err(|err| Error::RecvError(err)) + + // permit implicitly drops + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..af43323 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,45 @@ +use core::fmt; + +use crate::ExecutorStrategy; + +#[non_exhaustive] +#[derive(Debug)] +pub enum Error { + AlreadyInitialized(ExecutorStrategy), + InvalidConfig(InvalidConfig), + RecvError(tokio::sync::oneshot::error::RecvError), + Semaphore(tokio::sync::AcquireError), + BoxError(Box), + #[cfg(feature = "tokio")] + JoinError(tokio::task::JoinError), +} + +#[derive(Debug)] +pub struct InvalidConfig { + pub field: &'static str, + pub received: String, + pub expected: &'static str, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::AlreadyInitialized(strategy) => write!( + f, + "global strategy is already initialzed with strategy: {strategy:#?}" + ), + Error::InvalidConfig(err) => write!(f, "invalid config: {err:#?}"), + Error::BoxError(err) => write!(f, "custom executor error: {err}"), + Error::RecvError(err) => write!(f, "error in custom executor response channel: {err}"), + Error::Semaphore(err) => write!( + f, + "concurrency limiter semaphore channel is closed, continuing: {err}" + ), + #[cfg(feature = "tokio")] + Error::JoinError(err) => write!( + f, + "error joining tokio handle in spawn_blocking executor: {err}" + ), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index e69de29..db6a430 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -0,0 +1,577 @@ +#[cfg(feature = "tokio_block_in_place")] +mod block_in_place; +mod concurrency_limit; +mod current_context; +mod custom_executor; +pub mod error; +#[cfg(feature = "secondary_tokio_runtime")] +mod secondary_tokio_runtime; +#[cfg(feature = "tokio")] +mod spawn_blocking; + +pub use custom_executor::CustomExecutorClosure; +pub use error::Error; +#[cfg(feature = "secondary_tokio_runtime")] +pub use secondary_tokio_runtime::SecondaryTokioRuntimeStrategyBuilder; + +#[cfg(feature = "tokio_block_in_place")] +use block_in_place::BlockInPlaceExecutor; +use current_context::CurrentContextExecutor; +use custom_executor::CustomExecutor; +#[cfg(feature = "secondary_tokio_runtime")] +use secondary_tokio_runtime::SecondaryTokioRuntimeExecutor; +#[cfg(feature = "tokio")] +use spawn_blocking::SpawnBlockingExecutor; + +use std::{fmt::Debug, future::Future, sync::OnceLock}; + +use tokio::{select, sync::oneshot::Receiver}; + +// TODO: module docs, explain the point of this library, give some samples + +/// Initialize a builder to set the global compute heavy future +/// executor strategy. +#[must_use = "doesn't do anything unless used"] +pub fn global_strategy_builder() -> GlobalStrategyBuilder { + GlobalStrategyBuilder::default() +} + +/// Get the currently initialized strategy, or the default strategy for the +/// current feature and runtime type in case no strategy has been loaded. +pub fn global_strategy() -> CurrentStrategy { + match COMPUTE_HEAVY_FUTURE_EXECUTOR_STRATEGY.get() { + Some(strategy) => CurrentStrategy::Initialized(strategy.into()), + None => CurrentStrategy::Default(<&ExecutorStrategyImpl>::default().into()), + } +} + +#[must_use = "doesn't do anything unless used"] +#[derive(Default)] +pub struct GlobalStrategyBuilder { + max_concurrency: Option, +} + +impl GlobalStrategyBuilder { + /// Set the max number of simultaneous futures processed by this executor. + /// + /// If this number is exceeded, the futures sent to + /// [`execute_compute_heavy_future()`] will sleep until a permit + /// can be acquired. + /// + /// ## Default + /// No maximum concurrency + /// + /// # Example + /// + /// ``` + /// use compute_heavy_future_executor::global_strategy_builder; + /// + /// # async fn run() { + /// global_strategy_builder() + /// .max_concurrency(10) + /// .initialize_current_context() + /// .unwrap(); + /// # } + pub fn max_concurrency(self, max_task_concurrency: usize) -> Self { + Self { + max_concurrency: Some(max_task_concurrency), + ..self + } + } + + /// Initializes a new global strategy to wait in the current context. + /// + /// This is effectively a non-op wrapper that adds no special handling for the future besides optional concurrency control. + /// This is the default if the `tokio` feature is disabled. + /// + /// # Cancellation + /// Yes, the future is dropped if the caller drops the returned future from + ///[`execute_compute_heavy_future()`]. + /// + /// Note that it will only be dropped across yield points in the case of long-blocking futures. + /// + /// ## Error + /// Returns an error if the global strategy is already initialized. + /// It can only be initialized once. + /// + /// # Example + /// + /// ``` + /// use compute_heavy_future_executor::global_strategy_builder; + /// use compute_heavy_future_executor::execute_compute_heavy_future; + /// + /// # async fn run() { + /// global_strategy_builder().initialize_current_context().unwrap(); + /// + /// let future = async { + /// std::thread::sleep(std::time::Duration::from_millis(50)); + /// 5 + /// }; + /// + /// let res = execute_compute_heavy_future(future).await.unwrap(); + /// assert_eq!(res, 5); + /// # } + /// ``` + pub fn initialize_current_context(self) -> Result<(), Error> { + let strategy = + ExecutorStrategyImpl::CurrentContext(CurrentContextExecutor::new(self.max_concurrency)); + set_strategy(strategy) + } + + /// Initializes a new global strategy to execute futures by blocking on them inside the + /// tokio blocking threadpool. This is the default strategy if none is explicitly initialized, + /// if the `tokio` feature is enabled. + /// + /// By default, tokio will spin up a blocking thread + /// per task, which may be more than your count of CPU cores, depending on runtime config. + /// + /// If you expect many concurrent cpu-heavy futures, consider limiting your blocking + /// tokio threadpool size. + /// Or, you can use a heavier weight strategy like [`initialize_secondary_tokio_runtime()`]. + /// + /// # Cancellation + /// Yes, the future is dropped if the caller drops the returned future + /// from [`execute_compute_heavy_future()`]. + /// + /// Note that it will only be dropped across yield points in the case of long-blocking futures. + /// + /// ## Error + /// Returns an error if the global strategy is already initialized. + /// It can only be initialized once. + /// + /// # Example + /// + /// ``` + /// use compute_heavy_future_executor::global_strategy_builder; + /// use compute_heavy_future_executor::execute_compute_heavy_future; + /// + /// # async fn run() { + /// global_strategy_builder().initialize_spawn_blocking().unwrap(); + /// + /// let future = async { + /// std::thread::sleep(std::time::Duration::from_millis(50)); + /// 5 + /// }; + /// + /// let res = execute_compute_heavy_future(future).await.unwrap(); + /// assert_eq!(res, 5); + /// # } + /// ``` + #[cfg(feature = "tokio")] + pub fn initialize_spawn_blocking(self) -> Result<(), Error> { + let strategy = + ExecutorStrategyImpl::SpawnBlocking(SpawnBlockingExecutor::new(self.max_concurrency)); + set_strategy(strategy) + } + + /// Initializes a new global strategy to execute futures by calling tokio::task::block_in_place + /// on the current tokio worker thread. This evicts other tasks on same worker thread to + /// avoid blocking them. + /// + /// This approach can starve your executor of worker threads if called with too many + /// concurrent cpu-heavy futures. + /// + /// If you expect many concurrent cpu-heavy futures, consider a + /// heavier weight strategy like [`initialize_secondary_tokio_runtime()`]. + /// + /// # Cancellation + /// No, this strategy does not allow futures to be cancelled. + /// + /// ## Error + /// Returns an error if called from a context besides a tokio multithreaded runtime. + /// + /// Returns an error if the global strategy is already initialized. + /// It can only be initialized once. + /// + /// # Example + /// + /// ``` + /// use compute_heavy_future_executor::global_strategy_builder; + /// use compute_heavy_future_executor::execute_compute_heavy_future; + /// + /// # async fn run() { + /// global_strategy_builder().initialize_block_in_place().unwrap(); + /// + /// let future = async { + /// std::thread::sleep(std::time::Duration::from_millis(50)); + /// 5 + /// }; + /// + /// let res = execute_compute_heavy_future(future).await.unwrap(); + /// assert_eq!(res, 5); + /// # } + /// ``` + #[cfg(feature = "tokio_block_in_place")] + pub fn initialize_block_in_place(self) -> Result<(), Error> { + let strategy = + ExecutorStrategyImpl::BlockInPlace(BlockInPlaceExecutor::new(self.max_concurrency)?); + set_strategy(strategy) + } + + /// Initializes a new global strategy that spins up a secondary background tokio runtime + /// that executes futures on lower priority worker threads. + /// + /// This uses certain defaults, listed below. To modify these defaults, + /// instead use [`secondary_tokio_runtime_builder()`] + /// + /// # Defaults + /// ## Thread niceness + /// The thread niceness for the secondary runtime's worker threads, + /// which on linux is used to increase or lower relative + /// OS scheduling priority. + /// + /// Default: 10 + /// + /// ## Thread count + /// The count of worker threads in the secondary tokio runtime. + /// + /// Default: CPU core count + /// + /// ## Channel size + /// The buffer size of the channel used to spawn tasks + /// in the background executor. + /// + /// Default: 10 + /// + /// ## Max task concurrency + /// The max number of simultaneous background tasks running + /// + /// Default: no limit + /// + /// # Cancellation + /// Yes, the future is dropped if the caller drops the returned future + /// from [`execute_compute_heavy_future()`]. + /// + /// Note that it will only be dropped across yield points in the case of long-blocking futures. + /// + /// ## Error + /// Returns an error if the global strategy is already initialized. + /// It can only be initialized once. + /// + /// # Example + /// + /// ``` + /// use compute_heavy_future_executor::global_strategy_builder; + /// use compute_heavy_future_executor::execute_compute_heavy_future; + /// + /// # async fn run() { + /// global_strategy_builder().initialize_secondary_tokio_runtime().unwrap(); + /// + /// let future = async { + /// std::thread::sleep(std::time::Duration::from_millis(50)); + /// 5 + /// }; + /// + /// let res = execute_compute_heavy_future(future).await.unwrap(); + /// assert_eq!(res, 5); + /// # } + /// ``` + #[cfg(feature = "secondary_tokio_runtime")] + pub fn initialize_secondary_tokio_runtime(self) -> Result<(), Error> { + self.secondary_tokio_runtime_builder().initialize() + } + + /// Creates a [`SecondaryTokioRuntimeStrategyBuilder`] for a customized secondary tokio runtime strategy. + /// + /// Subsequent calls on the returned builder allow modifying defaults. + /// + /// The returned builder will require calling [`SecondaryTokioRuntimeStrategyBuilder::initialize()`] to + /// ultimately load the strategy. + /// + /// # Cancellation + /// Yes, the future is dropped if the caller drops the returned future + /// from [`execute_compute_heavy_future()`]. + /// + /// Note that it will only be dropped across yield points in the case of long-blocking futures. + /// + /// # Example + /// + /// ``` + /// use compute_heavy_future_executor::global_strategy_builder; + /// use compute_heavy_future_executor::execute_compute_heavy_future; + /// + /// # async fn run() { + /// global_strategy_builder() + /// .secondary_tokio_runtime_builder() + /// .niceness(1) + /// .thread_count(2) + /// .channel_size(3) + /// .max_concurrency(4) + /// .initialize() + /// .unwrap(); + /// + /// let future = async { + /// std::thread::sleep(std::time::Duration::from_millis(50)); + /// 5 + /// }; + /// + /// let res = execute_compute_heavy_future(future).await.unwrap(); + /// assert_eq!(res, 5); + /// # } + /// ``` + #[cfg(feature = "secondary_tokio_runtime")] + #[must_use = "doesn't do anything unless used"] + pub fn secondary_tokio_runtime_builder(self) -> SecondaryTokioRuntimeStrategyBuilder { + SecondaryTokioRuntimeStrategyBuilder::new(self.max_concurrency) + } + + /// Accepts a closure that will poll an arbitrary feature to completion. + /// + /// Intended for injecting arbitrary runtimes/strategies or customizing existing ones. + /// + /// # Cancellation + /// Yes, the closure's returned future is dropped if the caller drops the returned future from [`execute_compute_heavy_future()`]. + /// Note that it will only be dropped across yield points in the case of long-blocking futures. + /// + /// ## Error + /// Returns an error if the global strategy is already initialized. + /// It can only be initialized once. + /// + /// # Example + /// + /// ``` + /// use compute_heavy_future_executor::global_strategy_builder; + /// use compute_heavy_future_executor::execute_compute_heavy_future; + /// use compute_heavy_future_executor::CustomExecutorClosure; + /// + /// // this isn't actually a good strategy, to be clear + /// # async fn run() { + /// let closure: CustomExecutorClosure = Box::new(|fut| { + /// Box::new( + /// async move { + /// tokio::task::spawn(async move { fut.await }) + /// .await + /// .map_err(|err| err.into()) + /// } + /// ) + /// }); + /// + /// global_strategy_builder().initialize_custom_executor(closure).unwrap(); + /// + /// let future = async { + /// std::thread::sleep(std::time::Duration::from_millis(50)); + /// 5 + /// }; + /// + /// let res = execute_compute_heavy_future(future).await.unwrap(); + /// assert_eq!(res, 5); + /// # } + /// + /// ``` + pub fn initialize_custom_executor(self, closure: CustomExecutorClosure) -> Result<(), Error> { + let strategy = ExecutorStrategyImpl::CustomExecutor(CustomExecutor::new( + closure, + self.max_concurrency, + )); + set_strategy(strategy) + } +} + +pub(crate) fn set_strategy(strategy: ExecutorStrategyImpl) -> Result<(), Error> { + COMPUTE_HEAVY_FUTURE_EXECUTOR_STRATEGY + .set(strategy) + .map_err(|_| { + Error::AlreadyInitialized(COMPUTE_HEAVY_FUTURE_EXECUTOR_STRATEGY.get().unwrap().into()) + })?; + + log::info!( + "initialized compute-heavy future executor strategy - {:#?}", + global_strategy() + ); + + Ok(()) +} +trait ComputeHeavyFutureExecutor { + /// Accepts a future and returns its result + async fn execute(&self, fut: F) -> Result + where + F: Future + Send + 'static, + O: Send + 'static; +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum CurrentStrategy { + Default(ExecutorStrategy), + Initialized(ExecutorStrategy), +} + +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ExecutorStrategy { + /// A non-op strategy that awaits in the current context + CurrentContext, + /// User-provided closure + CustomExecutor, + /// tokio task::spawn_blocking + #[cfg(feature = "tokio")] + SpawnBlocking, + /// tokio task::block_in_place + #[cfg(feature = "tokio_block_in_place")] + BlockInPlace, + #[cfg(feature = "secondary_tokio_runtime")] + /// Spin up a second, lower-priority tokio runtime + /// that communicates via channels + SecondaryTokioRuntime, +} + +impl From<&ExecutorStrategyImpl> for ExecutorStrategy { + fn from(value: &ExecutorStrategyImpl) -> Self { + match value { + ExecutorStrategyImpl::CurrentContext(_) => Self::CurrentContext, + ExecutorStrategyImpl::CustomExecutor(_) => Self::CustomExecutor, + #[cfg(feature = "tokio")] + ExecutorStrategyImpl::SpawnBlocking(_) => Self::SpawnBlocking, + #[cfg(feature = "tokio_block_in_place")] + ExecutorStrategyImpl::BlockInPlace(_) => Self::BlockInPlace, + #[cfg(feature = "secondary_tokio_runtime")] + ExecutorStrategyImpl::SecondaryTokioRuntime(_) => Self::SecondaryTokioRuntime, + } + } +} + +/// The stored strategy used to spawn compute-heavy futures. +static COMPUTE_HEAVY_FUTURE_EXECUTOR_STRATEGY: OnceLock = OnceLock::new(); + +#[non_exhaustive] +enum ExecutorStrategyImpl { + /// A non-op strategy that awaits in the current context + CurrentContext(CurrentContextExecutor), + /// User-provided closure + CustomExecutor(CustomExecutor), + /// tokio task::spawn_blocking + #[cfg(feature = "tokio")] + SpawnBlocking(SpawnBlockingExecutor), + /// tokio task::block_in_place + #[cfg(feature = "tokio_block_in_place")] + BlockInPlace(BlockInPlaceExecutor), + #[cfg(feature = "secondary_tokio_runtime")] + /// Spin up a second, lower-priority tokio runtime + /// that communicates via channels + SecondaryTokioRuntime(SecondaryTokioRuntimeExecutor), +} + +impl ComputeHeavyFutureExecutor for ExecutorStrategyImpl { + async fn execute(&self, fut: F) -> Result + where + F: Future + Send + 'static, + O: Send + 'static, + { + match self { + ExecutorStrategyImpl::CurrentContext(executor) => executor.execute(fut).await, + ExecutorStrategyImpl::CustomExecutor(executor) => executor.execute(fut).await, + #[cfg(feature = "tokio")] + ExecutorStrategyImpl::SpawnBlocking(executor) => executor.execute(fut).await, + #[cfg(feature = "tokio_block_in_place")] + ExecutorStrategyImpl::BlockInPlace(executor) => executor.execute(fut).await, + #[cfg(feature = "secondary_tokio_runtime")] + ExecutorStrategyImpl::SecondaryTokioRuntime(executor) => executor.execute(fut).await, + } + } +} + +/// The fallback strategy used in case no strategy is explicitly set +static DEFAULT_COMPUTE_HEAVY_FUTURE_EXECUTOR_STRATEGY: OnceLock = + OnceLock::new(); + +impl Default for &ExecutorStrategyImpl { + fn default() -> Self { + &DEFAULT_COMPUTE_HEAVY_FUTURE_EXECUTOR_STRATEGY.get_or_init(|| { + #[cfg(feature = "tokio")] + { + log::info!("Defaulting to SpawnBlocking strategy for compute-heavy future executor \ + until a strategy is initialized"); + + ExecutorStrategyImpl::SpawnBlocking(SpawnBlockingExecutor::new(None)) + } + + #[cfg(not(feature = "tokio"))] + { + log::warn!("Defaulting to CurrentContext (non-op) strategy for compute-heavy future executor \ + until a strategy is initialized."); + ExecutorStrategyImpl::CurrentContext(CurrentContextExecutor::new(None)) + } + }) + } +} + +/// Spawn a future to the configured compute-heavy executor and wait on its output. +/// +/// # Strategy selection +/// +/// If no strategy is configured, this library will fall back to the following defaults: +/// - no `tokio`` feature - current context +/// - all other cases - spawn blocking +/// +/// You can override these defaults by initializing a strategy via [`global_strategy_builder()`] +/// and [`GlobalStrategyBuilder`]. +/// +/// # Cancellation +/// +/// Most strategies will cancel the input future, if the caller drops the returned future, +/// with the following exception: +/// - the block in place strategy never cancels the future (until the executor is shut down) +/// +/// # Example +/// +/// ``` +/// # async fn run() { +/// use compute_heavy_future_executor::execute_compute_heavy_future; +/// +/// let future = async { +/// std::thread::sleep(std::time::Duration::from_millis(50)); +/// 5 +/// }; +/// +/// let res = execute_compute_heavy_future(future).await.unwrap(); +/// assert_eq!(res, 5); +/// # } +/// +/// ``` +/// +pub async fn execute_compute_heavy_future(fut: F) -> Result +where + F: Future + Send + 'static, + R: Send + 'static, +{ + let executor = COMPUTE_HEAVY_FUTURE_EXECUTOR_STRATEGY + .get() + .unwrap_or_else(|| <&ExecutorStrategyImpl>::default()); + match executor { + ExecutorStrategyImpl::CurrentContext(executor) => executor.execute(fut).await, + ExecutorStrategyImpl::CustomExecutor(executor) => executor.execute(fut).await, + #[cfg(feature = "tokio_block_in_place")] + ExecutorStrategyImpl::BlockInPlace(executor) => executor.execute(fut).await, + #[cfg(feature = "tokio")] + ExecutorStrategyImpl::SpawnBlocking(executor) => executor.execute(fut).await, + #[cfg(feature = "secondary_tokio_runtime")] + ExecutorStrategyImpl::SecondaryTokioRuntime(executor) => executor.execute(fut).await, + } +} + +pub fn make_future_cancellable(fut: F) -> (impl Future, Receiver) +where + F: std::future::Future + Send + 'static, + O: Send + 'static, +{ + let (mut tx, rx) = tokio::sync::oneshot::channel(); + let wrapped_future = async { + select! { + // if tx is closed, we always want to poll that future first, + // so we don't need to add rng + biased; + + _ = tx.closed() => { + // receiver already dropped, don't need to do anything + // cancel the background future + }, + result = fut => { + // if this fails, the receiver already dropped, so we don't need to do anything + let _ = tx.send(result); + } + } + }; + + (wrapped_future, rx) +} + +// tests are in /tests/ to allow separate initialization of oncelock across processes when using default cargo test runner diff --git a/src/secondary_tokio_runtime.rs b/src/secondary_tokio_runtime.rs new file mode 100644 index 0000000..318418c --- /dev/null +++ b/src/secondary_tokio_runtime.rs @@ -0,0 +1,221 @@ +use std::{future::Future, pin::Pin}; + +use tokio::sync::mpsc::Sender; + +use crate::{ + concurrency_limit::ConcurrencyLimit, + error::{Error, InvalidConfig}, + make_future_cancellable, set_strategy, ComputeHeavyFutureExecutor, ExecutorStrategyImpl, +}; + +const DEFAULT_NICENESS: i8 = 10; +const DEFAULT_CHANNEL_SIZE: usize = 10; + +fn default_thread_count() -> usize { + num_cpus::get() +} + +/// Extention of [`GlobalStrategyBuilder`] for a customized secondary tokio runtime strategy. +/// +/// Requires calling [`SecondaryTokioRuntimeStrategyBuilder::initialize()`] to +/// initialize the strategy. +/// +/// # Example +/// +/// ``` +/// use compute_heavy_future_executor::global_strategy_builder; +/// use compute_heavy_future_executor::execute_compute_heavy_future; +/// +/// # async fn run() { +/// global_strategy_builder() +/// .secondary_tokio_runtime_builder() +/// .niceness(1) +/// .thread_count(2) +/// .channel_size(3) +/// .max_concurrency(4) +/// .initialize() +/// .unwrap(); +/// # } +/// ``` +#[must_use = "doesn't do anything unless used"] +#[derive(Default)] +pub struct SecondaryTokioRuntimeStrategyBuilder { + niceness: Option, + thread_count: Option, + channel_size: Option, + // passed down from the parent `GlobalStrategy` builder, not modified internally + max_concurrency: Option, +} + +impl SecondaryTokioRuntimeStrategyBuilder { + pub(crate) fn new(max_concurrency: Option) -> Self { + Self { + max_concurrency, + ..Default::default() + } + } +} + +impl SecondaryTokioRuntimeStrategyBuilder { + /// Set the thread niceness for the secondary runtime's worker threads, + /// which on linux is used to increase or lower relative + /// OS scheduling priority. + /// + /// Allowed values are -20..=19 + /// + /// ## Default + /// + /// The default value is 10. + pub fn niceness(self, niceness: i8) -> Self { + Self { + niceness: Some(niceness), + ..self + } + } + + /// Set the count of worker threads in the secondary tokio runtime. + /// + /// ## Default + /// + /// The default value is the number of cpu cores + pub fn thread_count(self, thread_count: usize) -> Self { + Self { + thread_count: Some(thread_count), + ..self + } + } + + /// Set the buffer size of the channel used to spawn tasks + /// in the background executor. + /// + /// ## Default + /// + /// The default value is 10 + pub fn channel_size(self, channel_size: usize) -> Self { + Self { + channel_size: Some(channel_size), + ..self + } + } + + /// Set the max number of simultaneous futures processed by this executor. + /// + /// Yes, the future is dropped if the caller drops the returned future from + ///[`execute_compute_heavy_future()`]. + /// + /// ## Default + /// No maximum concurrency + pub fn max_concurrency(self, max_task_concurrency: usize) -> Self { + Self { + max_concurrency: Some(max_task_concurrency), + ..self + } + } + + pub fn initialize(self) -> Result<(), Error> { + let niceness = self.niceness.unwrap_or(DEFAULT_NICENESS); + + // please https://github.com/rust-lang/rfcs/issues/671 + if !(-20..=19).contains(&niceness) { + return Err(Error::InvalidConfig(InvalidConfig { + field: "niceness", + received: niceness.to_string(), + expected: "-20..=19", + })); + } + + let thread_count = self.thread_count.unwrap_or_else(|| default_thread_count()); + let channel_size = self.channel_size.unwrap_or(DEFAULT_CHANNEL_SIZE); + + let executor = SecondaryTokioRuntimeExecutor::new( + niceness, + thread_count, + channel_size, + self.max_concurrency, + ); + + set_strategy(ExecutorStrategyImpl::SecondaryTokioRuntime(executor)) + } +} + +type BackgroundFuture = Pin + Send + 'static>>; + +pub(crate) struct SecondaryTokioRuntimeExecutor { + tx: Sender, + concurrency_limit: ConcurrencyLimit, +} + +impl SecondaryTokioRuntimeExecutor { + pub(crate) fn new( + niceness: i8, + thread_count: usize, + channel_size: usize, + max_concurrency: Option, + ) -> Self { + // channel is only for routing work to new task::spawn so should be very quick + let (tx, mut rx) = tokio::sync::mpsc::channel(channel_size); + + std::thread::Builder::new() + .name("compute-heavy-executor".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("compute-heavy-executor-pool-thread") + .worker_threads(thread_count) + .on_thread_start(move || unsafe { + // Reduce thread pool thread niceness, so they are lower priority + // than the foreground executor and don't interfere with I/O tasks + #[cfg(target_os = "linux")] + { + *libc::__errno_location() = 0; + if libc::nice(niceness.into()) == -1 && *libc::__errno_location() != 0 { + let error = std::io::Error::last_os_error(); + log::error!("failed to set threadpool niceness of secondary compute-heavy tokio executor: {}", error); + } + } + }) + .enable_all() + .build() + .unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e)); + + rt.block_on(async { + log::debug!("starting to process work on secondary compute-heavy tokio executor"); + + while let Some(work) = rx.recv().await { + tokio::task::spawn(async move { + work.await + }); + } + }); + log::warn!("exiting secondary compute heavy tokio runtime because foreground channel closed"); + }) + .unwrap_or_else(|e| panic!("secondary compute-heavy runtime thread failed_to_initialize: {}", e)); + + Self { + tx, + concurrency_limit: ConcurrencyLimit::new(max_concurrency), + } + } +} + +impl ComputeHeavyFutureExecutor for SecondaryTokioRuntimeExecutor { + async fn execute(&self, fut: F) -> Result + where + F: std::future::Future + Send + 'static, + O: Send + 'static, + { + let _permit = self.concurrency_limit.acquire_permit().await; + + let (wrapped_future, rx) = make_future_cancellable(fut); + + match self.tx.send(Box::pin(wrapped_future)).await { + Ok(_) => (), + Err(err) => { + panic!("secondary compute-heavy runtime channel cannot be reached: {err}") + } + } + + rx.await.map_err(|err| Error::RecvError(err)) + + // permit implicitly drops + } +} diff --git a/src/spawn_blocking.rs b/src/spawn_blocking.rs new file mode 100644 index 0000000..fee12a6 --- /dev/null +++ b/src/spawn_blocking.rs @@ -0,0 +1,38 @@ +use crate::{ + concurrency_limit::ConcurrencyLimit, error::Error, make_future_cancellable, + ComputeHeavyFutureExecutor, +}; + +pub(crate) struct SpawnBlockingExecutor { + concurrency_limit: ConcurrencyLimit, +} + +impl SpawnBlockingExecutor { + pub(crate) fn new(max_concurrency: Option) -> Self { + let concurrency_limit = ConcurrencyLimit::new(max_concurrency); + + Self { concurrency_limit } + } +} + +impl ComputeHeavyFutureExecutor for SpawnBlockingExecutor { + async fn execute(&self, fut: F) -> Result + where + F: std::future::Future + Send + 'static, + O: Send + 'static, + { + let _permit = self.concurrency_limit.acquire_permit().await; + + let (wrapped_future, rx) = make_future_cancellable(fut); + + if let Err(err) = tokio::task::spawn_blocking(move || { + tokio::runtime::Handle::current().block_on(wrapped_future) + }) + .await + { + return Err(Error::JoinError(err)); + } + + rx.await.map_err(|err| Error::RecvError(err)) + } +} diff --git a/tests/block_in_place_strategy.rs b/tests/block_in_place_strategy.rs new file mode 100644 index 0000000..c425260 --- /dev/null +++ b/tests/block_in_place_strategy.rs @@ -0,0 +1,58 @@ +#[cfg(feature = "tokio_block_in_place")] +mod test { + use std::time::Duration; + + use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy, global_strategy_builder, CurrentStrategy, + ExecutorStrategy, + }; + use futures_util::future::join_all; + + fn initialize() { + // we are racing all tests against the single oncelock + let _ = global_strategy_builder() + .max_concurrency(3) + .initialize_block_in_place(); + } + + #[cfg(feature = "tokio_block_in_place")] + #[tokio::test(flavor = "multi_thread")] + async fn block_in_place_strategy() { + initialize(); + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); + + assert_eq!( + global_strategy(), + CurrentStrategy::Initialized(ExecutorStrategy::BlockInPlace) + ); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 10)] + async fn block_in_place_concurrency() { + initialize(); + + let start = std::time::Instant::now(); + + let mut futures = Vec::new(); + + for _ in 0..5 { + let future = async move { std::thread::sleep(Duration::from_millis(15)) }; + // we need to spawn here since otherwise block in place will cancel other futures inside the same task, + // ref https://docs.rs/tokio/latest/tokio/task/fn.block_in_place.html + let future = + tokio::task::spawn(async move { execute_compute_heavy_future(future).await }); + futures.push(future); + } + + join_all(futures).await; + + let elapsed_millis = start.elapsed().as_millis(); + assert!(elapsed_millis < 50, "futures did not run concurrently"); + + assert!(elapsed_millis > 20, "futures exceeded max concurrency"); + } +} diff --git a/tests/block_in_place_wrong_runtime.rs b/tests/block_in_place_wrong_runtime.rs new file mode 100644 index 0000000..dec7b72 --- /dev/null +++ b/tests/block_in_place_wrong_runtime.rs @@ -0,0 +1,9 @@ +#[cfg(feature = "tokio_block_in_place")] +#[tokio::test] +async fn block_in_place_wrong_runtime() { + use compute_heavy_future_executor::{global_strategy_builder, Error}; + + let res = global_strategy_builder().initialize_block_in_place(); + + assert!(matches!(res, Err(Error::InvalidConfig(_)))); +} diff --git a/tests/current_context_default.rs b/tests/current_context_default.rs new file mode 100644 index 0000000..3985b67 --- /dev/null +++ b/tests/current_context_default.rs @@ -0,0 +1,19 @@ +#[cfg(not(feature = "tokio"))] +#[tokio::test] +async fn default_to_current_context_tokio_single_threaded() { + use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy, CurrentStrategy, ExecutorStrategy, + }; + + // this is a tokio test but we haven't enabled the tokio config flag + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); + + assert_eq!( + global_strategy(), + CurrentStrategy::Default(ExecutorStrategy::CurrentContext) + ); +} diff --git a/tests/current_context_strategy.rs b/tests/current_context_strategy.rs new file mode 100644 index 0000000..2669d82 --- /dev/null +++ b/tests/current_context_strategy.rs @@ -0,0 +1,78 @@ +use std::time::Duration; + +use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy, global_strategy_builder, CurrentStrategy, + ExecutorStrategy, +}; +use futures_util::future::join_all; +use tokio::select; + +fn initialize() { + // we are racing all tests against the single oncelock + let _ = global_strategy_builder() + .max_concurrency(3) + .initialize_current_context(); +} + +#[tokio::test] +async fn current_context_strategy() { + initialize(); + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); + + assert_eq!( + global_strategy(), + CurrentStrategy::Initialized(ExecutorStrategy::CurrentContext) + ); +} + +#[tokio::test] +async fn current_context_cancellable() { + initialize(); + + let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); + let future = async move { + { + tokio::time::sleep(Duration::from_secs(60)).await; + let _ = tx.send(()); + } + }; + + select! { + _ = tokio::time::sleep(Duration::from_millis(4)) => { }, + _ = execute_compute_heavy_future(future) => {} + } + + tokio::time::sleep(Duration::from_millis(8)).await; + + // future should have been cancelled when spawn compute heavy future was dropped + assert_eq!( + rx.try_recv(), + Err(tokio::sync::oneshot::error::TryRecvError::Closed) + ); +} + +#[tokio::test] +async fn current_context_concurrency() { + initialize(); + + let start = std::time::Instant::now(); + + let mut futures = Vec::new(); + + for _ in 0..5 { + // can't use std::thread::sleep because this is all in the same thread + let future = async move { tokio::time::sleep(Duration::from_millis(15)).await }; + futures.push(execute_compute_heavy_future(future)); + } + + join_all(futures).await; + + let elapsed_millis = start.elapsed().as_millis(); + assert!(elapsed_millis < 50, "futures did not run concurrently"); + + assert!(elapsed_millis > 20, "futures exceeded max concurrency"); +} diff --git a/tests/custom_executor_simple.rs b/tests/custom_executor_simple.rs new file mode 100644 index 0000000..5999adc --- /dev/null +++ b/tests/custom_executor_simple.rs @@ -0,0 +1,28 @@ +use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy, global_strategy_builder, CurrentStrategy, + CustomExecutorClosure, ExecutorStrategy, +}; + +#[tokio::test] +async fn custom_executor_simple() { + let closure: CustomExecutorClosure = Box::new(|fut| { + Box::new(async move { + let res = fut.await; + Ok(res) + }) + }); + + global_strategy_builder() + .initialize_custom_executor(closure) + .unwrap(); + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); + + assert_eq!( + global_strategy(), + CurrentStrategy::Initialized(ExecutorStrategy::CustomExecutor) + ); +} diff --git a/tests/custom_executor_strategy.rs b/tests/custom_executor_strategy.rs new file mode 100644 index 0000000..210f5bf --- /dev/null +++ b/tests/custom_executor_strategy.rs @@ -0,0 +1,80 @@ +use std::time::Duration; + +use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy_builder, CustomExecutorClosure, +}; +use futures_util::future::join_all; +use tokio::select; + +fn initialize() { + let closure: CustomExecutorClosure = Box::new(|fut| { + Box::new(async move { + tokio::task::spawn(async move { fut.await }) + .await + .map_err(|err| err.into()) + }) + }); + + // we are racing all tests against the single oncelock + let _ = global_strategy_builder() + .max_concurrency(3) + .initialize_custom_executor(closure); +} + +#[tokio::test] +async fn custom_executor_strategy() { + initialize(); + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); +} + +#[tokio::test] +async fn custom_executor_concurrency() { + initialize(); + + let start = std::time::Instant::now(); + + let mut futures = Vec::new(); + + for _ in 0..5 { + // can't use std::thread::sleep because this is all in the same thread + let future = async move { tokio::time::sleep(Duration::from_millis(15)).await }; + futures.push(execute_compute_heavy_future(future)); + } + + join_all(futures).await; + + let elapsed_millis = start.elapsed().as_millis(); + assert!(elapsed_millis < 50, "futures did not run concurrently"); + + assert!(elapsed_millis > 20, "futures exceeded max concurrency"); +} + +#[tokio::test] +async fn custom_executor_cancellable() { + initialize(); + + let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); + let future = async move { + { + tokio::time::sleep(Duration::from_secs(60)).await; + let _ = tx.send(()); + } + }; + + select! { + _ = tokio::time::sleep(Duration::from_millis(4)) => { }, + _ = execute_compute_heavy_future(future) => {} + } + + tokio::time::sleep(Duration::from_millis(8)).await; + + // future should have been cancelled when spawn compute heavy future was dropped + assert_eq!( + rx.try_recv(), + Err(tokio::sync::oneshot::error::TryRecvError::Closed) + ); +} diff --git a/tests/multiple_initialize_err.rs b/tests/multiple_initialize_err.rs new file mode 100644 index 0000000..5ad3e2d --- /dev/null +++ b/tests/multiple_initialize_err.rs @@ -0,0 +1,13 @@ +use compute_heavy_future_executor::{error::Error, global_strategy_builder}; + +#[test] +fn multiple_initialize_err() { + global_strategy_builder() + .initialize_current_context() + .unwrap(); + + assert!(matches!( + global_strategy_builder().initialize_current_context(), + Err(Error::AlreadyInitialized(_)) + )); +} diff --git a/tests/multiple_initialize_err_with_secondary_runtime_builder.rs b/tests/multiple_initialize_err_with_secondary_runtime_builder.rs new file mode 100644 index 0000000..5ec6f1b --- /dev/null +++ b/tests/multiple_initialize_err_with_secondary_runtime_builder.rs @@ -0,0 +1,16 @@ +#[cfg(feature = "secondary_tokio_runtime")] +#[test] +fn multiple_initialize_err_with_secondary_runtime_builder() { + use compute_heavy_future_executor::{error::Error, global_strategy_builder}; + + let builder = global_strategy_builder().secondary_tokio_runtime_builder(); // not yet initialized + + global_strategy_builder() + .initialize_current_context() + .unwrap(); + + assert!(matches!( + builder.initialize(), + Err(Error::AlreadyInitialized(_)) + )); +} diff --git a/tests/secondary_tokio_builder_allowed_config.rs b/tests/secondary_tokio_builder_allowed_config.rs new file mode 100644 index 0000000..c04b7da --- /dev/null +++ b/tests/secondary_tokio_builder_allowed_config.rs @@ -0,0 +1,19 @@ +#[cfg(feature = "secondary_tokio_runtime")] +#[tokio::test] +async fn secondary_tokio_runtime_builder_allowed_config() { + use compute_heavy_future_executor::{execute_compute_heavy_future, global_strategy_builder}; + + global_strategy_builder() + .max_concurrency(5) + .secondary_tokio_runtime_builder() + .channel_size(10) + .niceness(5) + .thread_count(2) + .initialize() + .unwrap(); + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); +} diff --git a/tests/secondary_tokio_builder_disallowed_config.rs b/tests/secondary_tokio_builder_disallowed_config.rs new file mode 100644 index 0000000..51b512a --- /dev/null +++ b/tests/secondary_tokio_builder_disallowed_config.rs @@ -0,0 +1,14 @@ +#[cfg(feature = "secondary_tokio_runtime")] +#[tokio::test] +#[should_panic] +async fn secondary_tokio_runtime_builder_disallowed_config() { + use compute_heavy_future_executor::{error::Error, global_strategy_builder}; + + let res = global_strategy_builder() + .secondary_tokio_runtime_builder() + .channel_size(10) + .niceness(5) + .initialize(); + + assert!(matches!(res, Err(Error::InvalidConfig(_)))); +} diff --git a/tests/secondary_tokio_strategy.rs b/tests/secondary_tokio_strategy.rs new file mode 100644 index 0000000..c0379f5 --- /dev/null +++ b/tests/secondary_tokio_strategy.rs @@ -0,0 +1,81 @@ +#[cfg(feature = "secondary_tokio_runtime")] +mod test { + use std::time::Duration; + + use futures_util::future::join_all; + use tokio::select; + + use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy, global_strategy_builder, CurrentStrategy, + ExecutorStrategy, + }; + + fn initialize() { + // we are racing all tests against the single oncelock + let _ = global_strategy_builder() + .max_concurrency(3) + .initialize_secondary_tokio_runtime(); + } + + #[tokio::test] + async fn secondary_tokio_runtime_strategy() { + initialize(); + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); + + assert_eq!( + global_strategy(), + CurrentStrategy::Initialized(ExecutorStrategy::SecondaryTokioRuntime) + ); + } + + #[tokio::test] + async fn secondary_tokio_runtime_concurrency() { + initialize(); + + let start = std::time::Instant::now(); + + let mut futures = Vec::new(); + + for _ in 0..5 { + let future = async move { std::thread::sleep(Duration::from_millis(15)) }; + futures.push(execute_compute_heavy_future(future)); + } + + join_all(futures).await; + + let elapsed_millis = start.elapsed().as_millis(); + assert!(elapsed_millis < 50, "futures did not run concurrently"); + + assert!(elapsed_millis > 20, "futures exceeded max concurrency"); + } + + #[tokio::test] + async fn secondary_tokio_runtime_strategy_cancel_safe() { + initialize(); + + let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); + let future = async move { + { + tokio::time::sleep(Duration::from_secs(60)).await; + let _ = tx.send(()); + } + }; + + select! { + _ = tokio::time::sleep(Duration::from_millis(4)) => { }, + _ = execute_compute_heavy_future(future) => {} + } + + tokio::time::sleep(Duration::from_millis(8)).await; + + // future should have been cancelled when spawn compute heavy future was dropped + assert_eq!( + rx.try_recv(), + Err(tokio::sync::oneshot::error::TryRecvError::Closed) + ); + } +} diff --git a/tests/spawn_blocking_default.rs b/tests/spawn_blocking_default.rs new file mode 100644 index 0000000..3214f99 --- /dev/null +++ b/tests/spawn_blocking_default.rs @@ -0,0 +1,17 @@ +#[cfg(feature = "tokio")] +#[tokio::test] +async fn spawn_blocking_strategy() { + use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy, CurrentStrategy, ExecutorStrategy, + }; + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); + + assert_eq!( + global_strategy(), + CurrentStrategy::Default(ExecutorStrategy::SpawnBlocking) + ); +} diff --git a/tests/spawn_blocking_strategy.rs b/tests/spawn_blocking_strategy.rs new file mode 100644 index 0000000..b44d05f --- /dev/null +++ b/tests/spawn_blocking_strategy.rs @@ -0,0 +1,80 @@ +#[cfg(feature = "tokio")] +mod test { + use std::time::Duration; + + use futures_util::future::join_all; + use tokio::select; + + use compute_heavy_future_executor::{ + execute_compute_heavy_future, global_strategy, global_strategy_builder, CurrentStrategy, + ExecutorStrategy, + }; + + fn initialize() { + // we are racing all tests against the single oncelock + let _ = global_strategy_builder() + .max_concurrency(3) + .initialize_spawn_blocking(); + } + + #[tokio::test] + async fn spawn_blocking_strategy() { + initialize(); + + let future = async { 5 }; + + let res = execute_compute_heavy_future(future).await.unwrap(); + assert_eq!(res, 5); + + assert_eq!( + global_strategy(), + CurrentStrategy::Initialized(ExecutorStrategy::SpawnBlocking) + ); + } + + #[tokio::test] + async fn spawn_blocking_concurrency() { + initialize(); + let start = std::time::Instant::now(); + + let mut futures = Vec::new(); + + for _ in 0..5 { + let future = async move { std::thread::sleep(Duration::from_millis(15)) }; + futures.push(execute_compute_heavy_future(future)); + } + + join_all(futures).await; + + let elapsed_millis = start.elapsed().as_millis(); + assert!(elapsed_millis < 60, "futures did not run concurrently"); + + assert!(elapsed_millis > 20, "futures exceeded max concurrency"); + } + + #[tokio::test] + async fn spawn_blocking_strategy_cancellable() { + initialize(); + + let (tx, mut rx) = tokio::sync::oneshot::channel::<()>(); + let future = async move { + { + tokio::time::sleep(Duration::from_secs(60)).await; + let _ = tx.send(()); + } + }; + + select! { + _ = tokio::time::sleep(Duration::from_millis(4)) => { }, + _ = execute_compute_heavy_future(future) => {} + } + + tokio::time::sleep(Duration::from_millis(8)).await; + + // future should have been cancelled when spawn compute heavy future was dropped + assert_eq!( + rx.try_recv(), + Err(tokio::sync::oneshot::error::TryRecvError::Closed) + ); + } +}