diff --git a/linera-core/src/client.rs b/linera-core/src/client.rs index 44a6fb8b009..65a7bc30a24 100644 --- a/linera-core/src/client.rs +++ b/linera-core/src/client.rs @@ -16,10 +16,9 @@ use crate::{ }, }; use futures::{ - channel::oneshot, future, lock::Mutex, - stream::{self, FuturesUnordered, StreamExt}, + stream::{self, AbortHandle, FuturesUnordered, StreamExt}, }; use linera_base::{ abi::{Abi, ContractAbi}, @@ -605,7 +604,7 @@ where async fn update_streams( this: &Arc>, - senders: &mut HashMap>, + senders: &mut HashMap, ) -> Result<(), ChainClientError> where P: Send + 'static, @@ -617,30 +616,28 @@ where (guard.chain_id, nodes, guard.node_client.clone()) }; // Drop removed validators. - senders.retain(|name, _| nodes.contains_key(name)); + senders.retain(|name, abort| { + if !nodes.contains_key(name) { + abort.abort(); + } + !abort.is_aborted() + }); // Add tasks for new validators. for (name, mut node) in nodes { let hash_map::Entry::Vacant(entry) = senders.entry(name) else { continue; }; - let stream = match node.subscribe(vec![chain_id]).await { - Err(e) => { - info!("Could not connect to validator {name}: {e:?}"); + let (mut stream, abort) = match node.subscribe(vec![chain_id]).await { + Err(error) => { + info!(?error, "Could not connect to validator {name}"); continue; } - Ok(stream) => stream, + Ok(stream) => stream::abortable(stream), }; let this = this.clone(); let local_node = local_node.clone(); - let (sender, receiver) = oneshot::channel(); - // Calling tokio_stream::StreamExt::merge explicitly because it would conflict with the - // the futures_util::StreamExt trait. - let mut cancelable_stream = tokio_stream::StreamExt::merge( - stream.map(Some), - stream::once(receiver).map(|_| None), - ); tokio::spawn(async move { - while let Some(Some(notification)) = cancelable_stream.next().await { + while let Some(notification) = stream.next().await { Self::process_notification( this.clone(), name, @@ -651,7 +648,7 @@ where .await; } }); - entry.insert(sender); + entry.insert(abort); } Ok(()) } diff --git a/linera-rpc/src/transport.rs b/linera-rpc/src/transport.rs index 9c475da87fc..e82566ff58c 100644 --- a/linera-rpc/src/transport.rs +++ b/linera-rpc/src/transport.rs @@ -7,7 +7,11 @@ use crate::{codec, codec::Codec, RpcMessage}; use async_trait::async_trait; -use futures::{future, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; +use futures::{ + future, + stream::{self, AbortHandle, AbortRegistration, Abortable}, + Sink, SinkExt, Stream, StreamExt, TryStreamExt, +}; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc}; use structopt::clap::arg_enum; @@ -60,7 +64,7 @@ pub trait MessageHandler: Clone { /// The result of spawning a server is oneshot channel to kill it and a handle to track completion. pub struct ServerHandle { - pub complete: futures::channel::oneshot::Sender<()>, + pub abort: AbortHandle, pub handle: tokio::task::JoinHandle>, } @@ -72,7 +76,7 @@ impl ServerHandle { } pub async fn kill(self) -> Result<(), std::io::Error> { - self.complete.send(()).unwrap(); + self.abort.abort(); self.handle.await??; Ok(()) } @@ -141,18 +145,18 @@ impl TransportProtocol { where S: MessageHandler + Send + 'static, { - let (complete, receiver) = futures::channel::oneshot::channel(); + let (abort, registration) = AbortHandle::new_pair(); let handle = match self { Self::Udp => { let socket = UdpSocket::bind(&address).await?; - tokio::spawn(Self::run_udp_server(socket, state, receiver)) + tokio::spawn(Self::run_udp_server(socket, state, registration)) } Self::Tcp => { let listener = TcpListener::bind(address).await?; - tokio::spawn(Self::run_tcp_server(listener, state, receiver)) + tokio::spawn(Self::run_tcp_server(listener, state, registration)) } }; - Ok(ServerHandle { complete, handle }) + Ok(ServerHandle { abort, handle }) } } @@ -189,31 +193,25 @@ impl TransportProtocol { async fn run_udp_server( socket: UdpSocket, state: S, - mut exit_future: futures::channel::oneshot::Receiver<()>, + registration: AbortRegistration, ) -> Result<(), std::io::Error> where S: MessageHandler + Send + 'static, { - let (udp_sink, mut udp_stream) = UdpFramed::new(socket, Codec).split(); + let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split(); + let mut udp_stream = Abortable::new(udp_stream, registration); let udp_sink = Arc::new(Mutex::new(udp_sink)); // Track the latest tasks for a given peer. This is used to return answers in the // same order as the queries. let mut previous_tasks = HashMap::new(); - loop { - let (message, peer) = match future::select(exit_future, udp_stream.next()).await { - future::Either::Left(_) => break, - future::Either::Right((value, new_exit_future)) => { - exit_future = new_exit_future; - match value { - Some(Ok(value)) => value, - Some(Err(codec::Error::Io(io_error))) => return Err(io_error), - Some(Err(other_error)) => { - warn!("Received an invalid message: {other_error}"); - continue; - } - None => return Err(std::io::ErrorKind::UnexpectedEof.into()), - } + while let Some(value) = udp_stream.next().await { + let (message, peer) = match value { + Ok(value) => value, + Err(codec::Error::Io(io_error)) => return Err(io_error), + Err(other_error) => { + warn!("Received an invalid message: {other_error}"); + continue; } }; @@ -296,19 +294,18 @@ impl TransportProtocol { async fn run_tcp_server( listener: TcpListener, state: S, - mut exit_future: futures::channel::oneshot::Receiver<()>, + registration: AbortRegistration, ) -> Result<(), std::io::Error> where S: MessageHandler + Send + 'static, { - loop { - let (socket, _) = match future::select(exit_future, Box::pin(listener.accept())).await { - future::Either::Left(_) => break, - future::Either::Right((value, new_exit_future)) => { - exit_future = new_exit_future; - value? - } - }; + let accept_stream = stream::try_unfold(listener, |listener| async move { + let (socket, _) = listener.accept().await?; + Ok::<_, io::Error>(Some((socket, listener))) + }); + let mut accept_stream = Box::pin(Abortable::new(accept_stream, registration)); + while let Some(value) = accept_stream.next().await { + let socket = value?; let mut handler = state.clone(); tokio::spawn(async move { let mut transport = Framed::new(socket, Codec);