diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 3057b3cf..1e76196e 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -10,7 +10,6 @@ use std::time::Duration; use std::{error::Error as StdError, sync::atomic::AtomicUsize}; use tokio::io::{AsyncBufRead, AsyncWrite}; use tokio::net::TcpStream; -use tokio::sync::Mutex; use tokio::task::{spawn_blocking, AbortHandle, JoinError, JoinSet}; use tokio::time::sleep as tokio_sleep; @@ -165,7 +164,10 @@ pub struct Worker { terminated: bool, forever: bool, shutdown_timeout: Option, - shutdown_signal: Option>>, + + // NOTE: this is always `Some` if `self.terminated == false` whenever any `pub` function + // on this type returns. it is `Some(std::future::pending())` if no shutdown signaler is set. + shutdown_signal: Option, } impl Worker { @@ -198,7 +200,9 @@ impl Worker { terminated: false, forever: false, shutdown_timeout, - shutdown_signal: shutdown_signal.map(|signal| Arc::new(Mutex::new(signal))), + shutdown_signal: Some( + shutdown_signal.unwrap_or_else(|| Box::pin(std::future::pending())), + ), } } } @@ -360,7 +364,7 @@ impl< terminated: self.terminated, forever: self.forever, shutdown_timeout: self.shutdown_timeout, - shutdown_signal: None, + shutdown_signal: Some(Box::pin(std::future::pending())), }) } @@ -429,15 +433,20 @@ impl< .await?; } - let maybe_shutdown_signal = self.shutdown_signal.clone(); + // the only place `shutdown_signal` is set to `None` is when we `.take()` it here. + // later on, we maintain the invariant that either `self.terminated` is set to `true` OR we + // restore `shutdown_signal` to `Some` (such as if the heartbeat future fails). in the + // former case, we'll never hit this `.take()` again due to the `assert` above, and in the + // latter case the `take()` will yet again succeed. + let mut shutdown_signal = self + .shutdown_signal + .take() + .expect("see shutdown_signal comment"); let maybe_shutdown_timeout = self.shutdown_timeout; let report = tokio::select! { // A signal from the user space received. - _ = async { - let signal = maybe_shutdown_signal.unwrap(); - signal.lock().await.as_mut().await; - }, if maybe_shutdown_signal.is_some() => { + _ = &mut shutdown_signal => { let nrunning = tokio::select! { _ = async { tokio_sleep(maybe_shutdown_timeout.unwrap()).await; }, if maybe_shutdown_timeout.is_some() => { 0 @@ -462,6 +471,8 @@ impl< // note that if it is an error from heartbeat(), the worker will _not_ be marked as // terminated and _can_ be restarted self.terminated = exit.is_ok(); + // restore shutdown signal since it has not resolved + self.shutdown_signal = Some(shutdown_signal); if let Ok(true) = exit { let nrunning = self.force_fail_all_workers("terminated").await;