Skip to content

Commit

Permalink
Reduce busywaiting in ChannelTx
Browse files Browse the repository at this point in the history
Completely eliminating the busywaiting will require using a different or
additional synchronization primitive for the window size.
`tokio::sync::Notify` will not (yet) suffice because it has no "_owned"
variant of its `Notified` future.
  • Loading branch information
mmirate committed Oct 27, 2023
1 parent 5defc91 commit 336e0c9
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 45 deletions.
2 changes: 1 addition & 1 deletion russh/src/channels/channel_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ where

impl<S> AsyncWrite for ChannelStream<S>
where
S: From<(ChannelId, ChannelMsg)> + 'static,
S: From<(ChannelId, ChannelMsg)> + 'static + Send + Sync,
{
fn poll_write(
mut self: Pin<&mut Self>,
Expand Down
116 changes: 73 additions & 43 deletions russh/src/channels/io/tx.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::task::{Context, Poll, ready};

use futures::FutureExt;
use russh_cryptovec::CryptoVec;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::Mutex;
use tokio::sync::mpsc::{self, OwnedPermit, error::SendError};
use tokio::sync::{Mutex, OwnedMutexGuard};

use super::ChannelMsg;
use crate::ChannelId;

type BoxedThreadsafeFuture<T> = Pin<Box<dyn Sync + Send + std::future::Future<Output=T>>>;
type OwnedPermitFuture<S> = BoxedThreadsafeFuture<Result<(OwnedPermit<S>, ChannelMsg, usize), SendError<()>>>;

pub struct ChannelTx<S> {
sender: mpsc::Sender<S>,
send_fut: Option<OwnedPermitFuture<S>>,
id: ChannelId,

window_size_fut: Option<BoxedThreadsafeFuture<OwnedMutexGuard<u32>>>,
window_size: Arc<Mutex<u32>>,
max_packet_size: u32,
ext: Option<u32>,
}

impl<S> ChannelTx<S> {
impl<S> ChannelTx<S>
where
S: From<(ChannelId, ChannelMsg)> + 'static + Send,
{
pub fn new(
sender: mpsc::Sender<S>,
id: ChannelId,
Expand All @@ -31,37 +39,31 @@ impl<S> ChannelTx<S> {
) -> Self {
Self {
sender,
send_fut: None,
id,
window_size,
window_size_fut: None,
max_packet_size,
ext,
}
}
}

impl<S> AsyncWrite for ChannelTx<S>
where
S: From<(ChannelId, ChannelMsg)> + 'static,
{
fn poll_write(
self: Pin<&mut Self>,
fn poll_mk_msg(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let mut window_size = match self.window_size.try_lock() {
Ok(window_size) => window_size,
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
) -> Poll<(ChannelMsg, usize)> {
let window_size = self.window_size.clone();
let window_size_fut = self.window_size_fut.get_or_insert_with(|| Box::pin(window_size.lock_owned()));
let mut window_size = ready!(window_size_fut.poll_unpin(cx));
self.window_size_fut.take();

let writable = self.max_packet_size.min(*window_size).min(buf.len() as u32) as usize;
let writable = (self.max_packet_size).min(*window_size).min(buf.len() as u32) as usize;
if writable == 0 {
// TODO fix this busywait
cx.waker().wake_by_ref();
return Poll::Pending;
}

let mut data = CryptoVec::new_zeroed(writable);
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min`
data.copy_from_slice(&buf[..writable]);
Expand All @@ -75,34 +77,62 @@ where
Some(ext) => ChannelMsg::ExtendedData { data, ext },
};

match self.sender.try_send((self.id, msg).into()) {
Ok(_) => Poll::Ready(Ok(writable)),
Err(err @ TrySendError::Closed(_)) => Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
err.to_string(),
))),
Err(TrySendError::Full(_)) => {
cx.waker().wake_by_ref();
Poll::Pending
Poll::Ready((msg, writable))
}

fn activate(&mut self, msg: ChannelMsg, writable: usize) -> &mut OwnedPermitFuture<S> {
use futures::TryFutureExt;
self.send_fut.insert(Box::pin(self.sender.clone().reserve_owned().map_ok(move |p| (p, msg, writable))))
}

fn handle_write_result(&mut self, r: Result<(OwnedPermit<S>, ChannelMsg, usize), SendError<()>>) -> Result<usize, io::Error> {
self.send_fut = None;
match r {
Ok((permit, msg, writable)) => {
permit.send((self.id, msg).into());
Ok(writable)
}
Err(SendError(())) => {
Err(io::Error::new(io::ErrorKind::BrokenPipe, "channel closed"))
}
}
}
}

impl<S> AsyncWrite for ChannelTx<S>
where
S: From<(ChannelId, ChannelMsg)> + 'static + Send,
{
#[allow(clippy::too_many_lines)]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let send_fut =
if let Some(x) = self.send_fut.as_mut() {
x
} else {
let (msg, writable) = ready!(self.poll_mk_msg(cx, buf));
self.activate(msg, writable)
};
let r = ready!(send_fut.as_mut().poll_unpin(cx));
Poll::Ready(self.handle_write_result(r))
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.sender.try_send((self.id, ChannelMsg::Eof).into()) {
Ok(_) => Poll::Ready(Ok(())),
Err(err @ TrySendError::Closed(_)) => Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
err.to_string(),
))),
Err(TrySendError::Full(_)) => {
cx.waker().wake_by_ref();
Poll::Pending
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let send_fut =
if let Some(x) = self.send_fut.as_mut() {
x
} else {
self.activate(ChannelMsg::Eof, 0)
};
let r = ready!(send_fut.as_mut().poll_unpin(cx))
.map(|(p, _, _)| (p, ChannelMsg::Eof, 0));
Poll::Ready(self.handle_write_result(r).map(drop))
}
}
2 changes: 1 addition & 1 deletion russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl<T: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<T> {
}
}

impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
pub(crate) fn new(
id: ChannelId,
sender: Sender<S>,
Expand Down

0 comments on commit 336e0c9

Please sign in to comment.