From 9f234cda4688600954a3d9cf6409a32c9e407bbd Mon Sep 17 00:00:00 2001 From: Slavik Date: Mon, 22 Apr 2024 17:45:28 +0200 Subject: [PATCH] fix(heartbeat): stop heartbeat if connection is not active --- src/channel.rs | 9 +++++++++ src/channels.rs | 23 ++++++++++++++++++++++- src/heartbeat.rs | 8 +++++++- src/socket_state.rs | 6 ++++++ 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index a92b05c9..a900dc29 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -275,6 +275,11 @@ impl Channel { self.waker.wake() } + fn readable(&self) { + trace!(channel=%self.id, "readable"); + self.waker.readable() + } + fn assert_channel0(&self, class_id: Identifier, method_id: Identifier) -> Result<()> { if self.id == 0 { Ok(()) @@ -374,6 +379,7 @@ impl Channel { trace!(channel=%self.id, "send_frame"); self.frames.push(self.id, frame, resolver, expected_reply); self.wake(); + self.readable(); } async fn send_method_frame_with_body( @@ -405,6 +411,7 @@ impl Channel { trace!(channel=%self.id, "send_frames"); let promise = self.frames.push_frames(frames); self.wake(); + self.readable(); promise.await?; Ok(publisher_confirms_result .unwrap_or_else(|| PublisherConfirm::not_requested(self.returned_messages.clone()))) @@ -842,6 +849,8 @@ impl Channel { ) -> Result<()> { self.connection_status.unblock(); self.wake(); + self.readable(); + Ok(()) } diff --git a/src/channels.rs b/src/channels.rs index a7194c7b..7150d1dc 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -154,10 +154,29 @@ impl Channels { .all(|c| c.status().flow()) } - pub(crate) fn send_heartbeat(&self) { + pub(crate) fn send_heartbeat(&self) -> Result<()> { debug!("send heartbeat"); if let Some(channel0) = self.get(0) { + debug!("connection status: {:?}", self.connection_status.state()); + + if !self.connection_status.connected() { + let error = AMQPError::new( + AMQPHardError::FRAMEERROR.into(), + "heartbeat frame was not received on channel 0".into(), + ); + + self.internal_rpc.register_internal_future(async move { + channel0 + .connection_close(error.get_id(), error.get_message().as_str(), 0, 0) + .await + }); + + return Err(Error::InvalidConnectionState( + self.connection_status.state(), + )); + } + let (promise, resolver) = Promise::new(); if level_enabled!(Level::TRACE) { @@ -167,6 +186,8 @@ impl Channels { channel0.send_frame(AMQPFrame::Heartbeat(0), resolver, None); self.internal_rpc.register_internal_future(promise); } + + Ok(()) } pub(crate) fn handle_frame(&self, f: AMQPFrame) -> Result<()> { diff --git a/src/heartbeat.rs b/src/heartbeat.rs index f6086369..e9e99db3 100644 --- a/src/heartbeat.rs +++ b/src/heartbeat.rs @@ -7,6 +7,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; +use tracing::error; #[derive(Clone)] pub struct Heartbeat { @@ -86,7 +87,12 @@ impl Inner { .unwrap_or_else(|| { // Update last_write so that if we cannot write to the socket yet, we don't enqueue countless heartbeats self.update_last_write(); - channels.send_heartbeat(); + + if let Err(err) = channels.send_heartbeat() { + self.timeout = None; + error!("Failed to send heartbeat: {}", err); + } + timeout }) }) diff --git a/src/socket_state.rs b/src/socket_state.rs index 848d0065..198ea66c 100644 --- a/src/socket_state.rs +++ b/src/socket_state.rs @@ -105,7 +105,13 @@ impl SocketStateHandle { let _ = self.sender.send(event); } + /// Wake the socket up pub fn wake(&self) { self.send(SocketEvent::Wake); } + + /// Notify that the socket is readable + pub fn readable(&self) { + self.send(SocketEvent::Readable); + } }