Skip to content

Commit

Permalink
Merge pull request #132 from 4t145/broadcast-arc-in-websocket
Browse files Browse the repository at this point in the history
use arc to share between broadcast channels
  • Loading branch information
4t145 authored Jun 19, 2024
2 parents 8957194 + 932a1b0 commit d15fb38
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 44 deletions.
3 changes: 2 additions & 1 deletion examples/websocket/src/processor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tardis::basic::result::TardisResult;
Expand Down Expand Up @@ -127,7 +128,7 @@ impl Page {
}

#[oai(path = "/ws/broadcast/:name", method = "get")]
async fn ws_broadcast(&self, name: Path<String>, websocket: WebSocket, sender: Data<&Sender<TardisWebsocketMgrMessage>>) -> BoxWebSocketUpgraded {
async fn ws_broadcast(&self, name: Path<String>, websocket: WebSocket, sender: Data<&Sender<Arc<TardisWebsocketMgrMessage>>>) -> BoxWebSocketUpgraded {
pub struct Hooks {
ext: HashMap<String, String>,
}
Expand Down
8 changes: 4 additions & 4 deletions tardis/src/cluster/cluster_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ where
T: Send + Sync + 'static + Clone + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug,
{
pub ident: String,
pub local_broadcast_channel: broadcast::Sender<T>,
pub local_broadcast_channel: broadcast::Sender<Arc<T>>,
}

impl<T> ClusterBroadcastChannel<T>
Expand All @@ -26,7 +26,7 @@ where
format!("tardis/broadcast/{}", self.ident)
}
pub async fn send(&self, message: T) -> TardisResult<()> {
match self.local_broadcast_channel.send(message.clone()) {
match self.local_broadcast_channel.send(message.clone().into()) {
Ok(size) => {
tracing::trace!("[Tardis.Cluster] broadcast channel send to {size} local subscribers");
}
Expand Down Expand Up @@ -73,7 +73,7 @@ impl<T> std::ops::Deref for ClusterBroadcastChannel<T>
where
T: Send + Sync + 'static + Clone + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug,
{
type Target = broadcast::Sender<T>;
type Target = broadcast::Sender<Arc<T>>;

fn deref(&self) -> &Self::Target {
&self.local_broadcast_channel
Expand All @@ -98,7 +98,7 @@ where
async fn handle(self: Arc<Self>, message_req: TardisClusterMessageReq) -> TardisResult<Option<Value>> {
if let Ok(message) = serde_json::from_value(message_req.msg) {
if let Some(chan) = self.channel.upgrade() {
let _ = chan.local_broadcast_channel.send(message);
let _ = chan.local_broadcast_channel.send(Arc::new(message));
} else {
unsubscribe(&self.event_name()).await;
}
Expand Down
66 changes: 36 additions & 30 deletions tardis/src/web/ws_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,17 @@ where
}

pub trait WsBroadcastSender: Send + Sync + 'static {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<TardisWebsocketMgrMessage>;
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Arc<TardisWebsocketMgrMessage>>;
fn send(&self, msg: TardisWebsocketMgrMessage) -> impl Future<Output = TardisResult<()>> + Send;
}

impl WsBroadcastSender for tokio::sync::broadcast::Sender<TardisWebsocketMgrMessage> {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<TardisWebsocketMgrMessage> {
impl WsBroadcastSender for tokio::sync::broadcast::Sender<Arc<TardisWebsocketMgrMessage>> {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Arc<TardisWebsocketMgrMessage>> {
self.subscribe()
}

async fn send(&self, msg: TardisWebsocketMgrMessage) -> TardisResult<()> {
let _ = self.send(msg).map_err(|_| TardisError::internal_error("tokio channel send error", ""))?;
let _ = self.send(msg.into()).map_err(|_| TardisError::internal_error("tokio channel send error", ""))?;
Ok(())
}
}
Expand Down Expand Up @@ -227,7 +227,7 @@ where
return Err("from_avatar is illegal".to_string());
}
// System process
if req_msg.event == Some(WS_SYSTEM_EVENT_INFO.to_string()) {
if req_msg.event.as_deref() == Some(WS_SYSTEM_EVENT_INFO) {
let msg = TardisFuns::json
.obj_to_json(&TardisWebsocketInstInfo {
inst_id: self.context.inst_id.clone(),
Expand Down Expand Up @@ -260,33 +260,39 @@ where
self.send_to_channel(send_msg);
return Ok(());
// For security reasons, adding an avatar needs to be handled by the management node
} else if self.context.mgr_node && req_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_ADD.to_string()) {
let Some(new_avatar) = req_msg.msg.as_str() else {
return Err("msg is not a string".to_string());
};
let Some(ref spec_inst_id) = req_msg.spec_inst_id else {
return Err("spec_inst_id is not specified".to_string());
};
#[cfg(feature = "cluster")]
{
let Ok(Some(_)) = insts_in_send.get(spec_inst_id.clone()).await else {
return Err("spec_inst_id not found".to_string());
} else if req_msg.event.as_deref() == Some(WS_SYSTEM_EVENT_AVATAR_ADD) {
if self.context.mgr_node {
let Some(new_avatar) = req_msg.msg.as_str() else {
return Err("msg is not a string".to_string());
};
trace!("[Tardis.WebServer] WS message add avatar {}:{} to {}", &msg_id, &new_avatar, &spec_inst_id);
let _ = insts_in_send.modify(spec_inst_id.clone(), "add_avatar", json!(new_avatar)).await;
return Ok(());
}
#[cfg(not(feature = "cluster"))]
{
let mut write_locked = insts_in_send.write().await;
let Some(inst) = write_locked.get_mut(spec_inst_id) else {
return Err("spec_inst_id not found".to_string());
let Some(ref spec_inst_id) = req_msg.spec_inst_id else {
return Err("spec_inst_id is not specified".to_string());
};
inst.push(new_avatar.to_string());
drop(write_locked);
trace!("[Tardis.WebServer] WS message add avatar {}:{} to {}", msg_id, new_avatar, spec_inst_id);
#[cfg(feature = "cluster")]
{
let Ok(Some(_)) = insts_in_send.get(spec_inst_id.clone()).await else {
return Err("spec_inst_id not found".to_string());
};
trace!("[Tardis.WebServer] WS message add avatar {}:{} to {}", &msg_id, &new_avatar, &spec_inst_id);
let _ = insts_in_send.modify(spec_inst_id.clone(), "add_avatar", json!(new_avatar)).await;
// return Ok(());
}
#[cfg(not(feature = "cluster"))]
{
let mut write_locked = insts_in_send.write().await;
let Some(inst) = write_locked.get_mut(spec_inst_id) else {
return Err("spec_inst_id not found".to_string());
};
inst.push(new_avatar.to_string());
drop(write_locked);
trace!("[Tardis.WebServer] WS message add avatar {}:{} to {}", msg_id, new_avatar, spec_inst_id);
// return Ok(());
}
} else {
// ignore this message
// return Ok(())
}
} else if req_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_DEL.to_string()) {
} else if req_msg.event.as_deref() == Some(WS_SYSTEM_EVENT_AVATAR_DEL) {
#[cfg(feature = "cluster")]
{
let Ok(Some(_)) = insts_in_send.get(self.context.inst_id.clone()).await else {
Expand Down Expand Up @@ -471,7 +477,7 @@ where
|| !mgr_message.ignore_avatars.is_empty() && mgr_message.ignore_avatars.iter().all(|avatar| current_avatars.contains(avatar))
{
let Ok(resp_msg) = (if context.mgr_node {
TardisFuns::json.obj_to_string(&mgr_message)
TardisFuns::json.obj_to_string(mgr_message.as_ref())
} else {
TardisFuns::json.obj_to_string(&TardisWebsocketMessage {
msg_id: mgr_message.msg_id.clone(),
Expand Down
4 changes: 2 additions & 2 deletions tardis/src/web/ws_processor/cluster_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{basic::result::TardisResult, cluster::cluster_broadcast::ClusterBroa
use super::{TardisWebsocketMgrMessage, WsBroadcastSender};

impl WsBroadcastSender for ClusterBroadcastChannel<TardisWebsocketMgrMessage> {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<TardisWebsocketMgrMessage> {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Arc<TardisWebsocketMgrMessage>> {
self.local_broadcast_channel.subscribe()
}

Expand All @@ -15,7 +15,7 @@ impl WsBroadcastSender for ClusterBroadcastChannel<TardisWebsocketMgrMessage> {
}

impl WsBroadcastSender for Arc<ClusterBroadcastChannel<TardisWebsocketMgrMessage>> {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<TardisWebsocketMgrMessage> {
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Arc<TardisWebsocketMgrMessage>> {
self.as_ref().subscribe()
}

Expand Down
50 changes: 43 additions & 7 deletions tardis/tests/test_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use tokio::time::sleep;
use tokio_tungstenite::tungstenite::Message;

lazy_static! {
static ref SENDERS: Arc<RwLock<HashMap<String, Sender<TardisWebsocketMgrMessage>>>> = Arc::new(RwLock::new(HashMap::new()));
static ref SENDERS: Arc<RwLock<HashMap<String, Sender<Arc<TardisWebsocketMgrMessage>>>>> = Arc::new(RwLock::new(HashMap::new()));
}

#[tokio::test(flavor = "multi_thread")]
Expand Down Expand Up @@ -54,6 +54,9 @@ async fn test_normal() -> TardisResult<()> {

// message illegal test
let error_client_a = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/gerror/a", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_not_found recv:{}", msg);
assert!(msg.contains("message illegal"));
Expand All @@ -65,6 +68,9 @@ async fn test_normal() -> TardisResult<()> {
error_client_a.send_text("hi".to_string()).await?;
// not found test
let error_client_b = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/gxxx/a", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_not_found recv:{}", msg);
assert_eq!(msg, "Websocket connection error: group not found");
Expand All @@ -83,6 +89,9 @@ async fn test_normal() -> TardisResult<()> {

// subscribe mode test
let sub_client_a = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/g1/a", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_a recv:{}", msg);
assert!(msg.contains(r#"service send:\"hi\""#));
Expand All @@ -92,6 +101,9 @@ async fn test_normal() -> TardisResult<()> {
})
.await?;
let sub_client_b1 = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/g1/b", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_b1 recv:{}", msg);
assert!(msg.contains(r#"service send:\"hi\""#));
Expand All @@ -111,6 +123,9 @@ async fn test_normal() -> TardisResult<()> {
})
.await?;
let sub_client_b2 = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/g1/b", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_b2 recv:{}", msg);
assert!(msg.contains(r#"service send:\"hi\""#));
Expand Down Expand Up @@ -153,6 +168,9 @@ async fn test_normal() -> TardisResult<()> {

// non-subscribe mode test
let non_sub_client_a = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/g2/a", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_a recv:{}", msg);
assert!(msg.contains(r#"service send:\"hi\""#));
Expand All @@ -162,6 +180,9 @@ async fn test_normal() -> TardisResult<()> {
})
.await?;
let non_sub_client_b1 = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/g2/b", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_b1 recv:{}", msg);
assert!(msg.contains(r#"service send:\"hi\""#));
Expand All @@ -181,6 +202,9 @@ async fn test_normal() -> TardisResult<()> {
})
.await?;
let non_sub_client_b2 = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/g2/b", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
if let Message::Text(msg) = msg {
println!("client_b2 recv:{}", msg);
assert!(msg.contains(r#"service send:\"hi\""#));
Expand Down Expand Up @@ -236,6 +260,9 @@ async fn test_dyn_avatar() -> TardisResult<()> {
static DEL_COUNTER: AtomicUsize = AtomicUsize::new(0);

TardisFuns::ws_client("ws://127.0.0.1:8080/ws/dyn/_/true", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
let receive_msg = msg.str_to_obj::<TardisWebsocketMgrMessage>().unwrap();
if receive_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_ADD.to_string()) && receive_msg.msg.as_str().unwrap() == "c" {
ADD_COUNTER.fetch_add(1, Ordering::SeqCst);
Expand All @@ -253,10 +280,13 @@ async fn test_dyn_avatar() -> TardisResult<()> {
.await?;

TardisFuns::ws_client("ws://127.0.0.1:8080/ws/dyn/a/false", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
let receive_msg = TardisFuns::json.str_to_obj::<TardisWebsocketMessage>(msg.to_text().unwrap()).unwrap();
if receive_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_ADD.to_string()) && receive_msg.msg.as_str().unwrap() == "c" {
panic!();
ADD_COUNTER.fetch_add(1, Ordering::SeqCst);
// panic!();
// ADD_COUNTER.fetch_add(1, Ordering::SeqCst);
}
if receive_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_DEL.to_string()) && receive_msg.msg.as_str().unwrap() == "c" {
panic!();
Expand All @@ -267,10 +297,13 @@ async fn test_dyn_avatar() -> TardisResult<()> {
.await?;

let a_client = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/dyn/a/false", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
let receive_msg = TardisFuns::json.str_to_obj::<TardisWebsocketMessage>(msg.to_text().unwrap()).unwrap();
if receive_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_ADD.to_string()) && receive_msg.msg.as_str().unwrap() == "c" {
panic!();
ADD_COUNTER.fetch_add(1, Ordering::SeqCst);
// panic!();
// ADD_COUNTER.fetch_add(1, Ordering::SeqCst);
}
if receive_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_DEL.to_string()) && receive_msg.msg.as_str().unwrap() == "c" {
panic!();
Expand All @@ -286,6 +319,9 @@ async fn test_dyn_avatar() -> TardisResult<()> {
.await?;

TardisFuns::ws_client("ws://127.0.0.1:8080/ws/dyn/a/false", move |msg| async move {
if msg.is_ping() || msg.is_pong() {
return None;
}
let receive_msg = TardisFuns::json.str_to_obj::<TardisWebsocketMessage>(msg.to_text().unwrap()).unwrap();
if receive_msg.msg.as_str().unwrap() == "a" {
ADD_COUNTER.fetch_add(1, Ordering::SeqCst);
Expand Down Expand Up @@ -381,7 +417,7 @@ impl Api {
#[oai(path = "/ws/broadcast/:group/:name", method = "get")]
async fn ws_broadcast(&self, group: Path<String>, name: Path<String>, websocket: WebSocket) -> BoxWebSocketUpgraded {
if !SENDERS.read().await.contains_key(&group.0) {
SENDERS.write().await.insert(group.0.clone(), tokio::sync::broadcast::channel::<TardisWebsocketMgrMessage>(100).0);
SENDERS.write().await.insert(group.0.clone(), tokio::sync::broadcast::channel::<_>(100).0);
}
let sender = SENDERS.read().await.get(&group.0).unwrap().clone();
if group.0 == "g1" {
Expand Down Expand Up @@ -422,7 +458,7 @@ impl Api {
#[oai(path = "/ws/dyn/:name/:mgr", method = "get")]
async fn ws_dyn_broadcast(&self, name: Path<String>, mgr: Path<bool>, websocket: WebSocket) -> BoxWebSocketUpgraded {
if !SENDERS.read().await.contains_key("dyn") {
SENDERS.write().await.insert("dyn".to_string(), tokio::sync::broadcast::channel::<TardisWebsocketMgrMessage>(100).0);
SENDERS.write().await.insert("dyn".to_string(), tokio::sync::broadcast::channel::<_>(100).0);
}
let sender = SENDERS.read().await.get("dyn").unwrap().clone();
WsBroadcast::new(sender, PassThroughHook, WsBroadcastContext::new(mgr.0, true)).run(vec![name.0], websocket).await
Expand Down

0 comments on commit d15fb38

Please sign in to comment.