From 6add451b5b052bd8c03fd4d85899b862c8b55cfa Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 31 Oct 2024 17:38:43 +0100 Subject: [PATCH] Actually consume jobs --- crates/storage-pg/src/queue/job.rs | 130 ++++++++- crates/storage/src/queue/job.rs | 48 +++- crates/tasks/src/email.rs | 162 ++++++----- crates/tasks/src/lib.rs | 20 +- crates/tasks/src/matrix.rs | 390 +++++++++++++------------- crates/tasks/src/new_queue.rs | 207 +++++++++++++- crates/tasks/src/recovery.rs | 198 +++++++------- crates/tasks/src/storage/from_row.rs | 70 ----- crates/tasks/src/storage/mod.rs | 14 - crates/tasks/src/storage/postgres.rs | 391 --------------------------- crates/tasks/src/user.rs | 211 +++++++-------- crates/tasks/src/utils.rs | 91 ------- 12 files changed, 848 insertions(+), 1084 deletions(-) delete mode 100644 crates/tasks/src/storage/from_row.rs delete mode 100644 crates/tasks/src/storage/mod.rs delete mode 100644 crates/tasks/src/storage/postgres.rs delete mode 100644 crates/tasks/src/utils.rs diff --git a/crates/storage-pg/src/queue/job.rs b/crates/storage-pg/src/queue/job.rs index 4b4433b50..90f8546a7 100644 --- a/crates/storage-pg/src/queue/job.rs +++ b/crates/storage-pg/src/queue/job.rs @@ -7,13 +7,16 @@ //! [`QueueJobRepository`]. use async_trait::async_trait; -use mas_storage::{queue::QueueJobRepository, Clock}; +use mas_storage::{ + queue::{Job, QueueJobRepository, Worker}, + Clock, +}; use rand::RngCore; use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{DatabaseError, ExecuteExt}; +use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt}; /// An implementation of [`QueueJobRepository`] for a PostgreSQL connection. pub struct PgQueueJobRepository<'c> { @@ -29,6 +32,37 @@ impl<'c> PgQueueJobRepository<'c> { } } +struct JobReservationResult { + queue_job_id: Uuid, + queue_name: String, + payload: serde_json::Value, + metadata: serde_json::Value, +} + +impl TryFrom for Job { + type Error = DatabaseInconsistencyError; + + fn try_from(value: JobReservationResult) -> Result { + let id = value.queue_job_id.into(); + let queue_name = value.queue_name; + let payload = value.payload; + + let metadata = serde_json::from_value(value.metadata).map_err(|e| { + DatabaseInconsistencyError::on("queue_jobs") + .column("metadata") + .row(id) + .source(e) + })?; + + Ok(Self { + id, + queue_name, + payload, + metadata, + }) + } +} + #[async_trait] impl<'c> QueueJobRepository for PgQueueJobRepository<'c> { type Error = DatabaseError; @@ -73,4 +107,96 @@ impl<'c> QueueJobRepository for PgQueueJobRepository<'c> { Ok(()) } + + #[tracing::instrument( + name = "db.queue_job.reserve", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn reserve( + &mut self, + clock: &dyn Clock, + worker: &Worker, + queues: &[&str], + count: usize, + ) -> Result, Self::Error> { + let now = clock.now(); + let max_count = i64::try_from(count).unwrap_or(i64::MAX); + let queues: Vec = queues.iter().map(|&s| s.to_owned()).collect(); + let results = sqlx::query_as!( + JobReservationResult, + r#" + -- We first grab a few jobs that are available, + -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently + -- and we don't get multiple workers grabbing the same jobs + WITH locked_jobs AS ( + SELECT queue_job_id + FROM queue_jobs + WHERE + status = 'available' + AND queue_name = ANY($1) + ORDER BY queue_job_id ASC + LIMIT $2 + FOR UPDATE + SKIP LOCKED + ) + -- then we update the status of those jobs to 'running', returning the job details + UPDATE queue_jobs + SET status = 'running', started_at = $3, started_by = $4 + FROM locked_jobs + WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id + RETURNING + queue_jobs.queue_job_id, + queue_jobs.queue_name, + queue_jobs.payload, + queue_jobs.metadata + "#, + &queues, + max_count, + now, + Uuid::from(worker.id), + ) + .traced() + .fetch_all(&mut *self.conn) + .await?; + + let jobs = results + .into_iter() + .map(TryFrom::try_from) + .collect::, _>>()?; + + Ok(jobs) + } + + #[tracing::instrument( + name = "db.queue_job.mark_as_completed", + skip_all, + fields( + db.query.text, + job.id = %id, + ), + err, + )] + async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> { + let now = clock.now(); + let res = sqlx::query!( + r#" + UPDATE queue_jobs + SET status = 'completed', completed_at = $1 + WHERE queue_job_id = $2 AND status = 'running' + "#, + now, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + Ok(()) + } } diff --git a/crates/storage/src/queue/job.rs b/crates/storage/src/queue/job.rs index af538f628..77d4c2025 100644 --- a/crates/storage/src/queue/job.rs +++ b/crates/storage/src/queue/job.rs @@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize}; use tracing_opentelemetry::OpenTelemetrySpanExt; use ulid::Ulid; +use super::Worker; use crate::{repository_impl, Clock}; /// Represents a job in the job queue @@ -19,6 +20,9 @@ pub struct Job { /// The ID of the job pub id: Ulid, + /// The queue on which the job was placed + pub queue_name: String, + /// The payload of the job pub payload: serde_json::Value, @@ -27,7 +31,7 @@ pub struct Job { } /// Metadata stored alongside the job -#[derive(Serialize, Deserialize, Default)] +#[derive(Serialize, Deserialize, Default, Clone, Debug)] pub struct JobMetadata { #[serde(default)] trace_id: String, @@ -89,6 +93,38 @@ pub trait QueueJobRepository: Send + Sync { payload: serde_json::Value, metadata: serde_json::Value, ) -> Result<(), Self::Error>; + + /// Reserve multiple jobs from multiple queues + /// + /// # Parameters + /// + /// * `clock` - The clock used to generate timestamps + /// * `worker` - The worker that is reserving the jobs + /// * `queues` - The queues to reserve jobs from + /// * `count` - The number of jobs to reserve + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn reserve( + &mut self, + clock: &dyn Clock, + worker: &Worker, + queues: &[&str], + count: usize, + ) -> Result, Self::Error>; + + /// Mark a job as completed + /// + /// # Parameters + /// + /// * `clock` - The clock used to generate timestamps + /// * `job` - The job to mark as completed + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails. + async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>; } repository_impl!(QueueJobRepository: @@ -100,6 +136,16 @@ repository_impl!(QueueJobRepository: payload: serde_json::Value, metadata: serde_json::Value, ) -> Result<(), Self::Error>; + + async fn reserve( + &mut self, + clock: &dyn Clock, + worker: &Worker, + queues: &[&str], + count: usize, + ) -> Result, Self::Error>; + + async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>; ); /// Extension trait for [`QueueJobRepository`] to help adding a job to the queue diff --git a/crates/tasks/src/email.rs b/crates/tasks/src/email.rs index a16ca29dc..3afbab8ce 100644 --- a/crates/tasks/src/email.rs +++ b/crates/tasks/src/email.rs @@ -5,97 +5,87 @@ // Please see LICENSE in the repository root for full details. use anyhow::Context; -use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use async_trait::async_trait; use chrono::Duration; use mas_email::{Address, Mailbox}; use mas_i18n::locale; -use mas_storage::{job::JobWithSpanContext, queue::VerifyEmailJob}; +use mas_storage::queue::VerifyEmailJob; use mas_templates::{EmailVerificationContext, TemplateContext}; use rand::{distributions::Uniform, Rng}; use tracing::info; -use crate::{storage::PostgresStorageFactory, JobContextExt, State}; - -#[tracing::instrument( - name = "job.verify_email", - fields(user_email.id = %job.user_email_id()), - skip_all, - err(Debug), -)] -async fn verify_email( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let mut repo = state.repository().await?; - let mut rng = state.rng(); - let mailer = state.mailer(); - let clock = state.clock(); - - let language = job - .language() - .and_then(|l| l.parse().ok()) - .unwrap_or(locale!("en").into()); - - // Lookup the user email - let user_email = repo - .user_email() - .lookup(job.user_email_id()) - .await? - .context("User email not found")?; - - // Lookup the user associated with the email - let user = repo - .user() - .lookup(user_email.user_id) - .await? - .context("User not found")?; - - // Generate a verification code - let range = Uniform::::from(0..1_000_000); - let code = rng.sample(range); - let code = format!("{code:06}"); - - let address: Address = user_email.email.parse()?; - - // Save the verification code in the database - let verification = repo - .user_email() - .add_verification_code( - &mut rng, - &clock, - &user_email, - Duration::try_hours(8).unwrap(), - code, - ) - .await?; - - // And send the verification email - let mailbox = Mailbox::new(Some(user.username.clone()), address); - - let context = - EmailVerificationContext::new(user.clone(), verification.clone()).with_language(language); - - mailer.send_verification_email(mailbox, &context).await?; - - info!( - email.id = %user_email.id, - "Verification email sent" - ); - - repo.save().await?; - - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let verify_email_worker = - crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory); - - monitor.register(verify_email_worker) +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; + +#[async_trait] +impl RunnableJob for VerifyEmailJob { + #[tracing::instrument( + name = "job.verify_email", + fields(user_email.id = %self.user_email_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let mailer = state.mailer(); + let clock = state.clock(); + + let language = self + .language() + .and_then(|l| l.parse().ok()) + .unwrap_or(locale!("en").into()); + + // Lookup the user email + let user_email = repo + .user_email() + .lookup(self.user_email_id()) + .await? + .context("User email not found")?; + + // Lookup the user associated with the email + let user = repo + .user() + .lookup(user_email.user_id) + .await? + .context("User not found")?; + + // Generate a verification code + let range = Uniform::::from(0..1_000_000); + let code = rng.sample(range); + let code = format!("{code:06}"); + + let address: Address = user_email.email.parse()?; + + // Save the verification code in the database + let verification = repo + .user_email() + .add_verification_code( + &mut rng, + &clock, + &user_email, + Duration::try_hours(8).unwrap(), + code, + ) + .await?; + + // And send the verification email + let mailbox = Mailbox::new(Some(user.username.clone()), address); + + let context = EmailVerificationContext::new(user.clone(), verification.clone()) + .with_language(language); + + mailer.send_verification_email(mailbox, &context).await?; + + info!( + email.id = %user_email.id, + "Verification email sent" + ); + + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index e56a082c7..ad2ede868 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -18,14 +18,13 @@ use rand::SeedableRng; use sqlx::{Pool, Postgres}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; +// TODO: we need to have a way to schedule recurring tasks // mod database; -// mod email; -// mod matrix; +mod email; +mod matrix; mod new_queue; -// mod recovery; -// mod storage; -// mod user; -// mod utils; +mod recovery; +mod user; #[derive(Clone)] struct State { @@ -111,6 +110,15 @@ pub async fn init( ); let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?; + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + worker.register_handler::(); + task_tracker.spawn(async move { if let Err(e) = worker.run().await { tracing::error!( diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index 3cc09b272..f4596c05f 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -7,239 +7,239 @@ use std::collections::HashSet; use anyhow::Context; -use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use async_trait::async_trait; use mas_data_model::Device; use mas_matrix::ProvisionRequest; use mas_storage::{ compat::CompatSessionFilter, - job::{JobRepositoryExt as _, JobWithSpanContext}, oauth2::OAuth2SessionFilter, - queue::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, SyncDevicesJob}, + queue::{ + DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _, + SyncDevicesJob, + }, user::{UserEmailRepository, UserRepository}, Pagination, RepositoryAccess, }; use tracing::info; -use crate::{storage::PostgresStorageFactory, JobContextExt, State}; +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; /// Job to provision a user on the Matrix homeserver. -/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id} -/// endpoint. -#[tracing::instrument( - name = "job.provision_user" - fields(user.id = %job.user_id()), - skip_all, - err(Debug), -)] -async fn provision_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - let mxid = matrix.mxid(&user.username); - let emails = repo - .user_email() - .all(&user) - .await? - .into_iter() - .filter(|email| email.confirmed_at.is_some()) - .map(|email| email.email) - .collect(); - let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); - - if let Some(display_name) = job.display_name_to_set() { - request = request.set_displayname(display_name.to_owned()); - } +/// This works by doing a PUT request to the +/// /_synapse/admin/v2/users/{user_id} endpoint. +#[async_trait] +impl RunnableJob for ProvisionUserJob { + #[tracing::instrument( + name = "job.provision_user" + fields(user.id = %self.user_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let clock = state.clock(); + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + let mxid = matrix.mxid(&user.username); + let emails = repo + .user_email() + .all(&user) + .await? + .into_iter() + .filter(|email| email.confirmed_at.is_some()) + .map(|email| email.email) + .collect(); + let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); + + if let Some(display_name) = self.display_name_to_set() { + request = request.set_displayname(display_name.to_owned()); + } - let created = matrix.provision_user(&request).await?; + let created = matrix.provision_user(&request).await?; - if created { - info!(%user.id, %mxid, "User created"); - } else { - info!(%user.id, %mxid, "User updated"); - } + if created { + info!(%user.id, %mxid, "User created"); + } else { + info!(%user.id, %mxid, "User updated"); + } - // Schedule a device sync job - let sync_device_job = SyncDevicesJob::new(&user); - repo.job().schedule_job(sync_device_job).await?; + // Schedule a device sync job + let sync_device_job = SyncDevicesJob::new(&user); + repo.queue_job() + .schedule_job(&mut rng, &clock, sync_device_job) + .await?; - repo.save().await?; + repo.save().await?; - Ok(()) + Ok(()) + } } /// Job to provision a device on the Matrix homeserver. /// /// This job is deprecated and therefore just schedules a [`SyncDevicesJob`] -#[tracing::instrument( - name = "job.provision_device" - fields( - user.id = %job.user_id(), - device.id = %job.device_id(), - ), - skip_all, - err(Debug), -)] -async fn provision_device( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - // Schedule a device sync job - repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; - - Ok(()) +#[async_trait] +impl RunnableJob for ProvisionDeviceJob { + #[tracing::instrument( + name = "job.provision_device" + fields( + user.id = %self.user_id(), + device.id = %self.device_id(), + ), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let mut repo = state.repository().await?; + let mut rng = state.rng(); + let clock = state.clock(); + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + // Schedule a device sync job + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; + + Ok(()) + } } /// Job to delete a device from a user's account. /// /// This job is deprecated and therefore just schedules a [`SyncDevicesJob`] -#[tracing::instrument( - name = "job.delete_device" - fields( - user.id = %job.user_id(), - device.id = %job.device_id(), - ), - skip_all, - err(Debug), -)] -async fn delete_device( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - // Schedule a device sync job - repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; - - Ok(()) +#[async_trait] +impl RunnableJob for DeleteDeviceJob { + #[tracing::instrument( + name = "job.delete_device" + fields( + user.id = %self.user_id(), + device.id = %self.device_id(), + ), + skip_all, + err(Debug), + )] + #[tracing::instrument( + name = "job.delete_device" + fields( + user.id = %self.user_id(), + device.id = %self.device_id(), + ), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let mut rng = state.rng(); + let clock = state.clock(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + // Schedule a device sync job + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; + + Ok(()) + } } /// Job to sync the list of devices of a user with the homeserver. -#[tracing::instrument( - name = "job.sync_devices", - fields(user.id = %job.user_id()), - skip_all, - err(Debug), -)] -async fn sync_devices( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - // Lock the user sync to make sure we don't get into a race condition - repo.user().acquire_lock_for_sync(&user).await?; - - let mut devices = HashSet::new(); - - // Cycle through all the compat sessions of the user, and grab the devices - let mut cursor = Pagination::first(100); - loop { - let page = repo - .compat_session() - .list( - CompatSessionFilter::new().for_user(&user).active_only(), - cursor, - ) - .await?; - - for (compat_session, _) in page.edges { - devices.insert(compat_session.device.as_str().to_owned()); - cursor = cursor.after(compat_session.id); - } +#[async_trait] +impl RunnableJob for SyncDevicesJob { + #[tracing::instrument( + name = "job.sync_devices", + fields(user.id = %self.user_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&user).await?; + + let mut devices = HashSet::new(); + + // Cycle through all the compat sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .compat_session() + .list( + CompatSessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for (compat_session, _) in page.edges { + devices.insert(compat_session.device.as_str().to_owned()); + cursor = cursor.after(compat_session.id); + } - if !page.has_next_page { - break; + if !page.has_next_page { + break; + } } - } - // Cycle though all the oauth2 sessions of the user, and grab the devices - let mut cursor = Pagination::first(100); - loop { - let page = repo - .oauth2_session() - .list( - OAuth2SessionFilter::new().for_user(&user).active_only(), - cursor, - ) - .await?; - - for oauth2_session in page.edges { - for scope in &*oauth2_session.scope { - if let Some(device) = Device::from_scope_token(scope) { - devices.insert(device.as_str().to_owned()); + // Cycle though all the oauth2 sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .oauth2_session() + .list( + OAuth2SessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for oauth2_session in page.edges { + for scope in &*oauth2_session.scope { + if let Some(device) = Device::from_scope_token(scope) { + devices.insert(device.as_str().to_owned()); + } } - } - cursor = cursor.after(oauth2_session.id); - } + cursor = cursor.after(oauth2_session.id); + } - if !page.has_next_page { - break; + if !page.has_next_page { + break; + } } - } - let mxid = matrix.mxid(&user.username); - matrix.sync_devices(&mxid, devices).await?; + let mxid = matrix.mxid(&user.username); + matrix.sync_devices(&mxid, devices).await?; - // We kept the connection until now, so that we still hold the lock on the user - // throughout the sync - repo.save().await?; + // We kept the connection until now, so that we still hold the lock on the user + // throughout the sync + repo.save().await?; - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let provision_user_worker = - crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory); - let provision_device_worker = - crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory); - let delete_device_worker = - crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory); - let sync_devices_worker = - crate::build!(SyncDevicesJob => sync_devices, suffix, state, storage_factory); - - monitor - .register(provision_user_worker) - .register(provision_device_worker) - .register(delete_device_worker) - .register(sync_devices_worker) + Ok(()) + } } diff --git a/crates/tasks/src/new_queue.rs b/crates/tasks/src/new_queue.rs index f90b72011..6a0b94d8f 100644 --- a/crates/tasks/src/new_queue.rs +++ b/crates/tasks/src/new_queue.rs @@ -3,12 +3,12 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use chrono::{DateTime, Duration, Utc}; use mas_storage::{ - queue::{InsertableJob, Job, Worker}, + queue::{InsertableJob, Job, JobMetadata, Worker}, Clock, RepositoryAccess, RepositoryError, }; use mas_storage_pg::{DatabaseError, PgRepository}; @@ -20,12 +20,41 @@ use sqlx::{ Acquire, Either, }; use thiserror::Error; +use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; +use tracing::{Instrument as _, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt as _; +use ulid::Ulid; use crate::State; +type JobPayload = serde_json::Value; + +#[derive(Clone)] +pub struct JobContext { + pub id: Ulid, + pub metadata: JobMetadata, + pub queue_name: String, + pub cancellation_token: CancellationToken, +} + +impl JobContext { + pub fn span(&self) -> Span { + let span = tracing::info_span!( + parent: Span::none(), + "job.run", + job.id = %self.id, + job.queue_name = self.queue_name, + ); + + span.add_link(self.metadata.span_context()); + + span + } +} + pub trait FromJob { - fn from_job(job: &Job) -> Result + fn from_job(payload: JobPayload) -> Result where Self: Sized; } @@ -34,14 +63,14 @@ impl FromJob for T where T: DeserializeOwned, { - fn from_job(job: &Job) -> Result { - serde_json::from_value(job.payload.clone()).map_err(Into::into) + fn from_job(payload: JobPayload) -> Result { + serde_json::from_value(payload).map_err(Into::into) } } #[async_trait] pub trait RunnableJob: FromJob + Send + 'static { - async fn run(&self, state: &State) -> Result<(), anyhow::Error>; + async fn run(&self, state: &State, context: JobContext) -> Result<(), anyhow::Error>; } fn box_runnable_job(job: T) -> Box { @@ -79,7 +108,13 @@ pub enum QueueRunnerError { const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900); const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100); -type JobFactory = Box Box + Send>; +// How many jobs can we run concurrently +const MAX_CONCURRENT_JOBS: usize = 10; + +// How many jobs can we fetch at once +const MAX_JOBS_TO_FETCH: usize = 5; + +type JobFactory = Arc Box + Send + Sync>; pub struct QueueWorker { rng: ChaChaRng, @@ -89,7 +124,14 @@ pub struct QueueWorker { am_i_leader: bool, last_heartbeat: DateTime, cancellation_token: CancellationToken, + state: State, + running_jobs: JoinSet>, + job_contexts: HashMap, factories: HashMap<&'static str, JobFactory>, + + #[allow(clippy::type_complexity)] + last_join_result: + Option), tokio::task::JoinError>>, } impl QueueWorker { @@ -115,6 +157,12 @@ impl QueueWorker { .await .map_err(QueueRunnerError::SetupListener)?; + // We get notifications when a job is available on this channel + listener + .listen("queue_available") + .await + .map_err(QueueRunnerError::SetupListener)?; + let txn = listener .begin() .await @@ -139,14 +187,22 @@ impl QueueWorker { am_i_leader: false, last_heartbeat: now, cancellation_token, + state, + job_contexts: HashMap::new(), + running_jobs: JoinSet::new(), factories: HashMap::new(), + last_join_result: None, }) } pub fn register_handler(&mut self) -> &mut Self { - // TODO: error handling - let factory = |job: &Job| box_runnable_job(T::from_job(job).unwrap()); - self.factories.insert(T::QUEUE_NAME, Box::new(factory)); + // There is a potential panic here, which is fine as it's going to be caught + // within the job task + let factory = |payload: JobPayload| { + box_runnable_job(T::from_job(payload).expect("Failed to deserialize job")) + }; + + self.factories.insert(T::QUEUE_NAME, Arc::new(factory)); self } @@ -164,6 +220,7 @@ impl QueueWorker { async fn run_loop(&mut self) -> Result<(), QueueRunnerError> { self.wait_until_wakeup().await?; + // TODO: join all the jobs handles when shutting down if self.cancellation_token.is_cancelled() { return Ok(()); } @@ -214,6 +271,8 @@ impl QueueWorker { .sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION)); let wakeup_sleep = tokio::time::sleep(sleep_duration); + // TODO: add metrics to track the wake up reasons + tokio::select! { () = self.cancellation_token.cancelled() => { tracing::debug!("Woke up from cancellation"); @@ -223,6 +282,11 @@ impl QueueWorker { tracing::debug!("Woke up from sleep"); }, + Some(result) = self.running_jobs.join_next_with_id() => { + tracing::debug!("Joined job task"); + self.last_join_result = Some(result); + }, + notification = self.listener.recv() => { match notification { Ok(notification) => { @@ -281,6 +345,127 @@ impl QueueWorker { .try_get_leader_lease(&self.clock, &self.registration) .await?; + // Find any job task which finished + // If we got woken up by a join on the joinset, it will be stored in the + // last_join_result so that we don't loose it + + if self.last_join_result.is_none() { + self.last_join_result = self.running_jobs.try_join_next_with_id(); + } + + while let Some(result) = self.last_join_result.take() { + // TODO: add metrics to track the job status and the time it took + let context = match result { + Ok((id, Ok(()))) => { + // The job succeeded + let context = self + .job_contexts + .remove(&id) + .expect("Job context not found"); + + tracing::info!( + job.id = %context.id, + job.queue_name = %context.queue_name, + "Job completed" + ); + + context + } + Ok((id, Err(e))) => { + // The job failed + let context = self + .job_contexts + .remove(&id) + .expect("Job context not found"); + + tracing::error!( + error = ?e, + job.id = %context.id, + job.queue_name = %context.queue_name, + "Job failed" + ); + + // TODO: reschedule the job + + context + } + Err(e) => { + // The job crashed (or was cancelled) + let id = e.id(); + let context = self + .job_contexts + .remove(&id) + .expect("Job context not found"); + + tracing::error!( + error = &e as &dyn std::error::Error, + job.id = %context.id, + job.queue_name = %context.queue_name, + "Job crashed" + ); + + // TODO: reschedule the job + + context + } + }; + + repo.queue_job() + .mark_as_completed(&self.clock, context.id) + .await?; + + self.last_join_result = self.running_jobs.try_join_next_with_id(); + } + + // Compute how many jobs we should fetch at most + let max_jobs_to_fetch = MAX_CONCURRENT_JOBS + .saturating_sub(self.running_jobs.len()) + .max(MAX_JOBS_TO_FETCH); + + if max_jobs_to_fetch == 0 { + tracing::warn!("Internal job queue is full, not fetching any new jobs"); + } else { + // Grab a few jobs in the queue + let queues = self.factories.keys().copied().collect::>(); + let jobs = repo + .queue_job() + .reserve(&self.clock, &self.registration, &queues, 10) + .await?; + + for Job { + id, + queue_name, + payload, + metadata, + } in jobs + { + let cancellation_token = self.cancellation_token.child_token(); + let factory = self.factories.get(queue_name.as_str()).cloned(); + let context = JobContext { + id, + metadata, + queue_name, + cancellation_token, + }; + + let task = { + let context = context.clone(); + let span = context.span(); + let state = self.state.clone(); + async move { + // We should never crash, but in case we do, we do that in the task and + // don't crash the worker + let job = factory.expect("unknown job factory")(payload); + job.run(&state, context).await + } + .instrument(span) + }; + + let handle = self.running_jobs.spawn(task); + self.job_contexts.insert(handle.id(), context); + } + } + // After this point, we are locking the leader table, so it's important that we // commit as soon as possible to not block the other workers for too long repo.into_inner() @@ -353,6 +538,8 @@ impl QueueWorker { .shutdown_dead_workers(&self.clock, Duration::minutes(2)) .await?; + // TODO: mark tasks those workers had as lost + // Release the leader lock let txn = repo .into_inner() diff --git a/crates/tasks/src/recovery.rs b/crates/tasks/src/recovery.rs index 79f469b06..cd3787d2a 100644 --- a/crates/tasks/src/recovery.rs +++ b/crates/tasks/src/recovery.rs @@ -5,11 +5,10 @@ // Please see LICENSE in the repository root for full details. use anyhow::Context; -use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use async_trait::async_trait; use mas_email::{Address, Mailbox}; use mas_i18n::DataLocale; use mas_storage::{ - job::JobWithSpanContext, queue::SendAccountRecoveryEmailsJob, user::{UserEmailFilter, UserRecoveryRepository}, Pagination, RepositoryAccess, @@ -18,117 +17,108 @@ use mas_templates::{EmailRecoveryContext, TemplateContext}; use rand::distributions::{Alphanumeric, DistString}; use tracing::{error, info}; -use crate::{storage::PostgresStorageFactory, JobContextExt, State}; +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; /// Job to send account recovery emails for a given recovery session. -#[tracing::instrument( - name = "job.send_account_recovery_email", - fields( - user_recovery_session.id = %job.user_recovery_session_id(), - user_recovery_session.email, - ), - skip_all, - err(Debug), -)] -async fn send_account_recovery_email_job( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let clock = state.clock(); - let mailer = state.mailer(); - let url_builder = state.url_builder(); - let mut rng = state.rng(); - let mut repo = state.repository().await?; - - let session = repo - .user_recovery() - .lookup_session(job.user_recovery_session_id()) - .await? - .context("User recovery session not found")?; - - tracing::Span::current().record("user_recovery_session.email", &session.email); - - if session.consumed_at.is_some() { - info!("Recovery session already consumed, not sending email"); - return Ok(()); - } +#[async_trait] +impl RunnableJob for SendAccountRecoveryEmailsJob { + #[tracing::instrument( + name = "job.send_account_recovery_email", + fields( + user_recovery_session.id = %self.user_recovery_session_id(), + user_recovery_session.email, + ), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let clock = state.clock(); + let mailer = state.mailer(); + let url_builder = state.url_builder(); + let mut rng = state.rng(); + let mut repo = state.repository().await?; + + let session = repo + .user_recovery() + .lookup_session(self.user_recovery_session_id()) + .await? + .context("User recovery session not found")?; + + tracing::Span::current().record("user_recovery_session.email", &session.email); + + if session.consumed_at.is_some() { + info!("Recovery session already consumed, not sending email"); + return Ok(()); + } - let mut cursor = Pagination::first(50); - - let lang: DataLocale = session - .locale - .parse() - .context("Invalid locale in database on recovery session")?; - - loop { - let page = repo - .user_email() - .list( - UserEmailFilter::new() - .for_email(&session.email) - .verified_only(), - cursor, - ) - .await?; - - for email in page.edges { - let ticket = Alphanumeric.sample_string(&mut rng, 32); - - let ticket = repo - .user_recovery() - .add_ticket(&mut rng, &clock, &session, &email, ticket) - .await?; + let mut cursor = Pagination::first(50); + + let lang: DataLocale = session + .locale + .parse() + .context("Invalid locale in database on recovery session")?; - let user_email = repo + loop { + let page = repo .user_email() - .lookup(email.id) - .await? - .context("User email not found")?; - - let user = repo - .user() - .lookup(user_email.user_id) - .await? - .context("User not found")?; - - let url = url_builder.account_recovery_link(ticket.ticket); - - let address: Address = user_email.email.parse()?; - let mailbox = Mailbox::new(Some(user.username.clone()), address); - - info!("Sending recovery email to {}", mailbox); - let context = - EmailRecoveryContext::new(user, session.clone(), url).with_language(lang.clone()); - - // XXX: we only log if the email fails to send, to avoid stopping the loop - if let Err(e) = mailer.send_recovery_email(mailbox, &context).await { - error!( - error = &e as &dyn std::error::Error, - "Failed to send recovery email" - ); - } + .list( + UserEmailFilter::new() + .for_email(&session.email) + .verified_only(), + cursor, + ) + .await?; - cursor = cursor.after(email.id); - } + for email in page.edges { + let ticket = Alphanumeric.sample_string(&mut rng, 32); - if !page.has_next_page { - break; - } - } + let ticket = repo + .user_recovery() + .add_ticket(&mut rng, &clock, &session, &email, ticket) + .await?; - repo.save().await?; + let user_email = repo + .user_email() + .lookup(email.id) + .await? + .context("User email not found")?; - Ok(()) -} + let user = repo + .user() + .lookup(user_email.user_id) + .await? + .context("User not found")?; -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let send_user_recovery_email_worker = crate::build!(SendAccountRecoveryEmailsJob => send_account_recovery_email_job, suffix, state, storage_factory); + let url = url_builder.account_recovery_link(ticket.ticket); - monitor.register(send_user_recovery_email_worker) + let address: Address = user_email.email.parse()?; + let mailbox = Mailbox::new(Some(user.username.clone()), address); + + info!("Sending recovery email to {}", mailbox); + let context = EmailRecoveryContext::new(user, session.clone(), url) + .with_language(lang.clone()); + + // XXX: we only log if the email fails to send, to avoid stopping the loop + if let Err(e) = mailer.send_recovery_email(mailbox, &context).await { + error!( + error = &e as &dyn std::error::Error, + "Failed to send recovery email" + ); + } + + cursor = cursor.after(email.id); + } + + if !page.has_next_page { + break; + } + } + + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/storage/from_row.rs b/crates/tasks/src/storage/from_row.rs deleted file mode 100644 index 5acf6848a..000000000 --- a/crates/tasks/src/storage/from_row.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::str::FromStr; - -use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId}; -use chrono::{DateTime, Utc}; -use serde_json::Value; -use sqlx::Row; - -/// Wrapper for [`JobRequest`] -pub(crate) struct SqlJobRequest(JobRequest); - -impl From> for JobRequest { - fn from(val: SqlJobRequest) -> Self { - val.0 - } -} - -impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow> - for SqlJobRequest -{ - fn from_row(row: &'r sqlx::postgres::PgRow) -> Result { - let job: Value = row.try_get("job")?; - let id: JobId = - JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode { - index: "id".to_owned(), - source: Box::new(e), - })?; - let mut context = JobContext::new(id); - - let run_at = row.try_get("run_at")?; - context.set_run_at(run_at); - - let attempts = row.try_get("attempts").unwrap_or(0); - context.set_attempts(attempts); - - let max_attempts = row.try_get("max_attempts").unwrap_or(25); - context.set_max_attempts(max_attempts); - - let done_at: Option> = row.try_get("done_at").unwrap_or_default(); - context.set_done_at(done_at); - - let lock_at: Option> = row.try_get("lock_at").unwrap_or_default(); - context.set_lock_at(lock_at); - - let last_error = row.try_get("last_error").unwrap_or_default(); - context.set_last_error(last_error); - - let status: String = row.try_get("status")?; - context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode { - index: "job".to_owned(), - source: Box::new(e), - })?); - - let lock_by: Option = row.try_get("lock_by").unwrap_or_default(); - context.set_lock_by(lock_by.map(WorkerId::new)); - - Ok(SqlJobRequest(JobRequest::new_with_context( - serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode { - index: "job".to_owned(), - source: Box::new(e), - })?, - context, - ))) - } -} diff --git a/crates/tasks/src/storage/mod.rs b/crates/tasks/src/storage/mod.rs deleted file mode 100644 index 5f6e77e31..000000000 --- a/crates/tasks/src/storage/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a -//! shared connection for the [`PgListener`] - -mod from_row; -mod postgres; - -use self::from_row::SqlJobRequest; -pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory; diff --git a/crates/tasks/src/storage/postgres.rs b/crates/tasks/src/storage/postgres.rs deleted file mode 100644 index f709579ed..000000000 --- a/crates/tasks/src/storage/postgres.rs +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration}; - -use apalis_core::{ - error::JobStreamError, - job::{Job, JobId, JobStreamResult}, - request::JobRequest, - storage::{StorageError, StorageResult, StorageWorkerPulse}, - utils::Timer, - worker::WorkerId, -}; -use async_stream::try_stream; -use chrono::{DateTime, Utc}; -use event_listener::Event; -use futures_lite::{Stream, StreamExt}; -use serde::{de::DeserializeOwned, Serialize}; -use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row}; -use tokio::task::JoinHandle; - -use super::SqlJobRequest; - -pub struct StorageFactory { - pool: PgPool, - event: Arc, -} - -impl StorageFactory { - pub fn new(pool: Pool) -> Self { - StorageFactory { - pool, - event: Arc::new(Event::new()), - } - } - - pub async fn listen(self) -> Result, sqlx::Error> { - let mut listener = PgListener::connect_with(&self.pool).await?; - listener.listen("apalis::job").await?; - - let handle = tokio::spawn(async move { - loop { - let notification = listener.recv().await.expect("Failed to poll notification"); - self.event.notify(usize::MAX); - tracing::debug!(?notification, "Broadcast notification"); - } - }); - - Ok(handle) - } - - pub fn build(&self) -> Storage { - Storage { - pool: self.pool.clone(), - event: self.event.clone(), - job_type: PhantomData, - } - } -} - -/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres -#[derive(Debug)] -pub struct Storage { - pool: PgPool, - event: Arc, - job_type: PhantomData, -} - -impl Clone for Storage { - fn clone(&self) -> Self { - Storage { - pool: self.pool.clone(), - event: self.event.clone(), - job_type: PhantomData, - } - } -} - -impl Storage { - fn stream_jobs( - &self, - worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> impl Stream, JobStreamError>> { - let pool = self.pool.clone(); - let sleeper = apalis_core::utils::timer::TokioTimer; - let worker_id = worker_id.clone(); - let event = self.event.clone(); - try_stream! { - loop { - // Wait for a notification or a timeout - let listener = event.listen(); - let interval = sleeper.sleep(interval); - futures_lite::future::race(interval, listener).await; - - let tx = pool.clone(); - let job_type = T::NAME; - let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);"; - let jobs: Vec> = sqlx::query_as(fetch_query) - .bind(worker_id.name()) - .bind(job_type) - // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html - .bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?) - .fetch_all(&tx) - .await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?; - for job in jobs { - yield job.into() - } - } - } - } - - async fn keep_alive_at( - &mut self, - worker_id: &WorkerId, - last_seen: DateTime, - ) -> StorageResult<()> { - let pool = self.pool.clone(); - - let worker_type = T::NAME; - let storage_name = std::any::type_name::(); - let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (id) DO - UPDATE SET last_seen = EXCLUDED.last_seen"; - sqlx::query(query) - .bind(worker_id.name()) - .bind(worker_type) - .bind(storage_name) - .bind(std::any::type_name::()) - .bind(last_seen) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } -} - -#[async_trait::async_trait] -impl apalis_core::storage::Storage for Storage -where - T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync, -{ - type Output = T; - - /// Push a job to Postgres [Storage] - /// - /// # SQL Example - /// - /// ```sql - /// SELECT apalis.push_job(job_type::text, job::json); - /// ``` - async fn push(&mut self, job: Self::Output) -> StorageResult { - let id = JobId::new(); - let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)"; - let pool = self.pool.clone(); - let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?; - let job_type = T::NAME; - sqlx::query(query) - .bind(job) - .bind(id.to_string()) - .bind(job_type) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(id) - } - - async fn schedule( - &mut self, - job: Self::Output, - on: chrono::DateTime, - ) -> StorageResult { - let query = - "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)"; - - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - - let id = JobId::new(); - let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?; - let job_type = T::NAME; - sqlx::query(query) - .bind(job) - .bind(id.to_string()) - .bind(job_type) - .bind(on) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(id) - } - - async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult>> { - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - - let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1"; - let res: Option> = sqlx::query_as(fetch_query) - .bind(job_id.to_string()) - .fetch_optional(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(res.map(Into::into)) - } - - async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult { - match pulse { - StorageWorkerPulse::EnqueueScheduled { count: _ } => { - // Ideally jobs are queue via run_at. So this is not necessary - Ok(true) - } - - // Worker not seen in 5 minutes yet has running jobs - StorageWorkerPulse::ReenqueueOrphaned { count, .. } => { - let job_type = T::NAME; - let mut conn = self - .pool - .acquire() - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - let query = "UPDATE apalis.jobs - SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned' - WHERE id in - (SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id - WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes' - AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);"; - sqlx::query(query) - .bind(job_type) - .bind(count) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(true) - } - - _ => unimplemented!(), - } - } - - async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - /// Puts the job instantly back into the queue - /// Another [Worker] may consume - async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - fn consume( - &mut self, - worker_id: &WorkerId, - interval: Duration, - buffer_size: usize, - ) -> JobStreamResult { - Box::pin( - self.stream_jobs(worker_id, interval, buffer_size) - .map(|r| r.map(Some)), - ) - } - async fn len(&self) -> StorageResult { - let pool = self.pool.clone(); - let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'"; - let record = sqlx::query(query) - .fetch_one(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(record - .try_get("count") - .map_err(|e| StorageError::Database(Box::from(e)))?) - } - async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> { - let pool = self.pool.clone(); - let query = - "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(worker_id.name()) - .execute(&pool) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn reschedule(&mut self, job: &JobRequest, wait: Duration) -> StorageResult<()> { - let pool = self.pool.clone(); - let job_id = job.id(); - - let wait: i64 = wait - .as_secs() - .try_into() - .map_err(|e| StorageError::Database(Box::new(e)))?; - let wait = chrono::Duration::microseconds(wait * 1000 * 1000); - // TODO: should we use a clock here? - #[allow(clippy::disallowed_methods)] - let run_at = Utc::now().add(wait); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1"; - sqlx::query(query) - .bind(job_id.to_string()) - .bind(run_at) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn update_by_id( - &self, - job_id: &JobId, - job: &JobRequest, - ) -> StorageResult<()> { - let pool = self.pool.clone(); - let status = job.status().as_ref(); - let attempts = job.attempts(); - let done_at = *job.done_at(); - let lock_by = job.lock_by().clone(); - let lock_at = *job.lock_at(); - let last_error = job.last_error().clone(); - - let mut conn = pool - .acquire() - .await - .map_err(|e| StorageError::Connection(Box::from(e)))?; - let query = - "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7"; - sqlx::query(query) - .bind(status.to_owned()) - .bind(attempts) - .bind(done_at) - .bind(lock_by.as_ref().map(WorkerId::name)) - .bind(lock_at) - .bind(last_error) - .bind(job_id.to_string()) - .execute(&mut *conn) - .await - .map_err(|e| StorageError::Database(Box::from(e)))?; - Ok(()) - } - - async fn keep_alive(&mut self, worker_id: &WorkerId) -> StorageResult<()> { - #[allow(clippy::disallowed_methods)] - let now = Utc::now(); - - self.keep_alive_at::(worker_id, now).await - } -} diff --git a/crates/tasks/src/user.rs b/crates/tasks/src/user.rs index b3d062bb4..ad4444be5 100644 --- a/crates/tasks/src/user.rs +++ b/crates/tasks/src/user.rs @@ -5,10 +5,9 @@ // Please see LICENSE in the repository root for full details. use anyhow::Context; -use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use async_trait::async_trait; use mas_storage::{ compat::CompatSessionFilter, - job::JobWithSpanContext, oauth2::OAuth2SessionFilter, queue::{DeactivateUserJob, ReactivateUserJob}, user::{BrowserSessionFilter, UserRepository}, @@ -16,122 +15,106 @@ use mas_storage::{ }; use tracing::info; -use crate::{storage::PostgresStorageFactory, JobContextExt, State}; +use crate::{ + new_queue::{JobContext, RunnableJob}, + State, +}; /// Job to deactivate a user, both locally and on the Matrix homeserver. -#[tracing::instrument( +#[async_trait] +impl RunnableJob for DeactivateUserJob { + #[tracing::instrument( name = "job.deactivate_user" - fields(user.id = %job.user_id(), erase = %job.hs_erase()), - skip_all, - err(Debug), -)] -async fn deactivate_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let clock = state.clock(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - // Let's first lock the user - let user = repo - .user() - .lock(&clock, user) - .await - .context("Failed to lock user")?; - - // Kill all sessions for the user - let n = repo - .browser_session() - .finish_bulk( - &clock, - BrowserSessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all browser sessions for user"); - - let n = repo - .oauth2_session() - .finish_bulk( - &clock, - OAuth2SessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all OAuth 2.0 sessions for user"); - - let n = repo - .compat_session() - .finish_bulk( - &clock, - CompatSessionFilter::new().for_user(&user).active_only(), - ) - .await?; - info!(affected = n, "Killed all compatibility sessions for user"); - - // Before calling back to the homeserver, commit the changes to the database, as - // we want the user to be locked out as soon as possible - repo.save().await?; - - let mxid = matrix.mxid(&user.username); - info!("Deactivating user {} on homeserver", mxid); - matrix.delete_user(&mxid, job.hs_erase()).await?; - - Ok(()) + fields(user.id = %self.user_id(), erase = %self.hs_erase()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let clock = state.clock(); + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + // Let's first lock the user + let user = repo + .user() + .lock(&clock, user) + .await + .context("Failed to lock user")?; + + // Kill all sessions for the user + let n = repo + .browser_session() + .finish_bulk( + &clock, + BrowserSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all browser sessions for user"); + + let n = repo + .oauth2_session() + .finish_bulk( + &clock, + OAuth2SessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all OAuth 2.0 sessions for user"); + + let n = repo + .compat_session() + .finish_bulk( + &clock, + CompatSessionFilter::new().for_user(&user).active_only(), + ) + .await?; + info!(affected = n, "Killed all compatibility sessions for user"); + + // Before calling back to the homeserver, commit the changes to the database, as + // we want the user to be locked out as soon as possible + repo.save().await?; + + let mxid = matrix.mxid(&user.username); + info!("Deactivating user {} on homeserver", mxid); + matrix.delete_user(&mxid, self.hs_erase()).await?; + + Ok(()) + } } /// Job to reactivate a user, both locally and on the Matrix homeserver. -#[tracing::instrument( - name = "job.reactivate_user", - fields(user.id = %job.user_id()), - skip_all, - err(Debug), -)] -pub async fn reactivate_user( - job: JobWithSpanContext, - ctx: JobContext, -) -> Result<(), anyhow::Error> { - let state = ctx.state(); - let matrix = state.matrix_connection(); - let mut repo = state.repository().await?; - - let user = repo - .user() - .lookup(job.user_id()) - .await? - .context("User not found")?; - - let mxid = matrix.mxid(&user.username); - info!("Reactivating user {} on homeserver", mxid); - matrix.reactivate_user(&mxid).await?; - - // We want to unlock the user from our side only once it has been reactivated on - // the homeserver - let _user = repo.user().unlock(user).await?; - repo.save().await?; - - Ok(()) -} - -pub(crate) fn register( - suffix: &str, - monitor: Monitor, - state: &State, - storage_factory: &PostgresStorageFactory, -) -> Monitor { - let deactivate_user_worker = - crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory); - - let reactivate_user_worker = - crate::build!(ReactivateUserJob => reactivate_user, suffix, state, storage_factory); - - monitor - .register(deactivate_user_worker) - .register(reactivate_user_worker) +#[async_trait] +impl RunnableJob for ReactivateUserJob { + #[tracing::instrument( + name = "job.reactivate_user", + fields(user.id = %self.user_id()), + skip_all, + err(Debug), + )] + async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> { + let matrix = state.matrix_connection(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(self.user_id()) + .await? + .context("User not found")?; + + let mxid = matrix.mxid(&user.username); + info!("Reactivating user {} on homeserver", mxid); + matrix.reactivate_user(&mxid).await?; + + // We want to unlock the user from our side only once it has been reactivated on + // the homeserver + let _user = repo.user().unlock(user).await?; + repo.save().await?; + + Ok(()) + } } diff --git a/crates/tasks/src/utils.rs b/crates/tasks/src/utils.rs deleted file mode 100644 index c5862f9cf..000000000 --- a/crates/tasks/src/utils.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use apalis_core::{job::Job, request::JobRequest}; -use mas_storage::job::JobWithSpanContext; -use mas_tower::{ - make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer, - TraceLayer, KV, -}; -use opentelemetry::{trace::SpanContext, Key, KeyValue}; -use tracing::info_span; -use tracing_opentelemetry::OpenTelemetrySpanExt; - -const JOB_NAME: Key = Key::from_static_str("job.name"); -const JOB_STATUS: Key = Key::from_static_str("job.status"); - -/// Represents a job that can may have a span context attached to it. -pub trait TracedJob: Job { - /// Returns the span context for this job, if any. - /// - /// The default implementation returns `None`. - fn span_context(&self) -> Option { - None - } -} - -/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`] -/// wrapper. -impl TracedJob for JobWithSpanContext { - fn span_context(&self) -> Option { - JobWithSpanContext::span_context(self) - } -} - -fn make_span_for_job_request(req: &JobRequest) -> tracing::Span { - let span = info_span!( - "job.run", - "otel.kind" = "consumer", - "otel.status_code" = tracing::field::Empty, - "job.id" = %req.id(), - "job.attempts" = req.attempts(), - "job.name" = J::NAME, - ); - - if let Some(context) = req.inner().span_context() { - span.add_link(context); - } - - span -} - -type TraceLayerForJob = - TraceLayer) -> tracing::Span>, KV<&'static str>, KV<&'static str>>; - -pub(crate) fn trace_layer() -> TraceLayerForJob -where - J: TracedJob, -{ - TraceLayer::new(make_span_fn( - make_span_for_job_request:: as fn(&JobRequest) -> tracing::Span, - )) - .on_response(KV("otel.status_code", "OK")) - .on_error(KV("otel.status_code", "ERROR")) -} - -type MetricsLayerForJob = ( - IdentityLayer>, - DurationRecorderLayer, - InFlightCounterLayer, -); - -pub(crate) fn metrics_layer() -> MetricsLayerForJob -where - J: Job, -{ - let duration_recorder = DurationRecorderLayer::new("job.run.duration") - .on_request(JOB_NAME.string(J::NAME)) - .on_response(JOB_STATUS.string("success")) - .on_error(JOB_STATUS.string("error")); - let in_flight_counter = - InFlightCounterLayer::new("job.run.active").on_request(JOB_NAME.string(J::NAME)); - - ( - IdentityLayer::default(), - duration_recorder, - in_flight_counter, - ) -}