diff --git a/roles/tests-integration/tests/common/mod.rs b/roles/tests-integration/tests/common/mod.rs index ed9be47a5e..3d5b4b6c76 100644 --- a/roles/tests-integration/tests/common/mod.rs +++ b/roles/tests-integration/tests/common/mod.rs @@ -275,7 +275,7 @@ pub async fn start_sniffer(upstream: SocketAddr, downstream: SocketAddr) -> Snif sniffer } -pub async fn start_poolsv2( +pub async fn start_pool( listening_address: Option, coinbase_outputs: Option>, template_provider_address: Option, @@ -288,9 +288,7 @@ pub async fn start_poolsv2( let pool = test_pool.pool.clone(); let pool_clone = pool.clone(); tokio::task::spawn(async move { - let ret = pool_clone.start().await; - dbg!(&ret); - assert!(ret.is_ok()); + assert!(pool_clone.start().await.is_ok()); }); tokio::time::sleep(std::time::Duration::from_secs(1)).await; pool diff --git a/roles/tests-integration/tests/common/sniffer/mod.rs b/roles/tests-integration/tests/common/sniffer/mod.rs index 750d4caf9a..fffa494dc6 100644 --- a/roles/tests-integration/tests/common/sniffer/mod.rs +++ b/roles/tests-integration/tests/common/sniffer/mod.rs @@ -49,8 +49,8 @@ impl Sniffer { Self { downstream, upstream, - downstream_messages: MessagesAggregator::new(Role::Downstream), - upstream_messages: MessagesAggregator::new(Role::Upstream), + downstream_messages: MessagesAggregator::new(), + upstream_messages: MessagesAggregator::new(), } } @@ -74,12 +74,51 @@ impl Sniffer { }; } - pub fn downstream_state(&self, message: ExpectMessage) -> bool { - self.downstream_messages.current_state(message) + pub fn downstream_state(&self) -> Option> { + self.downstream_messages.current_state() } - pub fn upstream_state(&self, message: ExpectMessage) -> bool { - self.upstream_messages.current_state(message) + pub fn upstream_state(&self) -> Option> { + self.upstream_messages.current_state() + } + + pub fn expect_downstream_setup_connection(&self) -> bool { + match self.downstream_state() { + Some(PoolMessages::Common(CommonMessages::SetupConnection(..))) => true, + _ => false, + } + } + + pub fn expect_downstream_coinbase_output_data_size(&self) -> bool { + match self.downstream_state() { + Some(PoolMessages::TemplateDistribution( + TemplateDistribution::CoinbaseOutputDataSize(..), + )) => true, + _ => false, + } + } + + pub fn expect_upstream_setup_connection_success(&self) -> bool { + match self.upstream_state() { + Some(PoolMessages::Common(CommonMessages::SetupConnectionSuccess(..))) => true, + _ => false, + } + } + + pub fn expect_upstream_new_template(&self) -> bool { + match self.upstream_state() { + Some(PoolMessages::TemplateDistribution(TemplateDistribution::NewTemplate(..))) => true, + _ => false, + } + } + + pub fn expect_upstream_set_new_prev_hash(&self) -> bool { + match self.upstream_state() { + Some(PoolMessages::TemplateDistribution(TemplateDistribution::SetNewPrevHash(..))) => { + true + } + _ => false, + } } async fn create_downstream( @@ -309,20 +348,12 @@ type MsgType = u8; #[derive(Debug, Clone)] struct MessagesAggregator { messages: Arc)>>>, - role: Role, -} - -#[derive(Debug, Clone)] -enum Role { - Upstream, - Downstream, } impl MessagesAggregator { - pub fn new(role: Role) -> Self { + pub fn new() -> Self { Self { messages: Arc::new(Mutex::new(VecDeque::new())), - role, } } @@ -338,61 +369,20 @@ impl MessagesAggregator { .unwrap() } - pub fn current_state(&self, expected_message: ExpectMessage) -> bool { + pub fn current_state(&self) -> Option> { // remove first element in vecqueue and compare it with expected message let is_state = self .messages .safe_lock(|messages| { let mut cloned = messages.clone(); if let Some((_msg_type, msg)) = cloned.pop_front() { - let msg = ExpectMessage::from(msg); - if expected_message == msg { - *messages = cloned; - true - } else { - false - } + *messages = cloned; + Some(msg) } else { - false + None } }) .unwrap(); is_state } } - -#[derive(Clone, PartialEq)] -pub enum ExpectMessage { - SetupConnection, - SetupConnectionSuccess, - SetupConnectionError, - CoinbaseOutputDataSize, - NewTemplate, - SetNewPrevHash, -} - -impl From> for ExpectMessage { - fn from(m: PoolMessages<'static>) -> Self { - match m { - PoolMessages::Common(CommonMessages::SetupConnection(_)) => { - ExpectMessage::SetupConnection - } - PoolMessages::Common(CommonMessages::SetupConnectionSuccess(_)) => { - ExpectMessage::SetupConnectionSuccess - } - PoolMessages::Common(CommonMessages::SetupConnectionError(_)) => { - ExpectMessage::SetupConnectionError - } - PoolMessages::TemplateDistribution(TemplateDistribution::CoinbaseOutputDataSize(_)) => { - ExpectMessage::CoinbaseOutputDataSize - } - PoolMessages::TemplateDistribution(TemplateDistribution::NewTemplate(_)) => { - ExpectMessage::NewTemplate - } - PoolMessages::TemplateDistribution(TemplateDistribution::SetNewPrevHash(_)) => { - ExpectMessage::SetNewPrevHash - } - _ => unimplemented!(), - } - } -} diff --git a/roles/tests-integration/tests/pool_integration.rs b/roles/tests-integration/tests/pool_integration.rs index 6743449f0d..78260f7652 100644 --- a/roles/tests-integration/tests/pool_integration.rs +++ b/roles/tests-integration/tests/pool_integration.rs @@ -1,5 +1,3 @@ -use common::sniffer::ExpectMessage; - mod common; #[tokio::test] @@ -9,10 +7,10 @@ async fn success_pool_template_provider_connection() { let pool_addr = common::get_available_address(); let _tp = common::start_template_provider(tp_addr.port()).await; let sniffer = common::start_sniffer(tp_addr, sniffer_addr).await; - let _pool = common::start_poolsv2(Some(pool_addr), None, Some(sniffer_addr)).await; - assert!(sniffer.downstream_state(ExpectMessage::SetupConnection)); - assert!(sniffer.upstream_state(ExpectMessage::SetupConnectionSuccess)); - assert!(sniffer.downstream_state(ExpectMessage::CoinbaseOutputDataSize)); - assert!(sniffer.upstream_state(ExpectMessage::NewTemplate)); - assert!(sniffer.upstream_state(ExpectMessage::SetNewPrevHash)); + let _ = common::start_pool(Some(pool_addr), None, Some(sniffer_addr)).await; + assert!(sniffer.expect_downstream_setup_connection()); + assert!(sniffer.expect_upstream_setup_connection_success()); + assert!(sniffer.expect_downstream_coinbase_output_data_size()); + assert!(sniffer.expect_upstream_new_template()); + assert!(sniffer.expect_upstream_set_new_prev_hash()); }