Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for draining publish events before shutting down. #32

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ readme = "README.md"
description = "An MQTT 3.1.1 client written in Rust, using async functions and tokio."
repository = "https://github.com/fluffysquirrels/mqtt-async-client-rs"

[[example]]
name = "mqttc"

[dependencies]
bytes = "0.4.0"
futures-core = "0.3.1"
Expand Down
101 changes: 68 additions & 33 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ use tokio::{
error::Elapsed,
Instant,
timeout,
},
}, task::JoinHandle,
};
#[cfg(feature = "tls")]
use tokio_rustls::{self, webpki::DNSNameRef, TlsConnector};
Expand Down Expand Up @@ -158,6 +158,9 @@ struct IoTaskHandle {

/// Signal to the IO task to shutdown. Shared with IoTask.
halt: Arc<AtomicBool>,

/// Handle to the async task
join_handle: Option<JoinHandle<()>>,
}

/// The state held by the IO task, a long-running tokio future. The IO
Expand Down Expand Up @@ -275,20 +278,21 @@ impl Client {
let (tx_recv_published, rx_recv_published) =
mpsc::channel::<Result<Packet>>(self.options.packet_buffer_len);
let halt = Arc::new(AtomicBool::new(false));
self.io_task_handle = Some(IoTaskHandle {
tx_io_requests,
rx_recv_published,
halt: halt.clone(),
});
let io = IoTask {
options: self.options.clone(),
rx_io_requests,
tx_recv_published,
state: IoTaskState::Disconnected,
subscriptions: BTreeMap::new(),
halt,
halt: halt.clone(),
};
self.options.runtime.spawn(io.run());
let join_handle = self.options.runtime.spawn(io.run());
self.io_task_handle = Some(IoTaskHandle {
tx_io_requests,
rx_recv_published,
halt,
join_handle: Some(join_handle),
});
Ok(())
}

Expand Down Expand Up @@ -479,11 +483,13 @@ impl Client {
}
}

async fn shutdown(&mut self) -> Result <()> {
let c = self.check_io_task()?;
c.halt.store(true, Ordering::SeqCst);
async fn shutdown(&mut self) -> Result<()> {
self.write_request(IoType::ShutdownConnection, None).await?;
self.io_task_handle = None;
let mut c = self.take_io_task()?;
c.halt.store(true, Ordering::SeqCst);
if let Some(h) = c.join_handle.take() {
h.await.map_err(Error::from_std_err)?;
}
Ok(())
}

Expand Down Expand Up @@ -522,6 +528,12 @@ impl Client {
None => Err("No IO task, did you call connect?".into()),
}
}
fn take_io_task(&mut self) -> Result<IoTaskHandle> {
match self.io_task_handle.take() {
Some(h) => Ok(h),
None => Err("No IO task, did you call connect?".into()),
}
}

fn check_io_task(&self) -> Result<&IoTaskHandle> {
match self.io_task_handle {
Expand Down Expand Up @@ -687,15 +699,18 @@ impl IoTask {
async fn run(mut self) {
loop {
if self.halt.load(Ordering::SeqCst) {
debug!("IoTask: draining by request.");
self.drain().await.unwrap();
self.shutdown_conn().await;
debug!("IoTask: halting by request.");
self.state = IoTaskState::Halted;
return;
}

match self.state {
IoTaskState::Halted => return,
IoTaskState::Disconnected =>
IoTaskState::Halted => {
debug!("IoTask: halting");
return;
}
IoTaskState::Disconnected => {
match Self::try_connect(&mut self).await {
Err(e) => {
error!("IoTask: Error connecting: {}", e);
Expand All @@ -712,20 +727,23 @@ impl IoTask {
error!("IoTask: Error replaying subscriptions on reconnect: {}",
e);
}
},
},
IoTaskState::Connected(_) =>
match Self::run_once_connected(&mut self).await {
Err(Error::Disconnected) => {
info!("IoTask: Disconnected, resetting state");
self.state = IoTaskState::Disconnected;
},
Err(e) => {
error!("IoTask: Quitting run loop due to error: {}", e);
return;
},
_ => {},
},
}
}
}
IoTaskState::Connected(_) => match Self::run_once_connected(&mut self).await {
Err(Error::Disconnected) => {
info!("IoTask: Disconnected, resetting state");
self.state = IoTaskState::Disconnected;
}
Err(Error::ZeroRead) => {
// Nothing to do,
}
Err(e) => {
error!("IoTask: Quitting run loop due to error: {}", e);
return;
}
_ => {}
},
}
}
}
Expand Down Expand Up @@ -907,6 +925,25 @@ impl IoTask {
}
}
}
async fn drain(&mut self) -> Result<()> {
// Do not accept any more io requests
self.rx_io_requests.close();
loop {
let req = self.rx_io_requests.recv().await;
match req {
None => {
// Sender closed.
debug!("IoTask: Req stream closed, shutting down.");
return Ok(());
}
Some(req) => match self.handle_io_req(req).await {
Err(Error::Disconnected) => {}
Err(e) => return Err(e),
Ok(_) => {}
},
}
}
}

async fn handle_read(&mut self, read: Result<Packet>) -> Result<()> {
let c = match self.state {
Expand Down Expand Up @@ -1104,9 +1141,7 @@ impl IoTask {
let nread = stream.read(&mut read_buf[*read_bufn..readlen]).await?;
*read_bufn += nread;
if nread == 0 {
// Socket disconnected
error!("IoTask: Socket disconnected");
return Err(Error::Disconnected);
return Err(Error::ZeroRead);
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub type Result<T> = std::result::Result<T, Error>;
pub enum Error {
/// The client is disconnected.
Disconnected,
/// The client read zero bytes from the stream.
ZeroRead,

/// An error represented by an implementation of std::error::Error.
StdError(Box<dyn std::error::Error + Send + Sync>),
Expand All @@ -33,6 +35,7 @@ impl Display for Error {
fn fmt(&self, f: &mut Formatter) -> std::result::Result<(), fmt::Error> {
match self {
Error::Disconnected => write!(f, "Disconnected"),
Error::ZeroRead => write!(f, "ZeroRead"),
Error::StdError(e) => write!(f, "{}", e),
Error::String(s) => write!(f, "{}", s),
Error::_NonExhaustive => panic!("Not reachable"),
Expand Down