Skip to content

Commit

Permalink
Split network modules for sync and async
Browse files Browse the repository at this point in the history
  • Loading branch information
imDema committed Oct 5, 2023
1 parent 08028f9 commit c64bc6d
Show file tree
Hide file tree
Showing 11 changed files with 512 additions and 398 deletions.
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,9 @@ harness = false
lto = true
strip = "symbols"
# debug = 1

[profile.release-fast]
inherits = "release"
lto = true
codegen-units = 1
panic = "abort"
2 changes: 1 addition & 1 deletion examples/nexmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ fn filter_bid(e: Event) -> Option<Bid> {
}

fn main() {
env_logger::init();
tracing_subscriber::fmt::init();

let (config, args) = EnvironmentConfig::from_args();
if args.len() != 2 {
Expand Down
13 changes: 10 additions & 3 deletions src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@ pub(crate) use topology::*;
use crate::operator::StreamElement;
use crate::scheduler::{BlockId, HostId, ReplicaId};

mod demultiplexer;
mod multiplexer;
#[cfg(feature = "async-tokio")]
mod tokio;
#[cfg(feature = "async-tokio")]
use tokio::*;

#[cfg(not(feature = "async-tokio"))]
mod sync;
#[cfg(not(feature = "async-tokio"))]
use sync::*;

mod network_channel;
mod remote;
mod topology;

#[derive(Debug, Clone)]
Expand Down
168 changes: 0 additions & 168 deletions src/network/demultiplexer.rs → src/network/sync/demultiplexer.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
#[cfg(not(feature = "async-tokio"))]
use std::net::{Shutdown, TcpListener, TcpStream};
#[cfg(not(feature = "async-tokio"))]
use std::thread::JoinHandle;

#[cfg(feature = "async-tokio")]
use tokio::io::AsyncWriteExt;
#[cfg(feature = "async-tokio")]
use tokio::net::{TcpListener, TcpStream};
#[cfg(feature = "async-tokio")]
use tokio::task::JoinHandle;

use anyhow::anyhow;
use std::collections::HashMap;
use std::net::ToSocketAddrs;
Expand All @@ -32,7 +23,6 @@ pub(crate) struct DemuxHandle<In: ExchangeData> {
tx_senders: UnboundedSender<(ReceiverEndpoint, Sender<NetworkMessage<In>>)>,
}

#[cfg(not(feature = "async-tokio"))]
impl<In: ExchangeData> DemuxHandle<In> {
/// Construct a new `DemultiplexingReceiver` for a block.
///
Expand Down Expand Up @@ -70,7 +60,6 @@ impl<In: ExchangeData> DemuxHandle<In> {
}

/// Bind the socket of this demultiplexer.
#[cfg(not(feature = "async-tokio"))]
fn bind_remotes<In: ExchangeData>(
coord: DemuxCoord,
address: (String, u16),
Expand Down Expand Up @@ -175,7 +164,6 @@ fn bind_remotes<In: ExchangeData>(
/// + Return an enum, either Queued or Overflowed
///
/// if overflowed send a yield request through a second channel
#[cfg(not(feature = "async-tokio"))]
fn demux_thread<In: ExchangeData>(
coord: DemuxCoord,
senders: HashMap<ReceiverEndpoint, Sender<NetworkMessage<In>>>,
Expand All @@ -199,159 +187,3 @@ fn demux_thread<In: ExchangeData>(
let _ = stream.shutdown(Shutdown::Both);
log::debug!("{} finished", coord);
}

#[cfg(feature = "async-tokio")]
impl<In: ExchangeData> DemuxHandle<In> {
    /// Create the demultiplexer handle of a block and spawn its `bind_remotes` task.
    ///
    /// Every local replica of the block is expected to be registered through
    /// [`register`](Self::register). `num_clients` is the number of incoming connections the
    /// spawned task accepts before it starts forwarding the registered senders.
    ///
    /// Returns the handle together with the `JoinHandle` of the spawned task.
    pub fn new(
        coord: DemuxCoord,
        address: (String, u16),
        num_clients: usize,
    ) -> (Self, JoinHandle<()>) {
        // The receiving half goes to the task, the sending half stays in the handle.
        let (tx_senders, registrations) = channel::unbounded();
        let acceptor = tokio::spawn(bind_remotes(coord, address, num_clients, registrations));

        let handle = Self { coord, tx_senders };
        (handle, acceptor)
    }

    /// Register a local receiver to this demultiplexer.
    ///
    /// The `(endpoint, sender)` pair is queued on the channel read by the `bind_remotes` task.
    ///
    /// # Panics
    ///
    /// Panics if the pair cannot be sent on that channel.
    pub fn register(
        &mut self,
        receiver_endpoint: ReceiverEndpoint,
        sender: Sender<NetworkMessage<In>>,
    ) {
        log::debug!(
            "registering {} to the demultiplexer of {}",
            receiver_endpoint,
            self.coord
        );

        let registration = (receiver_endpoint, sender);
        if self.tx_senders.send(registration).is_err() {
            panic!("register for {:?} failed", self.coord);
        }
    }
}

/// Bind the socket of this demultiplexer and drive all of its connections.
///
/// Steps, in order:
/// 1. resolve `address` and bind a `TcpListener` on it;
/// 2. accept exactly `num_clients` connections, spawning one task per connection;
/// 3. forward every `(endpoint, sender)` pair received from `rx_senders` to all those tasks;
/// 4. drop the forwarding channels and wait for every per-connection task to finish.
///
/// # Panics
///
/// Panics if the address cannot be resolved, if the socket cannot be bound, if forwarding a
/// registration to a per-connection task fails, or if a per-connection task fails to join.
#[cfg(feature = "async-tokio")]
async fn bind_remotes<In: ExchangeData>(
coord: DemuxCoord,
address: (String, u16),
num_clients: usize,
rx_senders: UnboundedReceiver<(ReceiverEndpoint, Sender<NetworkMessage<In>>)>,
) {
// Borrow the host name so that the tuple implements `ToSocketAddrs`.
let address = (address.0.as_ref(), address.1);
let address: Vec<_> = address
.to_socket_addrs()
.map_err(|e| format!("Failed to get the address for {}: {:?}", coord, e))
.unwrap()
.collect();

// NOTE(review): `address[0]` panics if the resolution produced no addresses — confirm that
// this cannot happen, or log the whole list instead.
log::debug!("demux binding {}", address[0]);
let listener = TcpListener::bind(&*address)
.await
.map_err(|e| {
anyhow!(
"Failed to bind socket for {} at {:?}: {:?}",
coord,
address,
e
)
})
.unwrap();
// From here on `address` is only a printable form of the bound local address.
let address = listener
.local_addr()
.map(|a| a.to_string())
.unwrap_or_else(|_| "unknown".to_string());
info!(
"Remote receiver at {} is ready to accept {} connections to {}",
coord, num_clients, address
);

// the JoinHandles of the per-connection tasks spawned below
let mut join_handles = vec![];
// one registration channel per accepted connection
let mut tx_broadcast = vec![];

let mut connected_clients = 0;
while connected_clients < num_clients {
let stream = listener.accept().await;
let (stream, peer_addr) = match stream {
Ok(stream) => stream,
Err(e) => {
// A failed accept is not counted: keep waiting for `num_clients` good connections.
warn!("Failed to accept incoming connection at {}: {:?}", coord, e);
continue;
}
};
connected_clients += 1;
info!(
"Remote receiver at {} accepted a new connection from {} ({} / {})",
coord, peer_addr, connected_clients, num_clients
);

let (demux_tx, demux_rx) = flume::unbounded();
let join_handle = tokio::spawn(async move {
// Collect the registrations until `recv_async` fails (presumably once `demux_tx` has
// been dropped), and only then start handling the connection.
let mut senders = HashMap::new();
while let Ok((endpoint, sender)) = demux_rx.recv_async().await {
senders.insert(endpoint, sender);
}
log::debug!("demux got senders");
demux_thread::<In>(coord, senders, stream).await;
});
join_handles.push(join_handle);
tx_broadcast.push(demux_tx);
}
log::debug!("All connection to {} started, waiting for senders", coord);

// Broadcast senders: every registration is cloned to each per-connection task.
// NOTE(review): `rx_senders.recv()` is not awaited; if it is a blocking call it stalls the
// executor thread running this task until the registration channel is closed — confirm.
while let Ok(t) = rx_senders.recv() {
for tx in tx_broadcast.iter() {
tx.send(t.clone()).unwrap();
}
}
// Dropping the forwarding channels ends the collection loops of the per-connection tasks.
drop(tx_broadcast); // Start all demuxes
for handle in join_handles {
handle.await.unwrap();
}
log::debug!("all demuxes for {} finished", coord);
}

/// Handle the connection with a remote sender.
///
/// Repeatedly obtains a `(destination, message)` pair from `stream` through `remote_recv` and
/// forwards the message to the sender registered in `senders` for that destination. A failed
/// forward is logged as a warning and the loop continues. When `remote_recv` yields `None` the
/// loop ends and the connection is shut down.
///
/// Every recipient must already be present in `senders` when this function starts: unknown
/// destinations are not waited for.
///
/// # Panics
///
/// Panics if a message is addressed to an endpoint that is missing from `senders`.
///
/// # Upgrade path
///
/// Replace send with queue.
///
/// The queue uses a hierarchical queue:
/// + First try to reserve and put the value in the fast queue
/// + If the fast queue is full, put in the slow (unbounded?) queue
/// + Return an enum, either Queued or Overflowed
///
/// if overflowed send a yield request through a second channel
#[cfg(feature = "async-tokio")]
async fn demux_thread<In: ExchangeData>(
    coord: DemuxCoord,
    senders: HashMap<ReceiverEndpoint, Sender<NetworkMessage<In>>>,
    mut stream: TcpStream,
) {
    // Printable peer address, handed to `remote_recv` alongside the stream.
    let address = stream
        .peer_addr()
        .map(|a| a.to_string())
        .unwrap_or_else(|_| "unknown".to_string());
    log::debug!("{} started", coord);

    while let Some((dest, message)) = remote_recv(coord, &mut stream, &address).await {
        if let Err(e) = senders[&dest].send(message) {
            warn!("demux failed to send message to {}: {:?}", dest, e);
        }
    }

    // Shutting down is best effort, as in the synchronous demultiplexer: the peer may have
    // already closed the connection, and that must not bring down the task.
    if let Err(e) = stream.shutdown().await {
        log::debug!(
            "{} failed to shut down the connection with {}: {:?}",
            coord,
            address,
            e
        );
    }
    log::debug!("{} finished", coord);
}
3 changes: 3 additions & 0 deletions src/network/sync/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub(super) mod demultiplexer;
pub(super) mod multiplexer;
pub(super) mod remote;
144 changes: 144 additions & 0 deletions src/network/sync/multiplexer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
use std::io::ErrorKind;
use std::time::Duration;

use std::net::{Shutdown, TcpStream, ToSocketAddrs};
use std::thread::{sleep, JoinHandle};

use crate::channel::{self, Receiver, Sender};
use crate::network::remote::remote_send;
use crate::network::{DemuxCoord, NetworkMessage, ReceiverEndpoint};
use crate::operator::ExchangeData;

//
// use crate::channel::Selector;

use crate::network::NetworkSender;

/// Maximum number of attempts to make for connecting to a remote host. Each attempt tries every
/// resolved socket address once.
const CONNECT_ATTEMPTS: usize = 32;
/// Timeout for a single connection try towards one socket address of a remote host.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// To avoid spamming the connections, wait this timeout before trying again. If the connection
/// fails again this timeout will be doubled up to `RETRY_MAX_TIMEOUT`.
const RETRY_INITIAL_TIMEOUT: Duration = Duration::from_millis(8);
/// Maximum timeout between connection attempts.
const RETRY_MAX_TIMEOUT: Duration = Duration::from_secs(1);

/// Capacity passed to `channel::bounded` when creating the channel that feeds the mux thread.
const MUX_CHANNEL_CAPACITY: usize = 10;
/// Like `NetworkSender`, but this should be used in a multiplexed channel (i.e. a remote one).
///
/// The `ReceiverEndpoint` is sent alongside the actual message in order to demultiplex it.
#[derive(Debug)]
pub struct MultiplexingSender<Out: ExchangeData> {
/// Sending half of the bounded channel read by the mux thread; `new` always stores `Some`.
/// NOTE(review): the reason for the `Option` wrapper is not visible here — presumably the
/// sender is taken or dropped elsewhere; confirm.
tx: Option<Sender<(ReceiverEndpoint, NetworkMessage<Out>)>>,
}

impl<Out: ExchangeData> MultiplexingSender<Out> {
    /// Create a multiplexing sender and spawn the thread that owns the remote connection.
    ///
    /// The spawned thread is named `mux-{host_id}:{prev_block_id}-{block_id}`. It connects to
    /// `address` with `connect_remote` and then runs `mux_thread` on the receiving half of a
    /// bounded channel of capacity `MUX_CHANNEL_CAPACITY`.
    ///
    /// Returns the sender together with the `JoinHandle` of the spawned thread.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be spawned.
    pub fn new(coord: DemuxCoord, address: (String, u16)) -> (Self, JoinHandle<()>) {
        let (tx, rx) = channel::bounded(MUX_CHANNEL_CAPACITY);

        let join_handle = std::thread::Builder::new()
            .name(format!(
                "mux-{}:{}-{}",
                coord.coord.host_id, coord.prev_block_id, coord.coord.block_id
            ))
            .spawn(move || {
                // Log the configured address as it is. Resolving it here would repeat the DNS
                // lookup done by `connect_remote` and, on failure, would panic inside the log
                // statement (only when debug logging is enabled) with an uninformative message.
                log::debug!("mux {coord} connecting to {}:{}", address.0, address.1);
                let stream = connect_remote(coord, address);

                mux_thread::<Out>(coord, rx, stream);
            })
            .unwrap();
        (Self { tx: Some(tx) }, join_handle)
    }

    /// Build a `NetworkSender` towards `receiver_endpoint` that shares this multiplexed channel.
    ///
    /// # Panics
    ///
    /// Panics if the internal sender is `None`.
    pub(crate) fn get_sender(&mut self, receiver_endpoint: ReceiverEndpoint) -> NetworkSender<Out> {
        crate::network::mux_sender(receiver_endpoint, self.tx.as_ref().unwrap().clone())
    }
}

/// Connect the sender to a remote channel located at the specified address.
///
/// - At first the address is resolved to the actual socket addresses (DNS resolution).
/// - Then at most `CONNECT_ATTEMPTS` attempts are performed. Each attempt tries every resolved
///   address once, with a `CONNECT_TIMEOUT` timeout, and returns the first stream that connects.
/// - Between two attempts the thread sleeps, starting from `RETRY_INITIAL_TIMEOUT` and doubling
///   the delay up to `RETRY_MAX_TIMEOUT`. No sleep is done after the last attempt.
///
/// # Panics
///
/// Panics if the address cannot be resolved or if no connection can be established within
/// `CONNECT_ATTEMPTS` attempts.
fn connect_remote(coord: DemuxCoord, address: (String, u16)) -> TcpStream {
    let socket_addrs: Vec<_> = address
        .to_socket_addrs()
        .map_err(|e| format!("Failed to get the address for {coord}: {e:?}",))
        .unwrap()
        .collect();
    let mut retry_delay = RETRY_INITIAL_TIMEOUT;
    for attempt in 1..=CONNECT_ATTEMPTS {
        log::debug!(
            "{} connecting to {:?} ({} attempt)",
            coord,
            socket_addrs,
            attempt,
        );

        for address in socket_addrs.iter() {
            match TcpStream::connect_timeout(address, CONNECT_TIMEOUT) {
                Ok(stream) => {
                    return stream;
                }
                Err(err) => match err.kind() {
                    ErrorKind::TimedOut => {
                        log::debug!("{coord} timeout connecting to {address:?}");
                    }
                    ErrorKind::ConnectionRefused => {
                        // Refusals are expected while the remote is still starting up: they
                        // become warnings only after a few attempts.
                        log::log!(
                            if attempt > 4 {
                                log::Level::Warn
                            } else {
                                log::Level::Debug
                            },
                            "{coord} connection refused connecting to {address:?} ({attempt})"
                        );
                    }
                    _ => {
                        log::warn!("{coord} failed to connect to {address:?}: {err:?}");
                    }
                },
            }
        }

        // Nothing will be retried after the last attempt: fail right away instead of logging a
        // retry that never happens and sleeping before the panic.
        if attempt == CONNECT_ATTEMPTS {
            break;
        }

        log::debug!(
            "{coord} retrying connection to {socket_addrs:?} in {}s",
            retry_delay.as_secs_f32(),
        );

        sleep(retry_delay);
        retry_delay = (2 * retry_delay).min(RETRY_MAX_TIMEOUT);
    }
    panic!("Failed to connect to remote {coord} at {address:?} after {CONNECT_ATTEMPTS} attempts",);
}

/// Body of the multiplexer thread: forward everything received on `rx` to the remote peer.
///
/// Each `(destination, message)` pair taken from `rx` is written to `stream` with
/// `remote_send`. When `rx.recv()` fails the loop ends, the stream is flushed and both
/// directions of the connection are shut down (shutdown errors are ignored).
///
/// # Panics
///
/// Panics if flushing the stream fails.
fn mux_thread<Out: ExchangeData>(
    coord: DemuxCoord,
    rx: Receiver<(ReceiverEndpoint, NetworkMessage<Out>)>,
    mut stream: TcpStream,
) {
    use std::io::Write;

    // Printable peer address, used for logging and handed to `remote_send`.
    let peer = match stream.peer_addr() {
        Ok(addr) => addr.to_string(),
        Err(_) => "unknown".to_string(),
    };
    log::debug!("{} connected to {:?}", coord, peer);

    // The messages are written straight to the socket, without an intermediate buffer.
    let mut w = &mut stream;

    while let Ok((dest, message)) = rx.recv() {
        remote_send(message, dest, &mut w, &peer);
    }

    w.flush().unwrap();
    let _ = stream.shutdown(Shutdown::Both);
    log::debug!("{} finished", coord);
}
Loading

0 comments on commit c64bc6d

Please sign in to comment.