From 20b6070aecdcba53a9893ee0435425e24258a2c6 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sun, 22 Sep 2024 16:34:44 +0300 Subject: [PATCH 1/8] netplay: remove redundant handling of `nullptr` coming from `allocSocketSet()` This instead should always throw `std::bad_alloc`. Signed-off-by: Pavel Solodovnikov --- lib/netplay/netplay.cpp | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index cd80b7ff7bf..cc306e98db2 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -4504,11 +4504,6 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator { server_socket_set = allocSocketSet(); } - if (server_socket_set == nullptr) - { - debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr())); - return false; - } // allocate socket storage for all possible players for (unsigned i = 0; i < MAX_CONNECTED_PLAYERS; ++i) { @@ -5102,11 +5097,6 @@ void NETacceptIncomingConnections() // initialize temporary server socket set // FIXME: why is this not done in NETinit()?? - Per tmp_socket_set = allocSocketSet(); - if (tmp_socket_set == nullptr) - { - debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr())); - return; - } // FIXME: I guess initialization of allowjoining is here now... - FlexCoral for (auto& tmpState : tmp_connectState) { From fd543ee04522a7313449a16cd595c9e0b49b8def Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sun, 29 Sep 2024 13:43:46 +0300 Subject: [PATCH 2/8] netsocket: remove redundant check for nullptr when allocating `Socket` `operator new` will throw a `std::bad_alloc` exception instead of returning `nullptr`, so there's no need to check for `nullptr` after trying to allocate a new `Socket` instance. Signed-off-by: Pavel Solodovnikov --- lib/netplay/netsocket.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/netplay/netsocket.cpp b/lib/netplay/netsocket.cpp index 9d42171dd33..1109efc9ebe 100644 --- a/lib/netplay/netsocket.cpp +++ b/lib/netplay/netsocket.cpp @@ -1391,12 +1391,6 @@ Socket *socketListen(unsigned int port) unsigned int i; Socket *const conn = new Socket; - if (conn == nullptr) - { - debug(LOG_ERROR, "Out of memory!"); - abort(); - return nullptr; - } // Mark all unused socket handles as invalid for (i = 0; i < ARRAY_SIZE(conn->fd); ++i) From 3081b66ad02f1660289205e50127e5fae89d58b6 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sun, 22 Sep 2024 16:30:04 +0300 Subject: [PATCH 3/8] 3rdparty: add `expected` submodule Signed-off-by: Pavel Solodovnikov --- .gitmodules | 3 +++ 3rdparty/CMakeLists.txt | 13 +++++++++++++ 3rdparty/expected | 1 + COPYING.NONGPL | 2 ++ 4 files changed, 19 insertions(+) create mode 160000 3rdparty/expected diff --git a/.gitmodules b/.gitmodules index 41ba63f817d..69910ef72cf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -74,3 +74,6 @@ [submodule "lib/netplay/3rdparty/miniupnp"] path = lib/netplay/3rdparty/miniupnp url = https://github.com/miniupnp/miniupnp.git +[submodule "3rdparty/expected"] + path = 3rdparty/expected + url = https://github.com/TartanLlama/expected.git diff --git a/3rdparty/CMakeLists.txt b/3rdparty/CMakeLists.txt index 18071178538..dad82b3add9 100644 --- a/3rdparty/CMakeLists.txt +++ b/3rdparty/CMakeLists.txt @@ -277,3 +277,16 @@ if(NOT MSVC) target_compile_options(plum-static PRIVATE ${_supported_libplum_c_compiler_flags}) endif() endif() + +set(EXPECTED_BUILD_TESTS OFF) +add_subdirectory(expected EXCLUDE_FROM_ALL) +# There isn't any release note or established CMake policy about this behavior, +# but looks like prior to CMake 3.19 only a handful of `INTERFACE_*` properties +# were allowed for INTERFACE libraries (which is our case). This restriction +# seems to be lifted in CMake 3.19 and later. +# +# See https://discourse.cmake.org/t/how-to-find-current-interface-library-property-whitelist/4784/2 +# for some clarification about this. +if (CMAKE_VERSION VERSION_GREATER "3.18") + set_target_properties(expected PROPERTIES FOLDER "3rdparty") +endif() diff --git a/3rdparty/expected b/3rdparty/expected new file mode 160000 index 00000000000..3f0ca7b1925 --- /dev/null +++ b/3rdparty/expected @@ -0,0 +1 @@ +Subproject commit 3f0ca7b19253129700a073abfa6d8638d9f7c80c diff --git a/COPYING.NONGPL b/COPYING.NONGPL index 0722ac2a893..db182c4c955 100644 --- a/COPYING.NONGPL +++ b/COPYING.NONGPL @@ -13,6 +13,8 @@ data/base/texpages/page-25-sky-urban.png - MIT, Various Authors, See: https://github.com/HowardHinnant/date/blob/master/include/date/date.h 3rdparty/discord-rpc/* - MIT, Copyright 2017 Discord, Inc. +3rdparty/expected/* + - CC0 1.0 Universal, Written in 2017 by Sy Brand (tartanllama@gmail.com, @TartanLlama) 3rdparty/json/* - MIT, Copyright (c) 2013-2018 Niels Lohmann 3rdparty/LRUCache11/* From 3943cd4f33cb35c6706797ec14d3be29c1a1f012 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sat, 12 Oct 2024 21:00:19 +0300 Subject: [PATCH 4/8] netplay: introduce custom error category for locale-indepdendent translation of network errors Signed-off-by: Pavel Solodovnikov --- lib/netplay/error_categories.cpp | 110 +++++++++++++++++++++++++++++++ lib/netplay/error_categories.h | 52 +++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 lib/netplay/error_categories.cpp create mode 100644 lib/netplay/error_categories.h diff --git a/lib/netplay/error_categories.cpp b/lib/netplay/error_categories.cpp new file mode 100644 index 00000000000..0b0683425b8 --- /dev/null +++ b/lib/netplay/error_categories.cpp @@ -0,0 +1,110 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "error_categories.h" + +#include "lib/framework/wzglobal.h" + +#ifdef WZ_OS_WIN +# include +# include +#elif defined(WZ_OS_UNIX) +# include +# include +#endif + +std::string GenericSystemErrorCategory::message(int ev) const +{ +#if defined(WZ_OS_WIN) + switch (ev) + { + case 0: return "No error"; + case WSAEINTR: return "Interrupted system call"; + case WSAEBADF: return "Bad file number"; + case WSAEACCES: return "Permission denied"; + case WSAEFAULT: return "Bad address"; + case WSAEINVAL: return "Invalid argument"; + case WSAEMFILE: return "Too many open sockets"; + case WSAEWOULDBLOCK: return "Operation would block"; + case WSAEINPROGRESS: return "Operation now in progress"; + case WSAEALREADY: return "Operation already in progress"; + case WSAENOTSOCK: return "Socket operation on non-socket"; + case WSAEDESTADDRREQ: return "Destination address required"; + case WSAEMSGSIZE: return "Message too long"; + case WSAEPROTOTYPE: return "Protocol wrong type for socket"; + case WSAENOPROTOOPT: return "Bad protocol option"; + case WSAEPROTONOSUPPORT: return "Protocol not supported"; + case WSAESOCKTNOSUPPORT: return "Socket type not supported"; + case WSAEOPNOTSUPP: return "Operation not supported on socket"; + case WSAEPFNOSUPPORT: return "Protocol family not supported"; + case WSAEAFNOSUPPORT: return "Address family not supported"; + case WSAEADDRINUSE: return "Address already in use"; + case WSAEADDRNOTAVAIL: return "Can't assign requested address"; + case WSAENETDOWN: return "Network is down"; + case WSAENETUNREACH: return "Network is unreachable"; + case WSAENETRESET: return "Net connection reset"; + case WSAECONNABORTED: return "Software caused connection abort"; + case WSAECONNRESET: return "Connection reset by peer"; + case WSAENOBUFS: return "No buffer space available"; + case WSAEISCONN: return "Socket is already connected"; + case WSAENOTCONN: return "Socket is not connected"; + case WSAESHUTDOWN: return "Can't send after socket shutdown"; + case WSAETOOMANYREFS: return "Too many references, can't splice"; + case WSAETIMEDOUT: return "Connection timed out"; + case WSAECONNREFUSED: return "Connection refused"; + case WSAELOOP: return "Too many levels of symbolic links"; + case WSAENAMETOOLONG: return "File name too long"; + case WSAEHOSTDOWN: return "Host is down"; + case WSAEHOSTUNREACH: return "No route to host"; + case WSAENOTEMPTY: return "Directory not empty"; + case WSAEPROCLIM: return "Too many processes"; + case WSAEUSERS: return "Too many users"; + case WSAEDQUOT: return "Disc quota exceeded"; + case WSAESTALE: return "Stale NFS file handle"; + case WSAEREMOTE: return "Too many levels of remote in path"; + case WSASYSNOTREADY: return "Network system is unavailable"; + case WSAVERNOTSUPPORTED: return "Winsock version out of range"; + case WSANOTINITIALISED: return "WSAStartup not yet called"; + case WSAEDISCON: return "Graceful shutdown in progress"; + case WSAHOST_NOT_FOUND: return "Host not found"; + case WSANO_DATA: return "No host data of that type was found"; + default: return "Unknown error"; + } +#elif defined(WZ_OS_UNIX) + return strerror(ev); +#endif +} + +std::error_condition GenericSystemErrorCategory::default_error_condition(int ev) const noexcept +{ + // Try to map the raw error values either to POSIX or Windows error codes (depending on the OS). + // The default system category should capture them all well. + return std::system_category().default_error_condition(ev); +} + +const std::error_category& generic_system_error_category() +{ + static GenericSystemErrorCategory instance; + return instance; +} + +std::error_code make_network_error_code(int ev) +{ + return { ev, generic_system_error_category() }; +} diff --git a/lib/netplay/error_categories.h b/lib/netplay/error_categories.h new file mode 100644 index 00000000000..a266fcfd5e0 --- /dev/null +++ b/lib/netplay/error_categories.h @@ -0,0 +1,52 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +/// +/// Custom error category class, which acts exactly as `std::system_category()` except +/// that the error messages are always translated to english representation (locale-agnostic behavior). +/// +/// Please see the bug https://github.com/microsoft/STL/issues/3254 for the explanation +/// as to why we would need to use a custom error category (at least on Windows). +/// Date: Sat, 12 Oct 2024 21:05:15 +0300 Subject: [PATCH 5/8] netplay: introduce custom error category for `getaddrinfo()` error codes Signed-off-by: Pavel Solodovnikov --- lib/netplay/error_categories.cpp | 16 ++++++++++++++++ lib/netplay/error_categories.h | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/lib/netplay/error_categories.cpp b/lib/netplay/error_categories.cpp index 0b0683425b8..eb990de3c2f 100644 --- a/lib/netplay/error_categories.cpp +++ b/lib/netplay/error_categories.cpp @@ -98,13 +98,29 @@ std::error_condition GenericSystemErrorCategory::default_error_condition(int ev) return std::system_category().default_error_condition(ev); } +std::string GetaddrinfoErrorCategory::message(int ev) const +{ + return gai_strerror(ev); +} + const std::error_category& generic_system_error_category() { static GenericSystemErrorCategory instance; return instance; } +const std::error_category& getaddrinfo_error_category() +{ + static GetaddrinfoErrorCategory instance; + return instance; +} + std::error_code make_network_error_code(int ev) { return { ev, generic_system_error_category() }; } + +std::error_code make_getaddrinfo_error_code(int ev) +{ + return { ev, getaddrinfo_error_category() }; +} diff --git a/lib/netplay/error_categories.h b/lib/netplay/error_categories.h index a266fcfd5e0..87a3fd20079 100644 --- a/lib/netplay/error_categories.h +++ b/lib/netplay/error_categories.h @@ -47,6 +47,26 @@ class GenericSystemErrorCategory : public std::error_category std::error_condition default_error_condition(int ev) const noexcept override; }; +/// +/// Custom error category which maps error codes from `getaddrinfo()` function to +/// the appropriate error messages. +/// +class GetaddrinfoErrorCategory : public std::error_category +{ +public: + + constexpr GetaddrinfoErrorCategory() = default; + + const char* name() const noexcept override + { + return "getaddrinfo"; + } + + std::string message(int ev) const override; +}; + const std::error_category& generic_system_error_category(); +const std::error_category& getaddrinfo_error_category(); std::error_code make_network_error_code(int ev); +std::error_code make_getaddrinfo_error_code(int ev); From e565f3606cb4d288ce15855299159bced42cde8f Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sat, 12 Oct 2024 21:09:00 +0300 Subject: [PATCH 6/8] netplay: introduce custom error category for zlib errors Signed-off-by: Pavel Solodovnikov --- lib/netplay/error_categories.cpp | 30 ++++++++++++++++++++++++++++++ lib/netplay/error_categories.h | 20 ++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/lib/netplay/error_categories.cpp b/lib/netplay/error_categories.cpp index eb990de3c2f..5be0196bc98 100644 --- a/lib/netplay/error_categories.cpp +++ b/lib/netplay/error_categories.cpp @@ -29,6 +29,8 @@ # include #endif +#include + std::string GenericSystemErrorCategory::message(int ev) const { #if defined(WZ_OS_WIN) @@ -103,6 +105,23 @@ std::string GetaddrinfoErrorCategory::message(int ev) const return gai_strerror(ev); } +std::string ZlibErrorCategory::message(int ev) const +{ + switch (ev) + { + case Z_STREAM_ERROR: + return "Z_STREAM_ERROR"; + case Z_NEED_DICT: + return "Z_NEED_DICT"; + case Z_DATA_ERROR: + return "Z_DATA_ERROR"; + case Z_MEM_ERROR: + return "Z_MEM_ERROR"; + default: + return "Unknown zlib error"; + } +} + const std::error_category& generic_system_error_category() { static GenericSystemErrorCategory instance; @@ -115,6 +134,12 @@ const std::error_category& getaddrinfo_error_category() return instance; } +const std::error_category& zlib_error_category() +{ + static ZlibErrorCategory instance; + return instance; +} + std::error_code make_network_error_code(int ev) { return { ev, generic_system_error_category() }; @@ -124,3 +149,8 @@ std::error_code make_getaddrinfo_error_code(int ev) { return { ev, getaddrinfo_error_category() }; } + +std::error_code make_zlib_error_code(int ev) +{ + return { ev, zlib_error_category() }; +} diff --git a/lib/netplay/error_categories.h b/lib/netplay/error_categories.h index 87a3fd20079..ddb98ed8815 100644 --- a/lib/netplay/error_categories.h +++ b/lib/netplay/error_categories.h @@ -65,8 +65,28 @@ class GetaddrinfoErrorCategory : public std::error_category std::string message(int ev) const override; }; +/// +/// Custom error category which maps some of the error codes from zlib to +/// the appropriate error messages. +/// +class ZlibErrorCategory : public std::error_category +{ +public: + + constexpr ZlibErrorCategory() = default; + + const char* name() const noexcept override + { + return "zlib"; + } + + std::string message(int ev) const override; +}; + const std::error_category& generic_system_error_category(); const std::error_category& getaddrinfo_error_category(); +const std::error_category& zlib_error_category(); std::error_code make_network_error_code(int ev); std::error_code make_getaddrinfo_error_code(int ev); +std::error_code make_zlib_error_code(int ev); From 30eeaf6c46057de72d6c2ee3cbee3521cab870fd Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sat, 12 Oct 2024 22:03:25 +0300 Subject: [PATCH 7/8] netsocket: provide the shorthand `net::result` for `tl::expected` Plus, add the necessary dependency for the `netplay` library in order for it to find `tl::expected` header files. Signed-off-by: Pavel Solodovnikov --- lib/netplay/CMakeLists.txt | 4 +++- lib/netplay/netsocket.h | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/lib/netplay/CMakeLists.txt b/lib/netplay/CMakeLists.txt index 963f4852394..12cb14cb648 100644 --- a/lib/netplay/CMakeLists.txt +++ b/lib/netplay/CMakeLists.txt @@ -57,7 +57,9 @@ add_dependencies(netplay autorevision_netcodeversion) set_property(TARGET netplay PROPERTY FOLDER "lib") include(WZTargetConfiguration) WZ_TARGET_CONFIGURATION(netplay) -target_link_libraries(netplay PRIVATE framework re2::re2 nlohmann_json plum-static Threads::Threads ZLIB::ZLIB) +target_link_libraries(netplay + PRIVATE framework re2::re2 nlohmann_json plum-static Threads::Threads ZLIB::ZLIB + PUBLIC tl::expected) if(WZ_USE_IMPORTED_MINIUPNPC) target_link_libraries(netplay PRIVATE imported-miniupnpc) diff --git a/lib/netplay/netsocket.h b/lib/netplay/netsocket.h index a36c3fcbfb5..1aad4412af6 100644 --- a/lib/netplay/netsocket.h +++ b/lib/netplay/netsocket.h @@ -23,8 +23,17 @@ #include "lib/framework/types.h" #include +#include #include +#include + +namespace net +{ + template + using result = ::tl::expected; +} // namespace net + #if defined(WZ_OS_UNIX) # include # include From 6337be008ccbee3081d5dd1166a6669ed1aea6f2 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sun, 13 Oct 2024 12:16:33 +0300 Subject: [PATCH 8/8] netsocket: socket operations return error codes via return values Switch to returning socket error codes explicitly via return values instead of using legacy functions `getSockErr()`, `setSockErr`, `strSockError()` to get extended error context. To facilitate that, use `tl::expected` as the return value type for all main socket operations: read, write, open, listen. Moving to this new approach has numerous benefits over the old one: 1. Instead of using POSIX constants directly, we wrap them into `std::error_code` instances, which allows to attach custom error categories to them. This can be used to customize error messages and error codes mapping to platform-independent `std::error_conditions`. 2. Calling separate `get/setSockErr()` functions is error-prone: one can easily forget to check the error condition from `getSockErr()` and the value will be overwritten by the next socket function without the ability to recover the former error. Conversely, one can forget to call `setSockErr()` to set the proper error code for the caller to check upon. 3. As mentioned above, `std::error_code:s` can be implicitly mapped to platform-independent `std::error_conditions`, allowing for this code to compile successfuly: if (errCode == std::errc::connection_reset) { ... } This allows for very convenient and portable error checking code, which completely hides implementation details of how a particular error code is implemented (but, if one really needs to, they still can extract the platform-dependent error code value to get the extended error context). The `getSockErr()`, `setSockErr` and `strSockError()` functions are still used in the `netsocket.cpp` implementation, but now they are strictly confined to this particular translation unit, meaning they have now become an implementation detail, rather than a part of public API contract of `netplay` library. Signed-off-by: Pavel Solodovnikov --- lib/netplay/netplay.cpp | 259 +++++++++++++++++----------------- lib/netplay/netsocket.cpp | 153 +++++++++++--------- lib/netplay/netsocket.h | 27 ++-- src/screens/joiningscreen.cpp | 37 ++--- 4 files changed, 249 insertions(+), 227 deletions(-) diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index cc306e98db2..fa2cd4a5432 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -64,6 +64,7 @@ #include "src/stdinreader.h" #include +#include #if defined (WZ_OS_MAC) # include "lib/framework/cocoa_wrapper.h" @@ -536,7 +537,6 @@ bool NETsetAsyncJoinApprovalResult(const std::string& uniqueJoinID, AsyncJoinApp static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *bufstart, int bufsize) { Socket *socket = *pSocket; - ssize_t size; if (!socketReadReady(*socket)) { @@ -544,26 +544,29 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b } size_t rawBytes; - size = readNoInt(*socket, bufstart, bufsize, &rawBytes); + const auto readResult = readNoInt(*socket, bufstart, bufsize, &rawBytes); - if ((size != 0 || !socketReadDisconnected(*socket)) && size != SOCKET_ERROR) + if (readResult.has_value()) { + const auto size = readResult.value(); + nStats.rawBytes.received += rawBytes; - nStats.uncompressedBytes.received += size; + nStats.uncompressedBytes.received += static_cast(size); nStats.packets.received += 1; return size; } else { - if (size == 0) + if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) { debug(LOG_NET, "Connection closed from the other side"); NETlogEntry("Connection closed from the other side..", SYNC_FLAG, selectedPlayer); } else { - debug(LOG_NET, "%s socket %p is now invalid", strSockError(getSockErr()), static_cast(socket)); + const auto readErrMsg = readResult.error().message(); + debug(LOG_NET, "%s socket %p is now invalid", readErrMsg.c_str(), static_cast(socket)); } // an error occurred, or the remote host has closed the connection. @@ -571,14 +574,6 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b { SocketSet_DelSocket(*pSocketSet, socket); } - - ASSERT(size <= bufsize, "Socket buffer is too small!"); - - if (size > bufsize) - { - debug(LOG_ERROR, "Fatal connection error: buffer size of (%d) was too small, current byte count was %ld", bufsize, (long)size); - NETlogEntry("Fatal connection error: buffer size was too small!", SYNC_FLAG, selectedPlayer); - } if (bsocket == socket) { debug(LOG_NET, "Host connection was lost!"); @@ -1204,7 +1199,7 @@ static constexpr size_t GAMESTRUCTmessageBufSize() * * @see GAMESTRUCT,NETrecvGAMESTRUCT */ -static bool NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourgamestruct) +static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourgamestruct) { // A buffer that's guaranteed to have the correct size (i.e. it // circumvents struct padding, which could pose a problem). Initialise @@ -1213,7 +1208,6 @@ static bool NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourgamestruct) char buf[GAMESTRUCTmessageBufSize()] = { 0 }; char *buffer = buf; unsigned int i; - ssize_t result; auto push32 = [&](uint32_t value) { uint32_t swapped = htonl(value); @@ -1311,20 +1305,17 @@ static bool NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourgamestruct) debug(LOG_NET, "sending GAMESTRUCT, size: %u", (unsigned int)sizeof(buf)); // Send over the GAMESTRUCT - result = writeAll(*sock, buf, sizeof(buf)); - if (result == SOCKET_ERROR) + const auto writeResult = writeAll(*sock, buf, sizeof(buf)); + if (!writeResult.has_value()) { - const int err = getSockErr(); - + const auto writeErrMsg = writeResult.error().message(); // If packet could not be sent, we should inform user of the error. - debug(LOG_ERROR, "Failed to send GAMESTRUCT. Reason: %s", strSockError(err)); + debug(LOG_ERROR, "Failed to send GAMESTRUCT. Reason: %s", writeErrMsg.c_str()); debug(LOG_ERROR, "Please make sure TCP ports %u & %u are open!", masterserver_port, gameserver_port); - setSockErr(err); - return false; + return tl::make_unexpected(writeResult.error()); } - - return true; + return {}; } /** @@ -1341,7 +1332,6 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) char buf[GAMESTRUCTmessageBufSize()] = { 0 }; char *buffer = buf; unsigned int i; - ssize_t result = 0; auto pop32 = [&]() -> uint32_t { uint32_t value = 0; @@ -1360,20 +1350,18 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) }; // Read a GAMESTRUCT from the connection - result = readAll(sock, buf, sizeof(buf), NET_TIMEOUT_DELAY); - bool failed = false; - if (result == SOCKET_ERROR) - { - debug(LOG_ERROR, "Lobby server connection error: %s", strSockError(getSockErr())); - failed = true; - } - else if ((unsigned)result != sizeof(buf)) - { - debug(LOG_ERROR, "GAMESTRUCT recv timed out; received %d bytes; expecting %d", (int)result, (int)sizeof(buf)); - failed = true; - } - if (failed) + auto readResult = readAll(sock, buf, sizeof(buf), NET_TIMEOUT_DELAY); + if (!readResult.has_value()) { + if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) + { + debug(LOG_ERROR, "GAMESTRUCT recv failed: timed out"); + } + else + { + const auto readErrMsg = readResult.error().message(); + debug(LOG_ERROR, "Lobby server connection error: %s", readErrMsg.c_str()); + } // caller handles invalidating and closing tcp_socket return false; } @@ -1724,7 +1712,6 @@ void NETsendProcessDelayedActions() bool NETsend(NETQUEUE queue, NetMessage const *message) { uint8_t player = queue.index; - ssize_t result = 0; if (!NetPlay.bComms) { @@ -1768,19 +1755,22 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) } ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - result = writeAll(*sockets[player], rawData, rawLen, &compressedRawLen); + const auto writeResult = writeAll(*sockets[player], rawData, rawLen, &compressedRawLen); + const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. - if (result == rawLen) + if (res == rawLen) { nStats.rawBytes.sent += compressedRawLen; nStats.uncompressedBytes.sent += rawLen; nStats.packets.sent += 1; } - else if (result == SOCKET_ERROR) + else if (res == SOCKET_ERROR) { + const auto writeErrMsg = writeResult.error().message(); // Write error, most likely client disconnect. - debug(LOG_ERROR, "Failed to send message (type: %" PRIu8 ", rawLen: %zu, compressedRawLen: %zu) to %" PRIu8 ": %s", message->type, message->rawLen(), compressedRawLen, player, strSockError(getSockErr())); + debug(LOG_ERROR, "Failed to send message (type: %" PRIu8 ", rawLen: %zu, compressedRawLen: %zu) to %" PRIu8 ": %s", + message->type, message->rawLen(), compressedRawLen, player, writeErrMsg.c_str()); if (!isTmpQueue) { netSendPendingDisconnectPlayerIndexes.insert(player); @@ -1798,19 +1788,21 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - result = writeAll(*bsocket, rawData, rawLen, &compressedRawLen); + const auto writeResult = writeAll(*bsocket, rawData, rawLen, &compressedRawLen); + const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. - if (result == rawLen) + if (res == rawLen) { nStats.rawBytes.sent += compressedRawLen; nStats.uncompressedBytes.sent += rawLen; nStats.packets.sent += 1; } - else if (result == SOCKET_ERROR) + else if (res == SOCKET_ERROR) { + const auto writeErrMsg = writeResult.error().message(); // Write error, most likely host disconnect. - debug(LOG_ERROR, "Failed to send message: %s", strSockError(getSockErr())); + debug(LOG_ERROR, "Failed to send message: %s", writeErrMsg.c_str()); debug(LOG_ERROR, "Host connection was broken, socket %p.", static_cast(bsocket)); NETlogEntry("write error--client disconnect.", SYNC_FLAG, player); SocketSet_DelSocket(*client_socket_set, bsocket); // mark it invalid @@ -1821,7 +1813,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) NetPlay.isHostAlive = false; } - return result == rawLen; + return res == rawLen; } } else @@ -3241,15 +3233,15 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) uint32_t lobbyStatusCode; uint32_t MOTDLength; uint32_t buffer[2]; - ssize_t result, received = 0; + ssize_t received = 0; // Get status and message length - result = readAll(sock, &buffer, sizeof(buffer), timeout); - if (result != sizeof(buffer)) + auto readResult = readAll(sock, &buffer, sizeof(buffer), timeout); + if (!readResult.has_value()) { goto error; } - received += result; + received += readResult.value(); lobbyStatusCode = ntohl(buffer[0]); MOTDLength = ntohl(buffer[1]); @@ -3259,12 +3251,12 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) free(NetPlay.MOTD); } NetPlay.MOTD = (char *)malloc(MOTDLength + 1); - result = readAll(sock, NetPlay.MOTD, MOTDLength, timeout); - if (result != MOTDLength) + readResult = readAll(sock, NetPlay.MOTD, MOTDLength, timeout); + if (!readResult.has_value()) { goto error; } - received += result; + received += readResult.value(); // NUL terminate string NetPlay.MOTD[MOTDLength] = '\0'; @@ -3293,37 +3285,19 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) return received; error: - if (result == SOCKET_ERROR) + if (NetPlay.MOTD) { - if (NetPlay.MOTD) - { - free(NetPlay.MOTD); - } - if (asprintf(&NetPlay.MOTD, "Error while connecting to the lobby server: %s\nMake sure port %d can receive incoming connections.", strSockError(getSockErr()), gameserver_port) == -1) - { - NetPlay.MOTD = nullptr; - } - else - { - NetPlay.ShowedMOTD = false; - debug(LOG_ERROR, "%s", NetPlay.MOTD); - } + free(NetPlay.MOTD); + } + const auto readErrMsg = readResult.error().message(); + if (asprintf(&NetPlay.MOTD, "Error while connecting to the lobby server: %s\nMake sure port %d can receive incoming connections.", readErrMsg.c_str(), gameserver_port) == -1) + { + NetPlay.MOTD = nullptr; } else { - if (NetPlay.MOTD) - { - free(NetPlay.MOTD); - } - if (asprintf(&NetPlay.MOTD, "Disconnected from lobby server. Failed to register game.") == -1) - { - NetPlay.MOTD = nullptr; - } - else - { - NetPlay.ShowedMOTD = false; - debug(LOG_ERROR, "%s", NetPlay.MOTD); - } + NetPlay.ShowedMOTD = false; + debug(LOG_ERROR, "%s", NetPlay.MOTD); } std::string strmotd = (NetPlay.MOTD) ? std::string(NetPlay.MOTD) : std::string(); @@ -3336,21 +3310,22 @@ bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function { unsigned int gamecount = 0; uint32_t gamesavailable = 0; - int result = 0; + const auto readResult = readAll(sock, &gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY); - if ((result = readAll(sock, &gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY)) == sizeof(gamesavailable)) + if (readResult.has_value()) { gamesavailable = ntohl(gamesavailable); } else { - if (result == SOCKET_ERROR) + if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) { - debug(LOG_NET, "Server socket encountered error: %s", strSockError(getSockErr())); + debug(LOG_NET, "Server didn't respond (timeout)"); } else { - debug(LOG_NET, "Server didn't respond (timeout)"); + const auto readErrMsg = readResult.error().message(); + debug(LOG_NET, "Server socket encountered error: %s", readErrMsg.c_str()); } return false; } @@ -3398,6 +3373,12 @@ bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function return true; } +template +static net::result ignoreExpectedResultValue(const net::result& res) +{ + return res.has_value() ? net::result{} : tl::make_unexpected(res.error()); +} + bool LobbyServerConnectionHandler::connect() { if (server_not_there) @@ -3419,18 +3400,19 @@ bool LobbyServerConnectionHandler::connect() bool bProcessingConnectOrDisconnectThisCall = true; uint32_t gameId = 0; - SocketAddress *const hosts = resolveHost(masterserver_name, masterserver_port); + const auto hostsResult = resolveHost(masterserver_name, masterserver_port); + const auto hosts = hostsResult.value_or(nullptr); if (hosts == nullptr) { - int sockErrInt = getSockErr(); - debug(LOG_ERROR, "Cannot resolve masterserver \"%s\": %s", masterserver_name, strSockError(sockErrInt)); + const auto hostsErrMsg = hostsResult.error().message(); + debug(LOG_ERROR, "Cannot resolve masterserver \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); free(NetPlay.MOTD); if (asprintf(&NetPlay.MOTD, _("Could not resolve masterserver name (%s)!"), masterserver_name) == -1) { NetPlay.MOTD = nullptr; } - wz_command_interface_output("WZEVENT: lobbyerror (%u): Cannot resolve lobby server: %s\n", 0, strSockError(sockErrInt)); + wz_command_interface_output("WZEVENT: lobbyerror (%u): Cannot resolve lobby server: %s\n", 0, hostsErrMsg.c_str()); server_not_there = true; return bProcessingConnectOrDisconnectThisCall; } @@ -3443,17 +3425,19 @@ bool LobbyServerConnectionHandler::connect() } // try each address from resolveHost until we successfully connect. - rs_socket = socketOpenAny(hosts, 1500); - int sockOpenErr = getSockErr(); + auto sockResult = socketOpenAny(hosts, 1500); deleteSocketAddress(hosts); + rs_socket = sockResult.value_or(nullptr); + // No address succeeded. if (rs_socket == nullptr) { - debug(LOG_ERROR, "Cannot connect to masterserver \"%s:%d\": %s", masterserver_name, masterserver_port, strSockError(sockOpenErr)); + const auto errMsg = sockResult.error().message(); + debug(LOG_ERROR, "Cannot connect to masterserver \"%s:%d\": %s", masterserver_name, masterserver_port, errMsg.c_str()); free(NetPlay.MOTD); if (asprintf(&NetPlay.MOTD, _("Error connecting to the lobby server: %s.\nMake sure port %d can receive incoming connections.\nIf you're using a router configure it to enable UPnP/NAT-PMP/PCP\n or to forward the port to your system."), - strSockError(getSockErr()), masterserver_port) == -1) + errMsg.c_str(), masterserver_port) == -1) { NetPlay.MOTD = nullptr; } @@ -3462,11 +3446,16 @@ bool LobbyServerConnectionHandler::connect() } // Get a game ID - if (writeAll(*rs_socket, "gaId", sizeof("gaId")) == SOCKET_ERROR - || readAll(*rs_socket, &gameId, sizeof(gameId), 10000) != sizeof(gameId)) + auto gameIdResult = writeAll(*rs_socket, "gaId", sizeof("gaId")); + if (gameIdResult.has_value()) + { + gameIdResult = readAll(*rs_socket, &gameId, sizeof(gameId), 10000); + } + if (!gameIdResult.has_value()) { + const auto gameIdErrMsg = gameIdResult.error().message(); free(NetPlay.MOTD); - if (asprintf(&NetPlay.MOTD, "Failed to retrieve a game ID: %s", strSockError(getSockErr())) == -1) + if (asprintf(&NetPlay.MOTD, "Failed to retrieve a game ID: %s", gameIdErrMsg.c_str()) == -1) { NetPlay.MOTD = nullptr; } @@ -3486,11 +3475,18 @@ bool LobbyServerConnectionHandler::connect() wz_command_interface_output("WZEVENT: lobbyid: %" PRIu32 "\n", gamestruct.gameId); // Register our game with the server - if (writeAll(*rs_socket, "addg", sizeof("addg")) == SOCKET_ERROR + const auto writeAddGameRes = writeAll(*rs_socket, "addg", sizeof("addg")); + + auto sendGamestructRes = ignoreExpectedResultValue(writeAddGameRes); + if (sendGamestructRes.has_value()) + { // and now send what the server wants - || !NETsendGAMESTRUCT(rs_socket, &gamestruct)) + sendGamestructRes = NETsendGAMESTRUCT(rs_socket, &gamestruct); + } + if (!sendGamestructRes.has_value()) { - debug(LOG_ERROR, "Failed to register game with server: %s", strSockError(getSockErr())); + const auto sendGameErrMsg = sendGamestructRes.error().message(); + debug(LOG_ERROR, "Failed to register game with server: %s", sendGameErrMsg.c_str()); disconnect(); return bProcessingConnectOrDisconnectThisCall; } @@ -3564,7 +3560,7 @@ void LobbyServerConnectionHandler::sendUpdateNow() return; } - if (!NETsendGAMESTRUCT(rs_socket, &gamestruct)) + if (!NETsendGAMESTRUCT(rs_socket, &gamestruct).has_value()) { disconnect(); } @@ -3580,7 +3576,7 @@ void LobbyServerConnectionHandler::sendUpdateNow() void LobbyServerConnectionHandler::sendKeepAlive() { ASSERT_OR_RETURN(, rs_socket != nullptr, "Null socket"); - if (writeAll(*rs_socket, "keep", sizeof("keep")) == SOCKET_ERROR) + if (!writeAll(*rs_socket, "keep", sizeof("keep")).has_value()) { // The socket has been invalidated, so get rid of it. (using them now may cause SIGPIPE). disconnect(); @@ -3871,10 +3867,10 @@ static void NETallowJoining() { char *p_buffer = tmp_connectState[i].buffer; - ssize_t sizeRead = readNoInt(*tmp_socket[i], p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer); - if (sizeRead != SOCKET_ERROR) + const auto sizeReadResult = readNoInt(*tmp_socket[i], p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer); + if (sizeReadResult.has_value()) { - tmp_connectState[i].usedBuffer += sizeRead; + tmp_connectState[i].usedBuffer += sizeReadResult.value(); } // A 2.3.7 client sends a "list" command first, just drop the connection. @@ -3986,27 +3982,28 @@ static void NETallowJoining() else if (tmp_connectState[i].connectState == TmpSocketInfo::TmpConnectState::PendingJoinRequest) { uint8_t buffer[NET_BUFFER_SIZE]; - ssize_t size = readNoInt(*tmp_socket[i], buffer, sizeof(buffer)); + const auto readResult = readNoInt(*tmp_socket[i], buffer, sizeof(buffer)); uint8_t rejected = 0; - if ((size == 0 && socketReadDisconnected(*tmp_socket[i])) || size == SOCKET_ERROR) + if (!readResult.has_value()) { // disconnect or programmer error - if (size == 0) + if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) { debug(LOG_NET, "Client socket disconnected."); } else { - debug(LOG_NET, "Client socket encountered error: %s", strSockError(getSockErr())); + const auto readErrMsg = readResult.error().message(); + debug(LOG_NET, "Client socket encountered error: %s", readErrMsg.c_str()); } NETlogEntry("Client socket disconnected (allowJoining)", SYNC_FLAG, i); debug(LOG_NET, "freeing temp socket %p (%d)", static_cast(tmp_socket[i]), __LINE__); NETcloseTempSocket(i); continue; } - - NETinsertRawData(NETnetTmpQueue(i), buffer, size); + const auto size = readResult.value(); + NETinsertRawData(NETnetTmpQueue(i), buffer, static_cast(size)); if (!NETisMessageReady(NETnetTmpQueue(i))) { @@ -4489,13 +4486,16 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator // These will initially be assigned to `tmp_socket[i]` until accepted in the game session, // in which case `tmp_socket[i]` will be assigned to `connected_bsocket[i]` and `tmp_socket[i]` // will become nullptr. + net::result serverListenResult = {}; if (!server_listen_socket) { - server_listen_socket = socketListen(gameserver_port); + serverListenResult = socketListen(gameserver_port); + server_listen_socket = serverListenResult.value_or(nullptr); } if (server_listen_socket == nullptr) { - debug(LOG_ERROR, "Cannot connect to master self: %s", strSockError(getSockErr())); + const auto sockErrMsg = serverListenResult.error().message(); + debug(LOG_ERROR, "Cannot connect to master self: %s", sockErrMsg.c_str()); return false; } debug(LOG_NET, "New server_listen_socket = %p", static_cast(server_listen_socket)); @@ -4594,8 +4594,6 @@ bool NEThaltJoining() // find games on open connection bool NETenumerateGames(const std::function& handleEnumerateGameFunc) { - SocketAddress *hosts; - int result = 0; debug(LOG_NET, "Looking for games..."); if (getLobbyError() == ERROR_INVALID || getLobbyError() == ERROR_KICKED || getLobbyError() == ERROR_HOSTDROPPED) @@ -4609,30 +4607,36 @@ bool NETenumerateGames(const std::function& handl debug(LOG_ERROR, "Likely missing NETinit(true) - this won't return any results"); return false; } - if ((hosts = resolveHost(masterserver_name, masterserver_port)) == nullptr) + const auto hostsResult = resolveHost(masterserver_name, masterserver_port); + SocketAddress* hosts = hostsResult.value_or(nullptr); + if (!hosts) { - debug(LOG_ERROR, "Cannot resolve hostname \"%s\": %s", masterserver_name, strSockError(getSockErr())); + const auto hostsErrMsg = hostsResult.error().message(); + debug(LOG_ERROR, "Cannot resolve hostname \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); setLobbyError(ERROR_CONNECTION); return false; } - Socket* sock = socketOpenAny(hosts, 15000); - + auto sockResult = socketOpenAny(hosts, 15000); deleteSocketAddress(hosts); hosts = nullptr; - if (sock == nullptr) - { - debug(LOG_ERROR, "Cannot connect to \"%s:%d\": %s", masterserver_name, masterserver_port, strSockError(getSockErr())); + if (!sockResult.has_value()) { + const auto sockErrMsg = sockResult.error().message(); + debug(LOG_ERROR, "Cannot connect to \"%s:%d\": %s", masterserver_name, masterserver_port, sockErrMsg.c_str()); setLobbyError(ERROR_CONNECTION); return false; } + Socket* sock = sockResult.value(); + debug(LOG_NET, "New socket = %p", static_cast(sock)); debug(LOG_NET, "Sending list cmd"); - if (writeAll(*sock, "list", sizeof("list")) == SOCKET_ERROR) + const auto writeResult = writeAll(*sock, "list", sizeof("list")); + if (!writeResult.has_value()) { - debug(LOG_NET, "Server socket encountered error: %s", strSockError(getSockErr())); + const auto writeErrMsg = writeResult.error().message(); + debug(LOG_NET, "Server socket encountered error: %s", writeErrMsg.c_str()); // mark it invalid socketClose(sock); @@ -4675,7 +4679,8 @@ bool NETenumerateGames(const std::function& handl // Hence as long as we don't treat "0" as signifying any change in behavior, this should be safe + backwards-compatible #define IGNORE_FIRST_BATCH 1 uint32_t responseParameters = 0; - if ((result = readAll(*sock, &responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY)) == sizeof(responseParameters)) + const auto readResult = readAll(*sock, &responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY); + if (readResult.has_value()) { responseParameters = ntohl(responseParameters); diff --git a/lib/netplay/netsocket.cpp b/lib/netplay/netsocket.cpp index 1109efc9ebe..b76a50bc420 100644 --- a/lib/netplay/netsocket.cpp +++ b/lib/netplay/netsocket.cpp @@ -26,6 +26,7 @@ #include "lib/framework/frame.h" #include "lib/framework/wzapp.h" #include "netsocket.h" +#include "error_categories.h" #include #include @@ -62,7 +63,7 @@ struct Socket * * All non-listening sockets will only use the first socket handle. */ - Socket() : ready(false), writeError(false), deleteLater(false), isCompressed(false), readDisconnected(false), zDeflateInSize(0) + Socket() : ready(false), deleteLater(false), isCompressed(false), readDisconnected(false), zDeflateInSize(0) { memset(&zDeflate, 0, sizeof(zDeflate)); memset(&zInflate, 0, sizeof(zInflate)); @@ -71,7 +72,7 @@ struct Socket SOCKET fd[SOCK_COUNT]; bool ready; - bool writeError; + std::error_code writeErrorCode = make_network_error_code(0); bool deleteLater; char textAddress[40] = {}; @@ -108,7 +109,7 @@ bool socketReadReady(const Socket& sock) } // Returns the last error for the calling thread -int getSockErr() +static int getSockErr() { #if defined(WZ_OS_UNIX) return errno; @@ -117,7 +118,7 @@ int getSockErr() #endif } -void setSockErr(int error) +static void setSockErr(int error) { #if defined(WZ_OS_UNIX) errno = error; @@ -287,7 +288,7 @@ static int addressToText(const struct sockaddr *addr, char *buf, size_t size) } } -const char *strSockError(int error) +static const char *strSockError(int error) { #if defined(WZ_OS_WIN) switch (error) @@ -489,7 +490,7 @@ static int socketThreadFunction(void *) if (!connectionIsOpen(sock)) { debug(LOG_NET, "Socket error"); - sock->writeError = true; + sock->writeErrorCode = make_network_error_code(getSockErr()); socketThreadWrites.erase(w); // Socket broken, don't try writing to it again. if (sock->deleteLater) { @@ -503,7 +504,7 @@ static int socketThreadFunction(void *) case EPIPE: #endif default: - sock->writeError = true; + sock->writeErrorCode = make_network_error_code(getSockErr()); socketThreadWrites.erase(w); // Socket broken, don't try writing to it again. if (sock->deleteLater) { @@ -532,7 +533,7 @@ static int socketThreadFunction(void *) * Similar to read(2) with the exception that this function won't be * interrupted by signals (EINTR). */ -ssize_t readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount) +net::result readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount) { size_t ignored; size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored; @@ -541,8 +542,7 @@ ssize_t readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount if (sock.fd[SOCK_CONNECTION] == INVALID_SOCKET) { debug(LOG_ERROR, "Invalid socket"); - setSockErr(EBADF); - return SOCKET_ERROR; + return tl::make_unexpected(make_network_error_code(EBADF)); } if (sock.isCompressed) @@ -562,7 +562,7 @@ ssize_t readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount while (received == SOCKET_ERROR && getSockErr() == EINTR); if (received < 0) { - return received; + return tl::make_unexpected(make_network_error_code(getSockErr())); } sock.zInflate.next_in = &sock.zInflateInBuf[0]; @@ -593,7 +593,8 @@ ssize_t readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount if (err != nullptr) { debug(LOG_ERROR, "Couldn't decompress data from socket. zlib error %s", err); - return -1; // Bad data! + // Bad data! + return tl::make_unexpected(make_zlib_error_code(ret)); } if (sock.zInflate.avail_out != 0) @@ -602,6 +603,11 @@ ssize_t readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount ASSERT(sock.zInflate.avail_in == 0, "zlib not consuming all input!"); } + if (sock.readDisconnected) + { + return tl::make_unexpected(make_network_error_code(ECONNRESET)); + } + return max_size - sock.zInflate.avail_out; // Got some data, return how much. } @@ -617,6 +623,10 @@ ssize_t readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount while (received == SOCKET_ERROR && getSockErr() == EINTR); sock.ready = false; + if (sock.readDisconnected) + { + return tl::make_unexpected(make_network_error_code(ECONNRESET)); + } rawBytes = received; return received; @@ -633,7 +643,7 @@ bool socketReadDisconnected(const Socket& sock) * * @return @c size when successful or @c SOCKET_ERROR if an error occurred. */ -ssize_t writeAll(Socket& sock, const void *buf, size_t size, size_t *rawByteCount) +net::result writeAll(Socket& sock, const void *buf, size_t size, size_t *rawByteCount) { size_t ignored; size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored; @@ -642,13 +652,12 @@ ssize_t writeAll(Socket& sock, const void *buf, size_t size, size_t *rawByteCoun if (sock.fd[SOCK_CONNECTION] == INVALID_SOCKET) { debug(LOG_ERROR, "Invalid socket (EBADF)"); - setSockErr(EBADF); - return SOCKET_ERROR; + return tl::make_unexpected(make_network_error_code(EBADF)); } - if (sock.writeError) + if (sock.writeErrorCode) { - return SOCKET_ERROR; + return tl::make_unexpected(sock.writeErrorCode); } if (size > 0) @@ -727,7 +736,7 @@ void socketFlush(Socket& sock, uint8_t player, size_t *rawByteCount) return; // Not compressed, so don't mess with zlib. } - ASSERT(!sock.writeError, "Socket write error?? (Player: %" PRIu8 "", player); + ASSERT(!sock.writeErrorCode, "Socket write error?? (Player: %" PRIu8 "", player); // Flush data out of zlib compression state. do @@ -1047,7 +1056,7 @@ int checkSockets(const SocketSet& set, unsigned int timeout) * when the other end disconnected or a timeout occurred. Or @c SOCKET_ERROR if * an error occurred. */ -ssize_t readAll(Socket& sock, void *buf, size_t size, unsigned int timeout) +net::result readAll(Socket& sock, void *buf, size_t size, unsigned int timeout) { ASSERT(!sock.isCompressed, "readAll on compressed sockets not implemented."); @@ -1058,8 +1067,7 @@ ssize_t readAll(Socket& sock, void *buf, size_t size, unsigned int timeout) if (sock.fd[SOCK_CONNECTION] == INVALID_SOCKET) { debug(LOG_ERROR, "Invalid socket (%p), sock->fd[SOCK_CONNECTION]=%" PRIuPTR"x (error: EBADF)", static_cast(&sock), static_cast(sock.fd[SOCK_CONNECTION])); - setSockErr(EBADF); - return SOCKET_ERROR; + return tl::make_unexpected(make_network_error_code(EBADF)); } while (received < size) @@ -1076,10 +1084,10 @@ ssize_t readAll(Socket& sock, void *buf, size_t size, unsigned int timeout) if (ret == 0) { debug(LOG_NET, "socket (%p) has timed out.", static_cast(&sock)); - setSockErr(ETIMEDOUT); + return tl::make_unexpected(make_network_error_code(ETIMEDOUT)); } debug(LOG_NET, "socket (%p) error.", static_cast(&sock)); - return SOCKET_ERROR; + return tl::make_unexpected(make_network_error_code(getSockErr())); } } @@ -1089,13 +1097,13 @@ ssize_t readAll(Socket& sock, void *buf, size_t size, unsigned int timeout) { debug(LOG_NET, "Socket %" PRIuPTR"x disconnected.", static_cast(sock.fd[SOCK_CONNECTION])); sock.readDisconnected = true; - setSockErr(ECONNRESET); - return received; + return tl::make_unexpected(make_network_error_code(ECONNRESET)); } if (ret == SOCKET_ERROR) { - switch (getSockErr()) + const auto sockErr = getSockErr(); + switch (sockErr) { case EAGAIN: #if defined(EWOULDBLOCK) && EAGAIN != EWOULDBLOCK @@ -1105,7 +1113,7 @@ ssize_t readAll(Socket& sock, void *buf, size_t size, unsigned int timeout) continue; default: - return SOCKET_ERROR; + return tl::make_unexpected(make_network_error_code(sockErr)); } } @@ -1237,18 +1245,12 @@ Socket *socketAccept(Socket *sock) return nullptr; } -Socket *socketOpen(const SocketAddress *addr, unsigned timeout) +net::result socketOpen(const SocketAddress *addr, unsigned timeout) { unsigned int i; int ret; Socket *const conn = new Socket; - if (conn == nullptr) - { - debug(LOG_ERROR, "Out of memory!"); - abort(); - return nullptr; - } ASSERT(addr != nullptr, "NULL Socket provided"); @@ -1276,9 +1278,10 @@ Socket *socketOpen(const SocketAddress *addr, unsigned timeout) if (conn->fd[SOCK_CONNECTION] == INVALID_SOCKET) { - debug(LOG_ERROR, "Failed to create a socket (%p): %s", static_cast(conn), strSockError(getSockErr())); + const auto sockErr = getSockErr(); + debug(LOG_ERROR, "Failed to create a socket (%p): %s", static_cast(conn), strSockError(sockErr)); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(sockErr)); } #if !defined(SOCK_CLOEXEC) @@ -1292,9 +1295,10 @@ Socket *socketOpen(const SocketAddress *addr, unsigned timeout) debug(LOG_NET, "setting socket (%p) blocking status (false).", static_cast(conn)); if (!setSocketBlocking(conn->fd[SOCK_CONNECTION], false)) { + const auto sockErr = getSockErr(); debug(LOG_NET, "Couldn't set socket (%p) blocking status (false). Closing.", static_cast(conn)); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(sockErr)); } socketBlockSIGPIPE(conn->fd[SOCK_CONNECTION], true); @@ -1315,9 +1319,10 @@ Socket *socketOpen(const SocketAddress *addr, unsigned timeout) #endif || timeout == 0) { - debug(LOG_NET, "Failed to start connecting: %s, using socket %p", strSockError(getSockErr()), static_cast(conn)); + const auto sockErr = getSockErr(); + debug(LOG_NET, "Failed to start connecting: %s, using socket %p", strSockError(sockErr), static_cast(conn)); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(sockErr)); } do @@ -1341,17 +1346,18 @@ Socket *socketOpen(const SocketAddress *addr, unsigned timeout) if (ret == SOCKET_ERROR) { - debug(LOG_NET, "Failed to wait for connection: %s, socket %p. Closing.", strSockError(getSockErr()), static_cast(conn)); + const auto sockErr = getSockErr(); + debug(LOG_NET, "Failed to wait for connection: %s, socket %p. Closing.", strSockError(sockErr), static_cast(conn)); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(sockErr)); } if (ret == 0) { - setSockErr(ETIMEDOUT); - debug(LOG_NET, "Timed out while waiting for connection to be established: %s, using socket %p. Closing.", strSockError(getSockErr()), static_cast(conn)); + const auto sockErr = ETIMEDOUT; + debug(LOG_NET, "Timed out while waiting for connection to be established: %s, using socket %p. Closing.", strSockError(sockErr), static_cast(conn)); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(sockErr)); } #if defined(WZ_OS_WIN) @@ -1367,16 +1373,17 @@ Socket *socketOpen(const SocketAddress *addr, unsigned timeout) && getSockErr() != EISCONN) #endif { - debug(LOG_NET, "Failed to connect: %s, with socket %p. Closing.", strSockError(getSockErr()), static_cast(conn)); + const auto sockErr = getSockErr(); + debug(LOG_NET, "Failed to connect: %s, with socket %p. Closing.", strSockError(sockErr), static_cast(conn)); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(sockErr)); } } return conn; } -Socket *socketListen(unsigned int port) +net::result socketListen(unsigned int port) { /* Enable the V4 to V6 mapping, but only when available, because it * isn't available on all platforms. @@ -1421,9 +1428,10 @@ Socket *socketListen(unsigned int port) if (conn->fd[SOCK_IPV4_LISTEN] == INVALID_SOCKET && conn->fd[SOCK_IPV6_LISTEN] == INVALID_SOCKET) { - debug(LOG_ERROR, "Failed to create an IPv4 and IPv6 (only supported address families) socket (%p): %s. Closing.", static_cast(conn), strSockError(getSockErr())); + const auto errorCode = getSockErr(); + debug(LOG_ERROR, "Failed to create an IPv4 and IPv6 (only supported address families) socket (%p): %s. Closing.", static_cast(conn), strSockError(errorCode)); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(errorCode)); } if (conn->fd[SOCK_IPV4_LISTEN] != INVALID_SOCKET) @@ -1520,25 +1528,29 @@ Socket *socketListen(unsigned int port) if (conn->fd[SOCK_IPV4_LISTEN] == INVALID_SOCKET && conn->fd[SOCK_IPV6_LISTEN] == INVALID_SOCKET) { + const auto errorCode = getSockErr(); debug(LOG_NET, "No IPv4 or IPv6 sockets created."); socketClose(conn); - return nullptr; + return tl::make_unexpected(make_network_error_code(errorCode)); } return conn; } -Socket *socketOpenAny(const SocketAddress *addr, unsigned timeout) +net::result socketOpenAny(const SocketAddress *addr, unsigned timeout) { - Socket *ret = nullptr; - while (addr != nullptr && ret == nullptr) + net::result res; + while (addr != nullptr) { - ret = socketOpen(addr, timeout); - + res = socketOpen(addr, timeout); + if (res) + { + return res; + } addr = addr->ai_next; } - return ret; + return res; } bool socketHasIPv4(const Socket& sock) @@ -1636,12 +1648,12 @@ std::string ipv6_NetBinary_To_AddressString(const std::vector& ip return ipv6Address; } -SocketAddress *resolveHost(const char *host, unsigned int port) +net::result resolveHost(const char *host, unsigned int port) { struct addrinfo *results; std::string service; struct addrinfo hint; - int error, flags = 0; + int flags = 0; hint.ai_family = AF_UNSPEC; hint.ai_socktype = SOCK_STREAM; @@ -1660,11 +1672,13 @@ SocketAddress *resolveHost(const char *host, unsigned int port) service = astringf("%u", port); - error = getaddrinfo(host, service.c_str(), &hint, &results); + auto error = getaddrinfo(host, service.c_str(), &hint, &results); if (error != 0) { - debug(LOG_NET, "getaddrinfo failed for %s:%s: %s", host, service.c_str(), gai_strerror(error)); - return nullptr; + const auto ec = make_getaddrinfo_error_code(error); + const auto errMsg = ec.message(); + debug(LOG_NET, "getaddrinfo failed for %s:%s: %s", host, service.c_str(), errMsg.c_str()); + return tl::make_unexpected(ec); } return results; @@ -1742,20 +1756,25 @@ void SOCKETshutdown() OpenConnectionResult socketOpenTCPConnectionSync(const char *host, uint32_t port) { - SocketAddress *hosts = resolveHost(host, port); + const auto hostsResult = resolveHost(host, port); + SocketAddress* hosts = hostsResult.value_or(nullptr); if (hosts == nullptr) { - int sErr = getSockErr(); - return OpenConnectionResult((sErr != 0) ? sErr : -1, astringf("Cannot resolve host \"%s\": [%d]: %s", host, sErr, strSockError(sErr))); + const auto hostsErr = hostsResult.error(); + const auto hostsErrMsg = hostsErr.message(); + return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); } - Socket* client_transient_socket = socketOpenAny(hosts, 15000); - int sockOpenErr = getSockErr(); + auto sockResult = socketOpenAny(hosts, 15000); + Socket* client_transient_socket = sockResult.value_or(nullptr); deleteSocketAddress(hosts); + hosts = nullptr; if (client_transient_socket == nullptr) { - return OpenConnectionResult((sockOpenErr != 0) ? sockOpenErr : -1, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, sockOpenErr, strSockError(sockOpenErr))); + const auto errValue = sockResult.error(); + const auto errMsg = errValue.message(); + return OpenConnectionResult(errValue, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, errValue.value(), errMsg.c_str())); } return OpenConnectionResult(client_transient_socket); diff --git a/lib/netplay/netsocket.h b/lib/netplay/netsocket.h index 1aad4412af6..8bca6a7a1c1 100644 --- a/lib/netplay/netsocket.h +++ b/lib/netplay/netsocket.h @@ -94,21 +94,16 @@ static const int SOCKET_ERROR = -1; void SOCKETinit(); void SOCKETshutdown(); -// General. -int getSockErr(); ///< Gets last socket error. (May be overwritten by functions that set errno.) -void setSockErr(int error); ///< Sets last socket error. -const char *strSockError(int error); ///< Converts a socket error to a string. - // Socket addresses. -SocketAddress *resolveHost(const char *host, unsigned port); ///< Looks up a socket address. +net::result resolveHost(const char *host, unsigned port); ///< Looks up a socket address. WZ_DECL_NONNULL(1) void deleteSocketAddress(SocketAddress *addr); ///< Destroys the socket address. // Sockets. -Socket *socketOpen(const SocketAddress *addr, unsigned timeout); ///< Opens a Socket, using the first address in addr. -Socket *socketListen(unsigned int port); ///< Creates a listen-only Socket, which listens for incoming connections. +net::result socketOpen(const SocketAddress *addr, unsigned timeout); ///< Opens a Socket, using the first address in addr. +net::result socketListen(unsigned int port); ///< Creates a listen-only Socket, which listens for incoming connections. WZ_DECL_NONNULL(1) Socket *socketAccept(Socket *sock); ///< Accepts an incoming Socket connection from a listening Socket. WZ_DECL_NONNULL(1) void socketClose(Socket *sock); ///< Destroys the Socket. -Socket *socketOpenAny(const SocketAddress *addr, unsigned timeout); ///< Opens a Socket, using the first address that works in addr. +net::result socketOpenAny(const SocketAddress *addr, unsigned timeout); ///< Opens a Socket, using the first address that works in addr. bool socketHasIPv4(const Socket& sock); bool socketHasIPv6(const Socket& sock); @@ -119,11 +114,11 @@ std::string ipv4_NetBinary_To_AddressString(const std::vector& ip std::string ipv6_NetBinary_To_AddressString(const std::vector& ip6NetBinaryForm); bool socketReadReady(const Socket& sock); ///< Returns if checkSockets found data to read from this Socket. WZ_DECL_NONNULL(2) -ssize_t readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount = nullptr); ///< Reads up to max_size bytes from the Socket. Raw count of bytes (after compression) returned in rawByteCount. +net::result readNoInt(Socket& sock, void *buf, size_t max_size, size_t *rawByteCount = nullptr); ///< Reads up to max_size bytes from the Socket. Raw count of bytes (after compression) returned in rawByteCount. WZ_DECL_NONNULL(2) -ssize_t readAll(Socket& sock, void *buf, size_t size, unsigned timeout);///< Reads exactly size bytes from the Socket, or blocks until the timeout expires. +net::result readAll(Socket& sock, void *buf, size_t size, unsigned timeout);///< Reads exactly size bytes from the Socket, or blocks until the timeout expires. WZ_DECL_NONNULL(2) -ssize_t writeAll(Socket& sock, const void *buf, size_t size, size_t *rawByteCount = nullptr); ///< Nonblocking write of size bytes to the Socket. All bytes will be written asynchronously, by a separate thread. Raw count of bytes (after compression) returned in rawByteCount, which will often be 0 until the socket is flushed. +net::result writeAll(Socket& sock, const void *buf, size_t size, size_t *rawByteCount = nullptr); ///< Nonblocking write of size bytes to the Socket. All bytes will be written asynchronously, by a separate thread. Raw count of bytes (after compression) returned in rawByteCount, which will often be 0 until the socket is flushed. bool socketSetTCPNoDelay(Socket& sock, bool nodelay); ///< nodelay = true disables the Nagle algorithm for TCP socket @@ -144,8 +139,8 @@ int checkSockets(const SocketSet& set, unsigned int timeout); ///< Checks which struct OpenConnectionResult { public: - OpenConnectionResult(int error, std::string errorString) - : error(error) + OpenConnectionResult(std::error_code ec, std::string errorString) + : errorCode(ec) , errorString(errorString) { } @@ -153,7 +148,7 @@ struct OpenConnectionResult : open_socket(open_socket) { } public: - bool hasError() const { return error != 0; } + bool hasError() const { return static_cast(errorCode); } public: OpenConnectionResult( const OpenConnectionResult& other ) = delete; // non construction-copyable OpenConnectionResult& operator=( const OpenConnectionResult& ) = delete; // non copyable @@ -164,7 +159,7 @@ struct OpenConnectionResult void operator()(Socket* b) { if (b) { socketClose(b); } } }; std::unique_ptr open_socket; - int error = 0; + std::error_code errorCode; std::string errorString; }; typedef std::function OpenConnectionToHostResultCallback; diff --git a/src/screens/joiningscreen.cpp b/src/screens/joiningscreen.cpp index 796285f4bda..d249fe76c90 100644 --- a/src/screens/joiningscreen.cpp +++ b/src/screens/joiningscreen.cpp @@ -1084,8 +1084,8 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect { debug(LOG_ERROR, "%s", result.errorString.c_str()); // Done trying connections - all failed - const char* pSocketErrorStr = strSockError(result.error); - auto localizedError = astringf(_("Failed to open connection: [%d] %s"), result.error, (pSocketErrorStr) ? pSocketErrorStr : ""); + const auto sockErrorMsg = result.errorCode.message(); + auto localizedError = astringf(_("Failed to open connection: [%d] %s"), result.errorCode.value(), sockErrorMsg.c_str()); handleFailure(FailureDetails::makeFromInternalError(WzString::fromUtf8(localizedError))); } return; @@ -1111,9 +1111,11 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect pushu32(NETGetMajorVersion()); pushu32(NETGetMinorVersion()); - if (writeAll(*client_transient_socket, buffer, sizeof(buffer)) == SOCKET_ERROR) + const auto writeResult = writeAll(*client_transient_socket, buffer, sizeof(buffer)); + if (!writeResult.has_value()) { - debug(LOG_ERROR, "Couldn't send my version."); + const auto writeErrMsg = writeResult.error().message(); + debug(LOG_ERROR, "Couldn't send my version: %s", writeErrMsg.c_str()); closeConnectionAttempt(); return; } @@ -1186,18 +1188,19 @@ bool WzJoiningGameScreen_HandlerRoot::joiningSocketNETsend() uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen = 0; - ssize_t result = writeAll(*client_transient_socket, rawData, rawLen, &compressedRawLen); + const auto writeResult = writeAll(*client_transient_socket, rawData, rawLen, &compressedRawLen); delete[] rawData; // Done with the data. queue->popMessageForNet(); - if (result == rawLen) + if (writeResult.has_value()) { // success writing to socket debug(LOG_NET, "Wrote initial message to socket to host"); } - else if (result == SOCKET_ERROR) + else { + const auto writeErrMsg = writeResult.error().message(); // Write error, most likely host disconnect. - debug(LOG_ERROR, "Failed to send message (type: %" PRIu8 ", rawLen: %zu, compressedRawLen: %zu) to host", message->type, message->rawLen(), compressedRawLen); + debug(LOG_ERROR, "Failed to send message (type: %" PRIu8 ", rawLen: %zu, compressedRawLen: %zu) to host: %s", message->type, message->rawLen(), compressedRawLen, writeErrMsg.c_str()); return false; } socketFlush(*client_transient_socket, NET_HOST_ONLY); // Make sure the message was completely sent. @@ -1286,10 +1289,10 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() } char *p_buffer = initialAckBuffer; - ssize_t sizeRead = readNoInt(*client_transient_socket, p_buffer + usedInitialAckBuffer, expectedInitialAckSize - usedInitialAckBuffer); - if (sizeRead != SOCKET_ERROR) + const auto readResult = readNoInt(*client_transient_socket, p_buffer + usedInitialAckBuffer, expectedInitialAckSize - usedInitialAckBuffer); + if (readResult.has_value()) { - usedInitialAckBuffer += sizeRead; + usedInitialAckBuffer += static_cast(readResult.value()); } if (usedInitialAckBuffer >= expectedInitialAckSize) @@ -1337,18 +1340,18 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() } uint8_t readBuffer[NET_BUFFER_SIZE]; - ssize_t size = readNoInt(*client_transient_socket, readBuffer, sizeof(readBuffer)); - - if ((size == 0 && socketReadDisconnected(*client_transient_socket)) || size == SOCKET_ERROR) + const auto readResult = readNoInt(*client_transient_socket, readBuffer, sizeof(readBuffer)); + if (!readResult.has_value()) { // disconnect or programmer error - if (size == 0) + if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) { debug(LOG_NET, "Client socket disconnected."); } else { - debug(LOG_NET, "Client socket encountered error: %s", strSockError(getSockErr())); + const auto readErrMsg = readResult.error().message(); + debug(LOG_NET, "Client socket encountered error: %s", readErrMsg.c_str()); } NETlogEntry("Client socket disconnected (allowJoining)", SYNC_FLAG, startTime); debug(LOG_NET, "freeing temp socket %p (%d)", static_cast(client_transient_socket), __LINE__); @@ -1358,7 +1361,7 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() } else { - NETinsertRawData(tmpJoiningQUEUE, readBuffer, size); + NETinsertRawData(tmpJoiningQUEUE, readBuffer, static_cast(readResult.value())); } }