From 7230fd84df86f85627330ecaf223a2ae0d247713 Mon Sep 17 00:00:00 2001 From: Max Kellermann Date: Wed, 3 Jul 2024 20:59:23 +0200 Subject: [PATCH] ssh/Connection: postpone Destroy() call until DISCONNECT is flushed If the connection is already encrypted, we need to wait for the worker thread to finish encrypting the DISCONNECT packet, and only after that can we send it to the socket. This finishes the second part of commit 62f4ca00924e8; it needs some complicated code to ensure that the connection does no other I/O while it's waiting for the flush. --- src/Connection.cxx | 8 ++++++++ src/OutgoingConnection.cxx | 2 ++ src/ssh/CConnection.cxx | 13 +++++++++++++ src/ssh/CConnection.hxx | 2 ++ src/ssh/Connection.cxx | 27 +++++++++++++++++++++++++++ src/ssh/Connection.hxx | 10 ++++++++++ src/ssh/GConnection.cxx | 11 +++++++++++ src/ssh/GConnection.hxx | 2 ++ src/ssh/Output.hxx | 7 +++++++ 9 files changed, 82 insertions(+) diff --git a/src/Connection.cxx b/src/Connection.cxx index 55c2d16..b24b58f 100644 --- a/src/Connection.cxx +++ b/src/Connection.cxx @@ -988,6 +988,14 @@ void Connection::OnDisconnecting(SSH::DisconnectReasonCode reason_code, std::string_view msg) noexcept { + CConnection::OnDisconnecting(reason_code, msg); + + /* some manual shutdown just in case the Destroy() is + postponed */ + auth_timeout.Cancel(); + socket_forward_listeners.clear_and_dispose(DeleteDisposer{}); + occupied_task = {}; + if (log_disconnect) { log_disconnect = false; LogFmt("Disconnecting: {}", msg); diff --git a/src/OutgoingConnection.cxx b/src/OutgoingConnection.cxx index 6f33c00..b06ea5b 100644 --- a/src/OutgoingConnection.cxx +++ b/src/OutgoingConnection.cxx @@ -162,6 +162,8 @@ void OutgoingConnection::OnDisconnecting(SSH::DisconnectReasonCode reason_code, std::string_view msg) noexcept { + SSH::Connection::OnDisconnecting(reason_code, msg); + handler.OnOutgoingDisconnecting(reason_code, msg); } diff --git a/src/ssh/CConnection.cxx b/src/ssh/CConnection.cxx index a4d351e..148ea4c 100644 --- a/src/ssh/CConnection.cxx +++ b/src/ssh/CConnection.cxx @@ -526,5 +526,18 @@ CConnection::OnWriteUnblocked() noexcept if (i != nullptr) i->OnWriteUnblocked(); } +void +CConnection::OnDisconnecting(DisconnectReasonCode reason_code, + std::string_view msg) noexcept +{ + GConnection::OnDisconnecting(reason_code, msg); + + /* delete all channels so they don't try to do any I/O while + we're waiting for the DISCONNECT to be flushed */ + for (auto &i : channels) { + delete i; + i = nullptr; + } +} } // namespace SSH diff --git a/src/ssh/CConnection.hxx b/src/ssh/CConnection.hxx index 1ca1f86..5555cbd 100644 --- a/src/ssh/CConnection.hxx +++ b/src/ssh/CConnection.hxx @@ -144,6 +144,8 @@ protected: std::span payload) override; void OnWriteBlocked() noexcept override; void OnWriteUnblocked() noexcept override; + void OnDisconnecting(DisconnectReasonCode reason_code, + std::string_view msg) noexcept; }; } // namespace SSH diff --git a/src/ssh/Connection.cxx b/src/ssh/Connection.cxx index 678c8b4..e064af2 100644 --- a/src/ssh/Connection.cxx +++ b/src/ssh/Connection.cxx @@ -106,10 +106,29 @@ Connection::SendPacket(MessageNumber msg, std::span payload) void Connection::DoDisconnect(DisconnectReasonCode reason_code, std::string_view msg) noexcept { + if (IsDead()) + return; + OnDisconnecting(reason_code, msg); SendPacket(MakeDisconnect(reason_code, msg)); + if (output.IsEncrypted()) { + /* we have to wait for the worker thread to encrypt + the the DISCONNECT packet before we can actually + send it to the socket; therefore postpone the + Destroy() call */ + dead = true; + + /* we now have very little patience with this + client */ + socket.SetWriteTimeout(std::chrono::seconds{1}); + + /* we don't want to receive anything from it */ + socket.UnscheduleRead(); + return; + } + try { /* attempt to flush the DISCONNECT packet immediately before we close the socket */ @@ -598,6 +617,11 @@ Connection::OnBufferedWrite() switch (output.Flush()) { case Output::FlushResult::DONE: + if (IsDead() && output.IsEmpty()) { + Destroy(); + return false; + } + socket.UnscheduleWrite(); break; @@ -629,6 +653,9 @@ Connection::OnBufferedError([[maybe_unused]] std::exception_ptr e) noexcept bool Connection::OnInputReady() noexcept try { + if (IsDead()) + return false; + while (true) { const auto payload = input.ReadPacket(); if (payload.data() == nullptr) diff --git a/src/ssh/Connection.hxx b/src/ssh/Connection.hxx index 06cbad3..2e3ff48 100644 --- a/src/ssh/Connection.hxx +++ b/src/ssh/Connection.hxx @@ -53,6 +53,12 @@ class Connection : BufferedSocketHandler, InputHandler const Role role; + /** + * If true, then the connection is about to be closed, only + * waiting for the DISCONNECT to be encrypted and sent. + */ + bool dead = false; + bool version_exchanged = false; bool authenticated = false; @@ -103,6 +109,10 @@ public: metrics = &_metrics; } + bool IsDead() const noexcept { + return dead; + } + [[gnu::pure]] bool IsEncrypted() const noexcept; diff --git a/src/ssh/GConnection.cxx b/src/ssh/GConnection.cxx index 68b9367..f027f9c 100644 --- a/src/ssh/GConnection.cxx +++ b/src/ssh/GConnection.cxx @@ -166,4 +166,15 @@ GConnection::HandlePacket(MessageNumber msg, } } +void +GConnection::OnDisconnecting(DisconnectReasonCode reason_code, + std::string_view msg) noexcept +{ + Connection::OnDisconnecting(reason_code, msg); + + /* cancel all pending requests so they don't try to do any I/O + while we're waiting for the DISCONNECT to be flushed */ + pending_global_requests.clear_and_dispose(DeleteDisposer{}); +} + } // namespace SSH diff --git a/src/ssh/GConnection.hxx b/src/ssh/GConnection.hxx index c22ae40..18556d1 100644 --- a/src/ssh/GConnection.hxx +++ b/src/ssh/GConnection.hxx @@ -55,6 +55,8 @@ protected: /* virtual methods from class SSH::Connection */ void HandlePacket(MessageNumber msg, std::span payload) override; + void OnDisconnecting(DisconnectReasonCode reason_code, + std::string_view msg) noexcept; }; } // namespace SSH diff --git a/src/ssh/Output.hxx b/src/ssh/Output.hxx index 8357a0f..76f278c 100644 --- a/src/ssh/Output.hxx +++ b/src/ssh/Output.hxx @@ -95,6 +95,13 @@ public: return push_cipher != nullptr; } + [[gnu::pure]] + bool IsEmpty() noexcept { + const std::scoped_lock lock{mutex}; + return pending_queue.empty() && plain_queue.empty() && + next_plain_queue.empty() && encrypted_queue.empty(); + } + const Cipher *GetCipher() const noexcept { return push_cipher; }