Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional Disconnect Handling #255

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 97 additions & 14 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::convert::TryInto;
use std::num::Wrapping;
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -49,7 +50,7 @@ use russh_keys::encoding::Reader;
#[cfg(feature = "openssl")]
use russh_keys::key::SignatureHash;
use russh_keys::key::{self, parse_public_key, PublicKey};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::{TcpStream, ToSocketAddrs};
use tokio::pin;
use tokio::sync::mpsc::{
Expand Down Expand Up @@ -195,6 +196,19 @@ pub struct Prompt {
pub echo: bool,
}

#[derive(Debug)]
pub struct RemoteDisconnectInfo {
pub reason_code: crate::Disconnect,
pub message: String,
pub lang_tag: String,
}

#[derive(Debug)]
pub enum DisconnectReason<E: From<crate::Error> + Send> {
ReceivedDisconnect(RemoteDisconnectInfo),
Error(E),
}

/// Handle to a session, used to send messages to a client outside of
/// the request/response cycle.
pub struct Handle<H: Handler> {
Expand Down Expand Up @@ -737,23 +751,59 @@ impl Session {

async fn run<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
mut self,
mut stream: SshRead<R>,
stream: SshRead<R>,
mut handler: H,
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
) -> Result<(), H::Error> {
let (stream_read, mut stream_write) = stream.split();
let result = self
.run_inner(
stream_read,
&mut stream_write,
&mut handler,
encrypted_signal,
)
.await;
trace!("disconnected");
self.receiver.close();
self.inbound_channel_receiver.close();
stream_write.shutdown().await.map_err(crate::Error::from)?;
match result {
Ok(v) => {
handler
.disconnected(DisconnectReason::ReceivedDisconnect(v))
.await?;
Ok(())
}
Err(e) => {
handler.disconnected(DisconnectReason::Error(e)).await?;
//Err(e)
Ok(())
}
}
}

async fn run_inner<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
&mut self,
stream_read: SshRead<ReadHalf<R>>,
stream_write: &mut WriteHalf<R>,
handler: &mut H,
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
) -> Result<RemoteDisconnectInfo, H::Error> {
let mut result: Result<RemoteDisconnectInfo, H::Error> =
Err(crate::Error::Disconnect.into());
self.flush()?;
if !self.common.write_buffer.buffer.is_empty() {
debug!("writing {:?} bytes", self.common.write_buffer.buffer.len());
stream
stream_write
.write_all(&self.common.write_buffer.buffer)
.await
.map_err(crate::Error::from)?;
stream.flush().await.map_err(crate::Error::from)?;
stream_write.flush().await.map_err(crate::Error::from)?;
}
self.common.write_buffer.buffer.clear();
let mut decomp = CryptoVec::new();

let (stream_read, mut stream_write) = stream.split();
let buffer = SSHBuffer::new();

// Allow handing out references to the cipher
Expand Down Expand Up @@ -805,10 +855,10 @@ impl Session {
if !buf.is_empty() {
#[allow(clippy::indexing_slicing)] // length checked
if buf[0] == crate::msg::DISCONNECT {
break;
result = self.process_disconnect(buf);
} else {
self.common.received_data = true;
reply( &mut self,&mut handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?;
reply( self,handler, &mut encrypted_signal, &mut buffer.seqn, buf).await?;
}
}

Expand Down Expand Up @@ -906,12 +956,30 @@ impl Session {
}
}
}
debug!("disconnected");
self.receiver.close();
self.inbound_channel_receiver.close();
stream_write.shutdown().await.map_err(crate::Error::from)?;

Ok(())
result
}

fn process_disconnect<E: From<crate::Error> + Send>(
&mut self,
buf: &[u8],
) -> Result<RemoteDisconnectInfo, E> {
self.common.disconnected = true;
let mut reader = buf.reader(1);

let reason_code = reader.read_u32().map_err(crate::Error::from)?.try_into()?;
let message = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?)
.map_err(crate::Error::from)?
.to_owned();
let lang_tag = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?)
.map_err(crate::Error::from)?
.to_owned();

