diff --git a/.gitignore b/.gitignore index e7c54e24a..c81a5cf2e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ Cargo.lock /.idea/ *.iml +*~ # Ignore generated protobuf files src/protos/*.rs diff --git a/client/Cargo.toml b/client/Cargo.toml index 065ac6ade..32392f37d 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -23,6 +23,7 @@ once_cell = "1.13" opentelemetry = { workspace = true, features = ["metrics"] } parking_lot = "0.12" prost-types = "0.11" +slotmap = "1.0" thiserror = "1.0" tokio = "1.1" tonic = { workspace = true, features = ["tls", "tls-roots"] } diff --git a/client/src/lib.rs b/client/src/lib.rs index 71a4e1350..b634c6087 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -6,10 +6,10 @@ #[macro_use] extern crate tracing; - mod metrics; mod raw; mod retry; +mod worker_registry; mod workflow_handle; pub use crate::retry::{CallType, RetryClient, RETRYABLE_ERROR_CODES}; @@ -23,6 +23,7 @@ pub use temporal_sdk_core_protos::temporal::api::{ }, }; pub use tonic; +pub use worker_registry::{Slot, SlotManager, SlotProvider, WorkerKey}; pub use workflow_handle::{WorkflowExecutionInfo, WorkflowExecutionResult}; use crate::{ @@ -280,6 +281,7 @@ pub struct ConfiguredClient { headers: Arc>>, /// Capabilities as read from the `get_system_info` RPC call made on client connection capabilities: Option, + workers: Arc, } impl ConfiguredClient { @@ -299,6 +301,11 @@ impl ConfiguredClient { pub fn capabilities(&self) -> Option<&get_system_info_response::Capabilities> { self.capabilities.as_ref() } + + /// Returns a cloned reference to a registry with workers using this client instance + pub fn workers(&self) -> Arc { + self.workers.clone() + } } // The configured client is effectively a "smart" (dumb) pointer @@ -377,6 +384,7 @@ impl ClientOptions { client: TemporalServiceClient::new(svc), options: Arc::new(self.clone()), capabilities: None, + workers: Arc::new(SlotManager::new()), }; match client .get_system_info(GetSystemInfoRequest::default()) @@ -974,6 +982,10 @@ pub struct WorkflowOptions { /// Optionally associate extra search attributes with a workflow pub search_attributes: Option>, + + /// Optionally enable Eager Workflow Start, a latency optimization using local workers + /// NOTE: Experimental and incompatible with versioning with BuildIDs + pub enable_eager_workflow_start: bool, } #[async_trait::async_trait] @@ -988,7 +1000,7 @@ impl WorkflowClientTrait for Client { options: WorkflowOptions, ) -> Result { Ok(WorkflowService::start_workflow_execution( - &mut self.inner.client.clone(), + &mut self.inner.clone(), StartWorkflowExecutionRequest { namespace: self.namespace.clone(), input: input.into_payloads(), @@ -1006,10 +1018,11 @@ impl WorkflowClientTrait for Client { workflow_execution_timeout: options .execution_timeout .and_then(|d| d.try_into().ok()), - workflow_run_timeout: options.execution_timeout.and_then(|d| d.try_into().ok()), + workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), search_attributes: options.search_attributes.and_then(|d| d.try_into().ok()), cron_schedule: options.cron_schedule.unwrap_or_default(), + request_eager_execution: options.enable_eager_workflow_start, ..Default::default() }, ) @@ -1169,9 +1182,7 @@ impl WorkflowClientTrait for Client { workflow_execution_timeout: workflow_options .execution_timeout .and_then(|d| d.try_into().ok()), - workflow_run_timeout: workflow_options - .execution_timeout - .and_then(|d| d.try_into().ok()), + workflow_run_timeout: workflow_options.run_timeout.and_then(|d| d.try_into().ok()), workflow_task_timeout: workflow_options .task_timeout .and_then(|d| d.try_into().ok()), diff --git a/client/src/raw.rs b/client/src/raw.rs index ad260f922..0874c8e2c 100644 --- a/client/src/raw.rs +++ b/client/src/raw.rs @@ -5,10 +5,12 @@ use crate::{ metrics::{namespace_kv, task_queue_kv}, raw::sealed::RawClientLike, + worker_registry::{Slot, SlotManager}, Client, ConfiguredClient, InterceptedMetricsSvc, RetryClient, TemporalServiceClient, LONG_POLL_TIMEOUT, }; use futures::{future::BoxFuture, FutureExt, TryFutureExt}; +use std::sync::Arc; use temporal_sdk_core_api::telemetry::metrics::MetricKeyValue; use temporal_sdk_core_protos::{ grpc::health::v1::{health_client::HealthClient, *}, @@ -55,6 +57,9 @@ pub(super) mod sealed { /// Return a mutable ref to the health service client instance fn health_client_mut(&mut self) -> &mut HealthClient; + /// Return a registry with workers using this client instance + fn get_workers_info(&self) -> Option>; + async fn call( &mut self, _call_name: &'static str, @@ -111,6 +116,10 @@ where self.get_client_mut().health_client_mut() } + fn get_workers_info(&self) -> Option> { + self.get_client().get_workers_info() + } + async fn call( &mut self, call_name: &'static str, @@ -173,6 +182,10 @@ where fn health_client_mut(&mut self) -> &mut HealthClient { self.health_svc_mut() } + + fn get_workers_info(&self) -> Option> { + None + } } impl RawClientLike for ConfiguredClient> @@ -216,6 +229,10 @@ where fn health_client_mut(&mut self) -> &mut HealthClient { self.client.health_client_mut() } + + fn get_workers_info(&self) -> Option> { + Some(self.workers()) + } } impl RawClientLike for Client { @@ -252,6 +269,10 @@ impl RawClientLike for Client { fn health_client_mut(&mut self) -> &mut HealthClient { self.inner.health_client_mut() } + + fn get_workers_info(&self) -> Option> { + self.inner.get_workers_info() + } } /// Helper for cloning a tonic request as long as the inner message may be cloned. @@ -356,10 +377,29 @@ macro_rules! proxy { self.call(stringify!($method), fact, request.into_request()) } }; + ($client_type:tt, $client_meth:ident, $method:ident, $req:ty, $resp:ty, + $closure_before:expr, $closure_after:expr) => { + #[doc = concat!("See [", stringify!($client_type), "::", stringify!($method), "]")] + fn $method( + &mut self, + request: impl tonic::IntoRequest<$req>, + ) -> BoxFuture, tonic::Status>> { + #[allow(unused_mut)] + let fact = |c: &mut Self, mut req: tonic::Request<$req>| { + let data = type_closure_two_arg(&mut req, c.get_workers_info().unwrap(), + $closure_before); + let mut c = c.$client_meth().clone(); + async move { + type_closure_two_arg(c.$method(req).await, data, $closure_after) + }.boxed() + }; + self.call(stringify!($method), fact, request.into_request()) + } + }; } macro_rules! proxier { ( $trait_name:ident; $impl_list_name:ident; $client_type:tt; $client_meth:ident; - $(($method:ident, $req:ty, $resp:ty $(, $closure:expr)? );)* ) => { + $(($method:ident, $req:ty, $resp:ty $(, $closure:expr $(, $closure_after:expr)?)? );)* ) => { #[cfg(test)] const $impl_list_name: &'static [&'static str] = &[$(stringify!($method)),*]; /// Trait version of the generated client with modifications to attach appropriate metric @@ -377,7 +417,8 @@ macro_rules! proxier { as tonic::codegen::Body>::Error: Into + Send, { $( - proxy!($client_type, $client_meth, $method, $req, $resp $(,$closure)*); + proxy!($client_type, $client_meth, $method, $req, $resp + $(,$closure $(,$closure_after)*)*); )* } }; @@ -388,6 +429,10 @@ fn type_closure_arg(arg: T, f: impl FnOnce(T) -> R) -> R { f(arg) } +fn type_closure_two_arg(arg1: R, arg2: T, f: impl FnOnce(R, T) -> S) -> S { + f(arg1, arg2) +} + proxier! { WorkflowService; ALL_IMPLEMENTED_WORKFLOW_SERVICE_RPCS; WorkflowServiceClient; workflow_client_mut; ( @@ -435,10 +480,35 @@ proxier! { start_workflow_execution, StartWorkflowExecutionRequest, StartWorkflowExecutionResponse, - |r| { + |r, workers| { + let mut slot: Option> = None; let mut labels = AttachMetricLabels::namespace(r.get_ref().namespace.clone()); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); + let req_mut = r.get_mut(); + if req_mut.request_eager_execution { + let namespace = req_mut.namespace.clone(); + let task_queue = req_mut.task_queue.clone().unwrap().name.clone(); + match workers.try_reserve_wft_slot(namespace, task_queue) { + Some(s) => slot = Some(s), + None => req_mut.request_eager_execution = false + } + } + slot + }, + |resp, slot| { + if let Some(mut s) = slot { + if let Ok(response) = resp.as_ref() { + if let Some(task) = response.get_ref().clone().eager_workflow_task { + if let Err(e) = s.schedule_wft(task) { + // This is a latency issue, i.e., the client does not need to handle + // this error, because the WFT will be retried after a timeout. + warn!(details = ?e, "Eager workflow task rejected by worker."); + } + } + } + } + resp } ); ( diff --git a/client/src/worker_registry/mod.rs b/client/src/worker_registry/mod.rs new file mode 100644 index 000000000..302c622b0 --- /dev/null +++ b/client/src/worker_registry/mod.rs @@ -0,0 +1,264 @@ +//! This module enables the tracking of workers that are associated with a client instance. +//! This is needed to implement Eager Workflow Start, a latency optimization in which the client, +//! after reserving a slot, directly forwards a WFT to a local worker. + +use parking_lot::RwLock; +use slotmap::SlotMap; +use std::collections::{hash_map::Entry::Vacant, HashMap}; + +use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse; + +slotmap::new_key_type! { + /// Registration key for a worker + pub struct WorkerKey; +} + +/// This trait is implemented by an object associated with a worker, which provides WFT processing slots. +#[cfg_attr(test, mockall::automock)] +pub trait SlotProvider: std::fmt::Debug { + /// The namespace for the WFTs that it can process. + fn namespace(&self) -> &str; + /// The task queue this provider listens to. + fn task_queue(&self) -> &str; + /// Try to reserve a slot on this worker. + fn try_reserve_wft_slot(&self) -> Option>; +} + +/// This trait represents a slot reserved for processing a WFT by a worker. +#[cfg_attr(test, mockall::automock)] +pub trait Slot { + /// Consumes this slot by dispatching a WFT to its worker. This can only be called once. + fn schedule_wft( + self: Box, + task: PollWorkflowTaskQueueResponse, + ) -> Result<(), anyhow::Error>; +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +struct SlotKey { + namespace: String, + task_queue: String, +} + +impl SlotKey { + fn new(namespace: String, task_queue: String) -> SlotKey { + SlotKey { + namespace, + task_queue, + } + } +} + +/// This is an inner class for [SlotManager] needed to hide the mutex. +#[derive(Default, Debug)] +struct SlotManagerImpl { + /// Maps keys, i.e., namespace#task_queue, to provider. + providers: HashMap>, + /// Maps ids to keys in `providers`. + index: SlotMap, +} + +impl SlotManagerImpl { + /// Factory method. + fn new() -> Self { + Self { + index: Default::default(), + providers: Default::default(), + } + } + + fn try_reserve_wft_slot( + &self, + namespace: String, + task_queue: String, + ) -> Option> { + let key = SlotKey::new(namespace, task_queue); + if let Some(p) = self.providers.get(&key) { + if let Some(slot) = p.try_reserve_wft_slot() { + return Some(slot); + } + } + None + } + + fn register(&mut self, provider: Box) -> Option { + let key = SlotKey::new( + provider.namespace().to_string(), + provider.task_queue().to_string(), + ); + if let Vacant(p) = self.providers.entry(key.clone()) { + p.insert(provider); + Some(self.index.insert(key)) + } else { + warn!("Ignoring registration for worker: {key:?}."); + None + } + } + + fn unregister(&mut self, id: WorkerKey) { + if let Some(key) = self.index.remove(id) { + self.providers.remove(&key); + } + } + + #[cfg(test)] + fn num_providers(&self) -> (usize, usize) { + (self.index.len(), self.providers.len()) + } +} + +/// Enables local workers to made themselves visible to a shared client instance. +/// There can only be one worker registered per namespace+queue_name+client, others will get ignored. +/// It also provides a convenient method to find compatible slots within the collection. +#[derive(Default, Debug)] +pub struct SlotManager { + manager: RwLock, +} + +impl SlotManager { + /// Factory method. + pub fn new() -> Self { + Self { + manager: RwLock::new(SlotManagerImpl::new()), + } + } + + /// Try to reserve a compatible processing slot in any of the registered workers. + pub(crate) fn try_reserve_wft_slot( + &self, + namespace: String, + task_queue: String, + ) -> Option> { + self.manager + .read() + .try_reserve_wft_slot(namespace, task_queue) + } + + /// Register a local worker that can provide WFT processing slots. + pub fn register(&self, provider: Box) -> Option { + self.manager.write().register(provider) + } + + /// Unregister a provider, typically when its worker starts shutdown. + pub fn unregister(&self, id: WorkerKey) { + self.manager.write().unregister(id) + } + + #[cfg(test)] + /// Returns (num_providers, num_buckets), where a bucket key is namespace+task_queue. + /// There is only one provider per bucket so `num_providers` should be equal to `num_buckets`. + pub fn num_providers(&self) -> (usize, usize) { + self.manager.read().num_providers() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn new_mock_slot(with_error: bool) -> Box { + let mut mock_slot = MockSlot::new(); + if with_error { + mock_slot + .expect_schedule_wft() + .returning(|_| Err(anyhow::anyhow!("Changed my mind"))); + } else { + mock_slot.expect_schedule_wft().returning(|_| Ok(())); + } + Box::new(mock_slot) + } + + fn new_mock_provider( + namespace: String, + task_queue: String, + with_error: bool, + no_slots: bool, + ) -> MockSlotProvider { + let mut mock_provider = MockSlotProvider::new(); + mock_provider + .expect_try_reserve_wft_slot() + .returning(move || { + if no_slots { + None + } else { + Some(new_mock_slot(with_error)) + } + }); + mock_provider.expect_namespace().return_const(namespace); + mock_provider.expect_task_queue().return_const(task_queue); + mock_provider + } + + #[test] + fn registry_respects_registration_order() { + let mock_provider1 = + new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false); + let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true); + + let manager = SlotManager::new(); + let some_slots = manager.register(Box::new(mock_provider1)); + let no_slots = manager.register(Box::new(mock_provider2)); + assert!(no_slots.is_none()); + + let mut found = 0; + for _ in 0..10 { + if manager + .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string()) + .is_some() + { + found += 1; + } + } + assert_eq!(found, 10); + assert_eq!((1, 1), manager.num_providers()); + + manager.unregister(some_slots.unwrap()); + assert_eq!((0, 0), manager.num_providers()); + + let mock_provider1 = + new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false); + let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true); + + let no_slots = manager.register(Box::new(mock_provider2)); + let some_slots = manager.register(Box::new(mock_provider1)); + assert!(some_slots.is_none()); + + let mut not_found = 0; + for _ in 0..10 { + if manager + .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string()) + .is_none() + { + not_found += 1; + } + } + assert_eq!(not_found, 10); + assert_eq!((1, 1), manager.num_providers()); + manager.unregister(no_slots.unwrap()); + assert_eq!((0, 0), manager.num_providers()); + } + + #[test] + fn registry_keeps_one_provider_per_namespace() { + let manager = SlotManager::new(); + let mut worker_keys = vec![]; + for i in 0..10 { + let namespace = format!("myId{}", i % 3); + let mock_provider = new_mock_provider(namespace, "bar_q".to_string(), false, false); + worker_keys.push(manager.register(Box::new(mock_provider))); + } + assert_eq!((3, 3), manager.num_providers()); + + let count = worker_keys + .iter() + .filter(|key| key.is_some()) + .fold(0, |count, key| { + manager.unregister(key.unwrap()); + // Should be idempotent + manager.unregister(key.unwrap()); + count + 1 + }); + assert_eq!(3, count); + assert_eq!((0, 0), manager.num_providers()); + } +} diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs index 9d288c926..25bd3bb92 100644 --- a/core/src/worker/client.rs +++ b/core/src/worker/client.rs @@ -1,8 +1,8 @@ //! Worker-specific client needs pub(crate) mod mocks; - -use temporal_client::{Client, RetryClient, WorkflowService}; +use std::sync::Arc; +use temporal_client::{Client, RetryClient, SlotManager, WorkflowService}; use temporal_sdk_core_protos::{ coresdk::workflow_commands::QueryResult, temporal::api::{ @@ -144,6 +144,7 @@ pub(crate) trait WorkerClient: Sync + Send { #[allow(clippy::needless_lifetimes)] // Clippy is wrong here fn capabilities<'a>(&'a self) -> Option<&'a get_system_info_response::Capabilities>; + fn workers(&self) -> Arc; } #[async_trait::async_trait] @@ -382,6 +383,10 @@ impl WorkerClient for WorkerClientBag { fn capabilities(&self) -> Option<&Capabilities> { self.client.get_client().inner().capabilities() } + + fn workers(&self) -> Arc { + self.client.get_client().inner().workers() + } } /// A version of [RespondWorkflowTaskCompletedRequest] that will finish being filled out by the diff --git a/core/src/worker/client/mocks.rs b/core/src/worker/client/mocks.rs index f92ec9f67..3864e3c02 100644 --- a/core/src/worker/client/mocks.rs +++ b/core/src/worker/client/mocks.rs @@ -1,5 +1,12 @@ use super::*; use futures::Future; +use lazy_static::lazy_static; +use std::sync::Arc; +use temporal_client::SlotManager; + +lazy_static! { + pub(crate) static ref DEFAULT_WORKERS_REGISTRY: Arc = Arc::new(SlotManager::new()); +} pub(crate) static DEFAULT_TEST_CAPABILITIES: &Capabilities = &Capabilities { signal_and_query_header: true, @@ -19,6 +26,8 @@ pub(crate) fn mock_workflow_client() -> MockWorkerClient { let mut r = MockWorkerClient::new(); r.expect_capabilities() .returning(|| Some(DEFAULT_TEST_CAPABILITIES)); + r.expect_workers() + .returning(|| DEFAULT_WORKERS_REGISTRY.clone()); r } @@ -27,6 +36,8 @@ pub(crate) fn mock_manual_workflow_client() -> MockManualWorkerClient { let mut r = MockManualWorkerClient::new(); r.expect_capabilities() .returning(|| Some(DEFAULT_TEST_CAPABILITIES)); + r.expect_workers() + .returning(|| DEFAULT_WORKERS_REGISTRY.clone()); r } @@ -103,5 +114,7 @@ mockall::mock! { where 'a: 'b, Self: 'b; fn capabilities(&self) -> Option<&'static get_system_info_response::Capabilities>; + + fn workers(&self) -> Arc; } } diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 9ea2ccbf2..cb040d44f 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -1,5 +1,6 @@ mod activities; pub(crate) mod client; +mod slot_provider; mod workflow; pub use temporal_sdk_core_api::worker::{WorkerConfig, WorkerConfigBuilder}; @@ -12,6 +13,8 @@ pub(crate) use activities::{ }; pub(crate) use workflow::{wft_poller::new_wft_poller, LEGACY_QUERY_ID}; +use temporal_client::WorkerKey; + use crate::{ abstractions::{dbg_panic, MeteredSemaphore}, errors::CompleteWfError, @@ -34,6 +37,8 @@ use crate::{ ActivityHeartbeat, CompleteActivityError, PollActivityError, PollWfError, WorkerTrait, }; use activities::WorkerActivityTasks; +use futures_util::stream; +use slot_provider::SlotProvider; use std::{ convert::TryInto, future, @@ -58,6 +63,7 @@ use temporal_sdk_core_protos::{ TaskToken, }; use tokio::sync::mpsc::unbounded_channel; +use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; #[cfg(test)] @@ -75,7 +81,8 @@ use { pub struct Worker { config: WorkerConfig, wf_client: Arc, - + /// Registration key to enable eager workflow start for this worker + worker_key: Option, /// Manages all workflows and WFT processing workflows: Workflows, /// Manages activity tasks for this worker/task queue @@ -170,7 +177,11 @@ impl WorkerTrait for Worker { ); } self.shutdown_token.cancel(); - // First, we want to stop polling of both activity and workflow tasks + // First, disable Eager Workflow Start + if let Some(key) = self.worker_key { + self.wf_client.workers().unregister(key); + } + // Second, we want to stop polling of both activity and workflow tasks if let Some(atm) = self.at_task_mgr.as_ref() { atm.initiate_shutdown(); } @@ -240,7 +251,7 @@ impl Worker { metrics.with_new_attrs([activity_worker_type()]), MetricsContext::available_task_slots, )); - + let (external_wft_tx, external_wft_rx) = unbounded_channel(); let (wft_stream, act_poller) = match task_pollers { TaskPollers::Real => { let max_nonsticky_polls = if sticky_queue_name.is_some() { @@ -302,6 +313,8 @@ impl Worker { sticky_queue_poller, )); let wft_stream = new_wft_poller(wf_task_poll_buffer, metrics.clone()); + let wft_stream = + stream::select(wft_stream, UnboundedReceiverStream::new(external_wft_rx)); #[cfg(test)] let wft_stream = wft_stream.left_stream(); (wft_stream, act_poll_buffer) @@ -353,7 +366,15 @@ impl Worker { info!("Activity polling is disabled for this worker"); }; let la_sink = LAReqSink::new(local_act_mgr.clone(), config.wf_state_inputs.clone()); + let provider = SlotProvider::new( + config.namespace.clone(), + config.task_queue.clone(), + wft_semaphore.clone(), + external_wft_tx, + ); + let worker_key = client.workers().register(Box::new(provider)); Self { + worker_key, wf_client: client.clone(), workflows: Workflows::new( build_wf_basics( diff --git a/core/src/worker/slot_provider.rs b/core/src/worker/slot_provider.rs new file mode 100644 index 000000000..bd9cbf13d --- /dev/null +++ b/core/src/worker/slot_provider.rs @@ -0,0 +1,175 @@ +//! This module implements traits defined in the client to dispatch a +//! WFT to a worker bypassing the server. +//! This enables latency optimizations such as Eager Workflow Start. + +use crate::{ + abstractions::{MeteredSemaphore, OwnedMeteredSemPermit}, + protosext::ValidPollWFTQResponse, + worker::workflow::wft_poller::validate_wft, +}; + +use std::sync::Arc; +use temporal_client::{Slot as SlotTrait, SlotProvider as SlotProviderTrait}; +use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse; +use tokio::sync::mpsc::UnboundedSender; +use tonic::Status; + +type WFTStreamSender = + UnboundedSender>; + +pub struct Slot { + permit: OwnedMeteredSemPermit, + external_wft_tx: WFTStreamSender, +} + +impl Slot { + fn new(permit: OwnedMeteredSemPermit, external_wft_tx: WFTStreamSender) -> Self { + Self { + permit, + external_wft_tx, + } + } +} + +impl SlotTrait for Slot { + fn schedule_wft( + self: Box, + task: PollWorkflowTaskQueueResponse, + ) -> Result<(), anyhow::Error> { + let wft = validate_wft(task)?; + self.external_wft_tx.send(Ok((wft, self.permit)))?; + Ok(()) + } +} + +#[derive(derive_more::DebugCustom)] +#[debug(fmt = "SlotProvider {{ namespace:{namespace}, task_queue: {task_queue} }}")] +pub struct SlotProvider { + namespace: String, + task_queue: String, + wft_semaphore: Arc, + external_wft_tx: WFTStreamSender, +} + +impl SlotProvider { + pub(crate) fn new( + namespace: String, + task_queue: String, + wft_semaphore: Arc, + external_wft_tx: WFTStreamSender, + ) -> Self { + Self { + namespace, + task_queue, + wft_semaphore, + external_wft_tx, + } + } +} + +impl SlotProviderTrait for SlotProvider { + fn namespace(&self) -> &str { + &self.namespace + } + fn task_queue(&self) -> &str { + &self.task_queue + } + fn try_reserve_wft_slot(&self) -> Option> { + match self.wft_semaphore.try_acquire_owned().ok() { + Some(permit) => Some(Box::new(Slot::new(permit, self.external_wft_tx.clone()))), + None => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use temporal_sdk_core_protos::temporal::api::{ + common::v1::{WorkflowExecution, WorkflowType}, + history::v1::History, + taskqueue::v1::TaskQueue, + }; + use tokio::sync::mpsc::unbounded_channel; + + // make validate_wft() happy + fn new_validatable_response() -> PollWorkflowTaskQueueResponse { + PollWorkflowTaskQueueResponse { + workflow_execution_task_queue: Some(TaskQueue::default()), + workflow_execution: Some(WorkflowExecution::default()), + workflow_type: Some(WorkflowType::default()), + history: Some(History::default()), + ..Default::default() + } + } + + #[tokio::test] + async fn slot_propagates_through_channel() { + let wft_semaphore = Arc::new(MeteredSemaphore::new( + 2, + crate::MetricsContext::no_op(), + |_, _| {}, + )); + let (external_wft_tx, mut external_wft_rx) = unbounded_channel(); + + let provider = SlotProvider::new( + "my_namespace".to_string(), + "my_queue".to_string(), + wft_semaphore, + external_wft_tx, + ); + + let slot = provider + .try_reserve_wft_slot() + .expect("failed to reserver slot"); + let p = slot.schedule_wft(new_validatable_response()); + assert!(p.is_ok()); + assert!(external_wft_rx.recv().await.is_some()); + } + + #[tokio::test] + async fn channel_closes_when_provider_drops() { + let (external_wft_tx, mut external_wft_rx) = unbounded_channel(); + { + let external_wft_tx = external_wft_tx; + let wft_semaphore = Arc::new(MeteredSemaphore::new( + 2, + crate::MetricsContext::no_op(), + |_, _| {}, + )); + let provider = SlotProvider::new( + "my_namespace".to_string(), + "my_queue".to_string(), + wft_semaphore, + external_wft_tx, + ); + assert!(provider.try_reserve_wft_slot().is_some()); + } + assert!(external_wft_rx.recv().await.is_none()); + } + + #[tokio::test] + async fn unused_slots_reclaimed() { + let wft_semaphore = Arc::new(MeteredSemaphore::new( + 2, + crate::MetricsContext::no_op(), + |_, _| {}, + )); + { + let wft_semaphore = wft_semaphore.clone(); + let (external_wft_tx, _) = unbounded_channel(); + let provider = SlotProvider::new( + "my_namespace".to_string(), + "my_queue".to_string(), + wft_semaphore.clone(), + external_wft_tx, + ); + let slot = provider.try_reserve_wft_slot(); + assert!(slot.is_some()); + assert_eq!(wft_semaphore.available_permits(), 1); + // drop slot without using it + } + assert_eq!(wft_semaphore.available_permits(), 2); + } +} diff --git a/etc/dynamic-config.yaml b/etc/dynamic-config.yaml index d69c4149b..751957200 100644 --- a/etc/dynamic-config.yaml +++ b/etc/dynamic-config.yaml @@ -1,5 +1,7 @@ system.enableActivityLocalDispatch: - value: true +system.enableEagerWorkflowStart: + - value: true frontend.workerVersioningWorkflowAPIs: - value: true frontend.workerVersioningDataAPIs: diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index d8ba89e1a..f4b97ec3e 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -9,6 +9,7 @@ pub mod interceptors; pub mod wf_input_saver; pub mod workflows; +use anyhow::Context; pub use temporal_sdk_core::replay::HistoryForReplay; use crate::{ @@ -55,7 +56,10 @@ use temporal_sdk_core_protos::{ }, workflow_completion::WorkflowActivationCompletion, }, - temporal::api::{common::v1::Payload, history::v1::History}, + temporal::api::{ + common::v1::Payload, history::v1::History, + workflowservice::v1::StartWorkflowExecutionResponse, + }, DEFAULT_ACTIVITY_TYPE, }; use tokio::sync::{mpsc::unbounded_channel, OnceCell}; @@ -254,6 +258,23 @@ impl CoreWfStarter { .unwrap() } + pub async fn eager_start_with_worker( + &self, + wf_name: impl Into, + worker: &mut TestWorker, + ) -> StartWorkflowExecutionResponse { + assert!(self.workflow_options.enable_eager_workflow_start); + worker + .eager_submit_wf( + self.task_queue_name.clone(), + wf_name.into(), + vec![], + self.workflow_options.clone(), + ) + .await + .unwrap() + } + pub async fn start_wf_with_id(&self, workflow_id: String) -> String { let iw = self.initted_worker.get().expect( "Worker must be initted before starting a workflow.\ @@ -455,6 +476,39 @@ impl TestWorker { } } + /// Similar to `submit_wf` but checking that the server returns the first + /// workflow task in the client response. + /// Note that this does not guarantee that the worker will execute this task eagerly. + pub async fn eager_submit_wf( + &self, + workflow_id: impl Into, + workflow_type: impl Into, + input: Vec, + options: WorkflowOptions, + ) -> Result { + let c = self.client.as_ref().context("client needed for eager wf")?; + let wfid = workflow_id.into(); + let res = c + .start_workflow( + input, + self.inner.task_queue().to_string(), + wfid.clone(), + workflow_type.into(), + None, + options, + ) + .await?; + res.eager_workflow_task + .as_ref() + .context("no eager workflow task")?; + self.started_workflows.lock().push(WorkflowExecutionInfo { + namespace: c.namespace().to_string(), + workflow_id: wfid, + run_id: Some(res.run_id.clone()), + }); + Ok(res) + } + /// Runs until all expected workflows have completed pub async fn run_until_done(&mut self) -> Result<(), anyhow::Error> { self.run_until_done_intercepted(Option::::None) diff --git a/tests/integ_tests/workflow_tests.rs b/tests/integ_tests/workflow_tests.rs index 83b3670c6..aa958c62e 100644 --- a/tests/integ_tests/workflow_tests.rs +++ b/tests/integ_tests/workflow_tests.rs @@ -5,6 +5,7 @@ mod cancel_wf; mod child_workflows; mod continue_as_new; mod determinism; +mod eager; mod local_activities; mod modify_wf_properties; mod patches; diff --git a/tests/integ_tests/workflow_tests/eager.rs b/tests/integ_tests/workflow_tests/eager.rs new file mode 100644 index 000000000..5002d7449 --- /dev/null +++ b/tests/integ_tests/workflow_tests/eager.rs @@ -0,0 +1,61 @@ +use std::time::Duration; +use temporal_client::{WorkflowClientTrait, WorkflowExecutionInfo}; +use temporal_sdk::{WfContext, WorkflowResult}; +use temporal_sdk_core_test_utils::{get_integ_server_options, CoreWfStarter, NAMESPACE}; + +pub async fn eager_wf(_context: WfContext) -> WorkflowResult<()> { + Ok(().into()) +} + +#[tokio::test] +async fn eager_wf_start() { + let wf_name = "eager_wf_start"; + let mut starter = CoreWfStarter::new(wf_name); + starter.workflow_options.enable_eager_workflow_start = true; + // hang the test if eager task dispatch failed + starter.workflow_options.task_timeout = Some(Duration::from_secs(1500)); + starter.no_remote_activities(); + let mut worker = starter.worker().await; + worker.register_wf(wf_name.to_owned(), eager_wf); + starter.eager_start_with_worker(wf_name, &mut worker).await; + worker.run_until_done().await.unwrap(); +} + +#[tokio::test] +async fn eager_wf_start_different_clients() { + let wf_name = "eager_wf_start"; + let mut starter = CoreWfStarter::new(wf_name); + starter.workflow_options.enable_eager_workflow_start = true; + // hang the test if wf task needs retry + starter.workflow_options.task_timeout = Some(Duration::from_secs(1500)); + starter.no_remote_activities(); + let mut worker = starter.worker().await; + worker.register_wf(wf_name.to_owned(), eager_wf); + + let client = get_integ_server_options() + .connect(NAMESPACE.to_string(), None, None) + .await + .expect("Should connect"); + let w = starter.get_worker().await; + let res = client + .start_workflow( + vec![], + w.get_config().task_queue.clone(), // task_queue + wf_name.to_string(), // workflow_id + wf_name.to_string(), // workflow_type + None, + starter.workflow_options.clone(), + ) + .await + .unwrap(); + // different clients means no eager_wf_start. + assert!(res.eager_workflow_task.is_none()); + + //wf task delivered through default path + worker.started_workflows.lock().push(WorkflowExecutionInfo { + namespace: NAMESPACE.to_string(), + workflow_id: wf_name.to_string(), + run_id: Some(res.run_id.clone()), + }); + worker.run_until_done().await.unwrap(); +} diff --git a/tests/runner.rs b/tests/runner.rs index 1a258fa9c..d0a2dc94f 100644 --- a/tests/runner.rs +++ b/tests/runner.rs @@ -70,6 +70,11 @@ async fn main() -> Result<(), anyhow::Error> { ServerKind::TemporalCLI => { let config = TemporalDevServerConfigBuilder::default() .exe(default_cached_download()) + // TODO: Delete when temporalCLI enables it by default. + .extra_args(vec![ + "--dynamic-config-value".to_string(), + "system.enableEagerWorkflowStart=true".to_string(), + ]) .build()?; println!("Using temporal CLI"); (