diff --git a/russh/src/channels/channel_stream.rs b/russh/src/channels/channel_stream.rs index 853a2c4d..9afb51fc 100644 --- a/russh/src/channels/channel_stream.rs +++ b/russh/src/channels/channel_stream.rs @@ -42,7 +42,7 @@ where impl AsyncWrite for ChannelStream where - S: From<(ChannelId, ChannelMsg)> + 'static, + S: From<(ChannelId, ChannelMsg)> + 'static + Send + Sync, { fn poll_write( mut self: Pin<&mut Self>, diff --git a/russh/src/channels/io/tx.rs b/russh/src/channels/io/tx.rs index 15d86a6f..65aec1cc 100644 --- a/russh/src/channels/io/tx.rs +++ b/russh/src/channels/io/tx.rs @@ -1,27 +1,35 @@ use std::io; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, ready}; +use futures::FutureExt; use russh_cryptovec::CryptoVec; use tokio::io::AsyncWrite; -use tokio::sync::mpsc; -use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::Mutex; +use tokio::sync::mpsc::{self, OwnedPermit, error::SendError}; +use tokio::sync::{Mutex, OwnedMutexGuard}; use super::ChannelMsg; use crate::ChannelId; +type BoxedThreadsafeFuture = Pin>>; +type OwnedPermitFuture = BoxedThreadsafeFuture, ChannelMsg, usize), SendError<()>>>; + pub struct ChannelTx { sender: mpsc::Sender, + send_fut: Option>, id: ChannelId, + window_size_fut: Option>>, window_size: Arc>, max_packet_size: u32, ext: Option, } -impl ChannelTx { +impl ChannelTx +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send, +{ pub fn new( sender: mpsc::Sender, id: ChannelId, @@ -31,37 +39,31 @@ impl ChannelTx { ) -> Self { Self { sender, + send_fut: None, id, window_size, + window_size_fut: None, max_packet_size, ext, } } -} -impl AsyncWrite for ChannelTx -where - S: From<(ChannelId, ChannelMsg)> + 'static, -{ - fn poll_write( - self: Pin<&mut Self>, + fn poll_mk_msg( + &mut self, cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { - let mut window_size = match self.window_size.try_lock() { - Ok(window_size) => window_size, - Err(_) => { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - }; + ) -> Poll<(ChannelMsg, usize)> { + let window_size = self.window_size.clone(); + let window_size_fut = self.window_size_fut.get_or_insert_with(|| Box::pin(window_size.lock_owned())); + let mut window_size = ready!(window_size_fut.poll_unpin(cx)); + self.window_size_fut.take(); - let writable = self.max_packet_size.min(*window_size).min(buf.len() as u32) as usize; + let writable = (self.max_packet_size).min(*window_size).min(buf.len() as u32) as usize; if writable == 0 { + // TODO fix this busywait cx.waker().wake_by_ref(); return Poll::Pending; } - let mut data = CryptoVec::new_zeroed(writable); #[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min` data.copy_from_slice(&buf[..writable]); @@ -75,34 +77,62 @@ where Some(ext) => ChannelMsg::ExtendedData { data, ext }, }; - match self.sender.try_send((self.id, msg).into()) { - Ok(_) => Poll::Ready(Ok(writable)), - Err(err @ TrySendError::Closed(_)) => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - err.to_string(), - ))), - Err(TrySendError::Full(_)) => { - cx.waker().wake_by_ref(); - Poll::Pending + Poll::Ready((msg, writable)) + } + + fn activate(&mut self, msg: ChannelMsg, writable: usize) -> &mut OwnedPermitFuture { + use futures::TryFutureExt; + self.send_fut.insert(Box::pin(self.sender.clone().reserve_owned().map_ok(move |p| (p, msg, writable)))) + } + + fn handle_write_result(&mut self, r: Result<(OwnedPermit, ChannelMsg, usize), SendError<()>>) -> Result { + self.send_fut = None; + match r { + Ok((permit, msg, writable)) => { + permit.send((self.id, msg).into()); + Ok(writable) + } + Err(SendError(())) => { + Err(io::Error::new(io::ErrorKind::BrokenPipe, "channel closed")) } } } +} + +impl AsyncWrite for ChannelTx +where + S: From<(ChannelId, ChannelMsg)> + 'static + Send, +{ + #[allow(clippy::too_many_lines)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let send_fut = + if let Some(x) = self.send_fut.as_mut() { + x + } else { + let (msg, writable) = ready!(self.poll_mk_msg(cx, buf)); + self.activate(msg, writable) + }; + let r = ready!(send_fut.as_mut().poll_unpin(cx)); + Poll::Ready(self.handle_write_result(r)) + } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.sender.try_send((self.id, ChannelMsg::Eof).into()) { - Ok(_) => Poll::Ready(Ok(())), - Err(err @ TrySendError::Closed(_)) => Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - err.to_string(), - ))), - Err(TrySendError::Full(_)) => { - cx.waker().wake_by_ref(); - Poll::Pending - } - } + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let send_fut = + if let Some(x) = self.send_fut.as_mut() { + x + } else { + self.activate(ChannelMsg::Eof, 0) + }; + let r = ready!(send_fut.as_mut().poll_unpin(cx)) + .map(|(p, _, _)| (p, ChannelMsg::Eof, 0)); + Poll::Ready(self.handle_write_result(r).map(drop)) } } diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs index 8677885c..8e9b4085 100644 --- a/russh/src/channels/mod.rs +++ b/russh/src/channels/mod.rs @@ -133,7 +133,7 @@ impl> std::fmt::Debug for Channel { } } -impl + Send + 'static> Channel { +impl + Send + Sync + 'static> Channel { pub(crate) fn new( id: ChannelId, sender: Sender,