Skip to content

Commit

Permalink
netplay: more strict checks for error codes from socket operations
Browse files Browse the repository at this point in the history
There isn't any guarantee that `getSockErr()`/`setSockErr()`
won't return `0` in all corner cases where something bad happens
with an underlying socket.

Provide more strict checks for socket write errors and for
checking the result of connection opening routines, allowing to
catch error conditions even in presence of error code == 0.

Signed-off-by: Pavel Solodovnikov <[email protected]>
  • Loading branch information
ManManson committed Nov 8, 2024
1 parent 7f8bb27 commit 9e4c59a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
8 changes: 4 additions & 4 deletions lib/netplay/netsocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct Socket

SOCKET fd[SOCK_COUNT];
bool ready;
std::error_code writeErrorCode = make_network_error_code(0);
nonstd::optional<std::error_code> writeErrorCode;
bool deleteLater;
char textAddress[40] = {};

Expand Down Expand Up @@ -650,9 +650,9 @@ net::result<ssize_t> writeAll(Socket& sock, const void *buf, size_t size, size_t
return tl::make_unexpected(make_network_error_code(EBADF));
}

if (sock.writeErrorCode)
if (sock.writeErrorCode.has_value())
{
return tl::make_unexpected(sock.writeErrorCode);
return tl::make_unexpected(sock.writeErrorCode.value());
}

if (size > 0)
Expand Down Expand Up @@ -731,7 +731,7 @@ void socketFlush(Socket& sock, uint8_t player, size_t *rawByteCount)
return; // Not compressed, so don't mess with zlib.
}

ASSERT(!sock.writeErrorCode, "Socket write error?? (Player: %" PRIu8 "", player);
ASSERT(!sock.writeErrorCode.has_value(), "Socket write error?? (Player: %" PRIu8 "", player);

// Flush data out of zlib compression state.
do
Expand Down
12 changes: 6 additions & 6 deletions lib/netplay/netsocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <system_error>
#include <vector>

#include <nonstd/optional.hpp>
#include <tl/expected.hpp>

namespace net
Expand Down Expand Up @@ -137,7 +138,6 @@ int checkSockets(const SocketSet& set, unsigned int timeout); ///< Checks which
// Higher-level functions for opening a connection / socket
struct OpenConnectionResult
{
public:
OpenConnectionResult(std::error_code ec, std::string errorString)
: errorCode(ec)
, errorString(errorString)
Expand All @@ -146,19 +146,19 @@ struct OpenConnectionResult
OpenConnectionResult(Socket* open_socket)
: open_socket(open_socket)
{ }
public:
bool hasError() const { return static_cast<bool>(errorCode); }
public:

bool hasError() const { return errorCode.has_value(); }

OpenConnectionResult( const OpenConnectionResult& other ) = delete; // non construction-copyable
OpenConnectionResult& operator=( const OpenConnectionResult& ) = delete; // non copyable
OpenConnectionResult(OpenConnectionResult&&) = default;
OpenConnectionResult& operator=(OpenConnectionResult&&) = default;
public:

struct SocketDeleter {
void operator()(Socket* b) { if (b) { socketClose(b); } }
};
std::unique_ptr<Socket, SocketDeleter> open_socket;
std::error_code errorCode;
nonstd::optional<std::error_code> errorCode;
std::string errorString;
};
typedef std::function<void (OpenConnectionResult&& result)> OpenConnectionToHostResultCallback;
Expand Down
5 changes: 3 additions & 2 deletions src/screens/joiningscreen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1149,8 +1149,9 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect
{
debug(LOG_ERROR, "%s", result.errorString.c_str());
// Done trying connections - all failed
const auto sockErrorMsg = result.errorCode.message();
auto localizedError = astringf(_("Failed to open connection: [%d] %s"), result.errorCode.value(), sockErrorMsg.c_str());
const auto errCode = result.errorCode.value();
const auto sockErrorMsg = errCode.message();
auto localizedError = astringf(_("Failed to open connection: [%d] %s"), errCode.value(), sockErrorMsg.c_str());
handleFailure(FailureDetails::makeFromInternalError(WzString::fromUtf8(localizedError)));
}
return;
Expand Down

0 comments on commit 9e4c59a

Please sign in to comment.