Skip to content

Commit

Permalink
fix(enrich-members): Abort ongoing tasks when the request is canceled
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiBardon committed Nov 21, 2024
1 parent 13b83f4 commit 86e3c7f
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 23 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ thiserror = "1"
# https://github.com/time-rs/time/issues/681
time = "=0.3.36"
tokio = { version = "1", features = ["rt"] }
tokio-util = "0.7"
tracing = { version = "0.1" }
tracing-subscriber = "0.3"
url_serde = "0.2"
Expand Down
60 changes: 50 additions & 10 deletions crates/rest-api/src/features/members/enrich_members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
use rocket::{
form::Strict,
get,
http::Status,
response::stream::{Event, EventStream},
serde::json::Json,
};
Expand All @@ -17,9 +18,18 @@ use service::{
members::{member_controller, MemberController},
models::BareJid,
};
use tokio::task::JoinHandle;
use tokio::{
sync::mpsc,
task::JoinHandle,
time::{sleep, Duration},
};
use tracing::{debug, error};

use crate::{error::Error, forms::JID as JIDUriParam, guards::LazyGuard};
use crate::{
error::{self, Error},
forms::JID as JIDUriParam,
guards::LazyGuard,
};

#[derive(Debug, Serialize, Deserialize)]
pub struct EnrichedMember {
Expand All @@ -35,30 +45,60 @@ pub struct JIDs {
}

#[get("/v1/enrich-members?<jids..>", format = "application/json")]
pub async fn enrich_members_route<'r>(
pub async fn enrich_members_route(
member_controller: LazyGuard<MemberController>,
jids: Strict<JIDs>,
) -> Result<Json<HashMap<BareJid, EnrichedMember>>, Error> {
let member_controller = member_controller.inner?;
let jids = jids.into_inner();

let mut tasks: FuturesUnordered<JoinHandle<EnrichedMember>> = FuturesUnordered::new();
let (tx, mut rx) = mpsc::channel::<EnrichedMember>(jids.len());
let tasks: FuturesUnordered<JoinHandle<()>> = FuturesUnordered::new();
for jid in jids.iter() {
let jid = jid.clone();
let member_controller = member_controller.clone();
let tx = tx.clone();
tasks.push(tokio::spawn(async move {
member_controller
let member = member_controller
.enrich_member(&jid)
.map(EnrichedMember::from)
.await
.await;
if let Err(err) = tx.send(member).await {
// TODO: Investigate why these messages are not received by
// the tracing subscriber (i.e. not logged).
if tx.is_closed() {
debug!("Cannot send enriched member: Task aborted.");
} else {
error!("Cannot send enriched member: {err}");
}
}
}));
}

let mut res = HashMap::with_capacity(jids.len());
while let Some(Ok(member)) = tasks.next().await {
res.insert(member.jid.clone(), member.into());
let res = tokio::select! {
res = async {
let mut res = HashMap::with_capacity(jids.len());
while let Some(member) = rx.recv().await {
res.insert(member.jid.clone(), member.into());
}
Ok(res.into())
} => {
res
}
_ = sleep(Duration::from_secs(3)) => {
debug!("Timed out.");
Err(error::HTTPStatus(Status::new(499)).into())
}
};

debug!("Cancelling all task…");
rx.close();
for task in tasks {
task.abort();
}
Ok(res.into())
member_controller.cancel_tasks();

res
}

#[get("/v1/enrich-members?<jids..>", format = "text/event-stream", rank = 2)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ impl<'r> LazyFromRequest<'r> for MemberController {
let db = try_outcome!(database_connection(req).await);
let xmpp_service = try_outcome!(XmppService::from_request(req).await);

Outcome::Success(MemberController {
db: Arc::new(db.clone()),
xmpp_service: Arc::new(xmpp_service),
})
Outcome::Success(MemberController::new(
Arc::new(db.clone()),
Arc::new(xmpp_service),
))
}
}
1 change: 1 addition & 0 deletions crates/service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ strum = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
url_serde = { workspace = true }
urlencoding = { workspace = true }
Expand Down
46 changes: 37 additions & 9 deletions crates/service/src/features/members/member_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;

use chrono::{DateTime, Utc};
use sea_orm::{DatabaseConnection, DbErr, ItemsAndPagesNumber};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};

use crate::xmpp::{BareJid, XmppService};
Expand All @@ -17,6 +18,22 @@ use super::{Member, MemberRepository};
pub struct MemberController {
pub db: Arc<DatabaseConnection>,
pub xmpp_service: Arc<XmppService>,
/// token used When
cancellation_token: CancellationToken,
}

impl MemberController {
pub fn new(db: Arc<DatabaseConnection>, xmpp_service: Arc<XmppService>) -> Self {
Self {
db,
xmpp_service,
cancellation_token: CancellationToken::new(),
}
}

pub fn cancel_tasks(&self) {
self.cancellation_token.cancel();
}
}

impl MemberController {
Expand All @@ -34,6 +51,14 @@ impl MemberController {
pub async fn enrich_member(&self, jid: &BareJid) -> EnrichedMember {
trace!("Enriching `{jid}`…");

let mut member = EnrichedMember {
jid: jid.to_owned(),
nickname: None,
avatar: None,
online: None,
};

trace!("-> Getting `{jid}`'s vCard…");
let vcard = match self.xmpp_service.get_vcard(jid).await {
Ok(Some(vcard)) => Some(vcard),
Ok(None) => {
Expand All @@ -47,11 +72,15 @@ impl MemberController {
None
}
};
let nickname = vcard
member.nickname = vcard
.and_then(|vcard| vcard.nickname.first().cloned())
.map(|p| p.value);

let avatar = match self.xmpp_service.get_avatar(jid).await {
if self.cancellation_token.is_cancelled() {
return member;
}
trace!("-> Getting `{jid}`'s avatar…");
member.avatar = match self.xmpp_service.get_avatar(jid).await {
Ok(Some(avatar)) => Some(avatar.base64().to_string()),
Ok(None) => {
debug!("`{jid}` has no avatar.");
Expand All @@ -65,7 +94,11 @@ impl MemberController {
}
};

let online = self
if self.cancellation_token.is_cancelled() {
return member;
}
trace!("-> Checking if `{jid}` is connected…");
member.online = self
.xmpp_service
.is_connected(jid)
.await
Expand All @@ -74,12 +107,7 @@ impl MemberController {
// But dismiss it
.ok();

EnrichedMember {
jid: jid.to_owned(),
nickname,
avatar,
online,
}
member
}
}

Expand Down

0 comments on commit 86e3c7f

Please sign in to comment.