diff --git a/lib/netplay/CMakeLists.txt b/lib/netplay/CMakeLists.txt index 12cb14cb648..65962e17e91 100644 --- a/lib/netplay/CMakeLists.txt +++ b/lib/netplay/CMakeLists.txt @@ -26,8 +26,16 @@ add_dependencies(autorevision_netcodeversion autorevision) # Ensure ordering and ############################ # netplay library -file(GLOB HEADERS "*.h") -file(GLOB SRC "*.cpp") +file(GLOB_RECURSE HEADERS "*.h") +file(GLOB_RECURSE SRC "*.cpp") + +if(MSVC AND CMAKE_VERSION VERSION_GREATER 3.7) + # Automatic detection of source groups via `source_group(TREE )` syntax + # has been introduced in CMake 3.8. + # Please consult https://cmake.org/cmake/help/latest/command/source_group.html for additional info. + source_group(TREE "${CMAKE_CURRENT_LIST_DIR}" PREFIX "Sources" FILES ${SRC}) + source_group(TREE "${CMAKE_CURRENT_LIST_DIR}" PREFIX "Headers" FILES ${HEADERS}) +endif() find_package (Threads REQUIRED) find_package (ZLIB REQUIRED) diff --git a/lib/netplay/byteorder_funcs_wrapper.cpp b/lib/netplay/byteorder_funcs_wrapper.cpp new file mode 100644 index 00000000000..79366a8949b --- /dev/null +++ b/lib/netplay/byteorder_funcs_wrapper.cpp @@ -0,0 +1,50 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 1999-2004 Eidos Interactive + Copyright (C) 2005-2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/byteorder_funcs_wrapper.h" + +#include "lib/framework/wzglobal.h" + +// bring in the original `htonl`/`htons`/`ntohs`/`htohl` functions +#if defined WZ_OS_WIN +# include +#else // *NIX / *BSD variants +# include +#endif + +uint32_t wz_htonl(uint32_t hostlong) +{ + return htonl(hostlong); +} + +uint16_t wz_htons(uint16_t hostshort) +{ + return htons(hostshort); +} + +uint32_t wz_ntohl(uint32_t netlong) +{ + return ntohl(netlong); +} + +uint16_t wz_ntohs(uint16_t netshort) +{ + return ntohs(netshort); +} diff --git a/lib/netplay/byteorder_funcs_wrapper.h b/lib/netplay/byteorder_funcs_wrapper.h new file mode 100644 index 00000000000..c1348fd6505 --- /dev/null +++ b/lib/netplay/byteorder_funcs_wrapper.h @@ -0,0 +1,34 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 1999-2004 Eidos Interactive + Copyright (C) 2005-2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +/// +/// byteorder functions wrappers for WZ just to avoid polluting all places, +/// where these functions are needed, with conditional includes of +/// and winsock headers. +/// + +uint32_t wz_htonl(uint32_t hostlong); +uint16_t wz_htons(uint16_t hostshort); +uint32_t wz_ntohl(uint32_t netlong); +uint16_t wz_ntohs(uint16_t netshort); diff --git a/lib/netplay/client_connection.h b/lib/netplay/client_connection.h new file mode 100644 index 00000000000..986fdc04b4a --- /dev/null +++ b/lib/netplay/client_connection.h @@ -0,0 +1,120 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +#include "lib/framework/types.h" // bring in `ssize_t` for MSVC +#include "lib/netplay/net_result.h" + +/// +/// Basic abstraction over client connection sockets. +/// +/// These are capable of reading (`readAll` and `readNoInt`) and +/// writing data (via `writeAll()` + `flush()` combination). +/// +/// The internal implementation may also implement advanced compression mechanisms +/// on top of these connections by providing non-trivial `enableCompression()` overload. +/// +/// In this case, `writeAll()` should somehow accumulate the data into a write queue, +/// compressing the outcoming data on-the-fly; and `flush()` should empty the write queue +/// and actually post a message to the transmission queue, which, in turn, will be emptied +/// by the internal connection interface in a timely manner, when there are enough messages +/// to be sent over the network. +/// +class IClientConnection +{ +public: + + virtual ~IClientConnection() = default; + + /// + /// Read exactly `size` bytes into `buf` buffer. + /// Supports setting a timeout value in milliseconds. + /// + /// Destination buffer to read the data into. + /// The size of data to be read in bytes. + /// Timeout value in milliseconds. + /// On success, returns the number of bytes read; + /// On failure, returns an `std::error_code` (having `GenericSystemErrorCategory` error category) + /// describing the actual error. + virtual net::result readAll(void* buf, size_t size, unsigned timeout) = 0; + /// + /// Reads at most `max_size` bytes into `buf` buffer. + /// Raw count of bytes (after compression) is returned in `rawByteCount`. + /// + /// Destination buffer to read the data into. + /// The maximum number of bytes to read from the client socket. + /// Output parameter: Raw count of bytes (after compression). + /// On success, returns the number of bytes read; + /// On failure, returns an `std::error_code` (having `GenericSystemErrorCategory` error category) + /// describing the actual error. + virtual net::result readNoInt(void* buf, size_t max_size, size_t* rawByteCount) = 0; + /// + /// Nonblocking write of `size` bytes to the socket. The data will be written to a + /// separate write queue in asynchronous manner, possibly by a separate thread. + /// Raw count of bytes (after compression) will be returned in `rawByteCount`, which + /// will often be 0 until the socket is flushed. + /// + /// The reason for this method to be async is that in some cases we want + /// client connections to have compression mechanism enabled. This naturally + /// introduces the 2-phase write process, which involves a write queue (accumulating + /// the data for compression on-the-fly) and a submission (transmission) + /// queue (for transmitting of compressed and assembled messages), + /// which is managed by the network backend implementation. + /// + /// Source buffer to read the data from. + /// The number of bytes to write to the socket. + /// Output parameter: raw count of bytes (after compression) written. + /// The total number of bytes written. + virtual net::result writeAll(const void* buf, size_t size, size_t* rawByteCount) = 0; + /// + /// This method indicates whether the socket has some data ready to be read (i.e. + /// whether the next `readAll/readNoInt` operation will execute without blocking or not). + /// + virtual bool readReady() const = 0; + /// + /// Actually sends the data written with `writeAll()`. Only useful with sockets + /// which have compression enabled. + /// Note that flushing too often makes compression less effective. + /// Raw count of bytes (after compression) is returned in `rawByteCount`. + /// + /// Raw count of bytes (after compression) as written + /// to the submission queue by the flush operation. + virtual void flush(size_t* rawByteCount) = 0; + /// + /// Enables compression for the current socket. + /// + /// This makes all subsequent write operations asynchronous, plus + /// the written data will need to be flushed explicitly at some point. + /// + virtual void enableCompression() = 0; + /// + /// Enables or disables the use of Nagle algorithm for the socket. + /// + /// For direct TCP connections this is equivalent to setting `TCP_NODELAY` to the + /// appropriate value (i.e.: + /// `enable == true` <=> `TCP_NODELAY == false`; + /// `enable == false` <=> `TCP_NODELAY == true`). + /// + virtual void useNagleAlgorithm(bool enable) = 0; + /// + /// Returns textual representation of the socket's connection address. + /// + virtual std::string textAddress() const = 0; +}; diff --git a/lib/netplay/connection_address.h b/lib/netplay/connection_address.h new file mode 100644 index 00000000000..876a46ea745 --- /dev/null +++ b/lib/netplay/connection_address.h @@ -0,0 +1,50 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include +#include + +#include "lib/netplay/net_result.h" + + + +/// +/// Opaque class representing abstract connection address to use with various +/// network backend implementations. The internal representation is made +/// hidden on purpose since we don't want to actually leak internal data layout +/// to clients. +/// +/// Instead, we would like to introduce "conversion routines" yielding +/// various representations for convenient consumption with various network +/// backends. +/// +/// NOTE: this class may or may not represent a chain of resolved network addresses +/// instead of just a single one, much like a `addrinfo` structure. +/// +/// Currently, only knows how to convert itself to `addrinfo` struct, +/// which is used with the `TCP_DIRECT` network backend. +/// +/// New conversion routines should be introduced for other network backends, +/// if deemed necessary. +/// +struct IConnectionAddress +{ + virtual ~IConnectionAddress() = default; +}; diff --git a/lib/netplay/connection_poll_group.h b/lib/netplay/connection_poll_group.h new file mode 100644 index 00000000000..a2b66bcb801 --- /dev/null +++ b/lib/netplay/connection_poll_group.h @@ -0,0 +1,42 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +class IClientConnection; + +/// +/// Abstract representation of a poll group comprised of several client connections. +/// +class IConnectionPollGroup +{ +public: + + virtual ~IConnectionPollGroup() = default; + + /// + /// Polls the sockets in the poll group for updates. + /// + /// Timeout value after which the internal implementation should abandon + /// polling the client connections and return. + /// On success, returns the number of connection descriptors in the poll group. + /// On failure, `0` can returned if the timeout expired before any connection descriptors + /// became ready, or `-1` if there was an error during the internal poll operation. + virtual int checkSockets(unsigned timeout) = 0; + + virtual void add(IClientConnection* conn) = 0; + virtual void remove(IClientConnection* conn) = 0; +}; diff --git a/lib/netplay/connection_provider_registry.cpp b/lib/netplay/connection_provider_registry.cpp new file mode 100644 index 00000000000..b68e7c8c331 --- /dev/null +++ b/lib/netplay/connection_provider_registry.cpp @@ -0,0 +1,57 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include + +#include "lib/netplay/connection_provider_registry.h" +#include "lib/netplay/tcp/tcp_connection_provider.h" + +ConnectionProviderRegistry& ConnectionProviderRegistry::Instance() +{ + static ConnectionProviderRegistry instance; + return instance; +} + +WzConnectionProvider& ConnectionProviderRegistry::Get(ConnectionProviderType pt) +{ + const auto it = registeredProviders_.find(pt); + if (it == registeredProviders_.end()) + { + throw std::runtime_error("Attempt to get nonexistent connection provider"); + } + return *it->second; +} + +void ConnectionProviderRegistry::Register(ConnectionProviderType pt) +{ + // No-op in case this provider has been already registered. + switch (pt) + { + case ConnectionProviderType::TCP_DIRECT: + registeredProviders_.emplace(pt, std::make_unique()); + break; + default: + throw std::runtime_error("Unknown connection provider type"); + } +} + +void ConnectionProviderRegistry::Deregister(ConnectionProviderType pt) +{ + registeredProviders_.erase(pt); +} diff --git a/lib/netplay/connection_provider_registry.h b/lib/netplay/connection_provider_registry.h new file mode 100644 index 00000000000..2a760b47536 --- /dev/null +++ b/lib/netplay/connection_provider_registry.h @@ -0,0 +1,54 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +#include "lib/netplay/wz_connection_provider.h" + +/// +/// Available types of connection providers (i.e. network backend implementations). +/// +enum class ConnectionProviderType +{ + TCP_DIRECT +}; + +/// +/// Global singleton registry containing available network connection providers. +/// +class ConnectionProviderRegistry +{ +public: + + static ConnectionProviderRegistry& Instance(); + + WzConnectionProvider& Get(ConnectionProviderType pt); + + void Register(ConnectionProviderType pt); + void Deregister(ConnectionProviderType pt); + +private: + + ConnectionProviderRegistry() = default; + + std::unordered_map> registeredProviders_; +}; diff --git a/lib/netplay/listen_socket.h b/lib/netplay/listen_socket.h new file mode 100644 index 00000000000..04cfdd9fc9e --- /dev/null +++ b/lib/netplay/listen_socket.h @@ -0,0 +1,45 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +class IClientConnection; + +/// +/// Server-side listen socket abstraction. +/// +class IListenSocket +{ +public: + + virtual ~IListenSocket() = default; + + enum class IPVersions : uint8_t + { + IPV4 = 0b00000001, + IPV6 = 0b00000010 + }; + using IPVersionsMask = std::underlying_type_t; + + /// + /// Accept an incoming client connection on the current server-side listen socket. + /// + virtual IClientConnection* accept() = 0; + virtual IPVersionsMask supportedIpVersions() const = 0; +}; diff --git a/lib/netplay/net_result.h b/lib/netplay/net_result.h new file mode 100644 index 00000000000..be558bcc517 --- /dev/null +++ b/lib/netplay/net_result.h @@ -0,0 +1,32 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include + +namespace net +{ + +template +using result = ::tl::expected; + +} // namespace net diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index 2a74a231660..467f6c1ba26 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -48,7 +48,11 @@ #include "netplay.h" #include "netlog.h" #include "netreplay.h" -#include "netsocket.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" +#include "lib/netplay/client_connection.h" +#include "lib/netplay/listen_socket.h" +#include "lib/netplay/connection_poll_group.h" +#include "lib/netplay/connection_provider_registry.h" #include "netpermissions.h" #include "sync_debug.h" #include "port_mapping_manager.h" @@ -71,6 +75,12 @@ # include "lib/framework/cocoa_wrapper.h" #endif +#ifndef WZ_OS_WIN +static const int SOCKET_ERROR = -1; +#else +# include // SOCKET_ERROR +#endif + // WARNING !!! This is initialised via configuration.c !!! char masterserver_name[255] = {'\0'}; static unsigned int masterserver_port = 0, gameserver_port = 0; @@ -161,8 +171,8 @@ class LobbyServerConnectionHandler Connected }; LobbyConnectionState currentState = LobbyConnectionState::Disconnected; - Socket *rs_socket = nullptr; - SocketSet* waitingForConnectionFinalize = nullptr; + IClientConnection* rs_socket = nullptr; + IConnectionPollGroup* waitingForConnectionFinalize = nullptr; uint32_t lastConnectionTime = 0; uint32_t lastServerUpdate = 0; bool queuedServerUpdate = false; @@ -232,15 +242,15 @@ bool netPlayersUpdated; // Server-side socket (host-only) which is used to listen for client connections. // There's also `rs_socket` held by `LobbyServerConnectionHandler`, which is used to communicate with the lobby server. -static Socket* server_listen_socket = nullptr; +static IListenSocket* server_listen_socket = nullptr; -static Socket *bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then client_transient_socket == NULL. -static Socket *connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only). +static IClientConnection* bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then client_transient_socket == NULL. +static IClientConnection* connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only). // Client-side socket set. Contains of only 1 socket at most: `bsocket` (which is a stable client connection to the host). -static SocketSet* client_socket_set = nullptr; +static IConnectionPollGroup* client_socket_set = nullptr; // Server-side socket set. Contains up to `MAX_CONNECTED_PLAYERS` sockets: // `connected_bsocket[i]` - sockets used to communicate with clients during a game session. -static SocketSet* server_socket_set = nullptr; +static IConnectionPollGroup* server_socket_set = nullptr; /** * Used for connections with clients. @@ -302,13 +312,13 @@ struct TmpSocketInfo } }; -static Socket *tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only). +static IClientConnection* tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only). static std::array tmp_connectState; static bool bAsyncJoinApprovalEnabled = false; static std::unordered_map tmp_pendingIPs; static lru11::Cache tmp_badIPs(512, 64); -static SocketSet *tmp_socket_set = nullptr; +static IConnectionPollGroup* tmp_socket_set = nullptr; static int32_t NetGameFlags[4] = { 0, 0, 0, 0 }; char iptoconnect[PATH_MAX] = "\0"; // holds IP/hostname from command line bool cliConnectToIpAsSpectator = false; // for cli option @@ -540,17 +550,17 @@ bool NETsetAsyncJoinApprovalResult(const std::string& uniqueJoinID, AsyncJoinApp // *********** Socket with buffer that read NETMSGs ****************** -static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *bufstart, int bufsize) +static size_t NET_fillBuffer(IClientConnection** pSocket, IConnectionPollGroup* pSocketSet, uint8_t *bufstart, int bufsize) { - Socket *socket = *pSocket; + IClientConnection* socket = *pSocket; - if (!socketReadReady(*socket)) + if (!socket->readReady()) { return 0; } size_t rawBytes; - const auto readResult = readNoInt(*socket, bufstart, bufsize, &rawBytes); + const auto readResult = socket->readNoInt(bufstart, bufsize, &rawBytes); if (readResult.has_value()) { @@ -578,7 +588,7 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b // an error occurred, or the remote host has closed the connection. if (pSocketSet != nullptr) { - SocketSet_DelSocket(*pSocketSet, socket); + pSocketSet->remove(socket); } if (bsocket == socket) { @@ -592,7 +602,7 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b NETclose(); return 0; } - socketClose(socket); + delete socket; *pSocket = nullptr; } @@ -974,8 +984,8 @@ static void NETplayerCloseSocket(UDWORD index, bool quietSocketClose) NETlogEntry("Player has left nicely.", SYNC_FLAG, index); // Although we can get a error result from DelSocket, it don't really matter here. - SocketSet_DelSocket(*server_socket_set, connected_bsocket[index]); - socketClose(connected_bsocket[index]); + server_socket_set->remove(connected_bsocket[index]); + delete connected_bsocket[index]; connected_bsocket[index] = nullptr; } else @@ -1221,7 +1231,7 @@ static constexpr size_t GAMESTRUCTmessageBufSize() * * @see GAMESTRUCT,NETrecvGAMESTRUCT */ -static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourgamestruct) +static net::result NETsendGAMESTRUCT(IClientConnection* sock, const GAMESTRUCT *ourgamestruct) { // A buffer that's guaranteed to have the correct size (i.e. it // circumvents struct padding, which could pose a problem). Initialise @@ -1232,13 +1242,13 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga unsigned int i; auto push32 = [&](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(buffer, &swapped, sizeof(swapped)); buffer += sizeof(swapped); }; auto push16 = [&](uint16_t value) { - uint16_t swapped = htons(value); + uint16_t swapped = wz_htons(value); memcpy(buffer, &swapped, sizeof(swapped)); buffer += sizeof(swapped); }; @@ -1327,7 +1337,7 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga debug(LOG_NET, "sending GAMESTRUCT, size: %u", (unsigned int)sizeof(buf)); // Send over the GAMESTRUCT - const auto writeResult = writeAll(*sock, buf, sizeof(buf)); + const auto writeResult = sock->writeAll(buf, sizeof(buf), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); @@ -1347,7 +1357,7 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga * * @see GAMESTRUCT,NETsendGAMESTRUCT */ -static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) +static bool NETrecvGAMESTRUCT(IClientConnection& sock, GAMESTRUCT *ourgamestruct) { // A buffer that's guaranteed to have the correct size (i.e. it // circumvents struct padding, which could pose a problem). @@ -1358,7 +1368,7 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) auto pop32 = [&]() -> uint32_t { uint32_t value = 0; memcpy(&value, buffer, sizeof(value)); - value = ntohl(value); + value = wz_ntohl(value); buffer += sizeof(value); return value; }; @@ -1366,13 +1376,13 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) auto pop16 = [&]() -> uint16_t { uint16_t value = 0; memcpy(&value, buffer, sizeof(value)); - value = ntohs(value); + value = wz_ntohs(value); buffer += sizeof(value); return value; }; // Read a GAMESTRUCT from the connection - auto readResult = readAll(sock, buf, sizeof(buf), NET_TIMEOUT_DELAY); + auto readResult = sock.readAll(buf, sizeof(buf), NET_TIMEOUT_DELAY); if (!readResult.has_value()) { if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) @@ -1532,7 +1542,8 @@ int NETinit(bool bFirstCall) NETlogEntry("NETinit!", SYNC_FLAG, selectedPlayer); NET_InitPlayers(true, true); - SOCKETinit(); + ConnectionProviderRegistry::Instance().Register(ConnectionProviderType::TCP_DIRECT); + ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).initialize(); if (bFirstCall) { @@ -1580,7 +1591,8 @@ int NETshutdown() } NetPlay.MOTD = nullptr; NETdeleteQueue(); - SOCKETshutdown(); + ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).shutdown(); + ConnectionProviderRegistry::Instance().Deregister(ConnectionProviderType::TCP_DIRECT); // Reset net usage statistics. nStats = nZeroStats; @@ -1611,7 +1623,7 @@ int NETclose() if (connected_bsocket[i]) { debug(LOG_NET, "Closing connected_bsocket[%u], %p", i, static_cast(connected_bsocket[i])); - socketClose(connected_bsocket[i]); + delete connected_bsocket[i]; connected_bsocket[i] = nullptr; } NET_DestroyPlayer(i, true); @@ -1620,7 +1632,7 @@ int NETclose() if (tmp_socket_set) { debug(LOG_NET, "Freeing tmp_socket_set %p", static_cast(tmp_socket_set)); - deleteSocketSet(tmp_socket_set); + delete tmp_socket_set; tmp_socket_set = nullptr; } @@ -1630,7 +1642,7 @@ int NETclose() { // FIXME: need SocketSet_DelSocket() as well, socket_set or tmp_socket_set? debug(LOG_NET, "Closing tmp_socket[%d] %p", i, static_cast(tmp_socket[i])); - socketClose(tmp_socket[i]); + delete tmp_socket[i]; tmp_socket[i] = nullptr; } } @@ -1639,28 +1651,28 @@ int NETclose() { if (bsocket) { - SocketSet_DelSocket(*client_socket_set, bsocket); + client_socket_set->remove(bsocket); } debug(LOG_NET, "Freeing socket_set %p", static_cast(client_socket_set)); - deleteSocketSet(client_socket_set); + delete client_socket_set; client_socket_set = nullptr; } else if (server_socket_set) { debug(LOG_NET, "Freeing socket_set %p", static_cast(server_socket_set)); - deleteSocketSet(server_socket_set); + delete server_socket_set; server_socket_set = nullptr; } if (server_listen_socket) { debug(LOG_NET, "Closing server_listen_socket %p", static_cast(server_listen_socket)); - socketClose(server_listen_socket); + delete server_listen_socket; server_listen_socket = nullptr; } if (bsocket) { debug(LOG_NET, "Closing bsocket %p", static_cast(bsocket)); - socketClose(bsocket); + delete bsocket; bsocket = nullptr; } @@ -1740,7 +1752,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) return true; } - Socket **sockets = connected_bsocket; + IClientConnection** sockets = connected_bsocket; bool isTmpQueue = false; switch (queue.queueType) { @@ -1777,7 +1789,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) } ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - const auto writeResult = writeAll(*sockets[player], rawData, rawLen, &compressedRawLen); + const auto writeResult = sockets[player]->writeAll(rawData, rawLen, &compressedRawLen); const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. @@ -1810,7 +1822,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - const auto writeResult = writeAll(*bsocket, rawData, rawLen, &compressedRawLen); + const auto writeResult = bsocket->writeAll(rawData, rawLen, &compressedRawLen); const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. @@ -1827,8 +1839,8 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) debug(LOG_ERROR, "Failed to send message: %s", writeErrMsg.c_str()); debug(LOG_ERROR, "Host connection was broken, socket %p.", static_cast(bsocket)); NETlogEntry("write error--client disconnect.", SYNC_FLAG, player); - SocketSet_DelSocket(*client_socket_set, bsocket); // mark it invalid - socketClose(bsocket); + client_socket_set->remove(bsocket); // mark it invalid + delete bsocket; bsocket = nullptr; NetPlay.players[NetPlay.hostPlayer].heartbeat = false; // mark host as dead //Game is pretty much over --should just end everything when HOST dies. @@ -1869,7 +1881,7 @@ void NETflush() // We are the host, send directly to player. if (connected_bsocket[player] != nullptr) { - socketFlush(*connected_bsocket[player], player, &compressedRawLen); + connected_bsocket[player]->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -1878,7 +1890,7 @@ void NETflush() // We are the host, send directly to player. if (tmp_socket[player] != nullptr) { - socketFlush(*tmp_socket[player], std::numeric_limits::max(), &compressedRawLen); + tmp_socket[player]->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -1887,7 +1899,7 @@ void NETflush() { if (bsocket != nullptr) { - socketFlush(*bsocket, NetPlay.hostPlayer, &compressedRawLen); + bsocket->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -2852,15 +2864,15 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) NETcheckPlayers(); // make sure players are still alive & well } - SocketSet* sset = NetPlay.isHost ? server_socket_set : client_socket_set; - if (sset == nullptr || checkSockets(*sset, NET_READ_TIMEOUT) <= 0) + IConnectionPollGroup* pollGroup = NetPlay.isHost ? server_socket_set : client_socket_set; + if (pollGroup == nullptr || pollGroup->checkSockets(NET_READ_TIMEOUT) <= 0) { goto checkMessages; } for (current = 0; current < MAX_CONNECTED_PLAYERS; ++current) { - Socket **pSocket = NetPlay.isHost ? &connected_bsocket[current] : &bsocket; + IClientConnection** pSocket = NetPlay.isHost ? &connected_bsocket[current] : &bsocket; uint8_t buffer[NET_BUFFER_SIZE]; size_t dataLen; @@ -2874,7 +2886,7 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) continue; } - dataLen = NET_fillBuffer(pSocket, sset, buffer, sizeof(buffer)); + dataLen = NET_fillBuffer(pSocket, pollGroup, buffer, sizeof(buffer)); if (dataLen > 0) { // we received some data, add to buffer @@ -3254,7 +3266,7 @@ unsigned NETgetDownloadProgress(unsigned player) return static_cast(progress); } -static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) +static ssize_t readLobbyResponse(IClientConnection& sock, unsigned int timeout) { uint32_t lobbyStatusCode; uint32_t MOTDLength; @@ -3262,14 +3274,14 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) ssize_t received = 0; // Get status and message length - auto readResult = readAll(sock, &buffer, sizeof(buffer), timeout); + auto readResult = sock.readAll(&buffer, sizeof(buffer), timeout); if (!readResult.has_value()) { goto error; } received += readResult.value(); - lobbyStatusCode = ntohl(buffer[0]); - MOTDLength = ntohl(buffer[1]); + lobbyStatusCode = wz_ntohl(buffer[0]); + MOTDLength = wz_ntohl(buffer[1]); // Get status message if (NetPlay.MOTD) @@ -3277,7 +3289,7 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) free(NetPlay.MOTD); } NetPlay.MOTD = (char *)malloc(MOTDLength + 1); - readResult = readAll(sock, NetPlay.MOTD, MOTDLength, timeout); + readResult = sock.readAll(NetPlay.MOTD, MOTDLength, timeout); if (!readResult.has_value()) { goto error; @@ -3332,15 +3344,15 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) return SOCKET_ERROR; } -bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function& handleEnumerateGameFunc) +bool readGameStructsList(IClientConnection& sock, unsigned int timeout, const std::function& handleEnumerateGameFunc) { unsigned int gamecount = 0; uint32_t gamesavailable = 0; - const auto readResult = readAll(sock, &gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY); + const auto readResult = sock.readAll(&gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY); if (readResult.has_value()) { - gamesavailable = ntohl(gamesavailable); + gamesavailable = wz_ntohl(gamesavailable); } else { @@ -3372,7 +3384,8 @@ bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function if (tmpGame.desc.host[0] == '\0') { memset(tmpGame.desc.host, 0, sizeof(tmpGame.desc.host)); - strncpy(tmpGame.desc.host, getSocketTextAddress(sock), sizeof(tmpGame.desc.host) - 1); + const auto textAddr = sock.textAddress(); + strncpy(tmpGame.desc.host, textAddr.data(), sizeof(tmpGame.desc.host) - 1); } uint32_t Vmgr = (tmpGame.future4 & 0xFFFF0000) >> 16; @@ -3424,12 +3437,13 @@ bool LobbyServerConnectionHandler::connect() return false; // already connecting or connected } + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + bool bProcessingConnectOrDisconnectThisCall = true; uint32_t gameId = 0; - const auto hostsResult = resolveHost(masterserver_name, masterserver_port); - const auto hosts = hostsResult.value_or(nullptr); + const auto hostsResult = connProvider.resolveHost(masterserver_name, masterserver_port); - if (hosts == nullptr) + if (!hostsResult.has_value()) { const auto hostsErrMsg = hostsResult.error().message(); debug(LOG_ERROR, "Cannot resolve masterserver \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); @@ -3443,16 +3457,17 @@ bool LobbyServerConnectionHandler::connect() return bProcessingConnectOrDisconnectThisCall; } + const auto& hosts = hostsResult.value(); + // Close an existing socket. if (rs_socket != nullptr) { - socketClose(rs_socket); + delete rs_socket; rs_socket = nullptr; } // try each address from resolveHost until we successfully connect. - auto sockResult = socketOpenAny(hosts, 1500); - deleteSocketAddress(hosts); + auto sockResult = connProvider.openClientConnectionAny(*hosts, 1500); rs_socket = sockResult.value_or(nullptr); @@ -3472,10 +3487,10 @@ bool LobbyServerConnectionHandler::connect() } // Get a game ID - auto gameIdResult = writeAll(*rs_socket, "gaId", sizeof("gaId")); + auto gameIdResult = rs_socket->writeAll("gaId", sizeof("gaId"), nullptr); if (gameIdResult.has_value()) { - gameIdResult = readAll(*rs_socket, &gameId, sizeof(gameId), 10000); + gameIdResult = rs_socket->readAll(&gameId, sizeof(gameId), 10000); } if (!gameIdResult.has_value()) { @@ -3495,13 +3510,13 @@ bool LobbyServerConnectionHandler::connect() return bProcessingConnectOrDisconnectThisCall; } - gamestruct.gameId = ntohl(gameId); + gamestruct.gameId = wz_ntohl(gameId); debug(LOG_NET, "Using game ID: %u", (unsigned int)gamestruct.gameId); wz_command_interface_output("WZEVENT: lobbyid: %" PRIu32 "\n", gamestruct.gameId); // Register our game with the server - const auto writeAddGameRes = writeAll(*rs_socket, "addg", sizeof("addg")); + const auto writeAddGameRes = rs_socket->writeAll("addg", sizeof("addg"), nullptr); auto sendGamestructRes = ignoreExpectedResultValue(writeAddGameRes); if (sendGamestructRes.has_value()) @@ -3521,8 +3536,8 @@ bool LobbyServerConnectionHandler::connect() queuedServerUpdate = false; lastConnectionTime = realTime; - waitingForConnectionFinalize = allocSocketSet(); - SocketSet_AddSocket(*waitingForConnectionFinalize, rs_socket); + waitingForConnectionFinalize = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).newConnectionPollGroup(); + waitingForConnectionFinalize->add(rs_socket); currentState = LobbyConnectionState::Connecting_WaitingForResponse; return bProcessingConnectOrDisconnectThisCall; @@ -3538,7 +3553,7 @@ bool LobbyServerConnectionHandler::disconnect() if (rs_socket != nullptr) { // we don't need this anymore, so clean up - socketClose(rs_socket); + delete rs_socket; rs_socket = nullptr; server_not_there = true; } @@ -3602,7 +3617,7 @@ void LobbyServerConnectionHandler::sendUpdateNow() void LobbyServerConnectionHandler::sendKeepAlive() { ASSERT_OR_RETURN(, rs_socket != nullptr, "Null socket"); - if (!writeAll(*rs_socket, "keep", sizeof("keep")).has_value()) + if (!rs_socket->writeAll("keep", sizeof("keep"), nullptr).has_value()) { // The socket has been invalidated, so get rid of it. (using them now may cause SIGPIPE). disconnect(); @@ -3625,21 +3640,21 @@ void LobbyServerConnectionHandler::run() bool exceededTimeout = (realTime - lastConnectionTime >= 10000); // We use readLobbyResponse to display error messages and handle state changes if there's no response // So if exceededTimeout, just call it with a low timeout - int checkSocketRet = checkSockets(*waitingForConnectionFinalize, NET_READ_TIMEOUT); + int checkSocketRet = waitingForConnectionFinalize->checkSockets(NET_READ_TIMEOUT); if (checkSocketRet == SOCKET_ERROR) { debug(LOG_ERROR, "Lost connection to lobby server"); disconnect(); break; } - if (exceededTimeout || (checkSocketRet > 0 && socketReadReady(*rs_socket))) + if (exceededTimeout || (checkSocketRet > 0 && rs_socket->readReady())) { if (readLobbyResponse(*rs_socket, NET_TIMEOUT_DELAY) == SOCKET_ERROR) { disconnect(); break; } - deleteSocketSet(waitingForConnectionFinalize); + delete waitingForConnectionFinalize; waitingForConnectionFinalize = nullptr; currentState = LobbyConnectionState::Connected; } @@ -3774,9 +3789,9 @@ static bool quickRejectConnection(const std::string& ip) static void NETcloseTempSocket(unsigned int i) { - std::string rIP = getSocketTextAddress(*tmp_socket[i]); - SocketSet_DelSocket(*tmp_socket_set, tmp_socket[i]); - socketClose(tmp_socket[i]); + std::string rIP = tmp_socket[i]->textAddress(); + tmp_socket_set->remove(tmp_socket[i]); + delete tmp_socket[i]; tmp_socket[i] = nullptr; tmp_connectState[i].reset(); auto it = tmp_pendingIPs.find(rIP); @@ -3795,14 +3810,14 @@ static void NETcloseTempSocket(unsigned int i) static void NEThostPromoteTempSocketToPermanentPlayerConnection(unsigned int tempSocketIdx, uint8_t index) { - std::string rIP = getSocketTextAddress(*tmp_socket[tempSocketIdx]); + std::string rIP = tmp_socket[tempSocketIdx]->textAddress(); debug(LOG_NET, "freeing temp socket %p (%d), creating permanent socket.", static_cast(tmp_socket[tempSocketIdx]), __LINE__); - SocketSet_DelSocket(*tmp_socket_set, tmp_socket[tempSocketIdx]); + tmp_socket_set->remove(tmp_socket[tempSocketIdx]); connected_bsocket[index] = tmp_socket[tempSocketIdx]; tmp_socket[tempSocketIdx] = nullptr; NET_waitingForIndexChangeAckSince[index] = nullopt; - SocketSet_AddSocket(*server_socket_set, connected_bsocket[index]); + server_socket_set->add(connected_bsocket[index]); NETmoveQueue(NETnetTmpQueue(tempSocketIdx), NETnetQueue(index)); // Copy player's IP address @@ -3846,12 +3861,13 @@ static void NETallowJoining() ActivitySink::ListeningInterfaces listeningInterfaces; if (server_listen_socket != nullptr) { - listeningInterfaces.IPv4 = socketHasIPv4(*server_listen_socket); + const auto supportedProtocols = server_listen_socket->supportedIpVersions(); + listeningInterfaces.IPv4 = supportedProtocols & static_cast(IListenSocket::IPVersions::IPV4); if (listeningInterfaces.IPv4) { listeningInterfaces.ipv4_port = NETgetGameserverPort(); } - listeningInterfaces.IPv6 = socketHasIPv6(*server_listen_socket); + listeningInterfaces.IPv6 = supportedProtocols & static_cast(IListenSocket::IPVersions::IPV6); if (listeningInterfaces.IPv6) { listeningInterfaces.ipv6_port = NETgetGameserverPort(); @@ -3875,7 +3891,7 @@ static void NETallowJoining() } ASSERT(tmp_socket_set != nullptr, "Null tmp_socket_set"); - if (checkSockets(*tmp_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { for (i = 0; i < MAX_TMP_SOCKETS; ++i) { @@ -3884,7 +3900,7 @@ static void NETallowJoining() continue; } - if (!socketReadReady(*tmp_socket[i])) + if (!tmp_socket[i]->readReady()) { continue; } @@ -3893,7 +3909,7 @@ static void NETallowJoining() { char *p_buffer = tmp_connectState[i].buffer; - const auto sizeReadResult = readNoInt(*tmp_socket[i], p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer); + const auto sizeReadResult = tmp_socket[i]->readNoInt(p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer, nullptr); if (sizeReadResult.has_value()) { tmp_connectState[i].usedBuffer += sizeReadResult.value(); @@ -3914,10 +3930,10 @@ static void NETallowJoining() // Check these numbers with our own. memcpy(&major, p_buffer, sizeof(uint32_t)); - major = ntohl(major); + major = wz_ntohl(major); p_buffer += sizeof(int32_t); memcpy(&minor, p_buffer, sizeof(uint32_t)); - minor = ntohl(minor); + minor = wz_ntohl(minor); if (major == 0 && minor == 0) { @@ -3927,7 +3943,7 @@ static void NETallowJoining() char buf[(sizeof(char) * 4) + sizeof(uint32_t) + sizeof(uint32_t)] = { 0 }; char *pLobbyRespBuffer = buf; auto push32 = [&pLobbyRespBuffer](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(pLobbyRespBuffer, &swapped, sizeof(swapped)); pLobbyRespBuffer += sizeof(swapped); }; @@ -3943,15 +3959,15 @@ static void NETallowJoining() // Copy gameId (as 32bit large big endian number) push32(gamestruct.gameId); - writeAll(*tmp_socket[i], buf, sizeof(buf)); + tmp_socket[i]->writeAll(buf, sizeof(buf), nullptr); connectFailed = true; } else if (NETisCorrectVersion(major, minor)) { - result = htonl(ERROR_NOERROR); + result = wz_htonl(ERROR_NOERROR); memcpy(&tmp_connectState[i].buffer, &result, sizeof(result)); - writeAll(*tmp_socket[i], &tmp_connectState[i].buffer, sizeof(result)); - socketBeginCompression(*tmp_socket[i]); + tmp_socket[i]->writeAll(&tmp_connectState[i].buffer, sizeof(result), nullptr); + tmp_socket[i]->enableCompression(); // Connection is successful. connectFailed = false; @@ -3966,9 +3982,9 @@ static void NETallowJoining() else { debug(LOG_INFO, "Received an invalid version \"%" PRIu32 ".%" PRIu32 "\".", major, minor); - result = htonl(ERROR_WRONGVERSION); + result = wz_htonl(ERROR_WRONGVERSION); memcpy(&tmp_connectState[i].buffer, &result, sizeof(result)); - writeAll(*tmp_socket[i], &tmp_connectState[i].buffer, sizeof(result)); + tmp_socket[i]->writeAll(&tmp_connectState[i].buffer, sizeof(result), nullptr); NETlogEntry("Invalid game version", SYNC_FLAG, i); NETaddSessionBanBadIP(tmp_connectState[i].ip); connectFailed = true; @@ -4008,7 +4024,7 @@ static void NETallowJoining() else if (tmp_connectState[i].connectState == TmpSocketInfo::TmpConnectState::PendingJoinRequest) { uint8_t buffer[NET_BUFFER_SIZE]; - const auto readResult = readNoInt(*tmp_socket[i], buffer, sizeof(buffer)); + const auto readResult = tmp_socket[i]->readNoInt(buffer, sizeof(buffer), nullptr); uint8_t rejected = 0; if (!readResult.has_value()) @@ -4406,7 +4422,7 @@ static void NETallowJoining() NETpop(tmpQueue); } - std::string rIP = getSocketTextAddress(*tmp_socket[i]); + std::string rIP = tmp_socket[i]->textAddress(); NETaddSessionBanBadIP(rIP); NETcloseTempSocket(i); @@ -4511,14 +4527,16 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator return true; } + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + // Start listening for client connections on `gameserver_port`. // These will initially be assigned to `tmp_socket[i]` until accepted in the game session, // in which case `tmp_socket[i]` will be assigned to `connected_bsocket[i]` and `tmp_socket[i]` // will become nullptr. - net::result serverListenResult = {}; + net::result serverListenResult = {}; if (!server_listen_socket) { - serverListenResult = socketListen(gameserver_port); + serverListenResult = connProvider.openListenSocket(gameserver_port); server_listen_socket = serverListenResult.value_or(nullptr); } if (server_listen_socket == nullptr) @@ -4531,7 +4549,7 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator // Host needs to create a socket set for MAX_PLAYERS if (!server_socket_set) { - server_socket_set = allocSocketSet(); + server_socket_set = connProvider.newConnectionPollGroup(); } // allocate socket storage for all possible players for (unsigned i = 0; i < MAX_CONNECTED_PLAYERS; ++i) @@ -4636,19 +4654,17 @@ bool NETenumerateGames(const std::function& handl debug(LOG_ERROR, "Likely missing NETinit(true) - this won't return any results"); return false; } - const auto hostsResult = resolveHost(masterserver_name, masterserver_port); - SocketAddress* hosts = hostsResult.value_or(nullptr); - if (!hosts) + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + const auto hostsResult = connProvider.resolveHost(masterserver_name, masterserver_port); + if (!hostsResult.has_value()) { const auto hostsErrMsg = hostsResult.error().message(); debug(LOG_ERROR, "Cannot resolve hostname \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); setLobbyError(ERROR_CONNECTION); return false; } - - auto sockResult = socketOpenAny(hosts, 15000); - deleteSocketAddress(hosts); - hosts = nullptr; + const auto& hosts = hostsResult.value(); + auto sockResult = connProvider.openClientConnectionAny(*hosts, 15000); if (!sockResult.has_value()) { const auto sockErrMsg = sockResult.error().message(); @@ -4656,18 +4672,18 @@ bool NETenumerateGames(const std::function& handl setLobbyError(ERROR_CONNECTION); return false; } - Socket* sock = sockResult.value(); + IClientConnection* sock = sockResult.value(); debug(LOG_NET, "New socket = %p", static_cast(sock)); debug(LOG_NET, "Sending list cmd"); - const auto writeResult = writeAll(*sock, "list", sizeof("list")); + const auto writeResult = sock->writeAll("list", sizeof("list"), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); debug(LOG_NET, "Server socket encountered error: %s", writeErrMsg.c_str()); // mark it invalid - socketClose(sock); + delete sock; // when we fail to receive a game count, bail out setLobbyError(ERROR_CONNECTION); @@ -4684,7 +4700,7 @@ bool NETenumerateGames(const std::function& handl })) { // mark it invalid - socketClose(sock); + delete sock; setLobbyError(ERROR_CONNECTION); return false; @@ -4694,7 +4710,7 @@ bool NETenumerateGames(const std::function& handl if (readLobbyResponse(*sock, NET_TIMEOUT_DELAY) == SOCKET_ERROR) { // mark it invalid - socketClose(sock); + delete sock; addConsoleMessage(_("Failed to get a lobby response!"), DEFAULT_JUSTIFY, NOTIFY_MESSAGE); // treat as fatal error @@ -4708,10 +4724,10 @@ bool NETenumerateGames(const std::function& handl // Hence as long as we don't treat "0" as signifying any change in behavior, this should be safe + backwards-compatible #define IGNORE_FIRST_BATCH 1 uint32_t responseParameters = 0; - const auto readResult = readAll(*sock, &responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY); + const auto readResult = sock->readAll(&responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY); if (readResult.has_value()) { - responseParameters = ntohl(responseParameters); + responseParameters = wz_ntohl(responseParameters); bool requestSecondBatch = true; bool ignoreFirstBatch = ((responseParameters & IGNORE_FIRST_BATCH) == IGNORE_FIRST_BATCH); @@ -4746,7 +4762,7 @@ bool NETenumerateGames(const std::function& handl debug(LOG_NET, "Second readGameStructsList call failed"); // mark it invalid - socketClose(sock); + delete sock; // when we fail to receive a game count, bail out setLobbyError(ERROR_CONNECTION); @@ -4769,7 +4785,7 @@ bool NETenumerateGames(const std::function& handl } // mark it invalid (we are done with it) - socketClose(sock); + delete sock; return true; } @@ -4825,7 +4841,7 @@ bool NETfindGame(uint32_t gameId, GAMESTRUCT& output) } // "consumes" the sockets and related info -bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, Socket **client_joining_socket, SocketSet **client_joining_socket_set) +bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, IClientConnection** client_joining_socket, IConnectionPollGroup** client_joining_socket_set) { if (hostPlayer >= MAX_CONNECTED_PLAYERS) { @@ -5130,7 +5146,8 @@ void NETacceptIncomingConnections() { // initialize temporary server socket set // FIXME: why is this not done in NETinit()?? - Per - tmp_socket_set = allocSocketSet(); + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + tmp_socket_set = connProvider.newConnectionPollGroup(); // FIXME: I guess initialization of allowjoining is here now... - FlexCoral for (auto& tmpState : tmp_connectState) { @@ -5156,12 +5173,12 @@ void NETacceptIncomingConnections() } // See if there's an incoming connection - tmp_socket[i] = socketAccept(server_listen_socket); + tmp_socket[i] = server_listen_socket->accept(); if (!tmp_socket[i]) { return; } - const std::string rIP = getSocketTextAddress(*tmp_socket[i]); + const std::string rIP = tmp_socket[i]->textAddress(); if (quickRejectConnection(rIP)) { debug(LOG_NET, "freeing temp socket %p (%d)", static_cast(tmp_socket[i]), __LINE__); @@ -5170,7 +5187,7 @@ void NETacceptIncomingConnections() } NETinitQueue(NETnetTmpQueue(i)); - SocketSet_AddSocket(*tmp_socket_set, tmp_socket[i]); + tmp_socket_set->add(tmp_socket[i]); tmp_pendingIPs[rIP]++; @@ -5184,7 +5201,7 @@ void NETacceptIncomingConnections() if (bEnableTCPNoDelay) { - // Enable TCP_NODELAY - socketSetTCPNoDelay(*tmp_socket[i], true); + // Disable use of Nagle Algorithm for the TCP socket (i.e. enable TCP_NODELAY option in case of TCP connection) + tmp_socket[i]->useNagleAlgorithm(false); } } diff --git a/lib/netplay/netplay.h b/lib/netplay/netplay.h index 6cfaf64846c..3147f0b074e 100644 --- a/lib/netplay/netplay.h +++ b/lib/netplay/netplay.h @@ -445,9 +445,11 @@ bool NEThaltJoining(); // stop new players joining this game bool NETenumerateGames(const std::function& handleEnumerateGameFunc); bool NETfindGames(std::vector& results, size_t startingIndex, size_t resultsLimit, bool onlyMatchingLocalVersion = false); bool NETfindGame(uint32_t gameId, GAMESTRUCT& output); -struct Socket; -struct SocketSet; -bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, Socket **client_joining_socket, SocketSet **client_joining_socket_set); + +class IClientConnection; +class IConnectionPollGroup; + +bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char* playername, NETQUEUE joiningQUEUEInfo, IClientConnection** client_joining_socket, IConnectionPollGroup** client_joining_socket_set); bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectatorHost, // host a game uint32_t gameType, uint32_t two, uint32_t three, uint32_t four, UDWORD plyrs); bool NETchangePlayerName(UDWORD player, char *newName);// change a players name. diff --git a/lib/netplay/open_connection_result.h b/lib/netplay/open_connection_result.h new file mode 100644 index 00000000000..ae45ddb9ff5 --- /dev/null +++ b/lib/netplay/open_connection_result.h @@ -0,0 +1,60 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include +#include +#include + +#include "lib/netplay/client_connection.h" + +#include +using nonstd::optional; +using nonstd::nullopt; + +// Higher-level functions for opening a connection / socket +struct OpenConnectionResult +{ +public: + OpenConnectionResult(std::error_code ec, std::string errorString) + : errorCode(ec) + , errorString(std::move(errorString)) + { } + + OpenConnectionResult(IClientConnection* open_socket) + : open_socket(open_socket) + { } + +public: + bool hasError() const { return errorCode.has_value(); } +public: + OpenConnectionResult(const OpenConnectionResult& other) = delete; // non construction-copyable + OpenConnectionResult& operator=(const OpenConnectionResult&) = delete; // non copyable + OpenConnectionResult(OpenConnectionResult&&) = default; + OpenConnectionResult& operator=(OpenConnectionResult&&) = default; +public: + + std::unique_ptr open_socket; + optional errorCode = nullopt; + std::string errorString; +}; + +typedef std::function OpenConnectionToHostResultCallback; diff --git a/lib/netplay/sync_debug.cpp b/lib/netplay/sync_debug.cpp index 536db78c080..e23499a86db 100644 --- a/lib/netplay/sync_debug.cpp +++ b/lib/netplay/sync_debug.cpp @@ -27,9 +27,9 @@ #include "lib/framework/debug.h" #include "lib/framework/physfs_ext.h" #include "lib/gamelib/gtime.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" #include "nettypes.h" #include "netplay.h" -#include "netsocket.h" // solely to bring in `htonl` function #include @@ -76,7 +76,7 @@ struct SyncDebugValueChange : public SyncDebugEntry variableName = vn; newValue = nv; id = i; - uint32_t valueBytes = htonl(newValue); + uint32_t valueBytes = wz_htonl(newValue); crc = wz::crc_update(crc, function, strlen(function) + 1); crc = wz::crc_update(crc, variableName, strlen(variableName) + 1); crc = wz::crc_update(crc, &valueBytes, 4); @@ -105,7 +105,7 @@ struct SyncDebugIntList : public SyncDebugEntry numInts = std::min(num, ARRAY_SIZE(valueBytes)); for (unsigned n = 0; n < numInts; ++n) { - valueBytes[n] = htonl(ints[n]); + valueBytes[n] = wz_htonl(ints[n]); } crc = wz::crc_update(crc, valueBytes, 4 * numInts); } diff --git a/lib/netplay/netsocket.cpp b/lib/netplay/tcp/netsocket.cpp similarity index 96% rename from lib/netplay/netsocket.cpp rename to lib/netplay/tcp/netsocket.cpp index f03be08e7e1..9f844793da8 100644 --- a/lib/netplay/netsocket.cpp +++ b/lib/netplay/tcp/netsocket.cpp @@ -26,7 +26,7 @@ #include "lib/framework/frame.h" #include "lib/framework/wzapp.h" #include "netsocket.h" -#include "error_categories.h" +#include "lib/netplay/error_categories.h" #include #include @@ -47,6 +47,9 @@ // Already included Winsock2.h which defines TCP_NODELAY #endif +namespace tcp +{ + enum { SOCK_CONNECTION, @@ -127,6 +130,8 @@ static void setSockErr(int error) #endif } +} // namespace tcp + #if defined(WZ_OS_WIN) typedef int (WINAPI *GETADDRINFO_DLL_FUNC)(const char *node, const char *service, const struct addrinfo *hints, @@ -216,6 +221,9 @@ static void freeaddrinfo(struct addrinfo *res) } #endif +namespace tcp +{ + static int addressToText(const struct sockaddr *addr, char *buf, size_t size) { auto handleIpv4 = [&](uint32_t addr) { @@ -1749,71 +1757,4 @@ void SOCKETshutdown() #endif } -OpenConnectionResult socketOpenTCPConnectionSync(const char *host, uint32_t port) -{ - const auto hostsResult = resolveHost(host, port); - SocketAddress* hosts = hostsResult.value_or(nullptr); - if (hosts == nullptr) - { - const auto hostsErr = hostsResult.error(); - const auto hostsErrMsg = hostsErr.message(); - return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); - } - - auto sockResult = socketOpenAny(hosts, 15000); - Socket* client_transient_socket = sockResult.value_or(nullptr); - deleteSocketAddress(hosts); - hosts = nullptr; - - if (client_transient_socket == nullptr) - { - const auto errValue = sockResult.error(); - const auto errMsg = errValue.message(); - return OpenConnectionResult(errValue, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, errValue.value(), errMsg.c_str())); - } - - return OpenConnectionResult(client_transient_socket); -} - -struct OpenConnectionRequest -{ - std::string host; - uint32_t port = 0; - OpenConnectionToHostResultCallback callback; -}; - -static int openDirectTCPConnectionAsyncImpl(void* data) -{ - OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; - if (!pRequestInfo) - { - return 1; - } - - pRequestInfo->callback(socketOpenTCPConnectionSync(pRequestInfo->host.c_str(), pRequestInfo->port)); - delete pRequestInfo; - return 0; -} - -bool socketOpenTCPConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) -{ - // spawn background thread to handle this - auto pRequest = new OpenConnectionRequest(); - pRequest->host = host; - pRequest->port = port; - pRequest->callback = callback; - - WZ_THREAD * pOpenConnectionThread = wzThreadCreate(openDirectTCPConnectionAsyncImpl, pRequest); - if (pOpenConnectionThread == nullptr) - { - debug(LOG_ERROR, "Failed to create thread for opening connection"); - delete pRequest; - return false; - } - - wzThreadDetach(pOpenConnectionThread); - // the thread handles deleting pRequest - pOpenConnectionThread = nullptr; - - return true; -} +} // namespace tcp diff --git a/lib/netplay/netsocket.h b/lib/netplay/tcp/netsocket.h similarity index 81% rename from lib/netplay/netsocket.h rename to lib/netplay/tcp/netsocket.h index 01413011cd3..b46a9ce89ce 100644 --- a/lib/netplay/netsocket.h +++ b/lib/netplay/tcp/netsocket.h @@ -21,21 +21,14 @@ #ifndef _net_socket_h #define _net_socket_h +#include "lib/framework/wzglobal.h" #include "lib/framework/types.h" #include #include #include +#include -#include -using nonstd::optional; -using nonstd::nullopt; -#include - -namespace net -{ - template - using result = ::tl::expected; -} // namespace net +#include "lib/netplay/net_result.h" #if defined(WZ_OS_UNIX) # include @@ -84,14 +77,22 @@ static const SOCKET INVALID_SOCKET = -1; # define MSG_NOSIGNAL 0 #endif +namespace tcp +{ + struct Socket; struct SocketSet; + +} // namespace tcp + typedef struct addrinfo SocketAddress; #ifndef WZ_OS_WIN static const int SOCKET_ERROR = -1; #endif +namespace tcp +{ // Init/shutdown. void SOCKETinit(); @@ -137,34 +138,6 @@ WZ_DECL_NONNULL(2) void SocketSet_AddSocket(SocketSet& set, Socket *socket); // WZ_DECL_NONNULL(2) void SocketSet_DelSocket(SocketSet& set, Socket *socket); ///< Removes a Socket from a SocketSet. int checkSockets(const SocketSet& set, unsigned int timeout); ///< Checks which Sockets are ready for reading. Returns the number of ready Sockets, or returns SOCKET_ERROR on error. -// Higher-level functions for opening a connection / socket -struct OpenConnectionResult -{ - OpenConnectionResult(std::error_code ec, std::string errorString) - : errorCode(ec) - , errorString(errorString) - { } - - OpenConnectionResult(Socket* open_socket) - : open_socket(open_socket) - { } - - bool hasError() const { return errorCode.has_value(); } - - OpenConnectionResult( const OpenConnectionResult& other ) = delete; // non construction-copyable - OpenConnectionResult& operator=( const OpenConnectionResult& ) = delete; // non copyable - OpenConnectionResult(OpenConnectionResult&&) = default; - OpenConnectionResult& operator=(OpenConnectionResult&&) = default; - - struct SocketDeleter { - void operator()(Socket* b) { if (b) { socketClose(b); } } - }; - std::unique_ptr open_socket; - optional errorCode = nullopt; - std::string errorString; -}; -typedef std::function OpenConnectionToHostResultCallback; -bool socketOpenTCPConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback); -OpenConnectionResult socketOpenTCPConnectionSync(const char *host, uint32_t port); +} // namespace tcp #endif //_net_socket_h diff --git a/lib/netplay/tcp/tcp_client_connection.cpp b/lib/netplay/tcp/tcp_client_connection.cpp new file mode 100644 index 00000000000..73d9d78c768 --- /dev/null +++ b/lib/netplay/tcp/tcp_client_connection.cpp @@ -0,0 +1,83 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" +#include "lib/framework/string_ext.h" + +namespace tcp +{ + +TCPClientConnection::TCPClientConnection(Socket* rawSocket) + : socket_(rawSocket) +{ + ASSERT(socket_ != nullptr, "Null socket passed to TCPClientConnection ctor"); +} + +TCPClientConnection::~TCPClientConnection() +{ + if (socket_) + { + socketClose(socket_); + } +} + +net::result TCPClientConnection::readAll(void* buf, size_t size, unsigned timeout) +{ + return tcp::readAll(*socket_, buf, size, timeout); +} + +net::result TCPClientConnection::readNoInt(void* buf, size_t maxSize, size_t* rawByteCount) +{ + return tcp::readNoInt(*socket_, buf, maxSize, rawByteCount); +} + +net::result TCPClientConnection::writeAll(const void* buf, size_t size, size_t* rawByteCount) +{ + return tcp::writeAll(*socket_, buf, size, rawByteCount); +} + +bool TCPClientConnection::readReady() const +{ + return socketReadReady(*socket_); +} + +void TCPClientConnection::flush(size_t* rawByteCount) +{ + socketFlush(*socket_, std::numeric_limits::max()/*unused*/, rawByteCount); +} + +void TCPClientConnection::enableCompression() +{ + socketBeginCompression(*socket_); +} + +void TCPClientConnection::useNagleAlgorithm(bool enable) +{ + socketSetTCPNoDelay(*socket_, !enable); +} + +std::string TCPClientConnection::textAddress() const +{ + return getSocketTextAddress(*socket_); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_client_connection.h b/lib/netplay/tcp/tcp_client_connection.h new file mode 100644 index 00000000000..c2d00959d7f --- /dev/null +++ b/lib/netplay/tcp/tcp_client_connection.h @@ -0,0 +1,52 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/client_connection.h" + +namespace tcp +{ + +struct Socket; + +class TCPClientConnection : public IClientConnection +{ +public: + + explicit TCPClientConnection(Socket* rawSocket); + virtual ~TCPClientConnection() override; + + virtual net::result readAll(void* buf, size_t size, unsigned timeout) override; + virtual net::result readNoInt(void* buf, size_t maxSize, size_t* rawByteCount) override; + virtual net::result writeAll(const void* buf, size_t size, size_t* rawByteCount) override; + virtual bool readReady() const override; + virtual void flush(size_t* rawByteCount) override; + virtual void enableCompression() override; + virtual void useNagleAlgorithm(bool enable) override; + virtual std::string textAddress() const override; + +private: + + Socket* socket_; + + friend class TCPConnectionPollGroup; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_address.cpp b/lib/netplay/tcp/tcp_connection_address.cpp new file mode 100644 index 00000000000..66f7d1981e7 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_address.cpp @@ -0,0 +1,30 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_connection_address.h" + +#include "lib/netplay/tcp/netsocket.h" // for `freeaddrinfo` +#include "lib/framework/frame.h" // for `ASSERT` + +TCPConnectionAddress::TCPConnectionAddress(SocketAddress* addr) + : addr_(addr) +{} + +TCPConnectionAddress::~TCPConnectionAddress() +{ + ASSERT(addr_ != nullptr, "Invalid addrinfo stored in the connection address"); + freeaddrinfo(addr_); +} diff --git a/lib/netplay/tcp/tcp_connection_address.h b/lib/netplay/tcp/tcp_connection_address.h new file mode 100644 index 00000000000..fdc7b0b6d76 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_address.h @@ -0,0 +1,43 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/connection_address.h" + +#if defined WZ_OS_UNIX +# include +#elif defined WZ_OS_WIN +# include +#endif + +typedef struct addrinfo SocketAddress; + +class TCPConnectionAddress : public IConnectionAddress +{ +public: + + /// Assumes ownership of `addr` + explicit TCPConnectionAddress(SocketAddress* addr); + virtual ~TCPConnectionAddress() override; + + // NOTE: The lifetime of the returned `addrinfo` struct is bounded by the parent object's lifetime! + const SocketAddress* asRawSocketAddress() const { return addr_; } + +private: + + SocketAddress* addr_; +}; diff --git a/lib/netplay/tcp/tcp_connection_poll_group.cpp b/lib/netplay/tcp/tcp_connection_poll_group.cpp new file mode 100644 index 00000000000..7aab0fcbe71 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_poll_group.cpp @@ -0,0 +1,60 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_connection_poll_group.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" + +namespace tcp +{ + +TCPConnectionPollGroup::TCPConnectionPollGroup(SocketSet* sset) + : sset_(sset) +{} + +TCPConnectionPollGroup::~TCPConnectionPollGroup() +{ + if (sset_) + { + deleteSocketSet(sset_); + } +} + +int TCPConnectionPollGroup::checkSockets(unsigned timeout) +{ + return tcp::checkSockets(*sset_, timeout); +} + +void TCPConnectionPollGroup::add(IClientConnection* conn) +{ + auto* tcpConn = dynamic_cast(conn); + ASSERT_OR_RETURN(, tcpConn != nullptr, "Expected to have TCPClientConnection instance"); + SocketSet_AddSocket(*sset_, tcpConn->socket_); +} + +void TCPConnectionPollGroup::remove(IClientConnection* conn) +{ + auto tcpConn = dynamic_cast(conn); + ASSERT_OR_RETURN(, tcpConn != nullptr, "Expected to have TCPClientConnection instance"); + SocketSet_DelSocket(*sset_, tcpConn->socket_); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_poll_group.h b/lib/netplay/tcp/tcp_connection_poll_group.h new file mode 100644 index 00000000000..d9eec5bb909 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_poll_group.h @@ -0,0 +1,45 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/connection_poll_group.h" + +namespace tcp +{ + +struct SocketSet; + +class TCPConnectionPollGroup : public IConnectionPollGroup +{ +public: + + explicit TCPConnectionPollGroup(SocketSet* sset); + virtual ~TCPConnectionPollGroup() override; + + virtual int checkSockets(unsigned timeout) override; + virtual void add(IClientConnection* conn) override; + virtual void remove(IClientConnection* conn) override; + +private: + + SocketSet* sset_; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_provider.cpp b/lib/netplay/tcp/tcp_connection_provider.cpp new file mode 100644 index 00000000000..1bf1edddf4b --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_provider.cpp @@ -0,0 +1,91 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "tcp_connection_provider.h" + +#include "lib/netplay/tcp/netsocket.h" +#include "lib/netplay/tcp/tcp_connection_address.h" +#include "lib/netplay/tcp/tcp_connection_poll_group.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/tcp_listen_socket.h" + +#include "lib/netplay/open_connection_result.h" +#include "lib/framework/wzapp.h" + +namespace tcp +{ + +void TCPConnectionProvider::initialize() +{ + SOCKETinit(); +} + +void TCPConnectionProvider::shutdown() +{ + SOCKETshutdown(); +} + +net::result> TCPConnectionProvider::resolveHost(const char* host, uint16_t port) +{ + auto resolved = tcp::resolveHost(host, port); + if (!resolved.has_value()) + { + return tl::make_unexpected(resolved.error()); + } + return std::make_unique(resolved.value()); +} + +net::result TCPConnectionProvider::openListenSocket(uint16_t port) +{ + auto res = tcp::socketListen(port); + if (!res.has_value()) + { + return tl::make_unexpected(res.error()); + } + return new TCPListenSocket(res.value()); +} + +net::result TCPConnectionProvider::openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) +{ + const auto* tcpAddr = dynamic_cast(&addr); + ASSERT(tcpAddr != nullptr, "Expected TCPConnectionAddress instance"); + if (!tcpAddr) + { + throw std::runtime_error("Expected TCPConnectionAddress instance"); + } + const auto* rawAddr = tcpAddr->asRawSocketAddress(); + auto res = socketOpenAny(rawAddr, timeout); + if (!res.has_value()) + { + return tl::make_unexpected(res.error()); + } + return new TCPClientConnection(res.value()); +} + +IConnectionPollGroup* TCPConnectionProvider::newConnectionPollGroup() +{ + auto* sset = allocSocketSet(); + if (!sset) + { + return nullptr; + } + return new TCPConnectionPollGroup(sset); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_provider.h b/lib/netplay/tcp/tcp_connection_provider.h new file mode 100644 index 00000000000..891bf469f9c --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_provider.h @@ -0,0 +1,47 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include +#include + +#include "lib/netplay/wz_connection_provider.h" + +namespace tcp +{ + +class TCPConnectionProvider final : public WzConnectionProvider +{ +public: + + virtual void initialize() override; + virtual void shutdown() override; + + virtual net::result> resolveHost(const char* host, uint16_t port) override; + + virtual net::result openListenSocket(uint16_t port) override; + + virtual net::result openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) override; + + virtual IConnectionPollGroup* newConnectionPollGroup() override; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_listen_socket.cpp b/lib/netplay/tcp/tcp_listen_socket.cpp new file mode 100644 index 00000000000..d7ea294b4fa --- /dev/null +++ b/lib/netplay/tcp/tcp_listen_socket.cpp @@ -0,0 +1,70 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_listen_socket.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" + +namespace tcp +{ + +TCPListenSocket::TCPListenSocket(Socket* rawSocket) + : listenSocket_(rawSocket) +{} + +TCPListenSocket::~TCPListenSocket() +{ + if (listenSocket_) + { + socketClose(listenSocket_); + } +} + +IClientConnection* TCPListenSocket::accept() +{ + ASSERT(listenSocket_ != nullptr, "Internal socket handle shouldn't be null!"); + if (!listenSocket_) + { + return nullptr; + } + auto* s = socketAccept(listenSocket_); + if (!s) + { + return nullptr; + } + return new TCPClientConnection(s); +} + +IListenSocket::IPVersionsMask TCPListenSocket::supportedIpVersions() const +{ + IPVersionsMask resMask = 0; + if (socketHasIPv4(*listenSocket_)) + { + resMask |= static_cast(IPVersions::IPV4); + } + if (socketHasIPv6(*listenSocket_)) + { + resMask |= static_cast(IPVersions::IPV6); + } + return resMask; +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_listen_socket.h b/lib/netplay/tcp/tcp_listen_socket.h new file mode 100644 index 00000000000..db78b25a0df --- /dev/null +++ b/lib/netplay/tcp/tcp_listen_socket.h @@ -0,0 +1,46 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include "lib/netplay/listen_socket.h" + +namespace tcp +{ + +struct Socket; + +class TCPListenSocket : public IListenSocket +{ +public: + + explicit TCPListenSocket(tcp::Socket* rawSocket); + virtual ~TCPListenSocket() override; + + virtual IClientConnection* accept() override; + virtual IPVersionsMask supportedIpVersions() const override; + +private: + + tcp::Socket* listenSocket_; +}; + +} // namespace tcp diff --git a/lib/netplay/wz_connection_provider.cpp b/lib/netplay/wz_connection_provider.cpp new file mode 100644 index 00000000000..89f4b8556b2 --- /dev/null +++ b/lib/netplay/wz_connection_provider.cpp @@ -0,0 +1,93 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "wz_connection_provider.h" + +#include "lib/framework/wzapp.h" + +namespace +{ + +OpenConnectionResult openClientConnectionSyncImpl(const char* host, uint32_t port, std::chrono::milliseconds timeout, WzConnectionProvider* connProvider) +{ + auto addrResult = connProvider->resolveHost(host, port); + if (!addrResult.has_value()) + { + const auto hostsErr = addrResult.error(); + const auto hostsErrMsg = hostsErr.message(); + return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); + } + auto connRes = connProvider->openClientConnectionAny(*addrResult.value(), timeout.count()); + if (!connRes.has_value()) + { + const auto connErr = connRes.error(); + const auto connErrMsg = connErr.message(); + return OpenConnectionResult(connErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, connErr.value(), connErrMsg.c_str())); + } + return OpenConnectionResult(connRes.value()); +} + +struct OpenConnectionRequest +{ + std::string host; + uint32_t port = 0; + std::chrono::milliseconds timeout{ 15000 }; + OpenConnectionToHostResultCallback callback; + WzConnectionProvider* connProvider; +}; + +int openDirectConnectionAsyncImpl(void* data) +{ + OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; + if (!pRequestInfo) + { + return 1; + } + pRequestInfo->callback(openClientConnectionSyncImpl( + pRequestInfo->host.c_str(), + pRequestInfo->port, + pRequestInfo->timeout, + pRequestInfo->connProvider)); + delete pRequestInfo; + return 0; +} + +} // anonymous namespace + +bool WzConnectionProvider::openClientConnectionAsync(const std::string& host, uint32_t port, std::chrono::milliseconds timeout, OpenConnectionToHostResultCallback callback) +{ + // spawn background thread to handle this + auto pRequest = new OpenConnectionRequest(); + pRequest->host = host; + pRequest->port = port; + pRequest->timeout = timeout; + pRequest->callback = callback; + pRequest->connProvider = this; + WZ_THREAD* pOpenConnectionThread = wzThreadCreate(openDirectConnectionAsyncImpl, pRequest); + if (pOpenConnectionThread == nullptr) + { + debug(LOG_ERROR, "Failed to create thread for opening connection"); + delete pRequest; + return false; + } + wzThreadDetach(pOpenConnectionThread); + // the thread handles deleting pRequest + pOpenConnectionThread = nullptr; + return true; +} diff --git a/lib/netplay/wz_connection_provider.h b/lib/netplay/wz_connection_provider.h new file mode 100644 index 00000000000..91683e3081b --- /dev/null +++ b/lib/netplay/wz_connection_provider.h @@ -0,0 +1,85 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include +#include +#include + +#include "lib/netplay/connection_address.h" +#include "lib/netplay/net_result.h" +#include "lib/netplay/open_connection_result.h" + +class IListenSocket; +class IClientConnection; +class IConnectionPollGroup; +struct IConnectionAddress; + +/// +/// Abstraction layer to facilitate creating client/server connections and +/// provide host resolution routines for a given network backend. +/// +/// A typical implementation of this interface should at least provide the following +/// things: +/// +/// 1. Initialization/teardown routines (setup some common state, like write/submission +/// queues or service threads, plus initialization of low-level backend code, e.g. +/// calls to init/deinit functions from a 3rd-party library). +/// 2. Host resolution. +/// 3. Opening server-side listen sockets. +/// 4. Opening client-side connections (sync and async). +/// 5. Creating connection poll groups. +/// +class WzConnectionProvider +{ +public: + + virtual ~WzConnectionProvider() = default; + + virtual void initialize() = 0; + virtual void shutdown() = 0; + + /// + /// Resolve host + port combination and return an opaque `ConnectionAddress` handle + /// representing the resolved network address. + /// + virtual net::result> resolveHost(const char* host, uint16_t port) = 0; + /// + /// Open a listening socket bound to a specified local port. + /// + virtual net::result openListenSocket(uint16_t port) = 0; + /// + /// Synchronously open a client connection bound to one of the addresses + /// represented by `addr` (the first one that succeeds). + /// + /// Connection address to bind the client connection to. + /// Timeout in milliseconds. + virtual net::result openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) = 0; + /// + /// Async variant of `openClientConnectionAny()` with the default implementation, which + /// spawns a new thread and piggybacks on the `resolveHost()` and `openClientConnectionAny()` combination. + /// + virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, std::chrono::milliseconds timeout, OpenConnectionToHostResultCallback callback); + /// + /// Create a group for polling client connections. + /// + virtual IConnectionPollGroup* newConnectionPollGroup() = 0; +}; diff --git a/po/POTFILES.in b/po/POTFILES.in index 9f8c0cc58b6..78dd4afab84 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -302,6 +302,7 @@ lib/ivis_opengl/png_util_spng.cpp lib/ivis_opengl/screen.cpp lib/ivis_opengl/tex.cpp lib/ivis_opengl/textdraw.cpp +lib/netplay/connection_provider_registry.cpp lib/netplay/error_categories.cpp lib/netplay/netjoin_stub.cpp lib/netplay/netlog.cpp @@ -309,8 +310,12 @@ lib/netplay/netpermissions.cpp lib/netplay/netplay.cpp lib/netplay/netqueue.cpp lib/netplay/netreplay.cpp -lib/netplay/netsocket.cpp lib/netplay/nettypes.cpp +lib/netplay/tcp/netsocket.cpp +lib/netplay/tcp/tcp_client_connection.cpp +lib/netplay/tcp/tcp_connection_poll_group.cpp +lib/netplay/tcp/tcp_connection_provider.cpp +lib/netplay/tcp/tcp_listen_socket.cpp lib/netplay/port_mapping_manager.cpp lib/netplay/port_mapping_manager_impl_libplum.cpp lib/netplay/port_mapping_manager_impl_miniupnpc.cpp diff --git a/src/integrations/wzdiscordrpc.cpp b/src/integrations/wzdiscordrpc.cpp index 8cfd245eb18..eaa897d5b16 100644 --- a/src/integrations/wzdiscordrpc.cpp +++ b/src/integrations/wzdiscordrpc.cpp @@ -22,7 +22,7 @@ #include "lib/framework/crc.h" #include "lib/gamelib/gtime.h" #include "lib/netplay/netplay.h" -#include "lib/netplay/netsocket.h" +#include "lib/netplay/tcp/netsocket.h" #include "../activity.h" #include "../frontend.h" #include "../multiint.h" @@ -395,10 +395,10 @@ static void joinGameFromSecret_v1(const std::string joinSecretStr) return; } std::vector ipNetworkBinaryFormat = EmbeddedJSONSignature::b64Decode(b64UrlSafeTob64(connectionDetails[0].toUtf8())); - std::string ipAddressStr = ipv6_NetBinary_To_AddressString(ipNetworkBinaryFormat); + std::string ipAddressStr = tcp::ipv6_NetBinary_To_AddressString(ipNetworkBinaryFormat); if (ipAddressStr.empty()) { - ipAddressStr = ipv4_NetBinary_To_AddressString(ipNetworkBinaryFormat); + ipAddressStr = tcp::ipv4_NetBinary_To_AddressString(ipNetworkBinaryFormat); } if (ipAddressStr.empty()) { @@ -766,7 +766,7 @@ void DiscordRPCActivitySink::setJoinInformation(const ActivitySink::MultiplayerG std::string joinSecretDetails; if (!pExternalIPv4Address->empty()) { - auto ipv4AddressBinaryForm = ipv4_AddressString_To_NetBinary(*pExternalIPv4Address); + auto ipv4AddressBinaryForm = tcp::ipv4_AddressString_To_NetBinary(*pExternalIPv4Address); if (!ipv4AddressBinaryForm.empty()) { joinSecretDetails += b64Tob64UrlSafe(EmbeddedJSONSignature::b64Encode(ipv4AddressBinaryForm)); @@ -775,7 +775,7 @@ void DiscordRPCActivitySink::setJoinInformation(const ActivitySink::MultiplayerG } if (!ipv6Address.empty()) { - auto ipv6AddressBinaryForm = ipv6_AddressString_To_NetBinary(ipv6Address); + auto ipv6AddressBinaryForm = tcp::ipv6_AddressString_To_NetBinary(ipv6Address); if (!ipv6AddressBinaryForm.empty()) { if (!joinSecretDetails.empty()) diff --git a/src/screens/joiningscreen.cpp b/src/screens/joiningscreen.cpp index 1f18c884aff..bf77aafe2fb 100644 --- a/src/screens/joiningscreen.cpp +++ b/src/screens/joiningscreen.cpp @@ -28,8 +28,12 @@ #include "lib/widget/scrollablelist.h" #include "lib/ivis_opengl/pieblitfunc.h" #include "lib/ivis_opengl/piepalette.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" #include "lib/netplay/netplay.h" -#include "lib/netplay/netsocket.h" +#include "lib/netplay/client_connection.h" +#include "lib/netplay/connection_poll_group.h" +#include "lib/netplay/open_connection_result.h" +#include "lib/netplay/connection_provider_registry.h" #include "../hci.h" #include "../activity.h" @@ -786,8 +790,8 @@ class WzJoiningGameScreen_HandlerRoot : public W_CLICKFORM // state when handling initial connection join uint32_t startTime = 0; - Socket* client_transient_socket = nullptr; - SocketSet* tmp_joining_socket_set = nullptr; + IClientConnection* client_transient_socket = nullptr; + IConnectionPollGroup* tmp_joining_socket_set = nullptr; NETQUEUE tmpJoiningQUEUE = {}; NetQueuePair *tmpJoiningQueuePair = nullptr; char initialAckBuffer[10] = {'\0'}; @@ -1162,22 +1166,22 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect if (NETgetEnableTCPNoDelay()) { - // Enable TCP_NODELAY - socketSetTCPNoDelay(*client_transient_socket, true); + // Disable use of Nagle Algorithm for the TCP socket (i.e. enable TCP_NODELAY option in case of TCP transport) + client_transient_socket->useNagleAlgorithm(false); } // Send initial connection data: NETCODE_VERSION_MAJOR and NETCODE_VERSION_MINOR char buffer[sizeof(int32_t) * 2] = { 0 }; char *p_buffer = buffer; auto pushu32 = [&](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(p_buffer, &swapped, sizeof(swapped)); p_buffer += sizeof(swapped); }; pushu32(NETGetMajorVersion()); pushu32(NETGetMinorVersion()); - const auto writeResult = writeAll(*client_transient_socket, buffer, sizeof(buffer)); + const auto writeResult = client_transient_socket->writeAll(buffer, sizeof(buffer), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); @@ -1186,7 +1190,8 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect return; } - tmp_joining_socket_set = allocSocketSet(); + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + tmp_joining_socket_set = connProvider.newConnectionPollGroup(); if (tmp_joining_socket_set == nullptr) { debug(LOG_ERROR, "Cannot create socket set - out of memory?"); @@ -1196,7 +1201,7 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect debug(LOG_NET, "Created socket_set %p", static_cast(tmp_joining_socket_set)); // `client_transient_socket` is used to talk to host machine - SocketSet_AddSocket(*tmp_joining_socket_set, client_transient_socket); + tmp_joining_socket_set->add(client_transient_socket); // Create temporary NETQUEUE auto NETnetJoinTmpQueue = [&]() @@ -1233,15 +1238,20 @@ void WzJoiningGameScreen_HandlerRoot::attemptToOpenConnection(size_t connectionI description.port = NETgetGameserverPort(); // use default configured port } auto weakSelf = std::weak_ptr(std::dynamic_pointer_cast(shared_from_this())); - socketOpenTCPConnectionAsync(description.host, description.port, [weakSelf, connectionIdx](OpenConnectionResult&& result) { - auto strongSelf = weakSelf.lock(); - if (!strongSelf) - { - // background thread ultimately returned after the requester has gone away (join was cancelled?) - just return - return; - } - strongSelf->processOpenConnectionResultOnMainThread(connectionIdx, std::move(result)); - }); + + constexpr std::chrono::milliseconds CLIENT_OPEN_ASYNC_TIMEOUT{ 15000 }; // Default timeout of 15s + + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + connProvider.openClientConnectionAsync(description.host, description.port, CLIENT_OPEN_ASYNC_TIMEOUT, + [weakSelf, connectionIdx](OpenConnectionResult&& result) { + auto strongSelf = weakSelf.lock(); + if (!strongSelf) + { + // background thread ultimately returned after the requester has gone away (join was cancelled?) - just return + return; + } + strongSelf->processOpenConnectionResultOnMainThread(connectionIdx, std::move(result)); + }); break; } updateJoiningStatus(_("Establishing connection with host")); @@ -1254,7 +1264,7 @@ bool WzJoiningGameScreen_HandlerRoot::joiningSocketNETsend() uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen = 0; - const auto writeResult = writeAll(*client_transient_socket, rawData, rawLen, &compressedRawLen); + const auto writeResult = client_transient_socket->writeAll(rawData, rawLen, &compressedRawLen); delete[] rawData; // Done with the data. queue->popMessageForNet(); if (writeResult.has_value()) @@ -1269,7 +1279,7 @@ bool WzJoiningGameScreen_HandlerRoot::joiningSocketNETsend() debug(LOG_ERROR, "Failed to send message (type: %" PRIu8 ", rawLen: %zu, compressedRawLen: %zu) to host: %s", message->type, message->rawLen(), compressedRawLen, writeErrMsg.c_str()); return false; } - socketFlush(*client_transient_socket, NET_HOST_ONLY); // Make sure the message was completely sent. + client_transient_socket->flush(nullptr); // Make sure the message was completely sent. ASSERT(queue->numMessagesForNet() == 0, "Queue not empty (%u messages remaining).", queue->numMessagesForNet()); return true; } @@ -1280,14 +1290,14 @@ void WzJoiningGameScreen_HandlerRoot::closeConnectionAttempt() { if (tmp_joining_socket_set) { - SocketSet_DelSocket(*tmp_joining_socket_set, client_transient_socket); + tmp_joining_socket_set->remove(client_transient_socket); } - socketClose(client_transient_socket); + delete client_transient_socket; client_transient_socket = nullptr; } if (tmp_joining_socket_set) { - deleteSocketSet(tmp_joining_socket_set); + delete tmp_joining_socket_set; tmp_joining_socket_set = nullptr; } if (tmpJoiningQueuePair) @@ -1347,15 +1357,17 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() if (currentJoiningState == JoiningState::AwaitingInitialNetcodeHandshakeAck) { // read in data, if we have it - if (checkSockets(*tmp_joining_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_joining_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { - if (!socketReadReady(*client_transient_socket)) + if (!client_transient_socket->readReady()) { return; // wait for next check } char *p_buffer = initialAckBuffer; - const auto readResult = readNoInt(*client_transient_socket, p_buffer + usedInitialAckBuffer, expectedInitialAckSize - usedInitialAckBuffer); + const auto readResult = client_transient_socket->readNoInt(p_buffer + usedInitialAckBuffer, + expectedInitialAckSize - usedInitialAckBuffer, + nullptr); if (readResult.has_value()) { usedInitialAckBuffer += static_cast(readResult.value()); @@ -1365,7 +1377,7 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() { uint32_t result = ERROR_CONNECTION; memcpy(&result, initialAckBuffer, sizeof(result)); - result = ntohl(result); + result = wz_ntohl(result); if (result != ERROR_NOERROR) { debug(LOG_ERROR, "Received error %d", result); @@ -1387,7 +1399,7 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() } // transition to net message mode (enable compression, wait for messages) - socketBeginCompression(*client_transient_socket); + client_transient_socket->enableCompression(); currentJoiningState = JoiningState::ProcessingJoinMessages; // permit fall-through to currentJoiningState == JoiningState::ProcessingJoinMessage case below } @@ -1398,15 +1410,15 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() if (currentJoiningState == JoiningState::ProcessingJoinMessages) { // read in data, if we have it - if (checkSockets(*tmp_joining_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_joining_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { - if (!socketReadReady(*client_transient_socket)) + if (!client_transient_socket->readReady()) { return; // wait for next check } uint8_t readBuffer[NET_BUFFER_SIZE]; - const auto readResult = readNoInt(*client_transient_socket, readBuffer, sizeof(readBuffer)); + const auto readResult = client_transient_socket->readNoInt(readBuffer, sizeof(readBuffer), nullptr); if (!readResult.has_value()) { // disconnect or programmer error