diff --git a/agentwire/Cargo.toml b/agentwire/Cargo.toml new file mode 100644 index 0000000..288bf24 --- /dev/null +++ b/agentwire/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "agentwire" +version = "0.0.1" +authors = ["Valentine Valyaeff "] +description = """ +A framework for asynchronous message-passing agents. +""" +publish = false + +edition.workspace = true +[features] +sandbox-network = [] + +[dependencies.agentwire-macros] +version = "=0.0.1" +path = "macros" + +[dependencies] +close_fds = "0.3.2" +futures = "0.3" +libc = "0.2.93" +nix = { version = "0.26.2", default-features = false, features = ["signal", "fs", "mman", "sched"] } +rkyv = "0.7.40" +shell-words = "1.1.0" +thiserror = "1.0.61" +tokio = { version = "1", features = ["rt-multi-thread", "process", "sync", "time", "io-util"] } +tracing = "0.1" + +[dev-dependencies] +tokio = { version = "1", features = ["macros"] } diff --git a/agentwire/macros/Cargo.toml b/agentwire/macros/Cargo.toml new file mode 100644 index 0000000..5bf5c52 --- /dev/null +++ b/agentwire/macros/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "agentwire-macros" +version = "0.0.1" +edition.workspace = true +authors = ["Valentine Valyaeff "] +description = """ +Procedural macros for agentwire. +""" +publish = false + +[lib] +proc-macro = true + +[dependencies] +heck = "0.5.0" +proc-macro2 = "1.0.79" +quote = "1.0.35" +syn = { version = "2.0.55", features = ["extra-traits"] } diff --git a/agentwire/macros/src/broker.rs b/agentwire/macros/src/broker.rs new file mode 100644 index 0000000..c509a92 --- /dev/null +++ b/agentwire/macros/src/broker.rs @@ -0,0 +1,335 @@ +use heck::ToSnakeCase as _; +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use std::collections::HashSet; +use syn::{ + parse::{Parse, ParseStream, Result}, + parse_macro_input, + punctuated::{Pair, Punctuated}, + Data, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, Ident, Path, Token, +}; + +#[derive(PartialEq, Eq, Hash)] +enum AgentAttr { + Task, + Thread, + Process, + Init, + InitAsync, + Logger(Expr), +} + +impl Parse for AgentAttr { + fn parse(input: ParseStream) -> Result { + let ident = input.parse::()?; + match ident.to_string().as_str() { + "task" => Ok(Self::Task), + "thread" => Ok(Self::Thread), + "process" => Ok(Self::Process), + "init" => Ok(Self::Init), + "init_async" => Ok(Self::InitAsync), + "logger" => { + input.parse::()?; + Ok(Self::Logger(input.parse()?)) + } + ident => panic!("Unknown #[agent] option: {ident}"), + } + } +} + +#[derive(PartialEq, Eq, Hash)] +enum BrokerAttr { + Plan(Path), + Error(Path), + PollExtra, +} + +impl Parse for BrokerAttr { + fn parse(input: ParseStream) -> Result { + let ident = input.parse::()?; + match ident.to_string().as_str() { + "plan" => { + input.parse::()?; + Ok(Self::Plan(input.parse()?)) + } + "error" => { + input.parse::()?; + Ok(Self::Error(input.parse()?)) + } + "poll_extra" => Ok(Self::PollExtra), + ident => panic!("Unknown #[broker] option: {ident}"), + } + } +} + +#[allow(clippy::too_many_lines)] +pub fn proc_macro_derive(input: TokenStream) -> TokenStream { + let DeriveInput { attrs, ident, data, .. } = parse_macro_input!(input); + let Data::Struct(DataStruct { fields, .. }) = data else { panic!("must be a struct") }; + let Fields::Named(FieldsNamed { named: fields, .. }) = fields else { + panic!("must have named fields") + }; + + let broker_attrs = attrs + .iter() + .find(|attr| attr.path().is_ident("broker")) + .expect("must have a `#[broker]` attribute") + .parse_args_with(Punctuated::::parse_terminated) + .expect("failed to parse `broker` attribute") + .into_pairs() + .map(Pair::into_value) + .collect::>(); + let broker_plan = broker_attrs + .iter() + .find_map(|attr| if let BrokerAttr::Plan(expr) = attr { Some(expr) } else { None }) + .expect("#[broker] attribute must set a `plan`"); + let broker_error = broker_attrs + .iter() + .find_map(|attr| if let BrokerAttr::Error(expr) = attr { Some(expr) } else { None }) + .expect("#[broker] attribute must set an `error`"); + + let agent_fields = fields.iter().filter_map(|field| { + field.attrs.iter().find(|attr| attr.path().is_ident("agent")).map(|attr| { + let attrs = attr + .parse_args_with(Punctuated::::parse_terminated) + .expect("failed to parse `agent` attribute"); + (field, attrs.into_pairs().map(Pair::into_value).collect::>()) + }) + }); + + let constructor_name = format_ident!("new_{}", ident.to_string().to_snake_case()); + let constructor_fields = agent_fields + .clone() + .map(|(Field { ident, .. }, _)| quote!(#ident: ::agentwire::agent::Cell::Vacant)); + let constructor = quote! { + macro_rules! #constructor_name { + ($($tokens:tt)*) => { + #ident { + #(#constructor_fields,)* + $($tokens)* + } + }; + } + }; + + let run_fut_name = format_ident!("Run{}", ident); + let run_handlers = agent_fields.clone().map(|(field, _)| { + let ident = field.ident.as_ref().unwrap(); + let handler = format_ident!("handle_{}", ident); + quote! { + if let Some(port) = fut.broker.#ident.enabled() { + loop { + match ::futures::StreamExt::poll_next_unpin(port, cx) { + ::std::task::Poll::Ready(Some(output)) if output.source_ts > fence => { + match fut.broker.#handler(fut.plan, output) { + ::std::result::Result::Ok(::agentwire::BrokerFlow::Break) => { + return ::std::task::Poll::Ready(::std::result::Result::Ok(())); + } + ::std::result::Result::Ok(::agentwire::BrokerFlow::Continue) => { + continue 'outer; + } + ::std::result::Result::Err(err) => { + return ::std::task::Poll::Ready( + ::std::result::Result::Err( + ::agentwire::BrokerError::Handler( + ::std::stringify!(#ident), + err, + ), + ), + ); + } + } + } + ::std::task::Poll::Ready(::std::option::Option::Some(_)) => { + continue; + } + ::std::task::Poll::Ready(::std::option::Option::None) => { + return ::std::task::Poll::Ready( + ::std::result::Result::Err( + ::agentwire::BrokerError::AgentTerminated( + ::std::stringify!(#ident), + ), + ), + ); + } + ::std::task::Poll::Pending => { + break; + } + } + } + } + } + }); + let poll_extra = broker_attrs.contains(&BrokerAttr::PollExtra).then(|| { + quote! { + match fut.broker.poll_extra(fut.plan, cx, fence) { + ::std::result::Result::Ok(::std::option::Option::Some(poll)) => { + break poll.map(Ok); + } + ::std::result::Result::Ok(::std::option::Option::None) => { + continue; + } + ::std::result::Result::Err(err) => { + return ::std::task::Poll::Ready(::std::result::Result::Err( + ::agentwire::BrokerError::PollExtra(err), + )); + } + } + } + }); + let run = quote! { + #[allow(missing_docs)] + pub struct #run_fut_name<'a> { + broker: &'a mut #ident, + plan: &'a mut dyn #broker_plan, + fence: ::std::time::Instant, + } + + impl ::futures::future::Future for #run_fut_name<'_> { + type Output = ::std::result::Result<(), ::agentwire::BrokerError<#broker_error>>; + + fn poll( + mut self: ::std::pin::Pin<&mut Self>, + cx: &mut ::std::task::Context<'_>, + ) -> ::std::task::Poll { + let fence = self.fence; + let fut = self.as_mut().get_mut(); + 'outer: loop { + #(#run_handlers)* + #poll_extra + } + } + } + + impl #ident { + #[allow(missing_docs)] + pub fn run<'a>(&'a mut self, plan: &'a mut dyn #broker_plan) -> #run_fut_name<'a> { + Self::run_with_fence(self, plan, ::std::time::Instant::now()) + } + + #[allow(missing_docs)] + pub fn run_with_fence<'a>( + &'a mut self, + plan: &'a mut dyn #broker_plan, + fence: ::std::time::Instant, + ) -> #run_fut_name<'a> { + #run_fut_name { + broker: self, + plan, + fence, + } + } + } + }; + + let methods = agent_fields.clone().map(|(field, attrs)| { + let ident = field.ident.as_ref().unwrap(); + let enable = format_ident!("enable_{}", ident); + let try_enable = format_ident!("try_enable_{}", ident); + let disable = format_ident!("disable_{}", ident); + let init = format_ident!("init_{}", ident); + let (init, init_async) = if attrs.contains(&AgentAttr::InitAsync) { + let init = quote! { + match self.#init().await { + ::std::result::Result::Ok(agent) => agent, + ::std::result::Result::Err(err) => { + return ::std::result::Result::Err( + ::agentwire::BrokerError::Init(::std::stringify!(#ident), err), + ); + } + } + }; + (init, quote!(async)) + } else if attrs.contains(&AgentAttr::Init) { + (quote!(self.#init()), quote!()) + } else { + (quote!(Default::default()), quote!()) + }; + let constructor = if attrs.contains(&AgentAttr::Process) { + let logger = if let Some(logger) = attrs + .iter() + .find_map(|attr| if let AgentAttr::Logger(expr) = attr { Some(expr) } else { None }) + { + quote!(#logger) + } else { + quote!(::agentwire::agent::process::default_logger) + }; + quote!(::agentwire::agent::Process::spawn_process(#init, #logger)) + } else if attrs.contains(&AgentAttr::Thread) { + quote! { + match ::agentwire::agent::Thread::spawn_thread(#init) { + ::std::result::Result::Ok(cell) => cell, + ::std::result::Result::Err(err) => { + return ::std::result::Result::Err( + ::agentwire::BrokerError::SpawnThread(::std::stringify!(#ident), err) + ); + } + } + } + } else if attrs.contains(&AgentAttr::Task) { + quote!(::agentwire::agent::Task::spawn_task(#init)) + } else { + panic!("must have `task`, `thread`, or `process` tag"); + }; + + quote! { + #[allow(missing_docs)] + pub #init_async fn #enable( + &mut self, + ) -> ::std::result::Result<(), ::agentwire::BrokerError<#broker_error>> { + match ::std::mem::replace(&mut self.#ident, ::agentwire::agent::Cell::Vacant) { + ::agentwire::agent::Cell::Vacant => { + self.#ident = ::agentwire::agent::Cell::Enabled(#constructor); + } + ::agentwire::agent::Cell::Enabled(agent) + | ::agentwire::agent::Cell::Disabled(agent) => { + self.#ident = ::agentwire::agent::Cell::Enabled(agent); + } + } + ::std::result::Result::Ok(()) + } + + #[allow(missing_docs)] + pub fn #try_enable(&mut self) { + match ::std::mem::replace(&mut self.#ident, ::agentwire::agent::Cell::Vacant) { + ::agentwire::agent::Cell::Vacant => {} + ::agentwire::agent::Cell::Enabled(agent) + | ::agentwire::agent::Cell::Disabled(agent) => { + self.#ident = ::agentwire::agent::Cell::Enabled(agent); + } + } + } + + #[allow(missing_docs)] + pub fn #disable(&mut self) { + match ::std::mem::replace(&mut self.#ident, ::agentwire::agent::Cell::Vacant) { + ::agentwire::agent::Cell::Vacant => {} + ::agentwire::agent::Cell::Enabled(agent) + | ::agentwire::agent::Cell::Disabled(agent) => { + self.#ident = ::agentwire::agent::Cell::Disabled(agent); + } + } + } + } + }); + + let disable_agents = agent_fields.map(|(field, _)| { + let disable = format_ident!("disable_{}", field.ident.as_ref().unwrap()); + quote!(#disable) + }); + + let expanded = quote! { + #constructor + #run + + impl #ident { + #(#methods)* + + #[allow(missing_docs)] + pub fn disable_agents(&mut self) { + #(self.#disable_agents();)* + } + } + }; + expanded.into() +} diff --git a/agentwire/macros/src/lib.rs b/agentwire/macros/src/lib.rs new file mode 100644 index 0000000..9c60011 --- /dev/null +++ b/agentwire/macros/src/lib.rs @@ -0,0 +1,21 @@ +//! Procedural macros for agentwire. + +#![warn(unsafe_op_in_unsafe_fn)] +#![warn(clippy::pedantic)] + +extern crate proc_macro; + +mod broker; +mod test; + +use proc_macro::TokenStream; + +#[proc_macro_derive(Broker, attributes(broker, agent))] +pub fn derive_broker(input: TokenStream) -> TokenStream { + broker::proc_macro_derive(input) +} + +#[proc_macro_attribute] +pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream { + test::proc_macro_attribute(attr, item) +} diff --git a/agentwire/macros/src/test.rs b/agentwire/macros/src/test.rs new file mode 100644 index 0000000..fe2c09a --- /dev/null +++ b/agentwire/macros/src/test.rs @@ -0,0 +1,69 @@ +use proc_macro::TokenStream; +use quote::quote; +use std::mem::take; +use syn::{ + parse::{Parse, ParseStream, Result}, + parse_macro_input, + punctuated::Punctuated, + Expr, Ident, ItemFn, LitStr, Token, +}; + +enum TestAttr { + Init(Expr), + Timeout(Expr), +} + +impl Parse for TestAttr { + fn parse(input: ParseStream) -> Result { + let ident = input.parse::()?; + match ident.to_string().as_str() { + "init" => { + input.parse::()?; + Ok(Self::Init(input.parse()?)) + } + "timeout" => { + input.parse::()?; + Ok(Self::Timeout(input.parse()?)) + } + ident => panic!("Unknown option: {ident}"), + } + } +} + +pub fn proc_macro_attribute(attr: TokenStream, item: TokenStream) -> TokenStream { + let test_attrs = + parse_macro_input!(attr with Punctuated::::parse_terminated); + let init = test_attrs + .iter() + .find_map(|attr| if let TestAttr::Init(expr) = attr { Some(quote!(#expr)) } else { None }) + .unwrap_or_else(|| quote!(|| {})); + let timeout = test_attrs + .iter() + .find_map( + |attr| { + if let TestAttr::Timeout(expr) = attr { Some(quote!(#expr)) } else { None } + }, + ) + .unwrap_or_else(|| quote!(::agentwire::testing_rt::DEFAULT_TIMEOUT)); + + let ItemFn { attrs, vis, mut sig, block } = parse_macro_input!(item as ItemFn); + let test_name = LitStr::new(&sig.ident.to_string(), sig.ident.span()); + assert!(take(&mut sig.asyncness).is_some(), "Test function must be async"); + + let expanded = quote! { + #(#attrs)* + #[test] + #vis #sig { + struct TestId; + let test_id = ::std::any::TypeId::of::(); + ::agentwire::testing_rt::run_broker_test( + ::std::stringify!(#test_name), + &::std::format!("{test_id:?}"), + ::std::time::Duration::from_millis(#timeout), + #init, + ::std::boxed::Box::pin(async move #block), + ) + } + }; + expanded.into() +} diff --git a/agentwire/src/agent/mod.rs b/agentwire/src/agent/mod.rs new file mode 100644 index 0000000..3f4b5e7 --- /dev/null +++ b/agentwire/src/agent/mod.rs @@ -0,0 +1,111 @@ +//! Agent module. +//! +//! # Examples +//! +//! ``` +//! # #[tokio::main] async fn main() { +//! use agentwire::{ +//! agent::{self, Agent, Task as _}, +//! port::{self, Port}, +//! }; +//! use futures::{ +//! channel::mpsc::{self, SendError}, +//! prelude::*, +//! }; +//! +//! /// An agent that receives numbers, multiplies them by 2, and sends them +//! /// back. +//! struct Doubler; +//! +//! impl Port for Doubler { +//! type Input = u32; +//! type Output = u32; +//! +//! const INPUT_CAPACITY: usize = 0; +//! const OUTPUT_CAPACITY: usize = 0; +//! } +//! +//! impl Agent for Doubler { +//! const NAME: &'static str = "doubler"; +//! } +//! +//! impl agent::Task for Doubler { +//! type Error = SendError; +//! +//! async fn run(self, mut port: port::Inner) -> Result<(), Self::Error> { +//! while let Some(x) = port.next().await { +//! port.send(x.chain(x.value * 2)).await?; +//! } +//! Ok(()) +//! } +//! } +//! +//! let (mut doubler, _kill) = Doubler.spawn_task(); +//! +//! // Send an input message to the agent. +//! doubler.send(port::Input::new(3)).await; +//! // Receive an output message from the agent. +//! let output = doubler.next().await; +//! assert_eq!(output.unwrap().value, 6); +//! # } +//! ``` + +pub mod process; + +mod task; +mod thread; + +pub use self::{process::Process, task::Task, thread::Thread}; + +use crate::port::{self, Port}; +use futures::prelude::*; +use std::{mem::replace, pin::Pin}; + +/// Abstract agent. +pub trait Agent: Port + Sized + 'static { + /// Name of the agent. Must be unique. + const NAME: &'static str; +} + +/// Future to kill an agent. +pub type Kill = Pin + Send>>; + +/// Agent cell inside a broker. +pub enum Cell { + /// Agent is not initialized. + Vacant, + /// Agent is initialized and enabled. + Enabled((port::Outer, Kill)), + /// Agent is initialized but disabled. + Disabled((port::Outer, Kill)), +} + +impl Cell { + /// Returns `Some(port)` if the agent is enabled, otherwise returns `None`. + pub fn enabled(&mut self) -> Option<&mut port::Outer> { + match self { + Self::Vacant | Self::Disabled(_) => None, + Self::Enabled((ref mut port, _kill)) => Some(port), + } + } + + /// Returns `true` if the agent is enabled. + #[must_use] + pub fn is_enabled(&self) -> bool { + matches!(self, Self::Enabled(_)) + } + + /// Returns `true` if the agent is initialized. + #[must_use] + pub fn is_initialized(&self) -> bool { + !matches!(self, Self::Vacant) + } + + /// Kills the agent. + pub async fn kill(&mut self) { + match replace(self, Self::Vacant) { + Self::Enabled((_port, kill)) | Self::Disabled((_port, kill)) => kill.await, + Self::Vacant => {} + } + } +} diff --git a/agentwire/src/agent/process.rs b/agentwire/src/agent/process.rs new file mode 100644 index 0000000..8337340 --- /dev/null +++ b/agentwire/src/agent/process.rs @@ -0,0 +1,350 @@ +//! Process-based agents. + +use super::{Agent, Kill}; +use crate::{ + port::{self, SharedPort, SharedSerializer}, + spawn_named_thread, +}; +use close_fds::close_open_fds; +use futures::{future::Either, prelude::*}; +use nix::{ + errno::Errno, + sched::{unshare, CloneFlags}, + sys::signal::{self, Signal}, + unistd::Pid, +}; +use rkyv::{de::deserializers::SharedDeserializeMap, Archive, Deserialize, Infallible, Serialize}; +use std::{ + env, + error::Error, + fmt::Debug, + io, + os::{ + fd::{AsRawFd as _, FromRawFd as _, OwnedFd, RawFd}, + unix::process::{parent_id, ExitStatusExt as _}, + }, + pin::pin, + process::{self, Stdio}, + sync::atomic::{AtomicBool, Ordering}, +}; +use thiserror::Error; +use tokio::{ + io::{AsyncBufReadExt as _, BufReader}, + process::{ChildStderr, ChildStdout, Command}, + runtime, + sync::oneshot, + task, +}; + +/// Environment variable to pass extra arguments to the agent process. +pub const ARGS_ENV: &str = "AGENTWIRE_PROCESS_ARGS"; + +const SHMEM_ENV: &str = "AGENTWIRE_PROCESS_SHMEM"; +const PARENT_PID_ENV: &str = "AGENTWIRE_PROCESS_PARENT_PID"; + +static INIT_PROCESSES: AtomicBool = AtomicBool::new(false); + +/// Error returned by [`Process::call`]. +#[derive(Error, Debug)] +pub enum CallError { + /// Error returned by the agent. + #[error("agent: {0}")] + Agent(T), + /// Error initializing the shared memory. + #[error("shared memory: {0}")] + SharedMemory(Errno), +} + +/// Exit strategy returned from [`Process::exit_strategy`]. +#[derive(Clone, Copy, Default, Debug)] +pub enum ExitStrategy { + /// Close the port without restarting the agent. + Close, + /// Keep the port open and restart the agent. + Restart, + /// Keep the port open, restart the agent, and retry the latest input. + #[default] + Retry, +} + +/// Additional settings for starting a new process. +pub trait Initializer: Send { + /// File descriptors to keep open when starting a new process. + #[must_use] + fn keep_file_descriptors(&self) -> Vec; + + /// Additional environment variables for the process. + #[must_use] + fn envs(&self) -> Vec<(String, String)>; +} + +/// Default initializer with no additional settings. +pub struct DefaultInitializer; + +impl Initializer for DefaultInitializer { + fn keep_file_descriptors(&self) -> Vec { + Vec::new() + } + + fn envs(&self) -> Vec<(String, String)> { + Vec::new() + } +} + +/// Agent running on a dedicated OS process. +pub trait Process +where + Self: Agent + + SharedPort + + Clone + + Send + + Debug + + Archive + + for<'a> Serialize>, + ::Archived: Deserialize, + Self::Input: Archive + for<'a> Serialize>, + Self::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + /// Error type returned by the agent. + type Error: Debug; + + /// Runs the agent event-loop inside a dedicated OS thread. + fn run(self, port: port::RemoteInner) -> Result<(), Self::Error>; + + /// Spawns a new process running the agent event-loop and returns a handle + /// for bi-directional communication with the agent. + /// + /// # Panics + /// + /// If [`init`] hasn't been called yet. + fn spawn_process(self, logger: F) -> (port::Outer, Kill) + where + F: Fn(&'static str, ChildStdout, ChildStderr) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + assert!( + INIT_PROCESSES.load(Ordering::Relaxed), + "process-based agents are not initialized (missing call to \ + `agentwire::agent::process::init`)" + ); + let (inner, outer) = port::new(); + let (send_kill_tx, send_kill_rx) = oneshot::channel(); + let (wait_kill_tx, wait_kill_rx) = oneshot::channel(); + let kill = async move { + let _ = send_kill_tx.send(()); + wait_kill_rx.await.unwrap(); + tracing::info!("Process agent {} killed", Self::NAME); + }; + let spawn_process = spawn_process_impl(self, inner, send_kill_rx, wait_kill_tx, logger); + spawn_named_thread(format!("proc-ipc-{}", Self::NAME), || { + let rt = runtime::Builder::new_current_thread().enable_all().build().unwrap(); + rt.block_on(task::LocalSet::new().run_until(spawn_process)); + }); + (outer, kill.boxed()) + } + + /// Connects to the shared memory and calls the [`run`](Self::run) method. + fn call(shmem: OwnedFd) -> Result<(), CallError> { + let mut inner = port::RemoteInner::::from_shared_memory(shmem) + .map_err(CallError::SharedMemory)?; + let agent = inner.init_state().deserialize(&mut Infallible).unwrap(); + agent.run(inner).map_err(CallError::Agent) + } + + /// When the agent process terminates, this method decides how to proceed. + /// See [`ExitStrategy`] for available options. + #[must_use] + fn exit_strategy(_code: Option, _signal: Option) -> ExitStrategy { + ExitStrategy::default() + } + + /// Additional settings for starting a new process. + #[must_use] + fn initializer() -> impl Initializer { + DefaultInitializer + } +} + +/// Initializes process-based agents. +/// +/// This function must be called as early in the program lifetime as possible. +/// Everything before this function call gets duplicated for each process-based +/// agent. +pub fn init(call_process_agent: impl FnOnce(&str, OwnedFd) -> Result<(), Box>) { + match (env::var(SHMEM_ENV), env::var(PARENT_PID_ENV)) { + (Ok(shmem), Ok(parent_pid)) => { + let result = unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL) }; + if result == -1 { + eprintln!( + "Failed to set the parent death signal: {:#?}", + io::Error::last_os_error() + ); + process::exit(1); + } + if parent_id() != parent_pid.parse::().unwrap() { + // The parent exited before the above `prctl` call. + process::exit(1); + } + let shmem_fd = unsafe { + OwnedFd::from_raw_fd( + shmem.parse::().expect("shared memory file descriptor to be an integer"), + ) + }; + // Agent's name is the first argument. + let argv0 = std::env::args().next().expect("argv[0] is not set"); + let name = argv0 + .strip_prefix("proc-") + .expect("mega-agent process name should start with 'proc-'"); + match call_process_agent(name, shmem_fd) { + Ok(()) => tracing::warn!("Agent {name} exited"), + Err(err) => tracing::error!("Agent {name} exited with an error: {err:#?}"), + } + process::exit(1); + } + (Err(_), Err(_)) => { + INIT_PROCESSES.store(true, Ordering::Relaxed); + } + (shmem, parent_pid) => { + panic!( + "Inconsistent state of the following environment variables: \ + {SHMEM_ENV}={shmem:?}, {PARENT_PID_ENV}={parent_pid:?}, " + ); + } + } +} + +/// Creates a default process agent logger. +pub async fn default_logger(agent_name: &'static str, stdout: ChildStdout, stderr: ChildStderr) { + let mut stdout = BufReader::new(stdout).lines(); + let mut stderr = BufReader::new(stderr).lines(); + loop { + match future::select(pin!(stdout.next_line()), pin!(stderr.next_line())).await { + Either::Left((Ok(Some(line)), _)) => { + tracing::info!("[{agent_name}] {line}"); + } + Either::Right((Ok(Some(line)), _)) => { + tracing::info!("[{agent_name}] {line}"); + } + Either::Left((Ok(None), _)) => { + tracing::warn!("[{agent_name}] closed"); + break; + } + Either::Right((Ok(None), _)) => { + tracing::warn!("[{agent_name}] closed"); + break; + } + Either::Left((Err(err), _)) => { + tracing::error!("[{agent_name}] {err:#?}"); + break; + } + Either::Right((Err(err), _)) => { + tracing::error!("[{agent_name}] {err:#?}"); + break; + } + } + } +} + +async fn spawn_process_impl( + init_state: T, + mut inner: port::Inner, + mut send_kill_rx: oneshot::Receiver<()>, + wait_kill_tx: oneshot::Sender<()>, + logger: F, +) where + F: Fn(&'static str, ChildStdout, ChildStderr) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + let mut recovered_inputs = Vec::new(); + loop { + let (shmem_fd, close) = inner + .into_shared_memory(T::NAME, &init_state, recovered_inputs) + .expect("couldn't initialize shared memory"); + let exe = env::current_exe().expect("couldn't determine current executable file"); + + let initializer = T::initializer(); + let mut child_fds = initializer.keep_file_descriptors(); + child_fds.push(shmem_fd.as_raw_fd()); + let mut child = unsafe { + Command::new(exe) + .arg0(format!("proc-{}", T::NAME)) + .args( + env::var(ARGS_ENV) + .map(|args| shell_words::split(&args).expect("invalid process arguments")) + .unwrap_or_default(), + ) + .envs(initializer.envs()) + .env(SHMEM_ENV, shmem_fd.as_raw_fd().to_string()) + .env(PARENT_PID_ENV, process::id().to_string()) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .pre_exec(sandbox_agent) + .pre_exec(move || { + close_open_fds(libc::STDERR_FILENO + 1, &child_fds); + Ok(()) + }) + .spawn() + .expect("failed to spawn a sub-process") + }; + drop(shmem_fd); + drop(initializer); + let pid = Pid::from_raw(child.id().unwrap().try_into().unwrap()); + task::spawn(logger(T::NAME, child.stdout.take().unwrap(), child.stderr.take().unwrap())); + tracing::info!("Process agent {} spawned with PID: {}", T::NAME, pid.as_raw()); + match future::select(Box::pin(child.wait()), &mut send_kill_rx).await { + Either::Left((status, _)) => { + let status = status.expect("failed to run a sub-process"); + let (code, signal) = (status.code(), status.signal()); + if signal.is_some_and(|signal| signal == libc::SIGINT) { + tracing::warn!("Process agent {} exited on Ctrl-C", T::NAME); + break; + } + let exit_strategy = T::exit_strategy(code, signal); + tracing::info!( + "Process agent {} exited with code {code:?} and signal {signal:?}, proceeding \ + with {exit_strategy:?}", + T::NAME + ); + (inner, recovered_inputs) = + close.await.expect("shared memory deinitialization failure"); + match exit_strategy { + ExitStrategy::Close => { + let _ = wait_kill_tx.send(()); + break; + } + ExitStrategy::Restart => { + recovered_inputs.clear(); + } + ExitStrategy::Retry => {} + } + } + Either::Right((_kill, wait)) => { + signal::kill(pid, Signal::SIGKILL) + .expect("failed to send SIGKILL to a sub-process"); + wait.await.expect("failed to kill a sub-process"); + close.await.expect("shared memory deinitialization failure"); + let _ = wait_kill_tx.send(()); + break; + } + }; + } +} + +fn sandbox_agent() -> std::io::Result<()> { + #[allow(unused_mut)] + let mut flags = CloneFlags::CLONE_NEWUSER | CloneFlags::CLONE_NEWIPC; + #[cfg(feature = "sandbox-network")] + { + flags |= CloneFlags::CLONE_NEWNET; + } + match unshare(flags) { + Ok(()) => Ok(()), + Err(err) => Err(err.into()), + } +} diff --git a/agentwire/src/agent/task.rs b/agentwire/src/agent/task.rs new file mode 100644 index 0000000..37954bb --- /dev/null +++ b/agentwire/src/agent/task.rs @@ -0,0 +1,32 @@ +use super::{Agent, Kill}; +use crate::port; +use futures::prelude::*; +use std::fmt::Debug; +use tokio::task; + +/// Agent running on a dedicated asynchronous task. +pub trait Task: Agent + Send { + /// Error type returned by the agent. + type Error: Debug; + + /// Runs the agent event-loop inside a dedicated asynchronous task. + fn run(self, port: port::Inner) -> impl Future> + Send; + + /// Spawns a new task running the agent event-loop and returns a handle for + /// bi-directional communication with the agent. + fn spawn_task(self) -> (port::Outer, Kill) { + let (inner, outer) = port::new(); + task::spawn(async move { + tracing::info!("Agent {} spawned", Self::NAME); + match self.run(inner).await { + Ok(()) => { + tracing::warn!("Task agent {} exited", Self::NAME); + } + Err(err) => { + tracing::error!("Task agent {} exited with error: {err:#?}", Self::NAME); + } + } + }); + (outer, future::pending().boxed()) + } +} diff --git a/agentwire/src/agent/thread.rs b/agentwire/src/agent/thread.rs new file mode 100644 index 0000000..8fd4367 --- /dev/null +++ b/agentwire/src/agent/thread.rs @@ -0,0 +1,31 @@ +use super::{Agent, Kill}; +use crate::{port, spawn_named_thread}; +use futures::prelude::*; +use std::{fmt::Debug, future, io}; + +/// Agent running on a dedicated OS thread. +pub trait Thread: Agent + Send { + /// Error type returned by the agent. + type Error: Debug; + + /// Runs the agent event-loop inside a dedicated OS thread. + fn run(self, port: port::Inner) -> Result<(), Self::Error>; + + /// Spawns a new thread running the agent event-loop and returns a handle for + /// bi-directional communication with the agent. + fn spawn_thread(self) -> io::Result<(port::Outer, Kill)> { + let (inner, outer) = port::new(); + spawn_named_thread(format!("thrd-{}", Self::NAME), move || { + tracing::info!("Agent {} spawned", Self::NAME); + match self.run(inner) { + Ok(()) => { + tracing::warn!("Thread agent {} exited", Self::NAME); + } + Err(err) => { + tracing::error!("Thread agent {} exited with error: {err:#?}", Self::NAME); + } + } + }); + Ok((outer, future::pending().boxed())) + } +} diff --git a/agentwire/src/lib.rs b/agentwire/src/lib.rs new file mode 100644 index 0000000..911f912 --- /dev/null +++ b/agentwire/src/lib.rs @@ -0,0 +1,350 @@ +//! A framework for asynchronous message-passing agents. +//! +//! There are three main components: +//! - **Agent** - a separate computation unit, which runs in its own isolated +//! task, thread or process. +//! - **Broker** - a manager of agents. It is responsible for spawning agents, +//! and for message passing between them. +//! - **Port** - a bi-directional communication channel between an agent and the +//! broker. +//! +//! # Agent +//! +//! An agent is a computation unit that runs concurrently with other agents. +//! It is a structure that implements [`Agent`], [`Port`](port::Port) and other +//! trais, depending on whether it is a task-based agent, thread-based agent, or +//! a process-based agent. +//! +//! Each agent defines its own input, output, and error types, and a `run` +//! method that is called when the agent is started. The agent structure defines +//! its initial state. +//! +//! See [`agent`] module for more details. +//! +//! # Port +//! +//! A port is a bi-directional communication channel between an agent and the +//! broker. It has an input and an output side. The input side is used by the +//! broker to send messages to the agent, and the output side is used by the +//! agent to send messages to the broker. +//! +//! When used for a process-based agent, the port works via shared memory, and +//! the serialization/deserialization is done using the `rkyv` library. +//! +//! See [`port`] module for more details. +//! +//! # Broker +//! +//! A broker is a manager of agents. It is responsible for spawning agents, +//! handling the agent messages, and running **plans**. A broker shouldn't run +//! any computationally expensive tasks, and should act only as a router between +//! agents. The agents shouldn't be connected to each other directly, but only +//! through the broker. The broker and the agents form a **star topology**. +//! +//! ```ignore +//! use agentwire::{agent, Broker}; +//! +//! #[derive(Broker)] +//! #[broker(plan = Plan, error = Error)] +//! struct MyBroker { +//! #[agent(task)] +//! foo: agent::Cell, +//! // non-agent fields can be added as well +//! bar: String, +//! } +//! +//! // A broker can be created using the `new_broker!` macro, passing the +//! // non-agent fields as arguments. +//! let my_broker = new_my_broker!(bar: "baz".to_string()); +//! ``` +//! +//! See [`Broker`] macro for the full list of supported options. +//! +//! Each broker defines its own **Plan** trait, with a handler for each agent. +//! +//! ```ignore +//! // It's advised to provide a default implementation for each handler +//! trait Plan { +//! // ... +//! +//! fn handle_foo( +//! &mut self, +//! broker: &mut Broker, +//! output: port::Output, +//! ) -> Result { +//! Ok(BrokerFlow::Continue) +//! } +//! +//! // ... +//! } +//! ``` +//! +//! A concrete plan can be defined by implementing the `Plan` trait for a +//! structure, and then calling the `Broker::run` method with the plan. +//! +//! ```ignore +//! struct MyPlan { +//! result: Option, +//! } +//! +//! // A concrete plan can implement a subset of handlers +//! impl Plan for MyPlan { +//! // ... +//! +//! fn handle_foo( +//! &mut self, +//! _broker: &mut Broker, +//! output: port::Output, +//! ) -> Result { +//! self.result = Some(output.value); +//! Ok(BrokerFlow::Break) +//! } +//! +//! // ... +//! } +//! +//! impl MyPlan { +//! // A run method can be defined to run the broker with the plan. +//! pub async fn run(mut self, broker: &mut Broker) -> Option { +//! // Enable needed agents. +//! broker.enable_foo()?; +//! // Run the broker until `BrokerFlow::Break` is returned from one of the handlers. +//! broker.run(&mut self).await?; +//! // Disable unneeded agents. +//! broker.disable_foo(); +//! // Return the result. +//! self.result +//! } +//! } +//! ``` +//! +//! # Process-based agents +//! +//! Process-based agents are agents that run inside their own separate +//! processes. They are isolated from the broker and other agents, and can be +//! used to run untrusted or unreliable code. +//! +//! If process-based agents are used, a special initialization method should be +//! called at the beginning of the program. It will branch the program into an +//! agent process when special environment variables are set. +//! +//! ```ignore +//! use agentwire::agent::Process as _; +//! +//! // NOTE: keep track of all process-based agents here! +//! fn call_process_agent(name: &str, fd: OwnedFd) -> Result<(), Box> { +//! match name { +//! "foo" => Foo::call(fd)?, +//! "bar" => Bar::call(fd)?, +//! _ => panic!("unregistered agent {name}"), +//! } +//! Ok(()) +//! } +//! +//! fn main() { +//! agentwire::agent::process::init(call_process_agent); +//! } +//! ``` +//! +//! # Testing +//! +//! The [`test`] macro is provided to simplify testing of brokers. See the macro +//! documentation for more details. + +#![warn(missing_docs, unsafe_op_in_unsafe_fn)] +#![warn(clippy::pedantic)] +#![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)] + +pub mod agent; +pub mod port; +pub mod testing_rt; + +pub use agent::Agent; + +/// A macro for creating a broker test. +/// +/// # Examples +/// +/// ```ignore +/// #[agentwire::test( +/// // Custom initialization method. Required if the broker has process-based +/// // agents. +/// init = init, +/// // Custom timeout in milliseconds. Defaults to 60000. +/// timeout = 10000, +/// )] +/// async fn test_foo() { +/// let mut broker = new_broker!(); +/// let mut plan = MyPlan::new(); +/// broker.run(&mut plan).await.unwrap(); +/// } +/// +/// // If the broker has process-based agents, a custom initialization method +/// // should be provided. +/// fn init() { +/// agentwire::agent::process::init(|name, fd| match name { +/// "foo" => Ok(Foo::call(fd)?), +/// _ => panic!("unregistered agent {name}"), +/// }); +/// } +/// ``` +pub use agentwire_macros::test; +/// A derive macro for creating a broker. +/// +/// # Examples +/// +/// ```ignore +/// use agentwire::{agent, Broker, BrokerFlow}; +/// use futures::future::BoxFuture; +/// use std::task::{Context, Instant, Poll}; +/// use thiserror::Error; +/// +/// // Define the error type for the broker. +/// #[derive(Error, Debug)] +/// pub enum Error {} +/// +/// // Define the plan trait for the broker. +/// pub trait Plan { +/// fn handle_foo( +/// &mut self, +/// broker: &mut Broker, +/// output: port::Output, +/// ) -> Result { +/// Ok(BrokerFlow::Continue) +/// } +/// } +/// +/// // Define the broker structure. +/// #[derive(Broker)] +/// #[broker( +/// plan = Plan, // Plan trait for the broker (required) +/// error = Error, // Error type used by the generated `run` method (required) +/// poll_extra, // Call `poll_extra` method in the generated `run` method (optional) +/// )] +/// pub struct MyBroker { +/// // Define the agents. Each agent should be annotated with the `agent` +/// // attribute, followed by the agent type (`task`, `thread`, `process`). +/// #[agent( +/// // The agent is task-based +/// task, +/// // The agent is thread-based +/// thread, +/// // The agent is process-based +/// process, +/// // The agent has a custom initialization method (instead of using +/// // `Default`) +/// init, +/// // The agent has a custom asynchronous initialization method (instead +/// // of using `Default`) +/// init_async, +/// // The process-agent has a custom logger +/// logger = self.process_logger().await, +/// )] +/// foo: agent::Cell, +/// // non-agent fields can be added as well +/// bar: String, +/// } +/// +/// impl MyBroker { +/// // Implement the `init_foo` method if the `init` option is enabled. +/// fn init_foo(&mut self) -> Foo { +/// Foo {} +/// } +/// +/// // Implement the asynchronous `init_foo` method if the `init_async` +/// // option is enabled. +/// async fn init_foo(&mut self) -> Result { +/// Ok(Foo {}) +/// } +/// +/// // Implement the handler method for the `foo` agent. +/// fn handle_foo( +/// &mut self, +/// plan: &mut dyn Plan, +/// output: port::Output, +/// ) -> Result { +/// plan.handle_foo(self, output) +/// } +/// +/// // Implement the `poll_extra` method if it's enabled. +/// fn poll_extra( +/// &mut self, +/// plan: &mut dyn Plan, +/// cx: &mut Context<'_>, +/// fence: Instant, +/// ) -> Result>> { +/// Ok(Some(Poll::Pending)) +/// } +/// +/// // Implement a custom logger for process-based agents. +/// async fn process_logger( +/// &self, +/// ) -> impl Fn(&'static str, ChildStdout, ChildStderr) -> BoxFuture<()> + Send + 'static +/// { +/// move |agent_name, stdout, stderr| { +/// Box::pin(agentwire::agent::process::default_logger(agent_name, stdout, stderr)) +/// } +/// } +/// } +/// +/// // `new_my_broker!` macro is generated by the `Broker` macro. It takes the +/// // non-agent fields as arguments. +/// let my_broker = new_my_broker!(bar: "baz".to_string()); +/// ``` +pub use agentwire_macros::Broker; + +use std::{ffi::CString, fmt::Display, io, thread}; +use thiserror::Error; + +/// Used to tell a broker whether it should exit early or go on as usual. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum BrokerFlow { + /// Continue managing agents. + Continue, + /// Stops the broker returning control to the caller. + Break, +} + +/// The type of error that can occur in a broker. +#[derive(Error, Debug)] +pub enum BrokerError { + /// An agent initialization error. + #[error("agent {0} initialization: {1}")] + Init(&'static str, T), + /// An agent spawning error. + #[error("agent {0} thread spawning: {1}")] + SpawnThread(&'static str, io::Error), + /// An agent handler error. + #[error("agent {0} handler: {1}")] + Handler(&'static str, T), + /// `poll_extra` method error. + #[error("poll_extra: {0}")] + PollExtra(T), + /// An agent has terminated. + #[error("agent {0} terminated")] + AgentTerminated(&'static str), +} + +fn spawn_named_thread(name: impl Into, f: F) -> thread::JoinHandle +where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, +{ + let name = name.into(); + thread::Builder::new() + .name(name.clone()) + .spawn(move || { + if let Ok(title) = CString::new(name.as_bytes()) { + let result = unsafe { libc::prctl(libc::PR_SET_NAME, title.as_ptr(), 0, 0, 0) }; + if result == -1 { + eprintln!( + "failed to set thread name to '{name}': {:#?}", + io::Error::last_os_error() + ); + } + } + f() + }) + .expect("failed to spawn thread") +} diff --git a/agentwire/src/port.rs b/agentwire/src/port.rs new file mode 100644 index 0000000..6908ba2 --- /dev/null +++ b/agentwire/src/port.rs @@ -0,0 +1,929 @@ +//! Bi-directional channel for a computation unit. +//! +//! There are two kinds of ports: internal and shared. +//! +//! # Internal ports +//! +//! Internal port is used when the agent is located in the same process as the +//! broker (task-based and thread-based agents). +//! +//! ```ignore +//! use agentwire::{Agent, Port}; +//! +//! struct Foo; +//! +//! impl Agent for Foo { +//! const NAME: &'static str = "foo"; +//! } +//! +//! impl Port for Foo { +//! type Input = Input; +//! type Output = Output; +//! +//! // Set to `0` to not buffer the input data. +//! const INPUT_CAPACITY: usize = 0; +//! // Set to `0` to not buffer the output data. +//! const OUTPUT_CAPACITY: usize = 0; +//! } +//! +//! enum Input { +//! // .. +//! } +//! +//! enum Output { +//! // .. +//! } +//! ``` +//! +//! # Shared ports +//! +//! Shared port is used when the agent is located in a separate process. The +//! shared memory is used to transfer the data between the processes. +//! +//! A shared port must define the buffer sizes for the initial state, input +//! messages, and output messages. The following example sets the sizes for +//! simple types. If a type contains dynamic data, e.g. vectors or strings, then +//! the buffer size should be set to the maximum possible size of the data. +//! +//! ```ignore +//! use agentwire::{Agent, Port, SharedPort}; +//! use rkyv::{Archive, Deserialize, Serialize}; +//! +//! #[derive(Archive, Serialize, Deserialize)] +//! struct Foo { +//! // .. +//! } +//! +//! impl Agent for Foo { +//! const NAME: &'static str = "foo"; +//! } +//! +//! impl Port for Foo { +//! type Input = Input; +//! type Output = Output; +//! +//! // Set to `0` to not buffer the input data. +//! const INPUT_CAPACITY: usize = 0; +//! // Set to `0` to not buffer the output data. +//! const OUTPUT_CAPACITY: usize = 0; +//! } +//! +//! impl SharedPort for Foo { +//! const SERIALIZED_INIT_SIZE: usize = +//! size_of::() + size_of::<::Archived>(); +//! const SERIALIZED_INPUT_SIZE: usize = +//! size_of::() + size_of::<::Archived>(); +//! const SERIALIZED_OUTPUT_SIZE: usize = +//! size_of::() + size_of::<::Archived>(); +//! } +//! +//! #[derive(Archive, Serialize, Deserialize)] +//! enum Input { +//! // .. +//! } +//! +//! #[derive(Archive, Serialize, Deserialize)] +//! enum Output { +//! // .. +//! } +//! ``` + +use futures::{ + channel::{ + mpsc::{self, SendError}, + oneshot, + }, + future::{select, Either}, + prelude::*, + select_biased, + stream::FusedStream, +}; +use libc::{c_int, c_uint, sem_t}; +use nix::{ + errno::Errno, + sys::{ + memfd::{memfd_create, MemFdCreateFlag}, + mman::{mmap, munmap, MapFlags, ProtFlags}, + }, + unistd::ftruncate, +}; +use rkyv::{ + de::deserializers::SharedDeserializeMap, + ser::{ + serializers::{ + AllocScratch, BufferSerializer, CompositeSerializer, FallbackScratch, HeapScratch, + SharedSerializeMap, + }, + Serializer, + }, + Archive, Deserialize, Infallible, Serialize, +}; +use std::{ + cmp::max, + ffi::{CString, NulError}, + fmt::Debug, + io, + marker::PhantomData, + mem, + num::NonZeroUsize, + os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}, + pin::Pin, + ptr, slice, + task::{Context, Poll}, + time::Instant, +}; +use thiserror::Error; +use tokio::task; + +const SCRATCH_SIZE: usize = 1024; + +/// Error occured during shared memory creation. +#[derive(Error, Debug)] +pub enum CreateSharedMemoryError { + /// Invalid shared memory name. + #[error("invalid name: {0}")] + InvalidName(NulError), + /// Error occured during `memfd_create`. + #[error("memfd_create: {0}")] + MemfdCreate(Errno), + /// Error occured during `ftruncate`. + #[error("ftruncate: {0}")] + Ftruncate(Errno), + #[error("mmap: {0}")] + /// Error occured during `mmap`. + Mmap(Errno), + /// Error occured during semaphore initialization. + #[error("sem_init: {0}")] + SemInit(io::Error), +} + +/// Error occured during shared memory destruction. +#[derive(Error, Debug)] +pub enum DestroySharedMemoryError { + /// Error occured during `munmap`. + #[error("munmap: {0}")] + Munmap(Errno), + /// Error occured during semaphore destruction. + #[error("sem_destroy: {0}")] + SemDestroy(io::Error), +} + +/// Error returned by [`Outer::send_unjam`]. +#[derive(Error, Debug)] +pub enum SendUnjamError { + /// Error occured during message sending. + #[error("send: {0}")] + Send(#[from] SendError), + /// Port is closed. + #[error("port is closed")] + Closed, +} + +/// Bi-directional channel description. +pub trait Port: 'static { + /// Input channel message type. + /// + /// Set to `!` if the agent doesn't have input, e.g. a raw sensor. + type Input: Send + Debug; + + /// Output channel message type. + /// + /// Set to `!` if the agent doesn't have output, e.g. a raw actuator. + type Output: Send + Debug; + + /// Input channel capacity. + /// + /// Set to `0` if the input data should to be as fresh as possible. + const INPUT_CAPACITY: usize; + + /// Output channel capacity. + /// + /// Set to `0` if the output data should to be as fresh as possible. + const OUTPUT_CAPACITY: usize; +} + +/// Shared memory serializer. +pub type SharedSerializer<'a> = CompositeSerializer< + BufferSerializer<&'a mut [u8]>, + FallbackScratch, AllocScratch>, + SharedSerializeMap, +>; + +/// Bi-directional channel description in shared memory. +#[allow(clippy::module_name_repetitions)] +pub trait SharedPort: Port +where + Self::Input: Archive + for<'a> Serialize>, + Self::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + /// Buffer size for input messages. Must be at least `size_of::()` + /// for a zero-sized input. + const SERIALIZED_INPUT_SIZE: usize; + + /// Buffer size for output messages. Must be at least `size_of::()` + /// for a zero-sized output. + const SERIALIZED_OUTPUT_SIZE: usize; + + /// Buffer size for initial agent state. Must be at least + /// `size_of::()` for a zero-sized state. + const SERIALIZED_INIT_SIZE: usize; +} + +/// Input message. +#[derive(Debug)] +pub struct Input { + /// Input value. + pub value: T::Input, + /// Source data timestamp. + pub source_ts: Instant, +} + +/// Archived input message. +pub struct ArchivedInput<'a, T: Port> +where + T::Input: Archive, +{ + /// Archived input value. + pub value: &'a ::Archived, + /// Source data timestamp. + pub source_ts: Instant, +} + +/// Output message. +#[derive(Debug)] +pub struct Output { + /// Output value. + pub value: T::Output, + /// Source data timestamp. + pub source_ts: Instant, +} + +/// A handle for bi-directional communication for the outside of the computation +/// unit. The type implements both [`Sink`] and [`Stream`] for the input and the +/// output channels respectively. +pub struct Outer { + /// Sender channel for the computation unit input. + pub tx: OuterTx, + /// Receiver channel for the computation unit output. + pub rx: OuterRx, +} + +/// A handle for bi-directional communication for the inside of the computation +/// unit. The type implements both [`Sink`] and [`Stream`] for the input and the +/// output channels respectively. +pub struct Inner { + /// Sender channel for the computation unit output. + pub tx: InnerTx, + /// Receiver channel for the computation unit input. + pub rx: InnerRx, +} + +/// A handle for bi-directional communication for the inside of the computation +/// unit, which is located in another process. +pub struct RemoteInner +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + shared_memory: *mut SharedMemory, + scratch: Option, AllocScratch>>, +} + +/// Sender channel for the computation unit input. +pub type OuterTx = mpsc::Sender>; + +/// Receiver channel for the computation unit output. +pub type OuterRx = mpsc::Receiver>; + +/// Sender channel for the computation unit output. +pub type InnerTx = mpsc::Sender>; + +/// Receiver channel for the computation unit input. +pub type InnerRx = mpsc::Receiver>; + +type InitialInputs = Vec<(Box<[u8]>, Instant)>; + +/// Creates a new bi-directional channel. +#[must_use] +pub fn new() -> (Inner, Outer) { + let (input_tx, input_rx) = mpsc::channel(T::INPUT_CAPACITY); + let (output_tx, output_rx) = mpsc::channel(T::OUTPUT_CAPACITY); + let inner = Inner { tx: output_tx, rx: input_rx }; + let outer = Outer { tx: input_tx, rx: output_rx }; + (inner, outer) +} + +impl Input { + /// Creates a new input value with the source timestamp of now. + pub fn new(value: T::Input) -> Self { + Self { value, source_ts: Instant::now() } + } + + /// Creates a new input value with the source timestamp of the original + /// input. + pub fn derive(&self, value: O::Input) -> Input { + Input { value, source_ts: self.source_ts } + } + + /// Creates a new output value with the source timestamp of the input. + pub fn chain(&self, value: T::Output) -> Output { + Output { value, source_ts: self.source_ts } + } + + /// Returns a closure, which creates a new output value with the source + /// timestamp of the input. + pub fn chain_fn(&self) -> impl Fn(T::Output) -> Output { + let source_ts = self.source_ts; + move |value| Output { value, source_ts } + } +} + +impl ArchivedInput<'_, T> +where + T::Input: Archive, +{ + /// Creates a new output value with the source timestamp of the input. + pub fn chain(&self, value: T::Output) -> Output { + Output { value, source_ts: self.source_ts } + } + + /// Returns a closure, which creates a new output value with the source + /// timestamp of the input. + pub fn chain_fn(&self) -> impl Fn(T::Output) -> Output { + let source_ts = self.source_ts; + move |value| Output { value, source_ts } + } +} + +impl Output { + /// Creates a new output value with the source timestamp of now. + pub fn new(value: T::Output) -> Self { + Self { value, source_ts: Instant::now() } + } + + /// Creates a new output value with the source timestamp of the original + /// output. + pub fn derive(&self, value: O::Output) -> Output { + Output { value, source_ts: self.source_ts } + } + + /// Returns a closure, which creates a new output value with the source + /// timestamp of the original output. + pub fn derive_fn(&self) -> impl Fn(O::Output) -> Output { + let source_ts = self.source_ts; + move |value| Output { value, source_ts } + } + + /// Creates a new input value with the source timestamp of the output. + pub fn chain(&self, value: O::Input) -> Input { + Input { value, source_ts: self.source_ts } + } + + /// Returns a closure, which creates a new input value with the source + /// timestamp of the output. + pub fn chain_fn(&self) -> impl Fn(O::Input) -> Input { + let source_ts = self.source_ts; + move |value| Input { value, source_ts } + } +} + +impl Outer { + /// Sends a message avoiding jams. Reading a message from the queue if + /// necessary. + /// + /// This is for situations where the agent may be blocked by sending a + /// message to the broker, but the broker is not listening to new messages + /// from the agent. Instead the broker sends a message to the agent and + /// blocks until it's received by the agent. + #[allow(clippy::mut_mut)] // triggered by `select!` internals + pub async fn send_unjam(&mut self, message: Input) -> Result<(), SendUnjamError> { + let mut send = self.tx.send(message).fuse(); + let mut recv = self.rx.next(); + loop { + select_biased! { + result = send => break Ok(result?), + item = recv => match item { + Some(item) => drop(item), + None => break Err(SendUnjamError::Closed), + } + } + } + } +} + +impl Stream for Outer { + type Item = Output; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.rx).poll_next(cx) + } +} + +impl FusedStream for Outer { + fn is_terminated(&self) -> bool { + self.rx.is_terminated() + } +} + +impl Sink> for Outer { + type Error = SendError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Input) -> Result<(), Self::Error> { + Pin::new(&mut self.tx).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_close(cx) + } +} + +impl Stream for Inner { + type Item = Input; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.rx).poll_next(cx) + } +} + +impl FusedStream for Inner { + fn is_terminated(&self) -> bool { + self.rx.is_terminated() + } +} + +impl Sink> for Inner { + type Error = SendError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Output) -> Result<(), Self::Error> { + Pin::new(&mut self.tx).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.tx).poll_close(cx) + } +} + +// This is a header of a shared memory. Right after the header, there is a raw +// data buffer. On initialization it contains the initial agent state. After +// initialization it contains the following content in the specific order: +// +// 1. Input buffer 0 +// 2. Input buffer 1 +// 3. Output buffer +struct SharedMemory +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + input_ts: [Instant; 2], + input_tx: sem_t, + input_rx: sem_t, + input_count: usize, + input_index: usize, + output_ts: Instant, + output_tx: sem_t, + output_rx: sem_t, + _marker: PhantomData, +} + +impl SharedMemory +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + fn size_of() -> NonZeroUsize { + let size = mem::size_of::() + + max( + mem::size_of::() + mem::size_of::(), + T::SERIALIZED_INPUT_SIZE * 2 + T::SERIALIZED_OUTPUT_SIZE, + ); + NonZeroUsize::new(size).expect("to always be positive") + } + + unsafe fn create(name: &str) -> Result<(*mut Self, OwnedFd), CreateSharedMemoryError> { + let size = Self::size_of(); + let name = CString::new(name).map_err(CreateSharedMemoryError::InvalidName)?; + let raw_fd = memfd_create(&name, MemFdCreateFlag::empty()) + .map_err(CreateSharedMemoryError::MemfdCreate)? as RawFd; + let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) }; + let len = size.get().try_into().expect("shared memory size is extremely large"); + ftruncate(fd.as_raw_fd(), len).map_err(CreateSharedMemoryError::Ftruncate)?; + let ptr = unsafe { + mmap( + None, + size, + ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, + MapFlags::MAP_SHARED, + fd.as_raw_fd(), + 0, + ) + .map_err(CreateSharedMemoryError::Mmap)? + .cast::() + }; + unsafe { + sem_init(&mut (*ptr).input_tx, 1, 0).map_err(CreateSharedMemoryError::SemInit)?; + sem_init(&mut (*ptr).input_rx, 1, 0).map_err(CreateSharedMemoryError::SemInit)?; + sem_init(&mut (*ptr).output_tx, 1, 1).map_err(CreateSharedMemoryError::SemInit)?; + sem_init(&mut (*ptr).output_rx, 1, 0).map_err(CreateSharedMemoryError::SemInit)?; + (*ptr).input_count = 0; + (*ptr).input_index = 0; + } + Ok((ptr, fd)) + } + + unsafe fn from_fd(fd: OwnedFd) -> Result<*mut Self, Errno> { + let ptr = unsafe { + mmap( + None, + Self::size_of(), + ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, + MapFlags::MAP_SHARED, + fd.as_raw_fd(), + 0, + )? + .cast::() + }; + drop(fd); + Ok(ptr) + } + + unsafe fn destroy(ptr: *mut Self) -> Result<(), DestroySharedMemoryError> { + unsafe { + sem_destroy(&mut (*ptr).input_tx).map_err(DestroySharedMemoryError::SemDestroy)?; + sem_destroy(&mut (*ptr).input_rx).map_err(DestroySharedMemoryError::SemDestroy)?; + sem_destroy(&mut (*ptr).output_tx).map_err(DestroySharedMemoryError::SemDestroy)?; + sem_destroy(&mut (*ptr).output_rx).map_err(DestroySharedMemoryError::SemDestroy)?; + munmap(ptr.cast(), Self::size_of().get()).map_err(DestroySharedMemoryError::Munmap)?; + } + Ok(()) + } + + unsafe fn init_state(&mut self) -> &mut [u8] { + unsafe { + slice::from_raw_parts_mut( + ptr::addr_of_mut!(*self).add(1).cast::(), + T::SERIALIZED_INIT_SIZE, + ) + } + } + + unsafe fn input(&mut self, n: usize) -> &mut [u8] { + unsafe { + slice::from_raw_parts_mut( + ptr::addr_of_mut!(*self).add(1).cast::().add(T::SERIALIZED_INPUT_SIZE * n), + T::SERIALIZED_INPUT_SIZE, + ) + } + } + + unsafe fn output(&mut self) -> &mut [u8] { + unsafe { + slice::from_raw_parts_mut( + ptr::addr_of_mut!(*self).add(1).cast::().add(T::SERIALIZED_INPUT_SIZE * 2), + T::SERIALIZED_OUTPUT_SIZE, + ) + } + } +} + +impl Inner +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + /// Sets up shared memory for this channel. + pub fn into_shared_memory( + self, + name: &str, + init_state: &T, + initial_inputs: InitialInputs, + ) -> Result< + (OwnedFd, impl Future>), + CreateSharedMemoryError, + > { + let Self { tx, rx } = self; + let (ptr, fd) = unsafe { SharedMemory::::create(name)? }; + let addr = ptr as usize; + let (stop_tx_tx, stop_tx_rx) = oneshot::channel(); + let (stop_rx_tx, stop_rx_rx) = oneshot::channel(); + set_init_state(addr, init_state); + let tx_task = spawn_shared_tx_task(tx, addr, stop_tx_rx); + let rx_task = spawn_shared_rx_task(rx, addr, stop_rx_rx, initial_inputs); + let close = async move { + let _ = stop_tx_tx.send(()); + let _ = stop_rx_tx.send(()); + let tx = tx_task.await.unwrap(); + let (rx, mut inputs) = rx_task.await.unwrap(); + unsafe { + let shared_memory = addr as *mut SharedMemory; + assert!((*shared_memory).input_count <= 2); + for mut i in 0..(*shared_memory).input_count { + if (*shared_memory).input_count == 2 && (*shared_memory).input_index == 0 { + i = (i + 1) % 2; + } + let input = Box::from(&*(*shared_memory).input(i)); + let input_ts = (*shared_memory).input_ts[i]; + inputs.push((input, input_ts)); + } + SharedMemory::destroy(shared_memory)?; + Ok((Self { tx, rx }, inputs)) + } + }; + Ok((fd, close)) + } +} + +impl RemoteInner +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + /// Creates a channel from the shared memory. + pub fn from_shared_memory(shmem_fd: OwnedFd) -> Result { + Ok(RemoteInner { + shared_memory: unsafe { SharedMemory::::from_fd(shmem_fd)? }, + scratch: Some(FallbackScratch::default()), + }) + } + + /// Reads the initial state. + #[allow(clippy::missing_panics_doc)] + pub fn init_state(&mut self) -> &::Archived { + unsafe { + let init_state = deserialize_message::((*self.shared_memory).init_state()); + sem_post(&mut (*self.shared_memory).input_tx).expect("semaphore failure"); + init_state + } + } + + /// Waits for a value on the receiver half. + #[allow(clippy::missing_panics_doc)] + pub fn recv(&mut self) -> ArchivedInput<'_, T> { + unsafe { + sem_wait(&mut (*self.shared_memory).input_rx).expect("semaphore failure"); + let input_index = 1 - (*self.shared_memory).input_index; + let value = deserialize_message::((*self.shared_memory).input(input_index)); + let source_ts = (*self.shared_memory).input_ts[input_index]; + sem_post(&mut (*self.shared_memory).input_tx).expect("semaphore failure"); + ArchivedInput { value, source_ts } + } + } + + /// Tries to receive a value on the receiver half. This function doesn't + /// block and returns `None` if the channel is empty. + #[allow(clippy::missing_panics_doc)] + pub fn try_recv(&mut self) -> Option> { + unsafe { + if sem_getvalue(&mut (*self.shared_memory).input_rx).expect("semaphore failure") > 0 { + Some(self.recv()) + } else { + None + } + } + } + + /// Sends a value on this channel. + #[allow(clippy::missing_panics_doc)] + pub fn send(&mut self, output: &Output) { + unsafe { + sem_wait(&mut (*self.shared_memory).output_tx).expect("semaphore failure"); + serialize_message((*self.shared_memory).output(), &mut self.scratch, &output.value); + (*self.shared_memory).output_ts = output.source_ts; + sem_post(&mut (*self.shared_memory).output_rx).expect("semaphore failure"); + } + } + + /// Tries to send a value on this channel. This function doesn't block and + /// do nothing if the channel is full (in which case it returns `false`). + #[allow(clippy::missing_panics_doc)] + pub fn try_send(&mut self, output: &Output) -> bool { + unsafe { + if sem_getvalue(&mut (*self.shared_memory).output_tx).expect("semaphore failure") > 0 { + self.send(output); + true + } else { + false + } + } + } +} + +fn serialize_message( + buf: &mut [u8], + scratch: &mut Option, AllocScratch>>, + value: &T, +) where + T: Archive + for<'a> Serialize> + Debug, +{ + let mut serializer = CompositeSerializer::new( + BufferSerializer::new(&mut buf[mem::size_of::()..]), + scratch.take().unwrap(), + SharedSerializeMap::new(), // reuse of this map doesn't work + ); + serializer.serialize_value(value).expect("failed to serialize an IPC message"); + let size = serializer.pos(); + let (_, c, _) = serializer.into_components(); + buf[..mem::size_of::()].copy_from_slice(&size.to_ne_bytes()); + *scratch = Some(c); +} + +unsafe fn deserialize_message(buf: &[u8]) -> &T::Archived +where + T: Archive + for<'a> Serialize>, +{ + let size = usize::from_ne_bytes(buf[..mem::size_of::()].try_into().unwrap()); + let bytes = &buf[mem::size_of::()..mem::size_of::() + size]; + unsafe { rkyv::archived_root::(bytes) } +} + +fn set_init_state(addr: usize, init_state: &T) +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + let mut scratch = Some(FallbackScratch::default()); + unsafe { + let shared_memory = addr as *mut SharedMemory; + serialize_message((*shared_memory).init_state(), &mut scratch, init_state); + } +} + +fn spawn_shared_tx_task( + mut tx: InnerTx, + addr: usize, + mut stop_tx_rx: oneshot::Receiver<()>, +) -> task::JoinHandle> +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + task::spawn_local(async move { + let spawn_sem_wait = || { + task::spawn_blocking(move || unsafe { + let shared_memory = addr as *mut SharedMemory; + sem_wait(&mut (*shared_memory).output_rx).expect("semaphore failure"); + }) + }; + let mut sem_wait = spawn_sem_wait(); + loop { + if let Either::Left((_, sem_wait)) = select(&mut stop_tx_rx, sem_wait).await { + unsafe { + let shared_memory = addr as *mut SharedMemory; + sem_post(&mut (*shared_memory).output_rx).expect("semaphore failure"); + } + sem_wait.await.unwrap(); + break; + } + let (value, source_ts) = unsafe { + let shared_memory = addr as *mut SharedMemory; + let archived = deserialize_message::((*shared_memory).output()); + // Reuse of `SharedDeserializeMap` doesn't work + let value = archived.deserialize(&mut SharedDeserializeMap::new()).unwrap(); + let source_ts = (*shared_memory).output_ts; + sem_post(&mut (*shared_memory).output_tx).expect("semaphore failure"); + (value, source_ts) + }; + let mut send = tx.feed(Output { value, source_ts }); + match select(&mut stop_tx_rx, &mut send).await { + Either::Left((_, _)) | Either::Right((Err(_), _)) => break, + Either::Right((Ok(result), _)) => result, + } + sem_wait = spawn_sem_wait(); + } + tx + }) +} + +fn spawn_shared_rx_task( + mut rx: InnerRx, + addr: usize, + mut stop_rx_rx: oneshot::Receiver<()>, + mut initial_inputs: InitialInputs, +) -> task::JoinHandle<(InnerRx, InitialInputs)> +where + T: SharedPort + Debug + Archive + for<'a> Serialize>, + ::Archived: Deserialize, + T::Input: Archive + for<'a> Serialize>, + T::Output: Archive + for<'a> Serialize>, + ::Archived: Deserialize, +{ + task::spawn_local(async move { + let spawn_sem_wait = || { + task::spawn_blocking(move || unsafe { + let shared_memory = addr as *mut SharedMemory; + sem_wait(&mut (*shared_memory).input_tx).expect("semaphore failure"); + }) + }; + let mut sem_wait = spawn_sem_wait(); + let mut scratch = Some(FallbackScratch::default()); + loop { + if let Either::Left((_, sem_wait)) = select(&mut stop_rx_rx, sem_wait).await { + unsafe { + let shared_memory = addr as *mut SharedMemory; + sem_post(&mut (*shared_memory).input_tx).expect("semaphore failure"); + } + sem_wait.await.unwrap(); + break; + } + let input = if let Some((input, input_ts)) = initial_inputs.pop() { + Either::Left((input, input_ts)) + } else { + match select(&mut stop_rx_rx, rx.next()).await { + Either::Left((_, _)) | Either::Right((None, _)) => break, + Either::Right((Some(input), _)) => Either::Right(input), + } + }; + unsafe { + let shared_memory = addr as *mut SharedMemory; + let input_index = (*shared_memory).input_index; + (*shared_memory).input_count = ((*shared_memory).input_count + 1).min(2); + (*shared_memory).input_index = ((*shared_memory).input_index + 1) % 2; + match input { + Either::Left((input, input_ts)) => { + ptr::copy_nonoverlapping::( + input.as_ptr(), + (*shared_memory).input(input_index).as_mut_ptr(), + input.len(), + ); + (*shared_memory).input_ts[input_index] = input_ts; + } + Either::Right(input) => { + serialize_message( + (*shared_memory).input(input_index), + &mut scratch, + &input.value, + ); + (*shared_memory).input_ts[input_index] = input.source_ts; + } + } + sem_post(&mut (*shared_memory).input_rx).expect("semaphore failure"); + } + sem_wait = spawn_sem_wait(); + } + (rx, initial_inputs) + }) +} + +unsafe fn sem_init(sem: *mut sem_t, pshared: c_int, value: c_uint) -> io::Result<()> { + let result = unsafe { libc::sem_init(sem, pshared, value) }; + if result == -1 { Err(io::Error::last_os_error()) } else { Ok(()) } +} + +unsafe fn sem_destroy(sem: *mut sem_t) -> io::Result<()> { + let result = unsafe { libc::sem_destroy(sem) }; + if result == -1 { Err(io::Error::last_os_error()) } else { Ok(()) } +} + +unsafe fn sem_post(sem: *mut sem_t) -> io::Result<()> { + let result = unsafe { libc::sem_post(sem) }; + if result == -1 { Err(io::Error::last_os_error()) } else { Ok(()) } +} + +unsafe fn sem_getvalue(sem: *mut sem_t) -> io::Result { + let mut value = 0; + let result = unsafe { libc::sem_getvalue(sem, &mut value) }; + if result == -1 { Err(io::Error::last_os_error()) } else { Ok(value) } +} + +unsafe fn sem_wait(sem: *mut sem_t) -> io::Result<()> { + let result = unsafe { libc::sem_wait(sem) }; + if result == -1 { Err(io::Error::last_os_error()) } else { Ok(()) } +} diff --git a/agentwire/src/testing_rt.rs b/agentwire/src/testing_rt.rs new file mode 100644 index 0000000..14713d8 --- /dev/null +++ b/agentwire/src/testing_rt.rs @@ -0,0 +1,78 @@ +//! Testing runtime. + +#![doc(hidden)] + +use crate::agent; +use futures::prelude::*; +use std::{ + env, + panic::{catch_unwind, AssertUnwindSafe}, + pin::Pin, + process, + time::Duration, +}; +use tokio::{process::Command, runtime, time}; + +/// Name of the environment variable used to pass the test ID. +pub const BROKER_TEST_ID_ENV: &str = "AGENTWIRE_BROKER_TEST_ID"; + +/// Default timeout for broker tests. +pub const DEFAULT_TIMEOUT: u64 = 60_000; + +/// Runs a broker test. +pub fn run_broker_test( + test_name: &str, + test_id: &str, + timeout: Duration, + init: impl FnOnce(), + f: Pin>>, +) { + let test_id = format!("{test_id:?}"); + if env::var(BROKER_TEST_ID_ENV).map_or(false, |var| var == test_id) { + let result = catch_unwind(AssertUnwindSafe(|| { + init(); + tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap().block_on(f); + })); + process::exit(result.is_err().into()); + } + let mut test_runner_args = env::args(); + let mut child_args = Vec::new(); + while let Some(arg) = test_runner_args.next() { + match arg.as_str() { + "--bench" + | "--exclude-should-panic" + | "--force-run-in-process" + | "--ignored" + | "--include-ignored" + | "--show-output" + | "--test" => { + child_args.push(arg); + } + "--color" | "-Z" => { + child_args.push(arg); + if let Some(arg) = test_runner_args.next() { + child_args.push(arg); + } + } + _ => {} + } + } + child_args.push("--quiet".into()); + child_args.push("--test-threads".into()); + child_args.push("1".into()); + child_args.push("--nocapture".into()); + child_args.push("--exact".into()); + child_args.push("--".into()); + child_args.push(test_name.into()); + let result = + runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async { + let mut child = Command::new(env::current_exe().unwrap()) + .args(&child_args) + .env(BROKER_TEST_ID_ENV, test_id) + .env(agent::process::ARGS_ENV, shell_words::join(&child_args)) + .spawn() + .unwrap(); + time::timeout(timeout, child.wait()).await.expect("timeouted").unwrap() + }); + assert!(result.success(), "test failed"); +} diff --git a/agentwire/tests/process.rs b/agentwire/tests/process.rs new file mode 100644 index 0000000..43a24a3 --- /dev/null +++ b/agentwire/tests/process.rs @@ -0,0 +1,111 @@ +use agentwire::{ + agent::{self, Process as _}, + port::{self, Port, SharedPort}, + Agent, Broker, BrokerFlow, +}; +use futures::prelude::*; +use rkyv::{Archive, Deserialize, Serialize}; +use std::{mem::size_of, time::Instant}; +use thiserror::Error; + +#[derive(Clone, Default, Archive, Serialize, Deserialize, Debug)] +struct Doubler; + +impl Port for Doubler { + type Input = u32; + type Output = u32; + + const INPUT_CAPACITY: usize = 0; + const OUTPUT_CAPACITY: usize = 0; +} + +impl SharedPort for Doubler { + const SERIALIZED_INIT_SIZE: usize = + size_of::() + size_of::<::Archived>(); + const SERIALIZED_INPUT_SIZE: usize = + size_of::() + size_of::<::Archived>(); + const SERIALIZED_OUTPUT_SIZE: usize = + size_of::() + size_of::<::Archived>(); +} + +impl Agent for Doubler { + const NAME: &'static str = "doubler"; +} + +#[derive(Error, Debug)] +pub enum DoublerError {} + +impl agent::Process for Doubler { + type Error = DoublerError; + + fn run(self, mut port: port::RemoteInner) -> Result<(), Self::Error> { + loop { + let input = port.recv(); + let output = input.chain(input.value * 2); + port.send(&output); + } + } +} + +#[derive(Error, Debug)] +pub enum Error {} + +trait Plan { + fn handle_doubler( + &mut self, + broker: &mut Broker, + output: port::Output, + ) -> Result; +} + +#[derive(Broker)] +#[broker(plan = Plan, error = Error)] +struct Broker { + #[agent(process)] + doubler: agent::Cell, +} + +impl Broker { + fn handle_doubler( + &mut self, + plan: &mut dyn Plan, + output: port::Output, + ) -> Result { + plan.handle_doubler(self, output) + } +} + +fn init() { + agent::process::init(|name, fd| match name { + "doubler" => Ok(Doubler::call(fd)?), + _ => panic!("unregistered agent {name}"), + }); +} + +#[agentwire::test(init = init)] +async fn test_process() { + struct TestPlan { + result: Option, + } + impl Plan for TestPlan { + fn handle_doubler( + &mut self, + _broker: &mut Broker, + output: port::Output, + ) -> Result { + self.result = Some(output.value); + Ok(BrokerFlow::Break) + } + } + + let mut broker = new_broker!(); + let mut plan = TestPlan { result: None }; + broker.enable_doubler().unwrap(); + + let fence = Instant::now(); + broker.doubler.enabled().unwrap().send(port::Input::new(3)).await.unwrap(); + broker.run_with_fence(&mut plan, fence).await.unwrap(); + + broker.disable_doubler(); + assert_eq!(plan.result, Some(6)); +} diff --git a/agentwire/tests/task.rs b/agentwire/tests/task.rs new file mode 100644 index 0000000..c55e62c --- /dev/null +++ b/agentwire/tests/task.rs @@ -0,0 +1,90 @@ +use agentwire::{ + agent, + port::{self, Port}, + Agent, Broker, BrokerFlow, +}; +use futures::{channel::mpsc::SendError, prelude::*}; +use std::time::Instant; +use thiserror::Error; + +#[derive(Default)] +struct Doubler; + +impl Port for Doubler { + type Input = u32; + type Output = u32; + + const INPUT_CAPACITY: usize = 0; + const OUTPUT_CAPACITY: usize = 0; +} + +impl Agent for Doubler { + const NAME: &'static str = "doubler"; +} + +impl agent::Task for Doubler { + type Error = SendError; + + async fn run(self, mut port: port::Inner) -> Result<(), Self::Error> { + while let Some(x) = port.next().await { + port.send(x.chain(x.value * 2)).await?; + } + Ok(()) + } +} + +#[derive(Error, Debug)] +pub enum Error {} + +trait Plan { + fn handle_doubler( + &mut self, + broker: &mut Broker, + output: port::Output, + ) -> Result; +} + +#[derive(Broker)] +#[broker(plan = Plan, error = Error)] +struct Broker { + #[agent(task)] + doubler: agent::Cell, +} + +impl Broker { + fn handle_doubler( + &mut self, + plan: &mut dyn Plan, + output: port::Output, + ) -> Result { + plan.handle_doubler(self, output) + } +} + +#[agentwire::test] +async fn test_task() { + struct TestPlan { + result: Option, + } + impl Plan for TestPlan { + fn handle_doubler( + &mut self, + _broker: &mut Broker, + output: port::Output, + ) -> Result { + self.result = Some(output.value); + Ok(BrokerFlow::Break) + } + } + + let mut broker = new_broker!(); + let mut plan = TestPlan { result: None }; + broker.enable_doubler().unwrap(); + + let fence = Instant::now(); + broker.doubler.enabled().unwrap().send(port::Input::new(3)).await.unwrap(); + broker.run_with_fence(&mut plan, fence).await.unwrap(); + + broker.disable_doubler(); + assert_eq!(plan.result, Some(6)); +} diff --git a/agentwire/tests/thread.rs b/agentwire/tests/thread.rs new file mode 100644 index 0000000..ce9c989 --- /dev/null +++ b/agentwire/tests/thread.rs @@ -0,0 +1,100 @@ +use agentwire::{ + agent, + port::{self, Port}, + Agent, Broker, BrokerFlow, +}; +use futures::{channel::mpsc::SendError, prelude::*}; +use std::{io, time::Instant}; +use thiserror::Error; +use tokio::runtime; + +#[derive(Default)] +struct Doubler; + +impl Port for Doubler { + type Input = u32; + type Output = u32; + + const INPUT_CAPACITY: usize = 0; + const OUTPUT_CAPACITY: usize = 0; +} + +impl Agent for Doubler { + const NAME: &'static str = "doubler"; +} + +#[derive(Error, Debug)] +pub enum DoublerError { + #[error("tokio runtime error")] + Runtime(#[from] io::Error), + #[error("send error")] + Send(#[from] SendError), +} + +impl agent::Thread for Doubler { + type Error = DoublerError; + + fn run(self, mut port: port::Inner) -> Result<(), Self::Error> { + let rt = runtime::Builder::new_current_thread().enable_all().build()?; + while let Some(x) = rt.block_on(port.next()) { + rt.block_on(port.send(x.chain(x.value * 2)))?; + } + Ok(()) + } +} + +#[derive(Error, Debug)] +pub enum Error {} + +trait Plan { + fn handle_doubler( + &mut self, + broker: &mut Broker, + output: port::Output, + ) -> Result; +} + +#[derive(Broker)] +#[broker(plan = Plan, error = Error)] +struct Broker { + #[agent(thread)] + doubler: agent::Cell, +} + +impl Broker { + fn handle_doubler( + &mut self, + plan: &mut dyn Plan, + output: port::Output, + ) -> Result { + plan.handle_doubler(self, output) + } +} + +#[agentwire::test] +async fn test_thread() { + struct TestPlan { + result: Option, + } + impl Plan for TestPlan { + fn handle_doubler( + &mut self, + _broker: &mut Broker, + output: port::Output, + ) -> Result { + self.result = Some(output.value); + Ok(BrokerFlow::Break) + } + } + + let mut broker = new_broker!(); + let mut plan = TestPlan { result: None }; + broker.enable_doubler().unwrap(); + + let fence = Instant::now(); + broker.doubler.enabled().unwrap().send(port::Input::new(3)).await.unwrap(); + broker.run_with_fence(&mut plan, fence).await.unwrap(); + + broker.disable_doubler(); + assert_eq!(plan.result, Some(6)); +}