diff --git a/roles/mining-proxy/src/lib/downstream_mining.rs b/roles/mining-proxy/src/lib/downstream_mining.rs index d4e9bcb174..2b38760e1a 100644 --- a/roles/mining-proxy/src/lib/downstream_mining.rs +++ b/roles/mining-proxy/src/lib/downstream_mining.rs @@ -155,7 +155,7 @@ impl DownstreamMiningNodeStatus { use core::convert::TryInto; use std::sync::Arc; -use tokio::task; +use tokio::{sync::oneshot::Receiver as TokioReceiver, task}; impl PartialEq for DownstreamMiningNode { fn eq(&self, other: &Self) -> bool { @@ -487,40 +487,49 @@ use network_helpers_sv2::plain_connection_tokio::PlainConnection; use std::net::SocketAddr; use tokio::net::TcpListener; -pub async fn listen_for_downstream_mining(address: SocketAddr) { +pub async fn listen_for_downstream_mining(address: SocketAddr, mut shutdown_rx: TokioReceiver<()>) { info!("Listening for downstream mining connections on {}", address); - let listner = TcpListener::bind(address).await.unwrap(); + let listener = TcpListener::bind(address).await.unwrap(); let mut ids = roles_logic_sv2::utils::Id::new(); - while let Ok((stream, _)) = listner.accept().await { - let (receiver, sender): (Receiver, Sender) = - PlainConnection::new(stream).await; - let node = DownstreamMiningNode::new(receiver, sender, ids.next()); - - task::spawn(async move { - let mut incoming: StdFrame = node.receiver.recv().await.unwrap().try_into().unwrap(); - let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); - let routing_logic = super::get_common_routing_logic(); - let node = Arc::new(Mutex::new(node)); - - // Call handle_setup_connection or fail - match DownstreamMiningNode::handle_message_common( - node.clone(), - message_type, - payload, - routing_logic, - ) { - Ok(SendToCommon::RelayNewMessageToRemote(_, message)) => { - let message = match message { - roles_logic_sv2::parsers::CommonMessages::SetupConnectionSuccess(m) => m, + let mut should_continue = true; + while should_continue { + tokio::select! { + Ok((stream,_)) = listener.accept() => { + let (receiver, sender): (Receiver, Sender) = + PlainConnection::new(stream).await; + let node = DownstreamMiningNode::new(receiver, sender, ids.next()); + + task::spawn(async move { + let mut incoming: StdFrame = node.receiver.recv().await.unwrap().try_into().unwrap(); + let message_type = incoming.get_header().unwrap().msg_type(); + let payload = incoming.payload(); + let routing_logic = super::get_common_routing_logic(); + let node = Arc::new(Mutex::new(node)); + + // Call handle_setup_connection or fail + match DownstreamMiningNode::handle_message_common( + node.clone(), + message_type, + payload, + routing_logic, + ) { + Ok(SendToCommon::RelayNewMessageToRemote(_, message)) => { + let message = match message { + roles_logic_sv2::parsers::CommonMessages::SetupConnectionSuccess(m) => m, + _ => panic!(), + }; + DownstreamMiningNode::start(node, message).await + } _ => panic!(), - }; - DownstreamMiningNode::start(node, message).await - } - _ => panic!(), + } + }); + }, + _ = &mut shutdown_rx => { + info!("Closing listener on {}", address); + should_continue = false; } - }); + } } } diff --git a/roles/mining-proxy/src/main.rs b/roles/mining-proxy/src/main.rs index 7ad23a616e..ab3f45bbab 100644 --- a/roles/mining-proxy/src/main.rs +++ b/roles/mining-proxy/src/main.rs @@ -23,6 +23,7 @@ mod lib; use lib::Config; use roles_logic_sv2::utils::{GroupId, Mutex}; use std::{net::SocketAddr, sync::Arc}; +use tokio::sync::oneshot; use tracing::{error, info, warn}; mod args { @@ -138,8 +139,17 @@ async fn main() { info!("PROXY INITIALIZED"); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + tokio::select! { - _ = lib::downstream_mining::listen_for_downstream_mining(socket) => warn!("Downstream mining exited"), - _ = tokio::signal::ctrl_c() => info!("Interrupt received"), + _ = lib::downstream_mining::listen_for_downstream_mining(socket, shutdown_rx) => { + warn!("Downstream mining listener exited prematurely"); + }, + _ = tokio::signal::ctrl_c() => { + let _ = shutdown_tx.send(()); + info!("Interrupt received"); + } } + + info!("Shutdown done"); }