diff --git a/roles/Cargo.lock b/roles/Cargo.lock index 3209c1ddf7..66a459744f 100644 --- a/roles/Cargo.lock +++ b/roles/Cargo.lock @@ -1351,6 +1351,7 @@ dependencies = [ name = "integration-test" version = "0.1.0" dependencies = [ + "async-channel 1.9.0", "binary_sv2", "bitcoind", "codec_sv2", diff --git a/roles/tests-integration/Cargo.toml b/roles/tests-integration/Cargo.toml index a54b5172bf..f7b4767d38 100644 --- a/roles/tests-integration/Cargo.toml +++ b/roles/tests-integration/Cargo.toml @@ -10,8 +10,7 @@ minreq = { version = "2.12.0", features = ["https"] } once_cell = "1.19.0" tar = "0.4.41" tokio = { version="1.36.0",features = ["full","tracing"] } -# demand-easy-sv2 = { version = "0.6.0", git = "https://github.com/demand-open-source/demand-easy-sv2.git" } - +async-channel = "1.5.1" codec_sv2 = { path = "../../protocols/v2/codec-sv2", features = ["noise_sv2"] } const_sv2 = { path = "../../protocols/v2/const-sv2" } binary_sv2 = { path = "../../protocols/v2/binary-sv2/binary-sv2" } diff --git a/roles/tests-integration/tests/common/mod.rs b/roles/tests-integration/tests/common/mod.rs index 62d16b086f..ed9be47a5e 100644 --- a/roles/tests-integration/tests/common/mod.rs +++ b/roles/tests-integration/tests/common/mod.rs @@ -1,9 +1,11 @@ +pub mod sniffer; + use bitcoind::{bitcoincore_rpc::RpcApi, BitcoinD, Conf}; use flate2::read::GzDecoder; use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey}; use once_cell::sync::Lazy; use pool_sv2::PoolSv2; -use proxy::SV2Proxy; +use sniffer::Sniffer; use std::{ collections::HashSet, env, @@ -16,8 +18,6 @@ use std::{ }; use tar::Archive; -pub mod proxy; - // prevents get_available_port from ever returning the same port twice static UNIQUE_PORTS: Lazy>> = Lazy::new(|| Mutex::new(HashSet::new())); @@ -266,13 +266,13 @@ pub async fn start_template_provider(tp_port: u16) -> TemplateProvider { template_provider } -pub async fn start_proxy(upstream: SocketAddr, downstream: SocketAddr) -> SV2Proxy { - let proxy = SV2Proxy::new(); - let proxy_clone = proxy.clone(); +pub async fn start_sniffer(upstream: SocketAddr, downstream: SocketAddr) -> Sniffer { + let sniffer = Sniffer::new(downstream, upstream).await; + let sniffer_clone = sniffer.clone(); tokio::spawn(async move { - proxy_clone.start(upstream, downstream).await; + sniffer_clone.start().await; }); - proxy + sniffer } pub async fn start_poolsv2( @@ -288,7 +288,9 @@ pub async fn start_poolsv2( let pool = test_pool.pool.clone(); let pool_clone = pool.clone(); tokio::task::spawn(async move { - assert!(pool_clone.start().await.is_ok()); + let ret = pool_clone.start().await; + dbg!(&ret); + assert!(ret.is_ok()); }); tokio::time::sleep(std::time::Duration::from_secs(1)).await; pool diff --git a/roles/tests-integration/tests/common/proxy/connection.rs b/roles/tests-integration/tests/common/proxy/connection.rs deleted file mode 100644 index f04eef9a5c..0000000000 --- a/roles/tests-integration/tests/common/proxy/connection.rs +++ /dev/null @@ -1,627 +0,0 @@ -use std::{sync::Arc, time::Duration}; -use tokio::sync::mpsc::{channel, Receiver, Sender}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, - task::{self, AbortHandle}, -}; - -use codec_sv2::{HandshakeRole, StandardEitherFrame, StandardNoiseDecoder}; - -use binary_sv2::{Deserialize, GetSize, Serialize}; -use tracing::{debug, error}; - -use codec_sv2::HandShakeFrame; -use const_sv2::{ - INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE, RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE, -}; -use futures::lock::Mutex; -use std::{convert::TryInto, sync::atomic::AtomicBool}; - -use super::error::ProxyError; - -trait SetState { - async fn set_state(self_: Arc>, state: codec_sv2::State); -} - -async fn initialize_as_downstream< - 'a, - Message: Serialize + Deserialize<'a> + GetSize, ->( - connection: Arc>, - role: HandshakeRole, - sender_outgoing: &mut Sender>, - receiver_incoming: &mut Receiver>, -) -> Result<(), ProxyError> { - let mut state = codec_sv2::State::initialized(role); - - // Create and send first handshake message - let first_message = state.step_0()?; - sender_outgoing - .send(first_message.into()) - .await - .map_err(|_| ProxyError::SendError)?; - - // Receive and deserialize second handshake message - let second_message = receiver_incoming - .recv() - .await - .ok_or(ProxyError::RecvError)?; - let second_message: HandShakeFrame = second_message - .try_into() - .map_err(|_| ProxyError::HandshakeRemoteInvalidMessage)?; - let second_message: [u8; INITIATOR_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = second_message - .get_payload_when_handshaking() - .try_into() - .map_err(|_| ProxyError::HandshakeRemoteInvalidMessage)?; - - // Create and send thirth handshake message - let transport_mode = state.step_2(second_message)?; - - Connection::set_state(connection, transport_mode).await; - while !TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) { - std::hint::spin_loop() - } - Ok(()) -} - -async fn initialize_as_upstream<'a, Message: Serialize + Deserialize<'a> + GetSize>( - connection: Arc>, - role: HandshakeRole, - sender_outgoing: &mut Sender>, - receiver_incoming: &mut Receiver>, -) -> Result<(), ProxyError> { - let mut state = codec_sv2::State::initialized(role); - - // Receive and deserialize first handshake message - let first_message: HandShakeFrame = receiver_incoming - .recv() - .await - .ok_or(ProxyError::RecvError)? - .try_into() - .map_err(|_| ProxyError::HandshakeRemoteInvalidMessage)?; - let first_message: [u8; RESPONDER_EXPECTED_HANDSHAKE_MESSAGE_SIZE] = first_message - .get_payload_when_handshaking() - .try_into() - .map_err(|_| ProxyError::HandshakeRemoteInvalidMessage)?; - - // Create and send second handshake message - let (second_message, transport_mode) = state.step_1(first_message)?; - HANDSHAKE_READY.store(false, std::sync::atomic::Ordering::SeqCst); - sender_outgoing.send(second_message.into()).await?; - - // This sets the state to Handshake state - this prompts the task above to move the state - // to transport mode so that the next incoming message will be decoded correctly - // It is important to do this directly before sending the fourth message - // T::set_state(connection, transport_mode).await; - Connection::set_state(connection, transport_mode).await; - while !TRANSPORT_READY.load(std::sync::atomic::Ordering::SeqCst) { - std::hint::spin_loop() - } - - Ok(()) -} - -static HANDSHAKE_READY: AtomicBool = AtomicBool::new(false); -static TRANSPORT_READY: AtomicBool = AtomicBool::new(false); - -#[derive(Debug)] -pub struct Connection { - pub state: codec_sv2::State, -} - -impl SetState for Connection { - async fn set_state(self_: Arc>, state: codec_sv2::State) { - loop { - if HANDSHAKE_READY.load(std::sync::atomic::Ordering::SeqCst) { - if let Some(mut connection) = self_.try_lock() { - connection.state = state; - TRANSPORT_READY.store(true, std::sync::atomic::Ordering::Relaxed); - break; - }; - } - task::yield_now().await; - } - } -} - -impl Connection { - #[allow(clippy::new_ret_no_self)] - pub async fn new<'a, Message: Serialize + Deserialize<'a> + GetSize + Send + 'static>( - stream: TcpStream, - role: HandshakeRole, - ) -> Result< - ( - Receiver>, - Sender>, - AbortHandle, - AbortHandle, - ), - ProxyError, - > { - let address = stream.peer_addr().map_err(|_| ProxyError::SocketClosed)?; - - let (mut reader, mut writer) = stream.into_split(); - - let (sender_incoming, mut receiver_incoming): ( - Sender>, - Receiver>, - ) = channel(10); - let (mut sender_outgoing, mut receiver_outgoing): ( - Sender>, - Receiver>, - ) = channel(10); - - let state = codec_sv2::State::not_initialized(&role); - - let connection = Arc::new(Mutex::new(Self { state })); - - let cloned1 = connection.clone(); - let cloned2 = connection.clone(); - - // RECEIVE AND PARSE INCOMING MESSAGES FROM TCP STREAM - let recv_task = task::spawn(async move { - let mut decoder = StandardNoiseDecoder::::new(); - - loop { - let writable = decoder.writable(); - match reader.read_exact(writable).await { - Ok(_) => { - let mut connection = cloned1.lock().await; - let decoded = decoder.next_frame(&mut connection.state); - drop(connection); - - match decoded { - Ok(x) => { - if sender_incoming.send(x).await.is_err() { - error!("Shutting down noise stream reader!"); - task::yield_now().await; - break; - } - } - Err(e) => { - if let codec_sv2::Error::MissingBytes(_) = e { - } else { - error!("Shutting down noise stream reader! {:#?}", e); - task::yield_now().await; - break; - } - } - } - } - Err(e) => { - error!( - "Disconnected from client while reading : {} - {}", - e, &address - ); - task::yield_now().await; - break; - } - } - } - drop(sender_incoming); - drop(cloned1); - drop(reader); - let mut times = 0; - while !decoder.droppable() { - tokio::time::sleep(Duration::from_secs(5)).await; - if times >= 10 { - error!("Irrecoverable error impossible to free decoder"); - std::process::exit(1); - } - times += 1; - } - }); - - // ENCODE AND SEND INCOMING MESSAGES TO TCP STREAM - let send_task = task::spawn(async move { - let mut encoder = codec_sv2::NoiseEncoder::::new(); - - loop { - let received = receiver_outgoing.recv().await; - - match received { - Some(frame) => { - let mut connection = cloned2.lock().await; - match encoder.encode(frame, &mut connection.state) { - Ok(b) => { - drop(connection); - - let b = b.as_ref(); - - match (writer).write_all(b).await { - Ok(_) => (), - Err(e) => { - let _ = writer.shutdown().await; - // Just fail and force to reinitialize everything - error!( - "Disconnecting from client due to error writing: {} - {}", - e, &address - ); - task::yield_now().await; - break; - } - } - } - Err(e) => { - error!( - "Disconnecting from client due to error encoding: {} - {}", - e, &address - ); - drop(connection); - task::yield_now().await; - break; - } - }; - } - None => { - // Just fail and force to reinitialize everything - let _ = writer.shutdown().await; - error!( - "Disconnecting from client due to error receiving from: {}", - &address - ); - task::yield_now().await; - break; - } - }; - HANDSHAKE_READY.store(true, std::sync::atomic::Ordering::Relaxed); - } - receiver_outgoing.close(); - drop(receiver_outgoing); - drop(cloned2); - drop(writer); - let mut times = 0; - while !encoder.droppable() { - tokio::time::sleep(Duration::from_secs(5)).await; - if times >= 10 { - error!("Irrecoverable error impossible to free encoder"); - std::process::exit(1); - } - times += 1; - } - }); - - // DO THE NOISE HANDSHAKE - match role { - HandshakeRole::Initiator(_) => { - debug!("Initializing as downstream for - {}", &address); - initialize_as_downstream( - connection.clone(), - role, - &mut sender_outgoing, - &mut receiver_incoming, - ) - .await? - } - HandshakeRole::Responder(_) => { - debug!("Initializing as upstream for - {}", &address); - initialize_as_upstream( - connection.clone(), - role, - &mut sender_outgoing, - &mut receiver_incoming, - ) - .await? - } - }; - debug!("Noise handshake complete - {}", &address); - Ok(( - receiver_incoming, - sender_outgoing, - recv_task.abort_handle(), - send_task.abort_handle(), - )) - } -} - -#[derive(Debug)] -pub struct NewConnection { - pub state: codec_sv2::State, - pub role: HandshakeRole, - pub sender_incoming: Sender>, - pub receiver_incoming: Receiver>, - pub sender_outgoing: Sender>, - pub receiver_outgoing: Receiver>, -} - -impl + GetSize + Send + 'static> NewConnection { - pub async fn new(role: HandshakeRole) -> Self { - let (sender_incoming, receiver_incoming): ( - Sender>, - Receiver>, - ) = channel(10); - let (sender_outgoing, receiver_outgoing): ( - Sender>, - Receiver>, - ) = channel(10); - Self { - state: codec_sv2::State::not_initialized(&role), - role, - sender_incoming, - receiver_incoming, - sender_outgoing, - receiver_outgoing, - } - } - - pub async fn on_new_stream(mut self, stream: TcpStream) { - let (mut reader, _writer) = stream.into_split(); - - // RECEIVE AND PARSE INCOMING MESSAGES FROM TCP STREAM - let _recv_task = task::spawn(async move { - let mut decoder = StandardNoiseDecoder::::new(); - loop { - let writable = decoder.writable(); - match reader.read_exact(writable).await { - Ok(_) => { - let decoded = decoder.next_frame(&mut self.state); - - match decoded { - Ok(x) => { - if self.sender_incoming.send(x).await.is_err() { - error!("Shutting down noise stream reader!"); - task::yield_now().await; - break; - } - } - Err(e) => { - if let codec_sv2::Error::MissingBytes(_) = e { - } else { - error!("Shutting down noise stream reader! {:#?}", e); - task::yield_now().await; - break; - } - } - } - } - Err(_e) => { - task::yield_now().await; - break; - } - } - } - drop(reader); - let mut times = 0; - while !decoder.droppable() { - tokio::time::sleep(Duration::from_secs(5)).await; - if times >= 10 { - error!("Irrecoverable error impossible to free decoder"); - std::process::exit(1); - } - times += 1; - } - }); - - // // ENCODE AND SEND INCOMING MESSAGES TO TCP STREAM - // let send_task = task::spawn(async move { - // let mut encoder = codec_sv2::NoiseEncoder::::new(); - - // loop { - // let received = receiver_outgoing.recv().await; - - // match received { - // Some(frame) => { - // let mut connection = cloned2.lock().await; - // match encoder.encode(frame, &mut connection.state) { - // Ok(b) => { - // drop(connection); - - // let b = b.as_ref(); - - // match (writer).write_all(b).await { - // Ok(_) => (), - // Err(e) => { - // let _ = writer.shutdown().await; - // // Just fail and force to reinitialize everything - // task::yield_now().await; - // break; - // } - // } - // } - // Err(e) => { - // drop(connection); - // task::yield_now().await; - // break; - // } - // }; - // } - // None => { - // // Just fail and force to reinitialize everything - // let _ = writer.shutdown().await; - // task::yield_now().await; - // break; - // } - // }; - // HANDSHAKE_READY.store(true, std::sync::atomic::Ordering::Relaxed); - // } - // receiver_outgoing.close(); - // drop(receiver_outgoing); - // drop(cloned2); - // drop(writer); - // let mut times = 0; - // while !encoder.droppable() { - // tokio::time::sleep(Duration::from_secs(5)).await; - // if times >= 10 { - // error!("Irrecoverable error impossible to free encoder"); - // std::process::exit(1); - // } - // times += 1; - // } - // }); - - // // DO THE NOISE HANDSHAKE - // // match role { - // // HandshakeRole::Initiator(_) => { - // // initialize_as_downstream( - // // connection.clone(), - // // role, - // // &mut sender_outgoing, - // // &mut receiver_incoming, - // // ) - // // .await? - // // } - // // HandshakeRole::Responder(_) => { - // // initialize_as_upstream( - // // connection.clone(), - // // role, - // // &mut sender_outgoing, - // // &mut receiver_incoming, - // // ) - // // .await? - // // } - // // }; - - } - - // #[allow(clippy::new_ret_no_self)] - // pub async fn new_a<'a, Message: Serialize + Deserialize<'a> + GetSize + Send + 'static>( - // self, - // stream: TcpStream, - // ) -> Result< - // ( - // Receiver>, - // Sender>, - // AbortHandle, - // AbortHandle, - // ), - // ProxyError, - // > { - // let (mut reader, mut writer) = stream.into_split(); - - // // RECEIVE AND PARSE INCOMING MESSAGES FROM TCP STREAM - // let recv_task = task::spawn(async move { - // let mut decoder = StandardNoiseDecoder::::new(); - - // loop { - // let writable = decoder.writable(); - // match reader.read_exact(writable).await { - // Ok(_) => { - // let mut connection = cloned1.lock().await; - // let decoded = decoder.next_frame(&mut connection.state); - // drop(connection); - - // match decoded { - // Ok(x) => { - // if sender_incoming.send(x).await.is_err() { - // error!("Shutting down noise stream reader!"); - // task::yield_now().await; - // break; - // } - // } - // Err(e) => { - // if let codec_sv2::Error::MissingBytes(_) = e { - // } else { - // error!("Shutting down noise stream reader! {:#?}", e); - // task::yield_now().await; - // break; - // } - // } - // } - // } - // Err(e) => { - // task::yield_now().await; - // break; - // } - // } - // } - // drop(sender_incoming); - // drop(cloned1); - // drop(reader); - // let mut times = 0; - // while !decoder.droppable() { - // tokio::time::sleep(Duration::from_secs(5)).await; - // if times >= 10 { - // error!("Irrecoverable error impossible to free decoder"); - // std::process::exit(1); - // } - // times += 1; - // } - // }); - - // // ENCODE AND SEND INCOMING MESSAGES TO TCP STREAM - // let send_task = task::spawn(async move { - // let mut encoder = codec_sv2::NoiseEncoder::::new(); - - // loop { - // let received = receiver_outgoing.recv().await; - - // match received { - // Some(frame) => { - // let mut connection = cloned2.lock().await; - // match encoder.encode(frame, &mut connection.state) { - // Ok(b) => { - // drop(connection); - - // let b = b.as_ref(); - - // match (writer).write_all(b).await { - // Ok(_) => (), - // Err(e) => { - // let _ = writer.shutdown().await; - // // Just fail and force to reinitialize everything - // task::yield_now().await; - // break; - // } - // } - // } - // Err(e) => { - // drop(connection); - // task::yield_now().await; - // break; - // } - // }; - // } - // None => { - // // Just fail and force to reinitialize everything - // let _ = writer.shutdown().await; - // task::yield_now().await; - // break; - // } - // }; - // HANDSHAKE_READY.store(true, std::sync::atomic::Ordering::Relaxed); - // } - // receiver_outgoing.close(); - // drop(receiver_outgoing); - // drop(cloned2); - // drop(writer); - // let mut times = 0; - // while !encoder.droppable() { - // tokio::time::sleep(Duration::from_secs(5)).await; - // if times >= 10 { - // error!("Irrecoverable error impossible to free encoder"); - // std::process::exit(1); - // } - // times += 1; - // } - // }); - - // // DO THE NOISE HANDSHAKE - // // match role { - // // HandshakeRole::Initiator(_) => { - // // initialize_as_downstream( - // // connection.clone(), - // // role, - // // &mut sender_outgoing, - // // &mut receiver_incoming, - // // ) - // // .await? - // // } - // // HandshakeRole::Responder(_) => { - // // initialize_as_upstream( - // // connection.clone(), - // // role, - // // &mut sender_outgoing, - // // &mut receiver_incoming, - // // ) - // // .await? - // // } - // // }; - // Ok(( - // receiver_incoming, - // sender_outgoing, - // recv_task.abort_handle(), - // send_task.abort_handle(), - // )) - // } -} diff --git a/roles/tests-integration/tests/common/proxy/mod.rs b/roles/tests-integration/tests/common/proxy/mod.rs deleted file mode 100644 index 82a70648a3..0000000000 --- a/roles/tests-integration/tests/common/proxy/mod.rs +++ /dev/null @@ -1,489 +0,0 @@ -mod connection; -mod error; - -use std::{net::SocketAddr, sync::Arc}; - -use codec_sv2::{HandshakeRole, Initiator, Responder, StandardEitherFrame, StandardSv2Frame}; -use connection::Connection; -use error::ProxyError; -use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey}; -use roles_logic_sv2::{ - parsers::{ - CommonMessages, - JobDeclaration::{ - AllocateMiningJobToken, AllocateMiningJobTokenSuccess, DeclareMiningJob, - DeclareMiningJobError, DeclareMiningJobSuccess, IdentifyTransactions, - IdentifyTransactionsSuccess, ProvideMissingTransactions, - ProvideMissingTransactionsSuccess, SubmitSolution, - }, - PoolMessages, - TemplateDistribution::{self, CoinbaseOutputDataSize}, - }, - utils::Mutex, -}; -use tokio::{ - net::{TcpListener, TcpStream}, - select, - sync::mpsc::{channel, Receiver, Sender}, -}; - -use codec_sv2::framing_sv2::framing::Frame as EitherFrame; - -type MessageFrame = StandardEitherFrame>; - -#[derive(Debug, Clone)] -pub struct SV2Proxy { - aggregator: MessagesAggregator, -} - -impl SV2Proxy { - pub fn new() -> Self { - let aggregator = MessagesAggregator::new(); - Self { aggregator } - } - - pub async fn start(&self, upstream: SocketAddr, downstream: SocketAddr) { - let mut builder = ProxyBuilder::new(); - builder - .try_add_client(Self::wait_for_client(downstream).await) - .await - .expect("Impossible to add client to the ProxyBuilder") - .try_add_server( - TcpStream::connect(upstream) - .await - .expect("Impossible to connect to server"), - ) - .await - .expect("Impossible to add server to the ProxyBuilder"); - self.listen_for_all_messages(&mut builder); - let err = builder - .start() - .await - .expect("Impossible to build the Client"); - eprintln!("{err:?}"); - } - - pub fn contains(&self, message: ExpectMessage, role: Role) -> bool { - self.aggregator.contains(message, role) - } - - fn listen_for_all_messages(&self, builder: &mut ProxyBuilder) { - for n in 0..44 { - let mut server_recv = builder.add_handler(Role::Server, n); - let mut client_recv = builder.add_handler(Role::Client, n); - let aggregator = self.aggregator.clone(); - tokio::spawn(async move { - tokio::select! { - m = server_recv.recv() => { - aggregator.add_message(m.unwrap(), Role::Server); - } - m = client_recv.recv() => { - aggregator.add_message(m.unwrap(), Role::Client); - } - } - }); - } - } - - async fn wait_for_client(client: SocketAddr) -> TcpStream { - let listner = TcpListener::bind(client) - .await - .expect("Impossible to listen on given address"); - if let Ok((stream, _)) = listner.accept().await { - stream - } else { - panic!("Impossible to accept dowsntream connetion") - } - } -} - -#[derive(Clone, PartialEq)] -pub enum ExpectMessage { - SetupConnection, - SetupConnectionSuccess, - SetupConnectionError, -} - -impl From> for ExpectMessage { - fn from(m: PoolMessages<'static>) -> Self { - match m { - PoolMessages::Common(CommonMessages::SetupConnection(_)) => { - ExpectMessage::SetupConnection - } - PoolMessages::Common(CommonMessages::SetupConnectionSuccess(_)) => { - ExpectMessage::SetupConnectionSuccess - } - PoolMessages::Common(CommonMessages::SetupConnectionError(_)) => { - ExpectMessage::SetupConnectionError - } - _ => unimplemented!(), - } - } -} - -pub struct ProxyBuilder { - from_client: Option>, - to_client: Option>, - from_server: Option>, - to_server: Option>, - cert_validity: u64, - proxy_pub_key: Secp256k1PublicKey, - proxy_sec_key: Secp256k1SecretKey, - server_auth_key: Option, - handlers: Vec, -} - -impl ProxyBuilder { - pub fn new() -> Self { - Self { - from_client: None, - to_client: None, - from_server: None, - to_server: None, - cert_validity: 10000, - proxy_pub_key: "9auqWEzQDVyd2oe1JVGFLMLHZtCo2FFqZwtKA5gd9xbuEu7PH72" - .to_string() - .parse() - .expect("Invalid default pub key"), - proxy_sec_key: "mkDLTBBRxdBv998612qipDYoTK3YUrqLe8uWw7gu3iXbSrn2n" - .to_string() - .parse() - .expect("Invalid default sec key"), - server_auth_key: None, - handlers: vec![], - } - } - - pub async fn try_add_client(&mut self, stream: TcpStream) -> Result<&mut Self, ProxyError> { - if self.from_client.is_none() && self.to_client.is_none() { - let auth_pub_k_as_bytes = self.proxy_pub_key.into_bytes(); - let auth_prv_k_as_bytes = self.proxy_sec_key.into_bytes(); - let responder = Responder::from_authority_kp( - &auth_pub_k_as_bytes, - &auth_prv_k_as_bytes, - std::time::Duration::from_secs(self.cert_validity), - ) - .expect("invalid key pair"); - - if let Ok((receiver_from_client, send_to_client, _, _)) = - Connection::new::<'static, PoolMessages<'static>>( - stream, - HandshakeRole::Responder(responder), - ) - .await - { - self.from_client = Some(receiver_from_client); - self.to_client = Some(send_to_client); - Ok(self) - } else { - Err(ProxyError::ImpossibleToCompleteHandShakeWithDownstream) - } - } else { - Err(ProxyError::CanNotHaveMoreThan1Client) - } - } - - pub async fn try_add_server(&mut self, stream: TcpStream) -> Result<&mut Self, ProxyError> { - if self.from_server.is_none() && self.to_server.is_none() { - let initiator = match self.server_auth_key { - Some(key) => Initiator::from_raw_k(key.into_bytes()) - .expect("Pub key is already checked for validity"), - None => Initiator::without_pk().expect("This fn call can not fail"), - }; - - if let Ok((receiver_from_client, send_to_client, _, _)) = - Connection::new::<'static, PoolMessages<'static>>( - stream, - HandshakeRole::Initiator(initiator), - ) - .await - { - self.from_server = Some(receiver_from_client); - self.to_server = Some(send_to_client); - Ok(self) - } else { - Err(ProxyError::ImpossibleToCompleteHandShakeWithUpstream) - } - } else { - Err(ProxyError::CanNotHaveMoreThan1Server) - } - } - - pub fn add_handler( - &mut self, - expect_from: Role, - message_type: u8, - ) -> Receiver> { - let (s, r) = channel(3); - let channel = MessageChannel { - message_type, - expect_from, - receiver: None, - sender: s, - }; - self.handlers.push(channel); - r - } - - pub async fn start(self) -> Result<(), ProxyError> { - if let (Some(from_client), Some(to_client), Some(from_server), Some(to_server)) = ( - self.from_client, - self.to_client, - self.from_server, - self.to_server, - ) { - let mut client_handlers = vec![]; - let mut server_handlers = vec![]; - for handler in self.handlers { - match handler.expect_from { - Role::Client => client_handlers.push(handler), - Role::Server => server_handlers.push(handler), - } - } - select! { - r = Self::recv_from_down_send_to_up(from_client, to_server, client_handlers) => r, - r = Self::recv_from_up_send_to_down(from_server, to_client, server_handlers) => r, - } - } else { - Err(ProxyError::IncompleteBuilder) - } - } - - async fn recv_from_down_send_to_up( - mut recv: Receiver, - send: Sender, - mut handlers: Vec, - ) -> Result<(), ProxyError> { - while let Some(mut frame) = recv.recv().await { - let mut send_original_frame_upstream = true; - for handler in handlers.iter_mut() { - if let Some(frame) = handler.on_message(&mut frame).await { - send_original_frame_upstream = false; - if send.send(frame).await.is_err() { - return Err(ProxyError::UpstreamClosed); - }; - } - } - if send_original_frame_upstream && send.send(frame).await.is_err() { - return Err(ProxyError::UpstreamClosed); - }; - } - Err(ProxyError::DownstreamClosed) - } - - async fn recv_from_up_send_to_down( - mut recv: Receiver, - send: Sender, - mut handlers: Vec, - ) -> Result<(), ProxyError> { - while let Some(mut frame) = recv.recv().await { - let mut send_original_frame_upstream = true; - for handler in handlers.iter_mut() { - if let Some(frame) = handler.on_message(&mut frame).await { - send_original_frame_upstream = false; - if send.send(frame).await.is_err() { - return Err(ProxyError::DownstreamClosed); - }; - } - } - if send_original_frame_upstream && send.send(frame).await.is_err() { - return Err(ProxyError::DownstreamClosed); - }; - } - Err(ProxyError::UpstreamClosed) - } -} - -#[derive(PartialEq, Clone, Debug)] -pub enum Role { - Client, - Server, -} - -impl std::fmt::Display for Role { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::Client => write!(f, "Client"), - Role::Server => write!(f, "Server"), - } - } -} - -#[derive(Debug, Clone)] -struct MessagesAggregator { - messages: Arc, Role)>>>, -} - -impl MessagesAggregator { - pub fn new() -> Self { - Self { - messages: Arc::new(Mutex::new(vec![])), - } - } - - pub fn add_message(&self, message: PoolMessages<'static>, role: Role) { - self.messages - .safe_lock(|messages| { - let mut new_state = messages.clone(); - new_state.push((message, role)); - *messages = new_state - }) - .unwrap(); - } - - pub fn contains(&self, message: ExpectMessage, role: Role) -> bool { - let messages = self - .messages - .safe_lock(|messages| messages.clone()) - .unwrap(); - let messages = messages - .iter() - .map(|(m, r)| (ExpectMessage::from(m.clone()), r.clone())); - for m in messages { - if m.0 == message && m.1 == role { - return true; - } - } - false - } -} - -struct MessageChannel { - pub message_type: u8, - pub expect_from: Role, - pub receiver: Option>>, - pub sender: Sender>, -} - -impl MessageChannel { - pub async fn on_message(&mut self, frame: &mut MessageFrame) -> Option { - let (mt, message) = self.message_from_frame(frame); - if mt == self.message_type { - if self.sender.send(message).await.is_err() { - panic!("Impossible to send message to message handler, for: {mt}"); - }; - if let Some(receiver) = &mut self.receiver { - if let Some(message) = receiver.recv().await { - let frame: StandardSv2Frame> = message - .try_into() - .expect("A message can always be converted in a frame"); - Some(frame.into()) - } else { - panic!("Impossible to receive message from message handler, for: {mt}"); - } - } else { - None - } - } else { - None - } - } - - fn message_from_frame(&self, frame: &mut MessageFrame) -> (u8, PoolMessages<'static>) { - match frame { - EitherFrame::Sv2(frame) => { - if let Some(header) = frame.get_header() { - let message_type = header.msg_type(); - let mut payload = frame.payload().to_vec(); - let message: Result, _> = - (message_type, payload.as_mut_slice()).try_into(); - match message { - Ok(message) => { - let message = Self::into_static(message); - (message_type, message) - } - _ => { - panic!( - "Received frame with invalid payload or message type: {frame:?}" - ); - } - } - } else { - panic!("Received frame with invalid header: {frame:?}"); - } - } - EitherFrame::HandShake(f) => { - panic!("Received unexpected handshake frame: {f:?}"); - } - } - } - - pub fn into_static(m: PoolMessages<'_>) -> PoolMessages<'static> { - match m { - PoolMessages::Mining(m) => PoolMessages::Mining(m.into_static()), - PoolMessages::Common(m) => match m { - CommonMessages::ChannelEndpointChanged(m) => { - PoolMessages::Common(CommonMessages::ChannelEndpointChanged(m.into_static())) - } - CommonMessages::SetupConnection(m) => { - PoolMessages::Common(CommonMessages::SetupConnection(m.into_static())) - } - CommonMessages::SetupConnectionError(m) => { - PoolMessages::Common(CommonMessages::SetupConnectionError(m.into_static())) - } - CommonMessages::SetupConnectionSuccess(m) => { - PoolMessages::Common(CommonMessages::SetupConnectionSuccess(m.into_static())) - } - }, - PoolMessages::JobDeclaration(m) => match m { - AllocateMiningJobToken(m) => { - PoolMessages::JobDeclaration(AllocateMiningJobToken(m.into_static())) - } - AllocateMiningJobTokenSuccess(m) => { - PoolMessages::JobDeclaration(AllocateMiningJobTokenSuccess(m.into_static())) - } - DeclareMiningJob(m) => { - PoolMessages::JobDeclaration(DeclareMiningJob(m.into_static())) - } - DeclareMiningJobError(m) => { - PoolMessages::JobDeclaration(DeclareMiningJobError(m.into_static())) - } - DeclareMiningJobSuccess(m) => { - PoolMessages::JobDeclaration(DeclareMiningJobSuccess(m.into_static())) - } - IdentifyTransactions(m) => { - PoolMessages::JobDeclaration(IdentifyTransactions(m.into_static())) - } - IdentifyTransactionsSuccess(m) => { - PoolMessages::JobDeclaration(IdentifyTransactionsSuccess(m.into_static())) - } - ProvideMissingTransactions(m) => { - PoolMessages::JobDeclaration(ProvideMissingTransactions(m.into_static())) - } - ProvideMissingTransactionsSuccess(m) => { - PoolMessages::JobDeclaration(ProvideMissingTransactionsSuccess(m.into_static())) - } - SubmitSolution(m) => PoolMessages::JobDeclaration(SubmitSolution(m.into_static())), - }, - PoolMessages::TemplateDistribution(m) => match m { - CoinbaseOutputDataSize(m) => { - PoolMessages::TemplateDistribution(CoinbaseOutputDataSize(m.into_static())) - } - TemplateDistribution::NewTemplate(m) => PoolMessages::TemplateDistribution( - TemplateDistribution::NewTemplate(m.into_static()), - ), - TemplateDistribution::RequestTransactionData(m) => { - PoolMessages::TemplateDistribution( - TemplateDistribution::RequestTransactionData(m.into_static()), - ) - } - TemplateDistribution::RequestTransactionDataError(m) => { - PoolMessages::TemplateDistribution( - TemplateDistribution::RequestTransactionDataError(m.into_static()), - ) - } - TemplateDistribution::RequestTransactionDataSuccess(m) => { - PoolMessages::TemplateDistribution( - TemplateDistribution::RequestTransactionDataSuccess(m.into_static()), - ) - } - TemplateDistribution::SetNewPrevHash(m) => PoolMessages::TemplateDistribution( - TemplateDistribution::SetNewPrevHash(m.into_static()), - ), - TemplateDistribution::SubmitSolution(m) => PoolMessages::TemplateDistribution( - TemplateDistribution::SubmitSolution(m.into_static()), - ), - }, - } - } -} diff --git a/roles/tests-integration/tests/common/proxy/error.rs b/roles/tests-integration/tests/common/sniffer/error.rs similarity index 100% rename from roles/tests-integration/tests/common/proxy/error.rs rename to roles/tests-integration/tests/common/sniffer/error.rs diff --git a/roles/tests-integration/tests/common/sniffer/mod.rs b/roles/tests-integration/tests/common/sniffer/mod.rs new file mode 100644 index 0000000000..111831ad75 --- /dev/null +++ b/roles/tests-integration/tests/common/sniffer/mod.rs @@ -0,0 +1,383 @@ +mod error; + +use std::{collections::VecDeque, net::SocketAddr, sync::Arc}; + +use async_channel::{Receiver, Sender}; +use codec_sv2::{HandshakeRole, Initiator, Responder, StandardEitherFrame}; +use error::ProxyError; +use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey}; +use network_helpers_sv2::noise_connection_tokio::Connection; +use roles_logic_sv2::{ + parsers::{ + CommonMessages, + JobDeclaration::{ + AllocateMiningJobToken, AllocateMiningJobTokenSuccess, DeclareMiningJob, + DeclareMiningJobError, DeclareMiningJobSuccess, IdentifyTransactions, + IdentifyTransactionsSuccess, ProvideMissingTransactions, + ProvideMissingTransactionsSuccess, SubmitSolution, + }, + PoolMessages, + TemplateDistribution::{self, CoinbaseOutputDataSize}, + }, + utils::Mutex, +}; +use tokio::{ + net::{TcpListener, TcpStream}, + select, +}; + +use codec_sv2::framing_sv2::framing::Frame as EitherFrame; + +type MessageFrame = StandardEitherFrame>; + +#[derive(Debug, Clone)] +pub struct Sniffer { + downstream: SocketAddr, + upstream: SocketAddr, + downstream_messages: MessagesAggregator, + upstream_messages: MessagesAggregator, +} + +impl Drop for Sniffer { + fn drop(&mut self) { + if !self.downstream_messages.is_empty() { + panic!("Downstream messages: {:?}", self.downstream_messages); + } + if !self.upstream_messages.is_empty() { + panic!("Upstream messages: {:?}", self.upstream_messages); + } + } +} + +impl Sniffer { + pub async fn new(downstream: SocketAddr, upstream: SocketAddr) -> Self { + Self { + downstream, + upstream, + downstream_messages: MessagesAggregator::new(Role::Downstream), + upstream_messages: MessagesAggregator::new(Role::Upstream), + } + } + + pub async fn start(self) { + let (downstream_receiver, downstream_sender) = + Self::create_downstream(Self::wait_for_client(self.downstream).await) + .await + .expect("Failed to create downstream"); + let (upstream_receiver, upstream_sender) = Self::create_upstream( + TcpStream::connect(self.upstream) + .await + .expect("Failed to connect to upstream"), + ) + .await + .expect("Failed to create upstream"); + let downstream_messages = self.downstream_messages.clone(); + let upstream_messages = self.upstream_messages.clone(); + let _ = select! { + r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages) => r, + r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages) => r, + }; + } + + pub fn downstream_state(&self, message: ExpectMessage) -> bool { + self.downstream_messages.current_state(message) + } + + pub fn upstream_state(&self, message: ExpectMessage) -> bool { + self.upstream_messages.current_state(message) + } + + async fn create_downstream( + stream: TcpStream, + ) -> Option<(Receiver, Sender)> { + let cert_validity = 10000; + let proxy_pub_key: Secp256k1PublicKey = + "9auqWEzQDVyd2oe1JVGFLMLHZtCo2FFqZwtKA5gd9xbuEu7PH72" + .to_string() + .parse() + .expect("Invalid default pub key"); + let proxy_sec_key: Secp256k1SecretKey = "mkDLTBBRxdBv998612qipDYoTK3YUrqLe8uWw7gu3iXbSrn2n" + .to_string() + .parse() + .expect("Invalid default pub key"); + let auth_pub_k_as_bytes = proxy_pub_key.into_bytes(); + let auth_prv_k_as_bytes = proxy_sec_key.into_bytes(); + let responder = Responder::from_authority_kp( + &auth_pub_k_as_bytes, + &auth_prv_k_as_bytes, + std::time::Duration::from_secs(cert_validity), + ) + .expect("invalid key pair"); + if let Ok((receiver_from_client, send_to_client, _, _)) = + Connection::new::<'static, PoolMessages<'static>>( + stream, + HandshakeRole::Responder(responder), + ) + .await + { + Some((receiver_from_client, send_to_client)) + } else { + None + } + } + + async fn create_upstream( + stream: TcpStream, + ) -> Option<(Receiver, Sender)> { + let initiator = Initiator::without_pk().expect("This fn call can not fail"); + if let Ok((receiver_from_client, send_to_client, _, _)) = + Connection::new::<'static, PoolMessages<'static>>( + stream, + HandshakeRole::Initiator(initiator), + ) + .await + { + Some((receiver_from_client, send_to_client)) + } else { + None + } + } + + async fn recv_from_down_send_to_up( + recv: Receiver, + send: Sender, + downstream_messages: MessagesAggregator, + ) -> Result<(), ProxyError> { + while let Ok(mut frame) = recv.recv().await { + let msg = Self::message_from_frame(&mut frame); + downstream_messages.add_message(msg.0, msg.1.clone()); + if send.send(frame).await.is_err() { + return Err(ProxyError::UpstreamClosed); + }; + } + Err(ProxyError::DownstreamClosed) + } + + async fn recv_from_up_send_to_down( + recv: Receiver, + send: Sender, + upstream_messages: MessagesAggregator, + ) -> Result<(), ProxyError> { + while let Ok(mut frame) = recv.recv().await { + let msg = Self::message_from_frame(&mut frame); + upstream_messages.add_message(msg.0, msg.1.clone()); + if send.send(frame).await.is_err() { + return Err(ProxyError::DownstreamClosed); + }; + } + Err(ProxyError::UpstreamClosed) + } + + fn message_from_frame(frame: &mut MessageFrame) -> (u8, PoolMessages<'static>) { + match frame { + EitherFrame::Sv2(frame) => { + if let Some(header) = frame.get_header() { + let message_type = header.msg_type(); + let mut payload = frame.payload().to_vec(); + let message: Result, _> = + (message_type, payload.as_mut_slice()).try_into(); + match message { + Ok(message) => { + let message = Self::into_static(message); + (message_type, message) + } + _ => { + panic!( + "Received frame with invalid payload or message type: {frame:?}" + ); + } + } + } else { + panic!("Received frame with invalid header: {frame:?}"); + } + } + EitherFrame::HandShake(f) => { + panic!("Received unexpected handshake frame: {f:?}"); + } + } + } + + fn into_static(m: PoolMessages<'_>) -> PoolMessages<'static> { + match m { + PoolMessages::Mining(m) => PoolMessages::Mining(m.into_static()), + PoolMessages::Common(m) => match m { + CommonMessages::ChannelEndpointChanged(m) => { + PoolMessages::Common(CommonMessages::ChannelEndpointChanged(m.into_static())) + } + CommonMessages::SetupConnection(m) => { + PoolMessages::Common(CommonMessages::SetupConnection(m.into_static())) + } + CommonMessages::SetupConnectionError(m) => { + PoolMessages::Common(CommonMessages::SetupConnectionError(m.into_static())) + } + CommonMessages::SetupConnectionSuccess(m) => { + PoolMessages::Common(CommonMessages::SetupConnectionSuccess(m.into_static())) + } + }, + PoolMessages::JobDeclaration(m) => match m { + AllocateMiningJobToken(m) => { + PoolMessages::JobDeclaration(AllocateMiningJobToken(m.into_static())) + } + AllocateMiningJobTokenSuccess(m) => { + PoolMessages::JobDeclaration(AllocateMiningJobTokenSuccess(m.into_static())) + } + DeclareMiningJob(m) => { + PoolMessages::JobDeclaration(DeclareMiningJob(m.into_static())) + } + DeclareMiningJobError(m) => { + PoolMessages::JobDeclaration(DeclareMiningJobError(m.into_static())) + } + DeclareMiningJobSuccess(m) => { + PoolMessages::JobDeclaration(DeclareMiningJobSuccess(m.into_static())) + } + IdentifyTransactions(m) => { + PoolMessages::JobDeclaration(IdentifyTransactions(m.into_static())) + } + IdentifyTransactionsSuccess(m) => { + PoolMessages::JobDeclaration(IdentifyTransactionsSuccess(m.into_static())) + } + ProvideMissingTransactions(m) => { + PoolMessages::JobDeclaration(ProvideMissingTransactions(m.into_static())) + } + ProvideMissingTransactionsSuccess(m) => { + PoolMessages::JobDeclaration(ProvideMissingTransactionsSuccess(m.into_static())) + } + SubmitSolution(m) => PoolMessages::JobDeclaration(SubmitSolution(m.into_static())), + }, + PoolMessages::TemplateDistribution(m) => match m { + CoinbaseOutputDataSize(m) => { + PoolMessages::TemplateDistribution(CoinbaseOutputDataSize(m.into_static())) + } + TemplateDistribution::NewTemplate(m) => PoolMessages::TemplateDistribution( + TemplateDistribution::NewTemplate(m.into_static()), + ), + TemplateDistribution::RequestTransactionData(m) => { + PoolMessages::TemplateDistribution( + TemplateDistribution::RequestTransactionData(m.into_static()), + ) + } + TemplateDistribution::RequestTransactionDataError(m) => { + PoolMessages::TemplateDistribution( + TemplateDistribution::RequestTransactionDataError(m.into_static()), + ) + } + TemplateDistribution::RequestTransactionDataSuccess(m) => { + PoolMessages::TemplateDistribution( + TemplateDistribution::RequestTransactionDataSuccess(m.into_static()), + ) + } + TemplateDistribution::SetNewPrevHash(m) => PoolMessages::TemplateDistribution( + TemplateDistribution::SetNewPrevHash(m.into_static()), + ), + TemplateDistribution::SubmitSolution(m) => PoolMessages::TemplateDistribution( + TemplateDistribution::SubmitSolution(m.into_static()), + ), + }, + } + } + + async fn wait_for_client(client: SocketAddr) -> TcpStream { + let listner = TcpListener::bind(client) + .await + .expect("Impossible to listen on given address"); + if let Ok((stream, _)) = listner.accept().await { + stream + } else { + panic!("Impossible to accept dowsntream connetion") + } + } +} + +type MsgType = u8; + +#[derive(Debug, Clone)] +struct MessagesAggregator { + messages: Arc)>>>, + role: Role, +} + +#[derive(Debug, Clone)] +enum Role { + Upstream, + Downstream, +} + +impl MessagesAggregator { + pub fn new(role: Role) -> Self { + Self { + messages: Arc::new(Mutex::new(VecDeque::new())), + role, + } + } + + pub fn add_message(&self, msg_type: MsgType, message: PoolMessages<'static>) { + self.messages + .safe_lock(|messages| messages.push_back((msg_type, message))) + .unwrap(); + } + + pub fn is_empty(&self) -> bool { + self.messages + .safe_lock(|messages| messages.is_empty()) + .unwrap() + } + + pub fn current_state(&self, expected_message: ExpectMessage) -> bool { + // remove first element in vecqueue and compare it with expected message + let is_state = self + .messages + .safe_lock(|messages| { + let mut cloned = messages.clone(); + if let Some((_msg_type, msg)) = cloned.pop_front() { + let msg = ExpectMessage::from(msg); + if expected_message == msg { + *messages = cloned; + true + } else { + false + } + } else { + false + } + }) + .unwrap(); + is_state + } +} + +#[derive(Clone, PartialEq)] +pub enum ExpectMessage { + SetupConnection, + SetupConnectionSuccess, + SetupConnectionError, + CoinbaseOutputDataSize, + NewTemplate, + SetNewPrevHash, +} + +impl From> for ExpectMessage { + fn from(m: PoolMessages<'static>) -> Self { + dbg!(&m); + match m { + PoolMessages::Common(CommonMessages::SetupConnection(_)) => { + ExpectMessage::SetupConnection + } + PoolMessages::Common(CommonMessages::SetupConnectionSuccess(_)) => { + ExpectMessage::SetupConnectionSuccess + } + PoolMessages::Common(CommonMessages::SetupConnectionError(_)) => { + ExpectMessage::SetupConnectionError + } + PoolMessages::TemplateDistribution(TemplateDistribution::CoinbaseOutputDataSize(_)) => { + ExpectMessage::CoinbaseOutputDataSize + } + PoolMessages::TemplateDistribution(TemplateDistribution::NewTemplate(_)) => { + ExpectMessage::NewTemplate + } + PoolMessages::TemplateDistribution(TemplateDistribution::SetNewPrevHash(_)) => { + ExpectMessage::SetNewPrevHash + } + _ => unimplemented!(), + } + } +} diff --git a/roles/tests-integration/tests/pool_integration.rs b/roles/tests-integration/tests/pool_integration.rs index b0d5e8b17b..a4fb837add 100644 --- a/roles/tests-integration/tests/pool_integration.rs +++ b/roles/tests-integration/tests/pool_integration.rs @@ -1,4 +1,4 @@ -use common::proxy::{ExpectMessage, Role}; +use common::sniffer::ExpectMessage; mod common; @@ -8,8 +8,11 @@ async fn success_pool_template_provider_connection() { let tp_addr = common::get_available_address(); let pool_addr = common::get_available_address(); let _tp = common::start_template_provider(tp_addr.port()).await; - let proxy = common::start_proxy(tp_addr, proxy_addr).await; + let proxy = common::start_sniffer(tp_addr, proxy_addr).await; let _pool = common::start_poolsv2(Some(pool_addr), None, Some(proxy_addr)).await; - assert!(proxy.contains(ExpectMessage::SetupConnection, Role::Client)); - assert!(proxy.contains(ExpectMessage::SetupConnectionSuccess, Role::Server)); + assert!(proxy.downstream_state(ExpectMessage::SetupConnection)); + assert!(proxy.upstream_state(ExpectMessage::SetupConnectionSuccess)); + assert!(proxy.downstream_state(ExpectMessage::CoinbaseOutputDataSize)); + assert!(proxy.upstream_state(ExpectMessage::NewTemplate)); + assert!(proxy.upstream_state(ExpectMessage::SetNewPrevHash)); }