Ok(RemoteDisconnectInfo {
reason_code,
message,
lang_tag,
})
}

fn handle_msg(&mut self, msg: Msg) -> Result<(), crate::Error> {
Expand Down Expand Up @@ -1360,7 +1428,7 @@ impl Default for Config {

#[async_trait]
pub trait Handler: Sized + Send {
type Error: From<crate::Error> + Send;
type Error: From<crate::Error> + Send + core::fmt::Debug;

/// Called when the server sends us an authentication banner. This
/// is usually meant to be shown to the user, see
Expand Down Expand Up @@ -1620,4 +1688,19 @@ pub trait Handler: Sized + Send {
debug!("openssh_ext_hostkeys_announced: {:?}", keys);
Ok(())
}

/// Called when the server sent a disconnect message
///
/// If reason is an Error, this function should re-return the error so the join can also evaluate it
#[allow(unused_variables)]
async fn disconnected(
&mut self,
reason: DisconnectReason<Self::Error>,
) -> Result<(), Self::Error> {
debug!("disconnected: {:?}", reason);
match reason {
DisconnectReason::ReceivedDisconnect(_) => Ok(()),
DisconnectReason::Error(e) => Err(e),
}
}
}
30 changes: 29 additions & 1 deletion russh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@
//! messages sent through a `server::Handle` are processed when there
//! is no incoming packet to read.

use std::fmt::{Debug, Display, Formatter};
use std::{
convert::TryFrom,
fmt::{Debug, Display, Formatter},
};

use log::debug;
use parsing::ChannelOpenConfirmation;
Expand Down Expand Up @@ -379,6 +382,31 @@ pub enum Disconnect {
IllegalUserName = 15,
}

impl TryFrom<u32> for Disconnect {
type Error = crate::Error;

fn try_from(value: u32) -> Result<Self, Self::Error> {
Ok(match value {
1 => Self::HostNotAllowedToConnect,
2 => Self::ProtocolError,
3 => Self::KeyExchangeFailed,
4 => Self::Reserved,
5 => Self::MACError,
6 => Self::CompressionError,
7 => Self::ServiceNotAvailable,
8 => Self::ProtocolVersionNotSupported,
9 => Self::HostKeyNotVerifiable,
10 => Self::ConnectionLost,
11 => Self::ByApplication,
12 => Self::TooManyConnections,
13 => Self::AuthCancelledByUser,
14 => Self::NoMoreAuthMethodsAvailable,
15 => Self::IllegalUserName,
_ => return Err(crate::Error::Inconsistent),
})
}
}

/// The type of signals that can be sent to a remote process. If you
/// plan to use custom signals, read [the
/// RFC](https://tools.ietf.org/html/rfc4254#section-6.10) to
Expand Down
25 changes: 25 additions & 0 deletions russh/src/server/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ pub enum Msg {
address: String,
port: u32,
},
Disconnect {
reason: crate::Disconnect,
description: String,
language_tag: String,
},
Channel(ChannelId, ChannelMsg),
}

Expand Down Expand Up @@ -348,6 +353,23 @@ impl Handle {
.await
.map_err(|_| ())
}

/// Allows a server to disconnect a client session
pub async fn disconnect(
&self,
reason: Disconnect,
description: String,
language_tag: String,
) -> Result<(), Error> {
self.sender
.send(Msg::Disconnect {
reason,
description,
language_tag,
})
.await
.map_err(|_| Error::SendError)
}
}

impl Session {
Expand Down Expand Up @@ -511,6 +533,9 @@ impl Session {
Some(Msg::CancelTcpIpForward { address, port, reply_channel }) => {
self.cancel_tcpip_forward(&address, port, reply_channel);
}
Some(Msg::Disconnect {reason, description, language_tag}) => {
self.common.disconnect(reason, &description, &language_tag);
}
Some(_) => {
// should be unreachable, since the receiver only gets
// messages from methods implemented within russh
Expand Down
Loading