From 73ae509d2b521e7aed9a3c725df8b8205153c26b Mon Sep 17 00:00:00 2001 From: Marc-Antoine Perennou Date: Tue, 1 Oct 2024 22:55:16 +0200 Subject: [PATCH] simplify waiting for recovery Signed-off-by: Marc-Antoine Perennou --- examples/p.rs | 20 ++++++++------------ src/acker.rs | 11 +++++++---- src/channel.rs | 18 +++++++++--------- src/channel_recovery_context.rs | 5 +++-- src/channels.rs | 4 ++-- src/error.rs | 27 ++++++++++++++++++--------- src/io_loop.rs | 2 +- 7 files changed, 48 insertions(+), 39 deletions(-) diff --git a/examples/p.rs b/examples/p.rs index 2bde3f4d..7abb29fa 100644 --- a/examples/p.rs +++ b/examples/p.rs @@ -1,6 +1,5 @@ use lapin::{ - options::*, types::FieldTable, BasicProperties, ChannelState, Connection, ConnectionProperties, - Error, + options::*, types::FieldTable, BasicProperties, Connection, ConnectionProperties, }; use tracing::info; @@ -85,16 +84,13 @@ fn main() { } Err(err) => { println!("GOT ERROR"); - match err { - Error::InvalidChannelState(ChannelState::Reconnecting, Some(notifier)) => { - notifier.await - } - err => { - if !err.is_amqp_soft_error() { - panic!("{}", err); - } - errors += 1; - } + let (soft, notifier) = err.is_amqp_soft_error(); + if !soft { + panic!("{}", err); + } + errors += 1; + if let Some(notifier) = notifier { + notifier.await } } } diff --git a/src/acker.rs b/src/acker.rs index 401426b0..53494c13 100644 --- a/src/acker.rs +++ b/src/acker.rs @@ -78,10 +78,13 @@ impl Acker { async fn rpc)>(&self, f: F) -> Result<()> { if self.used.swap(true, Ordering::SeqCst) { - return Err(Error::ProtocolError(AMQPError::new( - AMQPSoftError::PRECONDITIONFAILED.into(), - "Attempted to use an already used Acker".into(), - ))); + return Err(Error::ProtocolError( + AMQPError::new( + AMQPSoftError::PRECONDITIONFAILED.into(), + "Attempted to use an already used Acker".into(), + ), + None, + )); } if let Some(error) = self.error.as_ref() { error.check()?; diff --git a/src/channel.rs b/src/channel.rs index 92880f61..da7fdd93 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -295,7 +295,7 @@ impl Channel { class_id, method_id, ); - Err(Error::ProtocolError(error)) + Err(Error::ProtocolError(error, None)) } } @@ -426,7 +426,7 @@ impl Channel { class_id, method_id, ); - Err(Error::ProtocolError(error)) + Err(Error::ProtocolError(error, None)) } pub(crate) fn handle_content_header_frame( @@ -465,7 +465,7 @@ impl Channel { class_id, 0, ); - let error = Error::ProtocolError(error); + let error = Error::ProtocolError(error, None); self.set_connection_error(error.clone()); Err(error) }, @@ -534,7 +534,7 @@ impl Channel { ) .await }); - Err(Error::ProtocolError(err)) + Err(Error::ProtocolError(err, None)) } fn before_connection_start_ok( @@ -553,7 +553,7 @@ impl Channel { } fn on_connection_close_ok_sent(&self, error: Error) { - if let Error::ProtocolError(_) = error { + if let Error::ProtocolError(_, _) = error { self.internal_rpc.set_connection_error(error); } else { self.internal_rpc.set_connection_closed(error); @@ -573,7 +573,7 @@ impl Channel { fn on_channel_close_ok_sent(&self, error: Option) { if !self.recovery_config.auto_recover_channels - || !error.as_ref().map_or(false, Error::is_amqp_soft_error) + || !error.as_ref().map_or(false, |e| e.is_amqp_soft_error().0) { self.set_closed( error @@ -822,7 +822,7 @@ impl Channel { ?error, "Connection closed", ); - Error::ProtocolError(error) + Error::ProtocolError(error, None) }) .unwrap_or_else(|error| { error!(%error); @@ -911,13 +911,13 @@ impl Channel { channel=%self.id, ?method, ?error, "Channel closed" ); - Error::ProtocolError(error) + Error::ProtocolError(error, None) }); match ( self.recovery_config.auto_recover_channels, error.clone().ok(), ) { - (true, Some(error)) if error.is_amqp_soft_error() => { + (true, Some(error)) if error.is_amqp_soft_error().0 => { self.status.set_reconnecting(error) } (_, err) => self.set_closing(err), diff --git a/src/channel_recovery_context.rs b/src/channel_recovery_context.rs index dd9c5a6b..936123f1 100644 --- a/src/channel_recovery_context.rs +++ b/src/channel_recovery_context.rs @@ -14,10 +14,11 @@ pub(crate) struct ChannelRecoveryContext { impl ChannelRecoveryContext { pub(crate) fn new(cause: Error) -> Self { + let notifier = Notifier::default(); Self { - cause, + cause: cause.with_notifier(notifier.clone()), expected_replies: None, - notifier: Notifier::default(), + notifier, } } diff --git a/src/channels.rs b/src/channels.rs index 06481dd6..630fd706 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -225,7 +225,7 @@ impl Channels { .await }); } - return Err(Error::ProtocolError(error)); + return Err(Error::ProtocolError(error, None)); } } AMQPFrame::Header(channel_id, class_id, header) => { @@ -248,7 +248,7 @@ impl Channels { .await }); } - return Err(Error::ProtocolError(error)); + return Err(Error::ProtocolError(error, None)); } else { self.handle_content_header_frame( channel_id, diff --git a/src/error.rs b/src/error.rs index 0265dcb7..1528a4c8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,7 +27,7 @@ pub enum Error { IOError(Arc), ParsingError(ParserError), - ProtocolError(AMQPError), + ProtocolError(AMQPError, Option), SerialisationError(Arc), MissingHeartbeatError, @@ -53,23 +53,30 @@ impl Error { } } - pub fn is_amqp_soft_error(&self) -> bool { - if let Error::ProtocolError(e) = self { + pub fn is_amqp_soft_error(&self) -> (bool, Option) { + if let Error::ProtocolError(e, notifier) = self { if let AMQPErrorKind::Soft(_) = e.kind() { - return true; + return (true, notifier.clone()); } } - false + (false, None) } pub fn is_amqp_hard_error(&self) -> bool { - if let Error::ProtocolError(e) = self { + if let Error::ProtocolError(e, _) = self { if let AMQPErrorKind::Hard(_) = e.kind() { return true; } } false } + + pub(crate) fn with_notifier(self, notifier: Notifier) -> Self { + match self { + Self::ProtocolError(err, _) => Self::ProtocolError(err, Some(notifier)), + err => err, + } + } } impl fmt::Display for Error { @@ -91,7 +98,7 @@ impl fmt::Display for Error { Error::IOError(e) => write!(f, "IO error: {}", e), Error::ParsingError(e) => write!(f, "failed to parse: {}", e), - Error::ProtocolError(e) => write!(f, "protocol error: {}", e), + Error::ProtocolError(e, _) => write!(f, "protocol error: {}", e), Error::SerialisationError(e) => write!(f, "failed to serialise: {}", e), Error::MissingHeartbeatError => { @@ -119,7 +126,7 @@ impl error::Error for Error { match self { Error::IOError(e) => Some(&**e), Error::ParsingError(e) => Some(e), - Error::ProtocolError(e) => Some(e), + Error::ProtocolError(e, _) => Some(e), Error::SerialisationError(e) => Some(&**e), _ => None, } @@ -156,7 +163,9 @@ impl PartialEq for Error { false } (ParsingError(left_inner), ParsingError(right_inner)) => left_inner == right_inner, - (ProtocolError(left_inner), ProtocolError(right_inner)) => left_inner == right_inner, + (ProtocolError(left_inner, _), ProtocolError(right_inner, _)) => { + left_inner == right_inner + } (SerialisationError(_), SerialisationError(_)) => { error!("Unable to compare lapin::Error::SerialisationError"); false diff --git a/src/io_loop.rs b/src/io_loop.rs index 33113a66..6de68fed 100644 --- a/src/io_loop.rs +++ b/src/io_loop.rs @@ -440,7 +440,7 @@ impl IoLoop { 0, 0, ); - self.critical_error(Error::ProtocolError(error))?; + self.critical_error(Error::ProtocolError(error, None))?; } self.receive_buffer.consume(consumed); Ok(Some(f))