From c64bc6d445fc35dd0135f3e3433a486895517094 Mon Sep 17 00:00:00 2001 From: imDema Date: Thu, 5 Oct 2023 11:00:13 +0200 Subject: [PATCH] Split network modules for sync and async --- Cargo.toml | 6 + examples/nexmark.rs | 2 +- src/network/mod.rs | 13 +- src/network/{ => sync}/demultiplexer.rs | 168 ---------------------- src/network/sync/mod.rs | 3 + src/network/sync/multiplexer.rs | 144 +++++++++++++++++++ src/network/{ => sync}/remote.rs | 106 -------------- src/network/tokio/demultiplexer.rs | 184 ++++++++++++++++++++++++ src/network/tokio/mod.rs | 3 + src/network/{ => tokio}/multiplexer.rs | 122 +--------------- src/network/tokio/remote.rs | 159 ++++++++++++++++++++ 11 files changed, 512 insertions(+), 398 deletions(-) rename src/network/{ => sync}/demultiplexer.rs (53%) create mode 100644 src/network/sync/mod.rs create mode 100644 src/network/sync/multiplexer.rs rename src/network/{ => sync}/remote.rs (60%) create mode 100644 src/network/tokio/demultiplexer.rs create mode 100644 src/network/tokio/mod.rs rename src/network/{ => tokio}/multiplexer.rs (57%) create mode 100644 src/network/tokio/remote.rs diff --git a/Cargo.toml b/Cargo.toml index 721d192b..49bf0b6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -145,3 +145,9 @@ harness = false lto = true strip = "symbols" # debug = 1 + +[profile.release-fast] +inherits = "release" +lto = true +codegen-units = 1 +panic = "abort" diff --git a/examples/nexmark.rs b/examples/nexmark.rs index cec92170..85ae8060 100644 --- a/examples/nexmark.rs +++ b/examples/nexmark.rs @@ -462,7 +462,7 @@ fn filter_bid(e: Event) -> Option { } fn main() { - env_logger::init(); + tracing_subscriber::fmt::init(); let (config, args) = EnvironmentConfig::from_args(); if args.len() != 2 { diff --git a/src/network/mod.rs b/src/network/mod.rs index e1d17d02..8a910c95 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -8,10 +8,17 @@ pub(crate) use topology::*; use crate::operator::StreamElement; use crate::scheduler::{BlockId, HostId, ReplicaId}; -mod demultiplexer; -mod multiplexer; +#[cfg(feature = "async-tokio")] +mod tokio; +#[cfg(feature = "async-tokio")] +use tokio::*; + +#[cfg(not(feature = "async-tokio"))] +mod sync; +#[cfg(not(feature = "async-tokio"))] +use sync::*; + mod network_channel; -mod remote; mod topology; #[derive(Debug, Clone)] diff --git a/src/network/demultiplexer.rs b/src/network/sync/demultiplexer.rs similarity index 53% rename from src/network/demultiplexer.rs rename to src/network/sync/demultiplexer.rs index 036fc077..81d909ee 100644 --- a/src/network/demultiplexer.rs +++ b/src/network/sync/demultiplexer.rs @@ -1,15 +1,6 @@ -#[cfg(not(feature = "async-tokio"))] use std::net::{Shutdown, TcpListener, TcpStream}; -#[cfg(not(feature = "async-tokio"))] use std::thread::JoinHandle; -#[cfg(feature = "async-tokio")] -use tokio::io::AsyncWriteExt; -#[cfg(feature = "async-tokio")] -use tokio::net::{TcpListener, TcpStream}; -#[cfg(feature = "async-tokio")] -use tokio::task::JoinHandle; - use anyhow::anyhow; use std::collections::HashMap; use std::net::ToSocketAddrs; @@ -32,7 +23,6 @@ pub(crate) struct DemuxHandle { tx_senders: UnboundedSender<(ReceiverEndpoint, Sender>)>, } -#[cfg(not(feature = "async-tokio"))] impl DemuxHandle { /// Construct a new `DemultiplexingReceiver` for a block. /// @@ -70,7 +60,6 @@ impl DemuxHandle { } /// Bind the socket of this demultiplexer. -#[cfg(not(feature = "async-tokio"))] fn bind_remotes( coord: DemuxCoord, address: (String, u16), @@ -175,7 +164,6 @@ fn bind_remotes( /// + Return an enum, either Queued or Overflowed /// /// if overflowed send a yield request through a second channel -#[cfg(not(feature = "async-tokio"))] fn demux_thread( coord: DemuxCoord, senders: HashMap>>, @@ -199,159 +187,3 @@ fn demux_thread( let _ = stream.shutdown(Shutdown::Both); log::debug!("{} finished", coord); } - -#[cfg(feature = "async-tokio")] -impl DemuxHandle { - /// Construct a new `DemultiplexingReceiver` for a block. - /// - /// All the local replicas of this block should be registered to this demultiplexer. - /// `num_client` is the number of multiplexers that will connect to this demultiplexer. Since - /// the remote senders are all multiplexed this corresponds to the number of remote replicas in - /// the previous block (relative to the block this demultiplexer refers to). - pub fn new( - coord: DemuxCoord, - address: (String, u16), - num_clients: usize, - ) -> (Self, JoinHandle<()>) { - let (tx_senders, rx_senders) = channel::unbounded(); - - let join_handle = tokio::spawn(bind_remotes(coord, address, num_clients, rx_senders)); - (Self { coord, tx_senders }, join_handle) - } - - /// Register a local receiver to this demultiplexer. - pub fn register( - &mut self, - receiver_endpoint: ReceiverEndpoint, - sender: Sender>, - ) { - log::debug!( - "registering {} to the demultiplexer of {}", - receiver_endpoint, - self.coord - ); - self.tx_senders - .send((receiver_endpoint, sender)) - .unwrap_or_else(|_| panic!("register for {:?} failed", self.coord)) - } -} - -/// Bind the socket of this demultiplexer. -#[cfg(feature = "async-tokio")] -async fn bind_remotes( - coord: DemuxCoord, - address: (String, u16), - num_clients: usize, - rx_senders: UnboundedReceiver<(ReceiverEndpoint, Sender>)>, -) { - let address = (address.0.as_ref(), address.1); - let address: Vec<_> = address - .to_socket_addrs() - .map_err(|e| format!("Failed to get the address for {}: {:?}", coord, e)) - .unwrap() - .collect(); - - log::debug!("demux binding {}", address[0]); - let listener = TcpListener::bind(&*address) - .await - .map_err(|e| { - anyhow!( - "Failed to bind socket for {} at {:?}: {:?}", - coord, - address, - e - ) - }) - .unwrap(); - let address = listener - .local_addr() - .map(|a| a.to_string()) - .unwrap_or_else(|_| "unknown".to_string()); - info!( - "Remote receiver at {} is ready to accept {} connections to {}", - coord, num_clients, address - ); - - // the list of JoinHandle of all the spawned threads, including the demultiplexer one - let mut join_handles = vec![]; - let mut tx_broadcast = vec![]; - - let mut connected_clients = 0; - while connected_clients < num_clients { - let stream = listener.accept().await; - let (stream, peer_addr) = match stream { - Ok(stream) => stream, - Err(e) => { - warn!("Failed to accept incoming connection at {}: {:?}", coord, e); - continue; - } - }; - connected_clients += 1; - info!( - "Remote receiver at {} accepted a new connection from {} ({} / {})", - coord, peer_addr, connected_clients, num_clients - ); - - let (demux_tx, demux_rx) = flume::unbounded(); - let join_handle = tokio::spawn(async move { - let mut senders = HashMap::new(); - while let Ok((endpoint, sender)) = demux_rx.recv_async().await { - senders.insert(endpoint, sender); - } - log::debug!("demux got senders"); - demux_thread::(coord, senders, stream).await; - }); - join_handles.push(join_handle); - tx_broadcast.push(demux_tx); - } - log::debug!("All connection to {} started, waiting for senders", coord); - - // Broadcast senders - while let Ok(t) = rx_senders.recv() { - for tx in tx_broadcast.iter() { - tx.send(t.clone()).unwrap(); - } - } - drop(tx_broadcast); // Start all demuxes - for handle in join_handles { - handle.await.unwrap(); - } - log::debug!("all demuxes for {} finished", coord); -} - -/// Handle the connection with a remote sender. -/// -/// Will deserialize the message upon arrival and send to the corresponding recipient the -/// deserialized data. If the recipient is not yet known, it is waited until it registers. -/// -/// # Upgrade path -/// -/// Replace send with queue. -/// -/// The queue uses a hierarchical queue: -/// + First try to reserve and put the value in the fast queue -/// + If the fast queue is full, put in the slow (unbounded?) queue -/// + Return an enum, either Queued or Overflowed -/// -/// if overflowed send a yield request through a second channel -#[cfg(feature = "async-tokio")] -async fn demux_thread( - coord: DemuxCoord, - senders: HashMap>>, - mut stream: TcpStream, -) { - let address = stream - .peer_addr() - .map(|a| a.to_string()) - .unwrap_or_else(|_| "unknown".to_string()); - log::debug!("{} started", coord); - - while let Some((dest, message)) = remote_recv(coord, &mut stream, &address).await { - if let Err(e) = senders[&dest].send(message) { - warn!("demux failed to send message to {}: {:?}", dest, e); - } - } - - stream.shutdown().await.unwrap(); - log::debug!("{} finished", coord); -} diff --git a/src/network/sync/mod.rs b/src/network/sync/mod.rs new file mode 100644 index 00000000..ae5a84f4 --- /dev/null +++ b/src/network/sync/mod.rs @@ -0,0 +1,3 @@ +pub(super) mod demultiplexer; +pub(super) mod multiplexer; +pub(super) mod remote; diff --git a/src/network/sync/multiplexer.rs b/src/network/sync/multiplexer.rs new file mode 100644 index 00000000..65208cff --- /dev/null +++ b/src/network/sync/multiplexer.rs @@ -0,0 +1,144 @@ +use std::io::ErrorKind; +use std::time::Duration; + +use std::net::{Shutdown, TcpStream, ToSocketAddrs}; +use std::thread::{sleep, JoinHandle}; + +use crate::channel::{self, Receiver, Sender}; +use crate::network::remote::remote_send; +use crate::network::{DemuxCoord, NetworkMessage, ReceiverEndpoint}; +use crate::operator::ExchangeData; + +// +// use crate::channel::Selector; + +use crate::network::NetworkSender; + +/// Maximum number of attempts to make for connecting to a remote host. +const CONNECT_ATTEMPTS: usize = 32; +/// Timeout for connecting to a remote host. +const CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +/// To avoid spamming the connections, wait this timeout before trying again. If the connection +/// fails again this timeout will be doubled up to `RETRY_MAX_TIMEOUT`. +const RETRY_INITIAL_TIMEOUT: Duration = Duration::from_millis(8); +/// Maximum timeout between connection attempts. +const RETRY_MAX_TIMEOUT: Duration = Duration::from_secs(1); + +const MUX_CHANNEL_CAPACITY: usize = 10; +/// Like `NetworkSender`, but this should be used in a multiplexed channel (i.e. a remote one). +/// +/// The `ReceiverEndpoint` is sent alongside the actual message in order to demultiplex it. +#[derive(Debug)] +pub struct MultiplexingSender { + tx: Option)>>, +} + +impl MultiplexingSender { + pub fn new(coord: DemuxCoord, address: (String, u16)) -> (Self, JoinHandle<()>) { + let (tx, rx) = channel::bounded(MUX_CHANNEL_CAPACITY); + + let join_handle = std::thread::Builder::new() + .name(format!( + "mux-{}:{}-{}", + coord.coord.host_id, coord.prev_block_id, coord.coord.block_id + )) + .spawn(move || { + log::debug!( + "mux {coord} connecting to {}", + address.to_socket_addrs().unwrap().next().unwrap() + ); + let stream = connect_remote(coord, address); + + mux_thread::(coord, rx, stream); + }) + .unwrap(); + (Self { tx: Some(tx) }, join_handle) + } + + pub(crate) fn get_sender(&mut self, receiver_endpoint: ReceiverEndpoint) -> NetworkSender { + crate::network::mux_sender(receiver_endpoint, self.tx.as_ref().unwrap().clone()) + } +} + +/// Connect the sender to a remote channel located at the specified address. +/// +/// - At first the address is resolved to an actual address (DNS resolution) +/// - Then at most `CONNECT_ATTEMPTS` are performed, and an exponential backoff is used in case +/// of errors. +/// - If the connection cannot be established this function will panic. +fn connect_remote(coord: DemuxCoord, address: (String, u16)) -> TcpStream { + let socket_addrs: Vec<_> = address + .to_socket_addrs() + .map_err(|e| format!("Failed to get the address for {coord}: {e:?}",)) + .unwrap() + .collect(); + let mut retry_delay = RETRY_INITIAL_TIMEOUT; + for attempt in 1..=CONNECT_ATTEMPTS { + log::debug!( + "{} connecting to {:?} ({} attempt)", + coord, + socket_addrs, + attempt, + ); + + for address in socket_addrs.iter() { + match TcpStream::connect_timeout(address, CONNECT_TIMEOUT) { + Ok(stream) => { + return stream; + } + Err(err) => match err.kind() { + ErrorKind::TimedOut => { + log::debug!("{coord} timeout connecting to {address:?}"); + } + ErrorKind::ConnectionRefused => { + log::log!( + if attempt > 4 { + log::Level::Warn + } else { + log::Level::Debug + }, + "{coord} connection refused connecting to {address:?} ({attempt})" + ); + } + _ => { + log::warn!("{coord} failed to connect to {address:?}: {err:?}"); + } + }, + } + } + + log::debug!( + "{coord} retrying connection to {socket_addrs:?} in {}s", + retry_delay.as_secs_f32(), + ); + + sleep(retry_delay); + retry_delay = (2 * retry_delay).min(RETRY_MAX_TIMEOUT); + } + panic!("Failed to connect to remote {coord} at {address:?} after {CONNECT_ATTEMPTS} attempts",); +} + +fn mux_thread( + coord: DemuxCoord, + rx: Receiver<(ReceiverEndpoint, NetworkMessage)>, + mut stream: TcpStream, +) { + use std::io::Write; + + let address = stream + .peer_addr() + .map(|a| a.to_string()) + .unwrap_or_else(|_| "unknown".to_string()); + log::debug!("{} connected to {:?}", coord, address); + + // let mut w = std::io::BufWriter::new(&mut stream); + let mut w = &mut stream; + + while let Ok((dest, message)) = rx.recv() { + remote_send(message, dest, &mut w, &address); + } + + w.flush().unwrap(); + let _ = stream.shutdown(Shutdown::Both); + log::debug!("{} finished", coord); +} diff --git a/src/network/remote.rs b/src/network/sync/remote.rs similarity index 60% rename from src/network/remote.rs rename to src/network/sync/remote.rs index 8572d4d2..348c5f42 100644 --- a/src/network/remote.rs +++ b/src/network/sync/remote.rs @@ -3,8 +3,6 @@ use once_cell::sync::Lazy; use std::io::Read; #[cfg(not(feature = "async-tokio"))] use std::io::Write; -#[cfg(feature = "async-tokio")] -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use bincode::config::{FixintEncoding, RejectTrailing, WithOtherIntEncoding, WithOtherTrailing}; use bincode::{DefaultOptions, Options}; @@ -95,68 +93,6 @@ pub(crate) fn remote_send( ); } -/// Serialize and send a message to a remote socket. -/// -/// The network protocol works as follow: -/// - send a `MessageHeader` serialized with bincode with `FixintEncoding` -/// - send the message -#[cfg(feature = "async-tokio")] -pub(crate) async fn remote_send( - msg: NetworkMessage, - dest: ReceiverEndpoint, - writer: &mut W, - address: &str, -) { - let serialized_len = BINCODE_MSG_CONFIG - .serialized_size(&msg) - .unwrap_or_else(|e| { - panic!( - "Failed to compute serialized length of outgoing message to {}: {:?}", - dest, e - ) - }); - - let header = MessageHeader { - size: serialized_len.try_into().unwrap(), - replica_id: dest.coord.replica_id, - sender_block_id: dest.prev_block_id, - }; - - let mut buf = Vec::with_capacity(HEADER_SIZE + serialized_len as usize); - - BINCODE_HEADER_CONFIG - .serialize_into(&mut buf, &header) - .unwrap_or_else(|e| { - panic!( - "Failed to serialize header of message (was {} bytes) to {} at {}: {:?}", - serialized_len, dest, address, e - ) - }); - - BINCODE_MSG_CONFIG - .serialize_into(&mut buf, &msg) - .unwrap_or_else(|e| { - panic!( - "Failed to serialize message, {} bytes to {} at {}: {:?}", - serialized_len, dest, address, e - ) - }); - assert_eq!(buf.len(), HEADER_SIZE + serialized_len as usize); - - writer.write_all(buf.as_ref()).await.unwrap_or_else(|e| { - panic!( - "Failed to send message {} bytes to {} at {}: {:?}", - serialized_len, dest, address, e - ) - }); - - get_profiler().net_bytes_out( - msg.sender, - dest.coord, - HEADER_SIZE + serialized_len as usize, - ); -} - /// Receive a message from the remote channel. Returns `None` if there was a failure receiving the /// last message. /// @@ -203,48 +139,6 @@ pub(crate) fn remote_recv( Some((dest, msg)) } -#[cfg(feature = "async-tokio")] -pub(crate) async fn remote_recv( - coord: DemuxCoord, - reader: &mut R, - address: &str, -) -> Option<(ReceiverEndpoint, NetworkMessage)> { - let mut header = [0u8; HEADER_SIZE]; - match reader.read_exact(&mut header).await { - Ok(_) => {} - Err(e) => { - log::trace!( - "Failed to receive {} bytes of header to {} from {}: {:?}", - HEADER_SIZE, - coord, - address, - e - ); - return None; - } - } - let header: MessageHeader = BINCODE_HEADER_CONFIG - .deserialize(&header) - .expect("Malformed header"); - let mut buf = vec![0u8; header.size as usize]; - reader.read_exact(&mut buf).await.unwrap_or_else(|e| { - panic!( - "Failed to receive {} bytes to {} from {}: {:?}", - header.size, coord, address, e - ) - }); - let msg: NetworkMessage = BINCODE_MSG_CONFIG - .deserialize(buf.as_ref()) - .expect("Malformed message"); - - let dest = ReceiverEndpoint::new( - Coord::new(coord.coord.block_id, coord.coord.host_id, header.replica_id), - header.sender_block_id, - ); - get_profiler().net_bytes_in(msg.sender, dest.coord, HEADER_SIZE + header.size as usize); - Some((dest, msg)) -} - #[cfg(test)] mod tests { use bincode::Options; diff --git a/src/network/tokio/demultiplexer.rs b/src/network/tokio/demultiplexer.rs new file mode 100644 index 00000000..609d0e6e --- /dev/null +++ b/src/network/tokio/demultiplexer.rs @@ -0,0 +1,184 @@ +#[cfg(feature = "async-tokio")] +use tokio::io::AsyncWriteExt; +#[cfg(feature = "async-tokio")] +use tokio::net::{TcpListener, TcpStream}; +#[cfg(feature = "async-tokio")] +use tokio::task::JoinHandle; + +use anyhow::anyhow; +use std::collections::HashMap; +use std::net::ToSocketAddrs; + +use crate::channel::{self, Sender, UnboundedReceiver, UnboundedSender}; +use crate::network::remote::remote_recv; +use crate::network::{DemuxCoord, NetworkMessage, ReceiverEndpoint}; +use crate::operator::ExchangeData; + +/// Like `NetworkReceiver`, but this should be used in a multiplexed channel (i.e. a remote one). +/// +/// This receiver is handled in a separate thread that keeps track of the local registered receivers +/// and the open connections. The incoming messages are tagged with the receiver endpoint. Upon +/// arrival they are routed to the correct receiver according to the `ReceiverEndpoint` the message +/// is tagged with. +#[derive(Debug)] +pub(crate) struct DemuxHandle { + coord: DemuxCoord, + /// Tell the dem&ultiplexer that a new receiver is present, + tx_senders: UnboundedSender<(ReceiverEndpoint, Sender>)>, +} + +#[cfg(feature = "async-tokio")] +impl DemuxHandle { + /// Construct a new `DemultiplexingReceiver` for a block. + /// + /// All the local replicas of this block should be registered to this demultiplexer. + /// `num_client` is the number of multiplexers that will connect to this demultiplexer. Since + /// the remote senders are all multiplexed this corresponds to the number of remote replicas in + /// the previous block (relative to the block this demultiplexer refers to). + pub fn new( + coord: DemuxCoord, + address: (String, u16), + num_clients: usize, + ) -> (Self, JoinHandle<()>) { + let (tx_senders, rx_senders) = channel::unbounded(); + + let join_handle = tokio::spawn(bind_remotes(coord, address, num_clients, rx_senders)); + (Self { coord, tx_senders }, join_handle) + } + + /// Register a local receiver to this demultiplexer. + pub fn register( + &mut self, + receiver_endpoint: ReceiverEndpoint, + sender: Sender>, + ) { + log::debug!( + "registering {} to the demultiplexer of {}", + receiver_endpoint, + self.coord + ); + self.tx_senders + .send((receiver_endpoint, sender)) + .unwrap_or_else(|_| panic!("register for {:?} failed", self.coord)) + } +} + +/// Bind the socket of this demultiplexer. +#[cfg(feature = "async-tokio")] +async fn bind_remotes( + coord: DemuxCoord, + address: (String, u16), + num_clients: usize, + rx_senders: UnboundedReceiver<(ReceiverEndpoint, Sender>)>, +) { + let address = (address.0.as_ref(), address.1); + let address: Vec<_> = address + .to_socket_addrs() + .map_err(|e| format!("Failed to get the address for {}: {:?}", coord, e)) + .unwrap() + .collect(); + + log::debug!("demux binding {}", address[0]); + let listener = TcpListener::bind(&*address) + .await + .map_err(|e| { + anyhow!( + "Failed to bind socket for {} at {:?}: {:?}", + coord, + address, + e + ) + }) + .unwrap(); + let address = listener + .local_addr() + .map(|a| a.to_string()) + .unwrap_or_else(|_| "unknown".to_string()); + info!( + "Remote receiver at {} is ready to accept {} connections to {}", + coord, num_clients, address + ); + + // the list of JoinHandle of all the spawned threads, including the demultiplexer one + let mut join_handles = vec![]; + let mut tx_broadcast = vec![]; + + let mut connected_clients = 0; + while connected_clients < num_clients { + let stream = listener.accept().await; + let (stream, peer_addr) = match stream { + Ok(stream) => stream, + Err(e) => { + warn!("Failed to accept incoming connection at {}: {:?}", coord, e); + continue; + } + }; + connected_clients += 1; + info!( + "Remote receiver at {} accepted a new connection from {} ({} / {})", + coord, peer_addr, connected_clients, num_clients + ); + + let (demux_tx, demux_rx) = flume::unbounded(); + let join_handle = tokio::spawn(async move { + let mut senders = HashMap::new(); + while let Ok((endpoint, sender)) = demux_rx.recv_async().await { + senders.insert(endpoint, sender); + } + log::debug!("demux got senders"); + demux_thread::(coord, senders, stream).await; + }); + join_handles.push(join_handle); + tx_broadcast.push(demux_tx); + } + log::debug!("All connection to {} started, waiting for senders", coord); + + // Broadcast senders + while let Ok(t) = rx_senders.recv() { + for tx in tx_broadcast.iter() { + tx.send(t.clone()).unwrap(); + } + } + drop(tx_broadcast); // Start all demuxes + for handle in join_handles { + handle.await.unwrap(); + } + log::debug!("all demuxes for {} finished", coord); +} + +/// Handle the connection with a remote sender. +/// +/// Will deserialize the message upon arrival and send to the corresponding recipient the +/// deserialized data. If the recipient is not yet known, it is waited until it registers. +/// +/// # Upgrade path +/// +/// Replace send with queue. +/// +/// The queue uses a hierarchical queue: +/// + First try to reserve and put the value in the fast queue +/// + If the fast queue is full, put in the slow (unbounded?) queue +/// + Return an enum, either Queued or Overflowed +/// +/// if overflowed send a yield request through a second channel +#[cfg(feature = "async-tokio")] +async fn demux_thread( + coord: DemuxCoord, + senders: HashMap>>, + mut stream: TcpStream, +) { + let address = stream + .peer_addr() + .map(|a| a.to_string()) + .unwrap_or_else(|_| "unknown".to_string()); + log::debug!("{} started", coord); + + while let Some((dest, message)) = remote_recv(coord, &mut stream, &address).await { + if let Err(e) = senders[&dest].send(message) { + warn!("demux failed to send message to {}: {:?}", dest, e); + } + } + + stream.shutdown().await.unwrap(); + log::debug!("{} finished", coord); +} diff --git a/src/network/tokio/mod.rs b/src/network/tokio/mod.rs new file mode 100644 index 00000000..ae5a84f4 --- /dev/null +++ b/src/network/tokio/mod.rs @@ -0,0 +1,3 @@ +pub(super) mod demultiplexer; +pub(super) mod multiplexer; +pub(super) mod remote; diff --git a/src/network/multiplexer.rs b/src/network/tokio/multiplexer.rs similarity index 57% rename from src/network/multiplexer.rs rename to src/network/tokio/multiplexer.rs index 0b981a85..011dc00a 100644 --- a/src/network/multiplexer.rs +++ b/src/network/tokio/multiplexer.rs @@ -1,11 +1,6 @@ use std::io::ErrorKind; use std::time::Duration; -#[cfg(not(feature = "async-tokio"))] -use std::net::{Shutdown, TcpStream, ToSocketAddrs}; -#[cfg(not(feature = "async-tokio"))] -use std::thread::{sleep, JoinHandle}; - #[cfg(feature = "async-tokio")] use std::net::ToSocketAddrs; #[cfg(feature = "async-tokio")] @@ -23,7 +18,7 @@ use crate::operator::ExchangeData; // #[cfg(not(feature = "async-tokio"))] // use crate::channel::Selector; -use super::NetworkSender; +use crate::network::NetworkSender; /// Maximum number of attempts to make for connecting to a remote host. const CONNECT_ATTEMPTS: usize = 32; @@ -45,119 +40,6 @@ pub struct MultiplexingSender { tx: Option)>>, } -#[cfg(not(feature = "async-tokio"))] -impl MultiplexingSender { - pub fn new(coord: DemuxCoord, address: (String, u16)) -> (Self, JoinHandle<()>) { - let (tx, rx) = channel::bounded(MUX_CHANNEL_CAPACITY); - - let join_handle = std::thread::Builder::new() - .name(format!( - "mux-{}:{}-{}", - coord.coord.host_id, coord.prev_block_id, coord.coord.block_id - )) - .spawn(move || { - log::debug!( - "mux {coord} connecting to {}", - address.to_socket_addrs().unwrap().next().unwrap() - ); - let stream = connect_remote(coord, address); - - mux_thread::(coord, rx, stream); - }) - .unwrap(); - (Self { tx: Some(tx) }, join_handle) - } - - pub(crate) fn get_sender(&mut self, receiver_endpoint: ReceiverEndpoint) -> NetworkSender { - super::mux_sender(receiver_endpoint, self.tx.as_ref().unwrap().clone()) - } -} - -/// Connect the sender to a remote channel located at the specified address. -/// -/// - At first the address is resolved to an actual address (DNS resolution) -/// - Then at most `CONNECT_ATTEMPTS` are performed, and an exponential backoff is used in case -/// of errors. -/// - If the connection cannot be established this function will panic. -#[cfg(not(feature = "async-tokio"))] -fn connect_remote(coord: DemuxCoord, address: (String, u16)) -> TcpStream { - let socket_addrs: Vec<_> = address - .to_socket_addrs() - .map_err(|e| format!("Failed to get the address for {coord}: {e:?}",)) - .unwrap() - .collect(); - let mut retry_delay = RETRY_INITIAL_TIMEOUT; - for attempt in 1..=CONNECT_ATTEMPTS { - log::debug!( - "{} connecting to {:?} ({} attempt)", - coord, - socket_addrs, - attempt, - ); - - for address in socket_addrs.iter() { - match TcpStream::connect_timeout(address, CONNECT_TIMEOUT) { - Ok(stream) => { - return stream; - } - Err(err) => match err.kind() { - ErrorKind::TimedOut => { - log::debug!("{coord} timeout connecting to {address:?}"); - } - ErrorKind::ConnectionRefused => { - log::log!( - if attempt > 4 { - log::Level::Warn - } else { - log::Level::Debug - }, - "{coord} connection refused connecting to {address:?} ({attempt})" - ); - } - _ => { - log::warn!("{coord} failed to connect to {address:?}: {err:?}"); - } - }, - } - } - - log::debug!( - "{coord} retrying connection to {socket_addrs:?} in {}s", - retry_delay.as_secs_f32(), - ); - - sleep(retry_delay); - retry_delay = (2 * retry_delay).min(RETRY_MAX_TIMEOUT); - } - panic!("Failed to connect to remote {coord} at {address:?} after {CONNECT_ATTEMPTS} attempts",); -} - -#[cfg(not(feature = "async-tokio"))] -fn mux_thread( - coord: DemuxCoord, - rx: Receiver<(ReceiverEndpoint, NetworkMessage)>, - mut stream: TcpStream, -) { - use std::io::Write; - - let address = stream - .peer_addr() - .map(|a| a.to_string()) - .unwrap_or_else(|_| "unknown".to_string()); - log::debug!("{} connected to {:?}", coord, address); - - // let mut w = std::io::BufWriter::new(&mut stream); - let mut w = &mut stream; - - while let Ok((dest, message)) = rx.recv() { - remote_send(message, dest, &mut w, &address); - } - - w.flush().unwrap(); - let _ = stream.shutdown(Shutdown::Both); - log::debug!("{} finished", coord); -} - #[cfg(feature = "async-tokio")] impl MultiplexingSender { /// Construct a new `MultiplexingSender` for a block. @@ -188,7 +70,7 @@ impl MultiplexingSender { // } pub(crate) fn get_sender(&mut self, receiver_endpoint: ReceiverEndpoint) -> NetworkSender { - use super::mux_sender; + use crate::network::mux_sender; mux_sender(receiver_endpoint, self.tx.as_ref().unwrap().clone()) } } diff --git a/src/network/tokio/remote.rs b/src/network/tokio/remote.rs new file mode 100644 index 00000000..31a2210b --- /dev/null +++ b/src/network/tokio/remote.rs @@ -0,0 +1,159 @@ +use once_cell::sync::Lazy; +#[cfg(feature = "async-tokio")] +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use bincode::config::{FixintEncoding, RejectTrailing, WithOtherIntEncoding, WithOtherTrailing}; +use bincode::{DefaultOptions, Options}; +use serde::{Deserialize, Serialize}; + +use crate::network::{Coord, DemuxCoord, NetworkMessage, ReceiverEndpoint}; +use crate::operator::ExchangeData; +use crate::profiler::{get_profiler, Profiler}; +use crate::scheduler::BlockId; +use crate::scheduler::ReplicaId; + +/// Configuration of the header serializer: the integers must have a fixed length encoding. +static BINCODE_HEADER_CONFIG: Lazy< + WithOtherTrailing, RejectTrailing>, +> = Lazy::new(|| { + bincode::DefaultOptions::new() + .with_fixint_encoding() + .reject_trailing_bytes() +}); + +static BINCODE_MSG_CONFIG: Lazy = Lazy::new(bincode::DefaultOptions::new); + +pub(crate) const HEADER_SIZE: usize = 20; // std::mem::size_of::(); + +/// Header of a message sent before the actual message. +#[derive(Serialize, Deserialize, Default)] +struct MessageHeader { + /// The size of the actual message + size: u32, + /// The id of the replica this message is for. + replica_id: ReplicaId, + /// The id of the block that is sending the message. + sender_block_id: BlockId, +} + +/// Serialize and send a message to a remote socket. +/// +/// The network protocol works as follow: +/// - send a `MessageHeader` serialized with bincode with `FixintEncoding` +/// - send the message +#[cfg(feature = "async-tokio")] +pub(crate) async fn remote_send( + msg: NetworkMessage, + dest: ReceiverEndpoint, + writer: &mut W, + address: &str, +) { + let serialized_len = BINCODE_MSG_CONFIG + .serialized_size(&msg) + .unwrap_or_else(|e| { + panic!( + "Failed to compute serialized length of outgoing message to {}: {:?}", + dest, e + ) + }); + + let header = MessageHeader { + size: serialized_len.try_into().unwrap(), + replica_id: dest.coord.replica_id, + sender_block_id: dest.prev_block_id, + }; + + let mut buf = Vec::with_capacity(HEADER_SIZE + serialized_len as usize); + + BINCODE_HEADER_CONFIG + .serialize_into(&mut buf, &header) + .unwrap_or_else(|e| { + panic!( + "Failed to serialize header of message (was {} bytes) to {} at {}: {:?}", + serialized_len, dest, address, e + ) + }); + + BINCODE_MSG_CONFIG + .serialize_into(&mut buf, &msg) + .unwrap_or_else(|e| { + panic!( + "Failed to serialize message, {} bytes to {} at {}: {:?}", + serialized_len, dest, address, e + ) + }); + assert_eq!(buf.len(), HEADER_SIZE + serialized_len as usize); + + writer.write_all(buf.as_ref()).await.unwrap_or_else(|e| { + panic!( + "Failed to send message {} bytes to {} at {}: {:?}", + serialized_len, dest, address, e + ) + }); + + get_profiler().net_bytes_out( + msg.sender, + dest.coord, + HEADER_SIZE + serialized_len as usize, + ); +} + +#[cfg(feature = "async-tokio")] +pub(crate) async fn remote_recv( + coord: DemuxCoord, + reader: &mut R, + address: &str, +) -> Option<(ReceiverEndpoint, NetworkMessage)> { + let mut header = [0u8; HEADER_SIZE]; + match reader.read_exact(&mut header).await { + Ok(_) => {} + Err(e) => { + log::trace!( + "Failed to receive {} bytes of header to {} from {}: {:?}", + HEADER_SIZE, + coord, + address, + e + ); + return None; + } + } + let header: MessageHeader = BINCODE_HEADER_CONFIG + .deserialize(&header) + .expect("Malformed header"); + let mut buf = vec![0u8; header.size as usize]; + reader.read_exact(&mut buf).await.unwrap_or_else(|e| { + panic!( + "Failed to receive {} bytes to {} from {}: {:?}", + header.size, coord, address, e + ) + }); + let msg: NetworkMessage = BINCODE_MSG_CONFIG + .deserialize(buf.as_ref()) + .expect("Malformed message"); + + let dest = ReceiverEndpoint::new( + Coord::new(coord.coord.block_id, coord.coord.host_id, header.replica_id), + header.sender_block_id, + ); + get_profiler().net_bytes_in(msg.sender, dest.coord, HEADER_SIZE + header.size as usize); + Some((dest, msg)) +} + +#[cfg(test)] +mod tests { + use bincode::Options; + + use crate::network::remote::HEADER_SIZE; + + use super::{MessageHeader, BINCODE_HEADER_CONFIG}; + + #[test] + fn header_size() { + let computed_size = BINCODE_HEADER_CONFIG + .serialized_size(&MessageHeader::default()) + .unwrap(); + + assert_eq!(HEADER_SIZE as u64, computed_size); + } +}