From c135f5aec9ded04626d8686735208cc600f1e83f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Wala?= Date: Sat, 16 Dec 2023 15:15:24 +0100 Subject: [PATCH] Refactor `server.rs`, use UnboundedChannel --- client/src/streams/packets.rs | 2 +- common/src/packet.rs | 2 +- src/cmd/run.rs | 19 +++++++---- src/server.rs | 64 ++++++++++++----------------------- 4 files changed, 36 insertions(+), 51 deletions(-) diff --git a/client/src/streams/packets.rs b/client/src/streams/packets.rs index 2c4543e..f31aa70 100644 --- a/client/src/streams/packets.rs +++ b/client/src/streams/packets.rs @@ -47,7 +47,7 @@ impl Packets { pub fn id_count(&self) -> usize { match self.packets.last_key_value() { - Some((id, _)) => id + 1, + Some((id, _)) => *id, None => 0, } } diff --git a/common/src/packet.rs b/common/src/packet.rs index b5a409d..a029d27 100644 --- a/common/src/packet.rs +++ b/common/src/packet.rs @@ -185,7 +185,7 @@ fn is_rtp(packet: &RtpPacket) -> bool { if packet.version != 2 { return false; } - if let 72..=76 = packet.payload_type.id { + if let 72..=79 = packet.payload_type.id { return false; } if packet.ssrc == 0 { diff --git a/src/cmd/run.rs b/src/cmd/run.rs index f2f937a..6f53993 100644 --- a/src/cmd/run.rs +++ b/src/cmd/run.rs @@ -32,12 +32,6 @@ impl Run { let mut file_sniffers = get_sniffers(self.files, Sniffer::from_file); let mut interface_sniffers = get_sniffers(self.interfaces, Sniffer::from_device); - if file_sniffers.is_empty() && interface_sniffers.is_empty() { - // TODO: use some pretty printing (colors, bold font etc.) - println!("Error: no valid sources were passed"); - return; - } - let file_res = apply_filters(&mut file_sniffers, &self.capture); let interface_res = apply_filters(&mut interface_sniffers, &live_filter); @@ -46,8 +40,19 @@ impl Run { return; } + let sniffers: HashMap<_, _> = file_sniffers + .into_iter() + .chain(interface_sniffers) + .collect(); + + if sniffers.is_empty() { + // TODO: use some pretty printing (colors, bold font etc.) + println!("Error: no valid sources were passed"); + return; + } + let address = SocketAddr::new(self.address, self.port); - server::run(interface_sniffers, file_sniffers, address).await; + server::run(sniffers, address).await; } fn create_capture_filter(&self) -> String { diff --git a/src/server.rs b/src/server.rs index 2ac123d..7787869 100644 --- a/src/server.rs +++ b/src/server.rs @@ -13,7 +13,7 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; -use tokio::sync::{mpsc, mpsc::Sender, RwLock}; +use tokio::sync::{mpsc, mpsc::UnboundedSender, RwLock}; use warp::ws::{Message, WebSocket}; use warp::Filter; @@ -22,12 +22,12 @@ const WS_PATH: &str = "ws"; static NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::new(1); struct Client { - pub sender: mpsc::Sender, + pub sender: mpsc::UnboundedSender, pub source: Option, } impl Client { - pub fn new(sender: mpsc::Sender) -> Self { + pub fn new(sender: mpsc::UnboundedSender) -> Self { Self { sender, source: None, @@ -39,30 +39,14 @@ type Clients = Arc>>; type Packets = Arc>>; type PacketsMap = Arc>; -pub async fn run( - interface_sniffers: HashMap, - file_sniffers: HashMap, - addr: SocketAddr, -) { +pub async fn run(sniffers: HashMap, addr: SocketAddr) { let clients = Clients::default(); let mut source_to_packets = HashMap::new(); // a bit of repetition, but Rust bested me this time - for (file, sniffer) in file_sniffers { + for (_file, sniffer) in sniffers { let packets = Packets::default(); - let source = Source::File(file); - source_to_packets.insert(source, packets.clone()); - - let cloned_clients = clients.clone(); - tokio::task::spawn(async move { - sniff(sniffer, packets, cloned_clients).await; - }); - } - - for (interface, sniffer) in interface_sniffers { - let packets = Packets::default(); - let source = Source::Interface(interface); - source_to_packets.insert(source, packets.clone()); + source_to_packets.insert(sniffer.source.clone(), packets.clone()); let cloned_clients = clients.clone(); tokio::task::spawn(async move { @@ -101,7 +85,8 @@ async fn client_connected(ws: WebSocket, clients: Clients, source_to_packets: Pa // if buffer size is > 1, rx.recv always waits // until channel's buffer is full before receiving // might be because of blocking sniffers - let (tx, mut rx) = mpsc::channel(1); + let (tx, mut rx) = mpsc::unbounded_channel(); + // let (tx, mut rx) = mpsc::channel(1); tokio::task::spawn(async move { while let Some(message) = rx.recv().await { @@ -162,12 +147,9 @@ async fn sniff(mut sniffer: Sniffer, packets: Packets, clients: Clients) { source: Some(source), sender, } if *source == sniffer.source => { - sender - .send(msg.clone()) - .unwrap_or_else(|e| { - error!("Sniffer: error while sending packet: {}", e); - }) - .await; + sender.send(msg.clone()).unwrap_or_else(|e| { + error!("Sniffer: error while sending packet: {}", e); + }); } _ => {} } @@ -179,19 +161,20 @@ async fn sniff(mut sniffer: Sniffer, packets: Packets, clients: Clients) { } } -async fn send_all_packets(client_id: usize, packets: &Packets, ws_tx: &mut Sender) { +async fn send_all_packets( + client_id: usize, + packets: &Packets, + ws_tx: &mut UnboundedSender, +) { for pack in packets.read().await.iter() { let Ok(encoded) = pack.encode() else { error!("Failed to encode packet, client_id: {}", client_id); continue; }; let msg = Message::binary(encoded); - ws_tx - .send(msg) - .unwrap_or_else(|e| { - error!("WebSocket `feed` error: {}, client_id: {}", e, client_id); - }) - .await; + ws_tx.send(msg).unwrap_or_else(|e| { + error!("WebSocket `feed` error: {}, client_id: {}", e, client_id); + }); } info!( @@ -233,12 +216,9 @@ async fn reparse_packet( source: Some(source), sender, } if *source == *cur_source => { - sender - .send(msg.clone()) - .unwrap_or_else(|e| { - error!("Sniffer: error while sending packet: {}", e); - }) - .await; + sender.send(msg.clone()).unwrap_or_else(|e| { + error!("Sniffer: error while sending packet: {}", e); + }); } _ => {} };