Skip to content

Commit

Permalink
Use abortable streams instead of select. (#1292)
Browse files Browse the repository at this point in the history
  • Loading branch information
afck authored Nov 27, 2023
1 parent 738d183 commit f2310b4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 49 deletions.
31 changes: 14 additions & 17 deletions linera-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ use crate::{
},
};
use futures::{
channel::oneshot,
future,
lock::Mutex,
stream::{self, FuturesUnordered, StreamExt},
stream::{self, AbortHandle, FuturesUnordered, StreamExt},
};
use linera_base::{
abi::{Abi, ContractAbi},
Expand Down Expand Up @@ -605,7 +604,7 @@ where

async fn update_streams(
this: &Arc<Mutex<Self>>,
senders: &mut HashMap<ValidatorName, oneshot::Sender<()>>,
senders: &mut HashMap<ValidatorName, AbortHandle>,
) -> Result<(), ChainClientError>
where
P: Send + 'static,
Expand All @@ -617,30 +616,28 @@ where
(guard.chain_id, nodes, guard.node_client.clone())
};
// Drop removed validators.
senders.retain(|name, _| nodes.contains_key(name));
senders.retain(|name, abort| {
if !nodes.contains_key(name) {
abort.abort();
}
!abort.is_aborted()
});
// Add tasks for new validators.
for (name, mut node) in nodes {
let hash_map::Entry::Vacant(entry) = senders.entry(name) else {
continue;
};
let stream = match node.subscribe(vec![chain_id]).await {
Err(e) => {
info!("Could not connect to validator {name}: {e:?}");
let (mut stream, abort) = match node.subscribe(vec![chain_id]).await {
Err(error) => {
info!(?error, "Could not connect to validator {name}");
continue;
}
Ok(stream) => stream,
Ok(stream) => stream::abortable(stream),
};
let this = this.clone();
let local_node = local_node.clone();
let (sender, receiver) = oneshot::channel();
// Calling tokio_stream::StreamExt::merge explicitly because it would conflict with the
// the futures_util::StreamExt trait.
let mut cancelable_stream = tokio_stream::StreamExt::merge(
stream.map(Some),
stream::once(receiver).map(|_| None),
);
tokio::spawn(async move {
while let Some(Some(notification)) = cancelable_stream.next().await {
while let Some(notification) = stream.next().await {
Self::process_notification(
this.clone(),
name,
Expand All @@ -651,7 +648,7 @@ where
.await;
}
});
entry.insert(sender);
entry.insert(abort);
}
Ok(())
}
Expand Down
61 changes: 29 additions & 32 deletions linera-rpc/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

use crate::{codec, codec::Codec, RpcMessage};
use async_trait::async_trait;
use futures::{future, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use futures::{
future,
stream::{self, AbortHandle, AbortRegistration, Abortable},
Sink, SinkExt, Stream, StreamExt, TryStreamExt,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc};
use structopt::clap::arg_enum;
Expand Down Expand Up @@ -60,7 +64,7 @@ pub trait MessageHandler: Clone {

/// The result of spawning a server is oneshot channel to kill it and a handle to track completion.
pub struct ServerHandle {
pub complete: futures::channel::oneshot::Sender<()>,
pub abort: AbortHandle,
pub handle: tokio::task::JoinHandle<Result<(), std::io::Error>>,
}

Expand All @@ -72,7 +76,7 @@ impl ServerHandle {
}

pub async fn kill(self) -> Result<(), std::io::Error> {
self.complete.send(()).unwrap();
self.abort.abort();
self.handle.await??;
Ok(())
}
Expand Down Expand Up @@ -141,18 +145,18 @@ impl TransportProtocol {
where
S: MessageHandler + Send + 'static,
{
let (complete, receiver) = futures::channel::oneshot::channel();
let (abort, registration) = AbortHandle::new_pair();
let handle = match self {
Self::Udp => {
let socket = UdpSocket::bind(&address).await?;
tokio::spawn(Self::run_udp_server(socket, state, receiver))
tokio::spawn(Self::run_udp_server(socket, state, registration))
}
Self::Tcp => {
let listener = TcpListener::bind(address).await?;
tokio::spawn(Self::run_tcp_server(listener, state, receiver))
tokio::spawn(Self::run_tcp_server(listener, state, registration))
}
};
Ok(ServerHandle { complete, handle })
Ok(ServerHandle { abort, handle })
}
}

Expand Down Expand Up @@ -189,31 +193,25 @@ impl TransportProtocol {
async fn run_udp_server<S>(
socket: UdpSocket,
state: S,
mut exit_future: futures::channel::oneshot::Receiver<()>,
registration: AbortRegistration,
) -> Result<(), std::io::Error>
where
S: MessageHandler + Send + 'static,
{
let (udp_sink, mut udp_stream) = UdpFramed::new(socket, Codec).split();
let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split();
let mut udp_stream = Abortable::new(udp_stream, registration);
let udp_sink = Arc::new(Mutex::new(udp_sink));
// Track the latest tasks for a given peer. This is used to return answers in the
// same order as the queries.
let mut previous_tasks = HashMap::new();

loop {
let (message, peer) = match future::select(exit_future, udp_stream.next()).await {
future::Either::Left(_) => break,
future::Either::Right((value, new_exit_future)) => {
exit_future = new_exit_future;
match value {
Some(Ok(value)) => value,
Some(Err(codec::Error::Io(io_error))) => return Err(io_error),
Some(Err(other_error)) => {
warn!("Received an invalid message: {other_error}");
continue;
}
None => return Err(std::io::ErrorKind::UnexpectedEof.into()),
}
while let Some(value) = udp_stream.next().await {
let (message, peer) = match value {
Ok(value) => value,
Err(codec::Error::Io(io_error)) => return Err(io_error),
Err(other_error) => {
warn!("Received an invalid message: {other_error}");
continue;
}
};

Expand Down Expand Up @@ -296,19 +294,18 @@ impl TransportProtocol {
async fn run_tcp_server<S>(
listener: TcpListener,
state: S,
mut exit_future: futures::channel::oneshot::Receiver<()>,
registration: AbortRegistration,
) -> Result<(), std::io::Error>
where
S: MessageHandler + Send + 'static,
{
loop {
let (socket, _) = match future::select(exit_future, Box::pin(listener.accept())).await {
future::Either::Left(_) => break,
future::Either::Right((value, new_exit_future)) => {
exit_future = new_exit_future;
value?
}
};
let accept_stream = stream::try_unfold(listener, |listener| async move {
let (socket, _) = listener.accept().await?;
Ok::<_, io::Error>(Some((socket, listener)))
});
let mut accept_stream = Box::pin(Abortable::new(accept_stream, registration));
while let Some(value) = accept_stream.next().await {
let socket = value?;
let mut handler = state.clone();
tokio::spawn(async move {
let mut transport = Framed::new(socket, Codec);
Expand Down

0 comments on commit f2310b4

Please sign in to comment.