diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 69903bb4..8bb9e9c9 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -36,6 +36,7 @@ use std::cell::RefCell; use std::collections::{HashMap, VecDeque}; +use std::convert::TryInto; use std::num::Wrapping; use std::pin::Pin; use std::sync::Arc; @@ -49,7 +50,7 @@ use russh_keys::encoding::Reader; #[cfg(feature = "openssl")] use russh_keys::key::SignatureHash; use russh_keys::key::{self, parse_public_key, PublicKey}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::pin; use tokio::sync::mpsc::{ @@ -195,6 +196,19 @@ pub struct Prompt { pub echo: bool, } +#[derive(Debug)] +pub struct RemoteDisconnectInfo { + pub reason_code: crate::Disconnect, + pub message: String, + pub lang_tag: String, +} + +#[derive(Debug)] +pub enum DisconnectReason + Send> { + ReceivedDisconnect(RemoteDisconnectInfo), + Error(E), +} + /// Handle to a session, used to send messages to a client outside of /// the request/response cycle. pub struct Handle { @@ -737,23 +751,59 @@ impl Session { async fn run( mut self, - mut stream: SshRead, + stream: SshRead, mut handler: H, - mut encrypted_signal: Option>, + encrypted_signal: Option>, ) -> Result<(), H::Error> { + let (stream_read, mut stream_write) = stream.split(); + let result = self + .run_inner( + stream_read, + &mut stream_write, + &mut handler, + encrypted_signal, + ) + .await; + trace!("disconnected"); + self.receiver.close(); + self.inbound_channel_receiver.close(); + stream_write.shutdown().await.map_err(crate::Error::from)?; + match result { + Ok(v) => { + handler + .disconnected(DisconnectReason::ReceivedDisconnect(v)) + .await?; + Ok(()) + } + Err(e) => { + handler.disconnected(DisconnectReason::Error(e)).await?; + //Err(e) + Ok(()) + } + } + } + + async fn run_inner( + &mut self, + stream_read: SshRead>, + stream_write: &mut WriteHalf, + handler: &mut H, + mut encrypted_signal: Option>, + ) -> Result { + let mut result: Result = + Err(crate::Error::Disconnect.into()); self.flush()?; if !self.common.write_buffer.buffer.is_empty() { debug!("writing {:?} bytes", self.common.write_buffer.buffer.len()); - stream + stream_write .write_all(&self.common.write_buffer.buffer) .await .map_err(crate::Error::from)?; - stream.flush().await.map_err(crate::Error::from)?; + stream_write.flush().await.map_err(crate::Error::from)?; } self.common.write_buffer.buffer.clear(); let mut decomp = CryptoVec::new(); - let (stream_read, mut stream_write) = stream.split(); let buffer = SSHBuffer::new(); // Allow handing out references to the cipher @@ -805,10 +855,10 @@ impl Session { if !buf.is_empty() { #[allow(clippy::indexing_slicing)] // length checked if buf[0] == crate::msg::DISCONNECT { - break; + result = self.process_disconnect(buf); } else { self.common.received_data = true; - reply( &mut self,&mut handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?; + reply( self,handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?; } } @@ -906,12 +956,30 @@ impl Session { } } } - debug!("disconnected"); - self.receiver.close(); - self.inbound_channel_receiver.close(); - stream_write.shutdown().await.map_err(crate::Error::from)?; - Ok(()) + result + } + + fn process_disconnect + Send>( + &mut self, + buf: &[u8], + ) -> Result { + self.common.disconnected = true; + let mut reader = buf.reader(1); + + let reason_code = reader.read_u32().map_err(crate::Error::from)?.try_into()?; + let message = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?) + .map_err(crate::Error::from)? + .to_owned(); + let lang_tag = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?) + .map_err(crate::Error::from)? + .to_owned(); + + Ok(RemoteDisconnectInfo { + reason_code, + message, + lang_tag, + }) } fn handle_msg(&mut self, msg: Msg) -> Result<(), crate::Error> { @@ -1360,7 +1428,7 @@ impl Default for Config { #[async_trait] pub trait Handler: Sized + Send { - type Error: From + Send; + type Error: From + Send + core::fmt::Debug; /// Called when the server sends us an authentication banner. This /// is usually meant to be shown to the user, see @@ -1620,4 +1688,19 @@ pub trait Handler: Sized + Send { debug!("openssh_ext_hostkeys_announced: {:?}", keys); Ok(()) } + + /// Called when the server sent a disconnect message + /// + /// If reason is an Error, this function should re-return the error so the join can also evaluate it + #[allow(unused_variables)] + async fn disconnected( + &mut self, + reason: DisconnectReason, + ) -> Result<(), Self::Error> { + debug!("disconnected: {:?}", reason); + match reason { + DisconnectReason::ReceivedDisconnect(_) => Ok(()), + DisconnectReason::Error(e) => Err(e), + } + } } diff --git a/russh/src/lib.rs b/russh/src/lib.rs index 54c6d567..ac1a6b0f 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -94,7 +94,10 @@ //! messages sent through a `server::Handle` are processed when there //! is no incoming packet to read. -use std::fmt::{Debug, Display, Formatter}; +use std::{ + convert::TryFrom, + fmt::{Debug, Display, Formatter}, +}; use log::debug; use parsing::ChannelOpenConfirmation; @@ -379,6 +382,31 @@ pub enum Disconnect { IllegalUserName = 15, } +impl TryFrom for Disconnect { + type Error = crate::Error; + + fn try_from(value: u32) -> Result { + Ok(match value { + 1 => Self::HostNotAllowedToConnect, + 2 => Self::ProtocolError, + 3 => Self::KeyExchangeFailed, + 4 => Self::Reserved, + 5 => Self::MACError, + 6 => Self::CompressionError, + 7 => Self::ServiceNotAvailable, + 8 => Self::ProtocolVersionNotSupported, + 9 => Self::HostKeyNotVerifiable, + 10 => Self::ConnectionLost, + 11 => Self::ByApplication, + 12 => Self::TooManyConnections, + 13 => Self::AuthCancelledByUser, + 14 => Self::NoMoreAuthMethodsAvailable, + 15 => Self::IllegalUserName, + _ => return Err(crate::Error::Inconsistent), + }) + } +} + /// The type of signals that can be sent to a remote process. If you /// plan to use custom signals, read [the /// RFC](https://tools.ietf.org/html/rfc4254#section-6.10) to diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index b19bd6ed..0a333c5e 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -59,6 +59,11 @@ pub enum Msg { address: String, port: u32, }, + Disconnect { + reason: crate::Disconnect, + description: String, + language_tag: String, + }, Channel(ChannelId, ChannelMsg), } @@ -348,6 +353,23 @@ impl Handle { .await .map_err(|_| ()) } + + /// Allows a server to disconnect a client session + pub async fn disconnect( + &self, + reason: Disconnect, + description: String, + language_tag: String, + ) -> Result<(), Error> { + self.sender + .send(Msg::Disconnect { + reason, + description, + language_tag, + }) + .await + .map_err(|_| Error::SendError) + } } impl Session { @@ -511,6 +533,9 @@ impl Session { Some(Msg::CancelTcpIpForward { address, port, reply_channel }) => { self.cancel_tcpip_forward(&address, port, reply_channel); } + Some(Msg::Disconnect {reason, description, language_tag}) => { + self.common.disconnect(reason, &description, &language_tag); + } Some(_) => { // should be unreachable, since the receiver only gets // messages from methods implemented within russh