diff --git a/Cargo.toml b/Cargo.toml index 3198cc4..ed003fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,9 @@ readme = "README.md" description = "An MQTT 3.1.1 client written in Rust, using async functions and tokio." repository = "https://github.com/fluffysquirrels/mqtt-async-client-rs" +[[example]] +name = "mqttc" + [dependencies] bytes = "0.4.0" futures-core = "0.3.1" diff --git a/src/client/client.rs b/src/client/client.rs index 10e8157..65c74cb 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -69,7 +69,7 @@ use tokio::{ error::Elapsed, Instant, timeout, - }, + }, task::JoinHandle, }; #[cfg(feature = "tls")] use tokio_rustls::{self, webpki::DNSNameRef, TlsConnector}; @@ -158,6 +158,9 @@ struct IoTaskHandle { /// Signal to the IO task to shutdown. Shared with IoTask. halt: Arc, + + /// Handle to the async task + join_handle: Option>, } /// The state held by the IO task, a long-running tokio future. The IO @@ -275,20 +278,21 @@ impl Client { let (tx_recv_published, rx_recv_published) = mpsc::channel::>(self.options.packet_buffer_len); let halt = Arc::new(AtomicBool::new(false)); - self.io_task_handle = Some(IoTaskHandle { - tx_io_requests, - rx_recv_published, - halt: halt.clone(), - }); let io = IoTask { options: self.options.clone(), rx_io_requests, tx_recv_published, state: IoTaskState::Disconnected, subscriptions: BTreeMap::new(), - halt, + halt: halt.clone(), }; - self.options.runtime.spawn(io.run()); + let join_handle = self.options.runtime.spawn(io.run()); + self.io_task_handle = Some(IoTaskHandle { + tx_io_requests, + rx_recv_published, + halt, + join_handle: Some(join_handle), + }); Ok(()) } @@ -479,11 +483,13 @@ impl Client { } } - async fn shutdown(&mut self) -> Result <()> { - let c = self.check_io_task()?; - c.halt.store(true, Ordering::SeqCst); + async fn shutdown(&mut self) -> Result<()> { self.write_request(IoType::ShutdownConnection, None).await?; - self.io_task_handle = None; + let mut c = self.take_io_task()?; + c.halt.store(true, Ordering::SeqCst); + if let Some(h) = c.join_handle.take() { + h.await.map_err(Error::from_std_err)?; + } Ok(()) } @@ -522,6 +528,12 @@ impl Client { None => Err("No IO task, did you call connect?".into()), } } + fn take_io_task(&mut self) -> Result { + match self.io_task_handle.take() { + Some(h) => Ok(h), + None => Err("No IO task, did you call connect?".into()), + } + } fn check_io_task(&self) -> Result<&IoTaskHandle> { match self.io_task_handle { @@ -687,15 +699,18 @@ impl IoTask { async fn run(mut self) { loop { if self.halt.load(Ordering::SeqCst) { + debug!("IoTask: draining by request."); + self.drain().await.unwrap(); self.shutdown_conn().await; - debug!("IoTask: halting by request."); self.state = IoTaskState::Halted; - return; } match self.state { - IoTaskState::Halted => return, - IoTaskState::Disconnected => + IoTaskState::Halted => { + debug!("IoTask: halting"); + return; + } + IoTaskState::Disconnected => { match Self::try_connect(&mut self).await { Err(e) => { error!("IoTask: Error connecting: {}", e); @@ -712,20 +727,23 @@ impl IoTask { error!("IoTask: Error replaying subscriptions on reconnect: {}", e); } - }, - }, - IoTaskState::Connected(_) => - match Self::run_once_connected(&mut self).await { - Err(Error::Disconnected) => { - info!("IoTask: Disconnected, resetting state"); - self.state = IoTaskState::Disconnected; - }, - Err(e) => { - error!("IoTask: Quitting run loop due to error: {}", e); - return; - }, - _ => {}, - }, + } + } + } + IoTaskState::Connected(_) => match Self::run_once_connected(&mut self).await { + Err(Error::Disconnected) => { + info!("IoTask: Disconnected, resetting state"); + self.state = IoTaskState::Disconnected; + } + Err(Error::ZeroRead) => { + // Nothing to do, + } + Err(e) => { + error!("IoTask: Quitting run loop due to error: {}", e); + return; + } + _ => {} + }, } } } @@ -907,6 +925,25 @@ impl IoTask { } } } + async fn drain(&mut self) -> Result<()> { + // Do not accept any more io requests + self.rx_io_requests.close(); + loop { + let req = self.rx_io_requests.recv().await; + match req { + None => { + // Sender closed. + debug!("IoTask: Req stream closed, shutting down."); + return Ok(()); + } + Some(req) => match self.handle_io_req(req).await { + Err(Error::Disconnected) => {} + Err(e) => return Err(e), + Ok(_) => {} + }, + } + } + } async fn handle_read(&mut self, read: Result) -> Result<()> { let c = match self.state { @@ -1104,9 +1141,7 @@ impl IoTask { let nread = stream.read(&mut read_buf[*read_bufn..readlen]).await?; *read_bufn += nread; if nread == 0 { - // Socket disconnected - error!("IoTask: Socket disconnected"); - return Err(Error::Disconnected); + return Err(Error::ZeroRead); } } } diff --git a/src/error.rs b/src/error.rs index 85655ba..857a68e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,8 @@ pub type Result = std::result::Result; pub enum Error { /// The client is disconnected. Disconnected, + /// The client read zero bytes from the stream. + ZeroRead, /// An error represented by an implementation of std::error::Error. StdError(Box), @@ -33,6 +35,7 @@ impl Display for Error { fn fmt(&self, f: &mut Formatter) -> std::result::Result<(), fmt::Error> { match self { Error::Disconnected => write!(f, "Disconnected"), + Error::ZeroRead => write!(f, "ZeroRead"), Error::StdError(e) => write!(f, "{}", e), Error::String(s) => write!(f, "{}", s), Error::_NonExhaustive => panic!("Not reachable"),