diff --git a/utils/message-generator/src/executor.rs b/utils/message-generator/src/executor.rs index f31991eca2..22843e03a2 100644 --- a/utils/message-generator/src/executor.rs +++ b/utils/message-generator/src/executor.rs @@ -199,35 +199,21 @@ impl Executor { result ); - // If the connection should drop at this point then let's just break the loop - // Can't do anything else after the connection drops. - if *result == ActionResult::CloseConnection { - info!( - "Waiting 1 sec to make sure that remote have time to close the connection" - ); - tokio::time::sleep(std::time::Duration::from_millis(1000)).await; - recv.recv() - .await - .expect_err("Expecting the connection to be closed: wasn't"); - success = true; - break; - } - - let message = match recv.recv().await { - Ok(message) => message, - Err(_) => { - success = false; - error!("Connection closed before receiving the message"); - break; - } - }; - - let mut message: Sv2Frame, _> = message.try_into().unwrap(); - debug!("RECV {:#?}", message); - let header = message.get_header().unwrap(); - let payload = message.payload(); match result { ActionResult::MatchMessageType(message_type) => { + let message = match recv.recv().await { + Ok(message) => message, + Err(_) => { + success = false; + error!("Connection closed before receiving the message"); + break; + } + }; + + let message: Sv2Frame, _> = message.try_into().unwrap(); + debug!("RECV {:#?}", message); + let header = message.get_header().unwrap(); + if header.msg_type() != *message_type { error!( "WRONG MESSAGE TYPE expected: {} received: {}", @@ -245,6 +231,20 @@ impl Executor { message_type, field_data, // Vec<(String, Sv2Type)> )) => { + let message = match recv.recv().await { + Ok(message) => message, + Err(_) => { + success = false; + error!("Connection closed before receiving the message"); + break; + } + }; + + let mut message: Sv2Frame, _> = + message.try_into().unwrap(); + debug!("RECV {:#?}", message); + let header = message.get_header().unwrap(); + let payload = message.payload(); if subprotocol.as_str() == "CommonMessages" { match (header.msg_type(), payload).try_into() { Ok(roles_logic_sv2::parsers::CommonMessages::SetupConnection(m)) => { @@ -532,6 +532,20 @@ impl Executor { message_type: _, fields, } => { + let message = match recv.recv().await { + Ok(message) => message, + Err(_) => { + success = false; + error!("Connection closed before receiving the message"); + break; + } + }; + + let mut message: Sv2Frame, _> = + message.try_into().unwrap(); + debug!("RECV {:#?}", message); + let header = message.get_header().unwrap(); + let payload = message.payload(); if subprotocol.as_str() == "CommonMessages" { match (header.msg_type(), payload).try_into() { Ok(parsers::CommonMessages::SetupConnection(m)) => { @@ -730,6 +744,19 @@ impl Executor { }; } ActionResult::MatchMessageLen(message_len) => { + let message = match recv.recv().await { + Ok(message) => message, + Err(_) => { + success = false; + error!("Connection closed before receiving the message"); + break; + } + }; + + let mut message: Sv2Frame, _> = + message.try_into().unwrap(); + debug!("RECV {:#?}", message); + let payload = message.payload(); if payload.len() != *message_len { error!( "WRONG MESSAGE len expected: {} received: {}", @@ -741,6 +768,18 @@ impl Executor { } } ActionResult::MatchExtensionType(ext_type) => { + let message = match recv.recv().await { + Ok(message) => message, + Err(_) => { + success = false; + error!("Connection closed before receiving the message"); + break; + } + }; + + let message: Sv2Frame, _> = message.try_into().unwrap(); + debug!("RECV {:#?}", message); + let header = message.get_header().unwrap(); if header.ext_type() != *ext_type { error!( "WRONG EXTENSION TYPE expected: {} received: {}", @@ -752,7 +791,26 @@ impl Executor { } } ActionResult::CloseConnection => { - todo!() + info!( + "Waiting 1 sec to make sure that remote has time to close the connection" + ); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + if !recv.is_closed() { + error!("Expected connection to close, but it didn't. Test failed."); + success = false; + break; + } + } + ActionResult::SustainConnection => { + info!( + "Waiting 1 sec to make sure that remote has time to close the connection" + ); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + if recv.is_closed() { + error!("Expected connection to sustain, but it didn't. Test failed."); + success = false; + break; + } } ActionResult::None => todo!(), } diff --git a/utils/message-generator/src/main.rs b/utils/message-generator/src/main.rs index e2633ed23a..327d50cbc7 100644 --- a/utils/message-generator/src/main.rs +++ b/utils/message-generator/src/main.rs @@ -191,6 +191,7 @@ enum ActionResult { MatchMessageLen(usize), MatchExtensionType(u16), CloseConnection, + SustainConnection, None, } @@ -225,6 +226,7 @@ impl std::fmt::Display for ActionResult { write!(f, "MatchExtensionType: {}", extension_type) } ActionResult::CloseConnection => write!(f, "Close connection"), + ActionResult::SustainConnection => write!(f, "Sustain connection"), ActionResult::GetMessageField { subprotocol, fields, diff --git a/utils/message-generator/src/parser/actions.rs b/utils/message-generator/src/parser/actions.rs index ce84c7adf1..23bf6188be 100644 --- a/utils/message-generator/src/parser/actions.rs +++ b/utils/message-generator/src/parser/actions.rs @@ -91,6 +91,7 @@ impl Sv2ActionParser { "close_connection" => { action_results.push(ActionResult::CloseConnection); } + "sustain_connection" => action_results.push(ActionResult::SustainConnection), "none" => { action_results.push(ActionResult::None); }