From 2d49fd55aa671b05c90ea94ad87699fcb8307668 Mon Sep 17 00:00:00 2001 From: Jaco van den Bergh Date: Fri, 1 Nov 2024 08:19:11 +0200 Subject: [PATCH] Wait for background workers to finish current jobs before quitting (#860) * wait for background workers --- Cargo.toml | 1 + src/bgworker/mod.rs | 31 +++++++++++++----- src/bgworker/skq.rs | 12 ++++--- src/boot.rs | 77 ++++++++++++++++++++++++++++++++------------- 4 files changed, 88 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5f1acd055..ea8c63ca5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ sea-orm = { version = "1.1.0", features = [ ], optional = true } tokio = { version = "1.33.0", default-features = false } +tokio-util = "0.7.10" # the rest serde = { workspace = true } diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 8a1e0a199..0e5754d42 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use async_trait::async_trait; use serde::Serialize; +use tokio_util::sync::CancellationToken; use tracing::{debug, error}; #[cfg(feature = "bg_pg")] pub mod pg; @@ -20,6 +21,7 @@ pub enum Queue { Redis( bb8::Pool, Arc>, + CancellationToken, ), #[cfg(feature = "bg_pg")] Postgres( @@ -45,7 +47,7 @@ impl Queue { debug!(worker = class, "job enqueue"); match self { #[cfg(feature = "bg_redis")] - Self::Redis(pool, _) => { + Self::Redis(pool, _, _) => { skq::enqueue(pool, class, queue, args).await?; } #[cfg(feature = "bg_pg")] @@ -80,7 +82,7 @@ impl Queue { debug!(worker = W::class_name(), "register worker"); match self { #[cfg(feature = "bg_redis")] - Self::Redis(_, p) => { + Self::Redis(_, p, _) => { let mut p = p.lock().await; p.register(skq::SidekiqBackgroundWorker::new(worker)); } @@ -103,7 +105,7 @@ impl Queue { debug!("running background jobs"); match self { #[cfg(feature = "bg_redis")] - Self::Redis(_, p) => { + Self::Redis(_, p, _) => { p.lock().await.clone().run().await; } #[cfg(feature = "bg_pg")] @@ -133,7 +135,7 @@ impl Queue { debug!("workers setup"); match self { #[cfg(feature = "bg_redis")] - Self::Redis(_, _) => {} + Self::Redis(_, _, _) => {} #[cfg(feature = "bg_pg")] Self::Postgres(pool, _, _) => { pg::initialize_database(pool).await.map_err(Box::from)?; @@ -152,7 +154,7 @@ impl Queue { debug!("clearing job queues"); match self { #[cfg(feature = "bg_redis")] - Self::Redis(pool, _) => { + Self::Redis(pool, _, _) => { skq::clear(pool).await?; } #[cfg(feature = "bg_pg")] @@ -173,7 +175,7 @@ impl Queue { debug!("job queue ping requested"); match self { #[cfg(feature = "bg_redis")] - Self::Redis(pool, _) => { + Self::Redis(pool, _, _) => { skq::ping(pool).await?; } #[cfg(feature = "bg_pg")] @@ -189,12 +191,27 @@ impl Queue { pub fn describe(&self) -> String { match self { #[cfg(feature = "bg_redis")] - Self::Redis(_, _) => "redis queue".to_string(), + Self::Redis(_, _, _) => "redis queue".to_string(), #[cfg(feature = "bg_pg")] Self::Postgres(_, _, _) => "postgres queue".to_string(), _ => "no queue".to_string(), } } + + /// # Errors + /// + /// Does not currently return an error, but the postgres or other future queue implementations + /// might, so using Result here as return type. + pub fn shutdown(&self) -> Result<()> { + println!("waiting for running jobs to finish..."); + match self { + #[cfg(feature = "bg_redis")] + Self::Redis(_, _, cancellation_token) => cancellation_token.cancel(), + _ => {} + } + + Ok(()) + } } #[async_trait] diff --git a/src/bgworker/skq.rs b/src/bgworker/skq.rs index 2b8bab9ab..9ec5b77cd 100644 --- a/src/bgworker/skq.rs +++ b/src/bgworker/skq.rs @@ -118,12 +118,14 @@ pub async fn create_provider(qcfg: &RedisQueueConfig) -> Result { let manager = RedisConnectionManager::new(qcfg.uri.clone())?; let redis = Pool::builder().build(manager).await?; let queues = get_queues(&qcfg.queues); + let processor = Processor::new(redis.clone(), queues) + .with_config(ProcessorConfig::default().num_workers(qcfg.num_workers as usize)); + let cancellation_token = processor.get_cancellation_token(); + Ok(Queue::Redis( - redis.clone(), - Arc::new(tokio::sync::Mutex::new( - Processor::new(redis, queues) - .with_config(ProcessorConfig::default().num_workers(qcfg.num_workers as usize)), - )), + redis, + Arc::new(tokio::sync::Mutex::new(processor)), + cancellation_token, )) } diff --git a/src/boot.rs b/src/boot.rs index f93cdba96..e663d44b4 100644 --- a/src/boot.rs +++ b/src/boot.rs @@ -6,7 +6,8 @@ use std::path::PathBuf; use axum::Router; #[cfg(feature = "with-db")] use sea_orm_migration::MigratorTrait; -use tokio::signal; +use tokio::task::JoinHandle; +use tokio::{select, signal}; use tracing::{debug, error, info, warn}; #[cfg(feature = "with-db")] @@ -87,31 +88,29 @@ pub async fn start( H::serve(router, &app_context).await?; } (Some(router), true) => { - debug!("note: worker is run in-process (tokio spawn)"); - if app_context.config.workers.mode == WorkerMode::BackgroundQueue { - if let Some(queue) = &app_context.queue_provider { - let cloned_queue = queue.clone(); - tokio::spawn(async move { - let res = cloned_queue.run().await; - if res.is_err() { - error!( - err = res.unwrap_err().to_string(), - "error while running worker" - ); - } - }); - } else { - return Err(Error::QueueProviderMissing); - } - } + let handle = if app_context.config.workers.mode == WorkerMode::BackgroundQueue { + Some(start_queue_worker(&app_context)?) + } else { + None + }; H::serve(router, &app_context).await?; + + if let Some(handle) = handle { + shutdown_and_await_queue_worker(&app_context, handle).await?; + } } (None, true) => { - if let Some(queue) = &app_context.queue_provider { - queue.run().await?; + let handle = if app_context.config.workers.mode == WorkerMode::BackgroundQueue { + Some(start_queue_worker(&app_context)?) } else { - return Err(Error::QueueProviderMissing); + None + }; + + shutdown_signal().await; + + if let Some(handle) = handle { + shutdown_and_await_queue_worker(&app_context, handle).await?; } } _ => {} @@ -119,6 +118,42 @@ pub async fn start( Ok(()) } +fn start_queue_worker(app_context: &AppContext) -> Result> { + debug!("note: worker is run in-process (tokio spawn)"); + + if let Some(queue) = &app_context.queue_provider { + let cloned_queue = queue.clone(); + let handle = tokio::spawn(async move { + let res = cloned_queue.run().await; + if res.is_err() { + error!( + err = res.unwrap_err().to_string(), + "error while running worker" + ); + } + }); + return Ok(handle); + } + + Err(Error::QueueProviderMissing) +} + +async fn shutdown_and_await_queue_worker( + app_context: &AppContext, + handle: JoinHandle<()>, +) -> Result<(), Error> { + if let Some(queue) = &app_context.queue_provider { + queue.shutdown()?; + } + + println!("press ctrl-c again to force quit"); + select! { + _ = handle => {} + () = shutdown_signal() => {} + } + Ok(()) +} + /// Run task /// /// # Errors