Skip to content

Commit

Permalink
Refactor server.rs, use UnboundedChannel
Browse files Browse the repository at this point in the history
  • Loading branch information
LVala committed Dec 16, 2023
1 parent 90a9946 commit c135f5a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 51 deletions.
2 changes: 1 addition & 1 deletion client/src/streams/packets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl Packets {

pub fn id_count(&self) -> usize {
match self.packets.last_key_value() {
Some((id, _)) => id + 1,
Some((id, _)) => *id,
None => 0,
}
}
Expand Down
2 changes: 1 addition & 1 deletion common/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ fn is_rtp(packet: &RtpPacket) -> bool {
if packet.version != 2 {
return false;
}
if let 72..=76 = packet.payload_type.id {
if let 72..=79 = packet.payload_type.id {
return false;
}
if packet.ssrc == 0 {
Expand Down
19 changes: 12 additions & 7 deletions src/cmd/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ impl Run {
let mut file_sniffers = get_sniffers(self.files, Sniffer::from_file);
let mut interface_sniffers = get_sniffers(self.interfaces, Sniffer::from_device);

if file_sniffers.is_empty() && interface_sniffers.is_empty() {
// TODO: use some pretty printing (colors, bold font etc.)
println!("Error: no valid sources were passed");
return;
}

let file_res = apply_filters(&mut file_sniffers, &self.capture);
let interface_res = apply_filters(&mut interface_sniffers, &live_filter);

Expand All @@ -46,8 +40,19 @@ impl Run {
return;
}

let sniffers: HashMap<_, _> = file_sniffers
.into_iter()
.chain(interface_sniffers)
.collect();

if sniffers.is_empty() {
// TODO: use some pretty printing (colors, bold font etc.)
println!("Error: no valid sources were passed");
return;
}

let address = SocketAddr::new(self.address, self.port);
server::run(interface_sniffers, file_sniffers, address).await;
server::run(sniffers, address).await;
}

fn create_capture_filter(&self) -> String {
Expand Down
64 changes: 22 additions & 42 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use tokio::sync::{mpsc, mpsc::Sender, RwLock};
use tokio::sync::{mpsc, mpsc::UnboundedSender, RwLock};
use warp::ws::{Message, WebSocket};
use warp::Filter;

Expand All @@ -22,12 +22,12 @@ const WS_PATH: &str = "ws";
static NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::new(1);

struct Client {
pub sender: mpsc::Sender<Message>,
pub sender: mpsc::UnboundedSender<Message>,
pub source: Option<Source>,
}

impl Client {
pub fn new(sender: mpsc::Sender<Message>) -> Self {
pub fn new(sender: mpsc::UnboundedSender<Message>) -> Self {
Self {
sender,
source: None,
Expand All @@ -39,30 +39,14 @@ type Clients = Arc<RwLock<HashMap<usize, Client>>>;
type Packets = Arc<RwLock<Vec<Response>>>;
type PacketsMap = Arc<HashMap<Source, Packets>>;

pub async fn run(
interface_sniffers: HashMap<String, Sniffer>,
file_sniffers: HashMap<String, Sniffer>,
addr: SocketAddr,
) {
pub async fn run(sniffers: HashMap<String, Sniffer>, addr: SocketAddr) {
let clients = Clients::default();
let mut source_to_packets = HashMap::new();

// a bit of repetition, but Rust bested me this time
for (file, sniffer) in file_sniffers {
for (_file, sniffer) in sniffers {
let packets = Packets::default();
let source = Source::File(file);
source_to_packets.insert(source, packets.clone());

let cloned_clients = clients.clone();
tokio::task::spawn(async move {
sniff(sniffer, packets, cloned_clients).await;
});
}

for (interface, sniffer) in interface_sniffers {
let packets = Packets::default();
let source = Source::Interface(interface);
source_to_packets.insert(source, packets.clone());
source_to_packets.insert(sniffer.source.clone(), packets.clone());

let cloned_clients = clients.clone();
tokio::task::spawn(async move {
Expand Down Expand Up @@ -101,7 +85,8 @@ async fn client_connected(ws: WebSocket, clients: Clients, source_to_packets: Pa
// if buffer size is > 1, rx.recv always waits
// until channel's buffer is full before receiving
// might be because of blocking sniffers
let (tx, mut rx) = mpsc::channel(1);
let (tx, mut rx) = mpsc::unbounded_channel();
// let (tx, mut rx) = mpsc::channel(1);

tokio::task::spawn(async move {
while let Some(message) = rx.recv().await {
Expand Down Expand Up @@ -162,12 +147,9 @@ async fn sniff(mut sniffer: Sniffer, packets: Packets, clients: Clients) {
source: Some(source),
sender,
} if *source == sniffer.source => {
sender
.send(msg.clone())
.unwrap_or_else(|e| {
error!("Sniffer: error while sending packet: {}", e);
})
.await;
sender.send(msg.clone()).unwrap_or_else(|e| {
error!("Sniffer: error while sending packet: {}", e);
});
}
_ => {}
}
Expand All @@ -179,19 +161,20 @@ async fn sniff(mut sniffer: Sniffer, packets: Packets, clients: Clients) {
}
}

async fn send_all_packets(client_id: usize, packets: &Packets, ws_tx: &mut Sender<Message>) {
async fn send_all_packets(
client_id: usize,
packets: &Packets,
ws_tx: &mut UnboundedSender<Message>,
) {
for pack in packets.read().await.iter() {
let Ok(encoded) = pack.encode() else {
error!("Failed to encode packet, client_id: {}", client_id);
continue;
};
let msg = Message::binary(encoded);
ws_tx
.send(msg)
.unwrap_or_else(|e| {
error!("WebSocket `feed` error: {}, client_id: {}", e, client_id);
})
.await;
ws_tx.send(msg).unwrap_or_else(|e| {
error!("WebSocket `feed` error: {}, client_id: {}", e, client_id);
});
}

info!(
Expand Down Expand Up @@ -233,12 +216,9 @@ async fn reparse_packet(
source: Some(source),
sender,
} if *source == *cur_source => {
sender
.send(msg.clone())
.unwrap_or_else(|e| {
error!("Sniffer: error while sending packet: {}", e);
})
.await;
sender.send(msg.clone()).unwrap_or_else(|e| {
error!("Sniffer: error while sending packet: {}", e);
});
}
_ => {}
};
Expand Down

0 comments on commit c135f5a

Please sign in to comment.