From a784526cb1b13f44d3f8f37133f301dda06d9c3d Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Fri, 25 Oct 2024 14:55:33 +0300 Subject: [PATCH 1/3] netplay: introduce abstractions for client/server-side sockets and connection providers This patch introduces multiple high-level abstractions over raw low-level sockets, which are necessary for supporting network backends other than default legacy `TCP_DIRECT` implementation: 1. `WzConnectionProvider` - abstracts the way WZ establishes server-side and client-side connections. This thing effectively provides usable listen sockets and client connections to work with, hence the name. 2. `IListenSocket` - abstraction over listen sockets. 3. `IClientConnection` - abstraction over client-side sockets (and also server-side connections to the game clients). 4. `IConnectionPollGroup` - generalization of socket sets for polling multiple connections in one go. 5. `ConnectionProviderRegistry` - trivial singleton class providing storage for connection providers. 6. `ConnectionAddress` - opaque connection address object, aimed to replace direct uses of `addrinfo` and provide a bit more abstract way to represent connection credentials. Still looks like a crutch right now, but it's better than nothing, nonetheless. The existing implementation in `netplay/netsocket.h(.cpp)` has been moved to the `tcp` subfolder and wrapped entirely into the `tcp` namespace. The patch provides `TCP*`-prefixed implementations of the base interfaces mentioned above, which are implemented in terms of the old `netsocket` code. There's now a `ConnectionProviderType::TCP_DIRECT` enumeration descriptor for accessing the default connection provider. All uses in the high-level code (`netplay.cpp`, `joiningscreen.cpp`) are amended appropriately to use the all-new high-level abstractions instead of old low-level tcp-specific `Socket` and `SocketSet`. NOTE: there are still a few functions from the `tcp::` namespace used directly in the Discord RPC integration code, but these shouldn't pose any problem to either extract these into a more generic abstraction layer or to be rewritten not to use these functions at all, because they don't actually use any low-level stuff that's hard to refactor. Signed-off-by: Pavel Solodovnikov --- lib/netplay/CMakeLists.txt | 12 +- lib/netplay/byteorder_funcs_wrapper.cpp | 50 ++++ lib/netplay/byteorder_funcs_wrapper.h | 34 +++ lib/netplay/client_connection.h | 120 ++++++++ lib/netplay/connection_address.cpp | 65 +++++ lib/netplay/connection_address.h | 72 +++++ lib/netplay/connection_poll_group.h | 42 +++ lib/netplay/connection_provider_registry.cpp | 57 ++++ lib/netplay/connection_provider_registry.h | 54 ++++ lib/netplay/listen_socket.h | 45 +++ lib/netplay/net_result.h | 32 +++ lib/netplay/netplay.cpp | 261 ++++++++++-------- lib/netplay/netplay.h | 8 +- lib/netplay/open_connection_result.h | 60 ++++ lib/netplay/sync_debug.cpp | 6 +- lib/netplay/{ => tcp}/netsocket.cpp | 79 +----- lib/netplay/{ => tcp}/netsocket.h | 51 +--- lib/netplay/tcp/tcp_client_connection.cpp | 83 ++++++ lib/netplay/tcp/tcp_client_connection.h | 52 ++++ lib/netplay/tcp/tcp_connection_poll_group.cpp | 60 ++++ lib/netplay/tcp/tcp_connection_poll_group.h | 45 +++ lib/netplay/tcp/tcp_connection_provider.cpp | 153 ++++++++++ lib/netplay/tcp/tcp_connection_provider.h | 48 ++++ lib/netplay/tcp/tcp_listen_socket.cpp | 70 +++++ lib/netplay/tcp/tcp_listen_socket.h | 46 +++ lib/netplay/wz_connection_provider.h | 80 ++++++ po/POTFILES.in | 8 +- src/integrations/wzdiscordrpc.cpp | 10 +- src/screens/joiningscreen.cpp | 55 ++-- 29 files changed, 1491 insertions(+), 267 deletions(-) create mode 100644 lib/netplay/byteorder_funcs_wrapper.cpp create mode 100644 lib/netplay/byteorder_funcs_wrapper.h create mode 100644 lib/netplay/client_connection.h create mode 100644 lib/netplay/connection_address.cpp create mode 100644 lib/netplay/connection_address.h create mode 100644 lib/netplay/connection_poll_group.h create mode 100644 lib/netplay/connection_provider_registry.cpp create mode 100644 lib/netplay/connection_provider_registry.h create mode 100644 lib/netplay/listen_socket.h create mode 100644 lib/netplay/net_result.h create mode 100644 lib/netplay/open_connection_result.h rename lib/netplay/{ => tcp}/netsocket.cpp (96%) rename lib/netplay/{ => tcp}/netsocket.h (81%) create mode 100644 lib/netplay/tcp/tcp_client_connection.cpp create mode 100644 lib/netplay/tcp/tcp_client_connection.h create mode 100644 lib/netplay/tcp/tcp_connection_poll_group.cpp create mode 100644 lib/netplay/tcp/tcp_connection_poll_group.h create mode 100644 lib/netplay/tcp/tcp_connection_provider.cpp create mode 100644 lib/netplay/tcp/tcp_connection_provider.h create mode 100644 lib/netplay/tcp/tcp_listen_socket.cpp create mode 100644 lib/netplay/tcp/tcp_listen_socket.h create mode 100644 lib/netplay/wz_connection_provider.h diff --git a/lib/netplay/CMakeLists.txt b/lib/netplay/CMakeLists.txt index 12cb14cb648..65962e17e91 100644 --- a/lib/netplay/CMakeLists.txt +++ b/lib/netplay/CMakeLists.txt @@ -26,8 +26,16 @@ add_dependencies(autorevision_netcodeversion autorevision) # Ensure ordering and ############################ # netplay library -file(GLOB HEADERS "*.h") -file(GLOB SRC "*.cpp") +file(GLOB_RECURSE HEADERS "*.h") +file(GLOB_RECURSE SRC "*.cpp") + +if(MSVC AND CMAKE_VERSION VERSION_GREATER 3.7) + # Automatic detection of source groups via `source_group(TREE )` syntax + # has been introduced in CMake 3.8. + # Please consult https://cmake.org/cmake/help/latest/command/source_group.html for additional info. + source_group(TREE "${CMAKE_CURRENT_LIST_DIR}" PREFIX "Sources" FILES ${SRC}) + source_group(TREE "${CMAKE_CURRENT_LIST_DIR}" PREFIX "Headers" FILES ${HEADERS}) +endif() find_package (Threads REQUIRED) find_package (ZLIB REQUIRED) diff --git a/lib/netplay/byteorder_funcs_wrapper.cpp b/lib/netplay/byteorder_funcs_wrapper.cpp new file mode 100644 index 00000000000..79366a8949b --- /dev/null +++ b/lib/netplay/byteorder_funcs_wrapper.cpp @@ -0,0 +1,50 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 1999-2004 Eidos Interactive + Copyright (C) 2005-2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/byteorder_funcs_wrapper.h" + +#include "lib/framework/wzglobal.h" + +// bring in the original `htonl`/`htons`/`ntohs`/`htohl` functions +#if defined WZ_OS_WIN +# include +#else // *NIX / *BSD variants +# include +#endif + +uint32_t wz_htonl(uint32_t hostlong) +{ + return htonl(hostlong); +} + +uint16_t wz_htons(uint16_t hostshort) +{ + return htons(hostshort); +} + +uint32_t wz_ntohl(uint32_t netlong) +{ + return ntohl(netlong); +} + +uint16_t wz_ntohs(uint16_t netshort) +{ + return ntohs(netshort); +} diff --git a/lib/netplay/byteorder_funcs_wrapper.h b/lib/netplay/byteorder_funcs_wrapper.h new file mode 100644 index 00000000000..c1348fd6505 --- /dev/null +++ b/lib/netplay/byteorder_funcs_wrapper.h @@ -0,0 +1,34 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 1999-2004 Eidos Interactive + Copyright (C) 2005-2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +/// +/// byteorder functions wrappers for WZ just to avoid polluting all places, +/// where these functions are needed, with conditional includes of +/// and winsock headers. +/// + +uint32_t wz_htonl(uint32_t hostlong); +uint16_t wz_htons(uint16_t hostshort); +uint32_t wz_ntohl(uint32_t netlong); +uint16_t wz_ntohs(uint16_t netshort); diff --git a/lib/netplay/client_connection.h b/lib/netplay/client_connection.h new file mode 100644 index 00000000000..986fdc04b4a --- /dev/null +++ b/lib/netplay/client_connection.h @@ -0,0 +1,120 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +#include "lib/framework/types.h" // bring in `ssize_t` for MSVC +#include "lib/netplay/net_result.h" + +/// +/// Basic abstraction over client connection sockets. +/// +/// These are capable of reading (`readAll` and `readNoInt`) and +/// writing data (via `writeAll()` + `flush()` combination). +/// +/// The internal implementation may also implement advanced compression mechanisms +/// on top of these connections by providing non-trivial `enableCompression()` overload. +/// +/// In this case, `writeAll()` should somehow accumulate the data into a write queue, +/// compressing the outcoming data on-the-fly; and `flush()` should empty the write queue +/// and actually post a message to the transmission queue, which, in turn, will be emptied +/// by the internal connection interface in a timely manner, when there are enough messages +/// to be sent over the network. +/// +class IClientConnection +{ +public: + + virtual ~IClientConnection() = default; + + /// + /// Read exactly `size` bytes into `buf` buffer. + /// Supports setting a timeout value in milliseconds. + /// + /// Destination buffer to read the data into. + /// The size of data to be read in bytes. + /// Timeout value in milliseconds. + /// On success, returns the number of bytes read; + /// On failure, returns an `std::error_code` (having `GenericSystemErrorCategory` error category) + /// describing the actual error. + virtual net::result readAll(void* buf, size_t size, unsigned timeout) = 0; + /// + /// Reads at most `max_size` bytes into `buf` buffer. + /// Raw count of bytes (after compression) is returned in `rawByteCount`. + /// + /// Destination buffer to read the data into. + /// The maximum number of bytes to read from the client socket. + /// Output parameter: Raw count of bytes (after compression). + /// On success, returns the number of bytes read; + /// On failure, returns an `std::error_code` (having `GenericSystemErrorCategory` error category) + /// describing the actual error. + virtual net::result readNoInt(void* buf, size_t max_size, size_t* rawByteCount) = 0; + /// + /// Nonblocking write of `size` bytes to the socket. The data will be written to a + /// separate write queue in asynchronous manner, possibly by a separate thread. + /// Raw count of bytes (after compression) will be returned in `rawByteCount`, which + /// will often be 0 until the socket is flushed. + /// + /// The reason for this method to be async is that in some cases we want + /// client connections to have compression mechanism enabled. This naturally + /// introduces the 2-phase write process, which involves a write queue (accumulating + /// the data for compression on-the-fly) and a submission (transmission) + /// queue (for transmitting of compressed and assembled messages), + /// which is managed by the network backend implementation. + /// + /// Source buffer to read the data from. + /// The number of bytes to write to the socket. + /// Output parameter: raw count of bytes (after compression) written. + /// The total number of bytes written. + virtual net::result writeAll(const void* buf, size_t size, size_t* rawByteCount) = 0; + /// + /// This method indicates whether the socket has some data ready to be read (i.e. + /// whether the next `readAll/readNoInt` operation will execute without blocking or not). + /// + virtual bool readReady() const = 0; + /// + /// Actually sends the data written with `writeAll()`. Only useful with sockets + /// which have compression enabled. + /// Note that flushing too often makes compression less effective. + /// Raw count of bytes (after compression) is returned in `rawByteCount`. + /// + /// Raw count of bytes (after compression) as written + /// to the submission queue by the flush operation. + virtual void flush(size_t* rawByteCount) = 0; + /// + /// Enables compression for the current socket. + /// + /// This makes all subsequent write operations asynchronous, plus + /// the written data will need to be flushed explicitly at some point. + /// + virtual void enableCompression() = 0; + /// + /// Enables or disables the use of Nagle algorithm for the socket. + /// + /// For direct TCP connections this is equivalent to setting `TCP_NODELAY` to the + /// appropriate value (i.e.: + /// `enable == true` <=> `TCP_NODELAY == false`; + /// `enable == false` <=> `TCP_NODELAY == true`). + /// + virtual void useNagleAlgorithm(bool enable) = 0; + /// + /// Returns textual representation of the socket's connection address. + /// + virtual std::string textAddress() const = 0; +}; diff --git a/lib/netplay/connection_address.cpp b/lib/netplay/connection_address.cpp new file mode 100644 index 00000000000..e4c277b8c50 --- /dev/null +++ b/lib/netplay/connection_address.cpp @@ -0,0 +1,65 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/connection_address.h" +#include "lib/netplay/tcp/netsocket.h" // for `resolveHost` + +#include "lib/framework/frame.h" // for `ASSERT` + +struct ConnectionAddress::Impl final +{ + explicit Impl(SocketAddress* addr) + : mAddr_(addr) + {} + + ~Impl() + { + ASSERT(mAddr_ != nullptr, "Invalid addrinfo stored in the connection address"); + freeaddrinfo(mAddr_); + } + + SocketAddress* mAddr_; +}; + +ConnectionAddress::ConnectionAddress() = default; +ConnectionAddress::ConnectionAddress(ConnectionAddress&&) = default; +ConnectionAddress::~ConnectionAddress() = default; + +const SocketAddress* ConnectionAddress::asRawSocketAddress() const +{ + return mPimpl_->mAddr_; +} + + +net::result ConnectionAddress::parse(const char* hostname, uint16_t port) +{ + ConnectionAddress res; + const auto addr = tcp::resolveHost(hostname, port); + if (!addr.has_value()) + { + return tl::make_unexpected(addr.error()); + } + res.mPimpl_ = std::make_unique(addr.value()); + return net::result{std::move(res)}; +} + +net::result ConnectionAddress::parse(const std::string& hostname, uint16_t port) +{ + return parse(hostname.c_str(), port); +} diff --git a/lib/netplay/connection_address.h b/lib/netplay/connection_address.h new file mode 100644 index 00000000000..604d9dc12da --- /dev/null +++ b/lib/netplay/connection_address.h @@ -0,0 +1,72 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include +#include + +#include "lib/netplay/net_result.h" + +#if defined WZ_OS_UNIX +# include +#elif defined WZ_OS_WIN +# include +#endif + +typedef struct addrinfo SocketAddress; + +/// +/// Opaque class representing abstract connection address to use with various +/// network backend implementations. The internal representation is made +/// hidden on purpose since we don't want to actually leak internal data layout +/// to clients. +/// +/// Instead, we would like to introduce "conversion routines" yielding +/// various representations for convenient consumption with various network +/// backends. +/// +/// NOTE: this class may or may not represent a chain of resolved network addresses +/// instead of just a single one, much like a `addrinfo` structure. +/// +/// Currently, only knows how to convert itself to `addrinfo` struct, +/// which is used with the `TCP_DIRECT` network backend. +/// +/// New conversion routines should be introduced for other network backends, +/// if deemed necessary. +/// +class ConnectionAddress +{ +public: + + ConnectionAddress(); + ConnectionAddress(ConnectionAddress&&); + ConnectionAddress(const ConnectionAddress&) = delete; + ~ConnectionAddress(); + + static net::result parse(const char* hostname, uint16_t port); + static net::result parse(const std::string& hostname, uint16_t port); + + // NOTE: The lifetime of the returned `addrinfo` struct is bounded by the parent object's lifetime! + const SocketAddress* asRawSocketAddress() const; + +private: + + struct Impl; + std::unique_ptr mPimpl_; +}; diff --git a/lib/netplay/connection_poll_group.h b/lib/netplay/connection_poll_group.h new file mode 100644 index 00000000000..a2b66bcb801 --- /dev/null +++ b/lib/netplay/connection_poll_group.h @@ -0,0 +1,42 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +class IClientConnection; + +/// +/// Abstract representation of a poll group comprised of several client connections. +/// +class IConnectionPollGroup +{ +public: + + virtual ~IConnectionPollGroup() = default; + + /// + /// Polls the sockets in the poll group for updates. + /// + /// Timeout value after which the internal implementation should abandon + /// polling the client connections and return. + /// On success, returns the number of connection descriptors in the poll group. + /// On failure, `0` can returned if the timeout expired before any connection descriptors + /// became ready, or `-1` if there was an error during the internal poll operation. + virtual int checkSockets(unsigned timeout) = 0; + + virtual void add(IClientConnection* conn) = 0; + virtual void remove(IClientConnection* conn) = 0; +}; diff --git a/lib/netplay/connection_provider_registry.cpp b/lib/netplay/connection_provider_registry.cpp new file mode 100644 index 00000000000..b68e7c8c331 --- /dev/null +++ b/lib/netplay/connection_provider_registry.cpp @@ -0,0 +1,57 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include + +#include "lib/netplay/connection_provider_registry.h" +#include "lib/netplay/tcp/tcp_connection_provider.h" + +ConnectionProviderRegistry& ConnectionProviderRegistry::Instance() +{ + static ConnectionProviderRegistry instance; + return instance; +} + +WzConnectionProvider& ConnectionProviderRegistry::Get(ConnectionProviderType pt) +{ + const auto it = registeredProviders_.find(pt); + if (it == registeredProviders_.end()) + { + throw std::runtime_error("Attempt to get nonexistent connection provider"); + } + return *it->second; +} + +void ConnectionProviderRegistry::Register(ConnectionProviderType pt) +{ + // No-op in case this provider has been already registered. + switch (pt) + { + case ConnectionProviderType::TCP_DIRECT: + registeredProviders_.emplace(pt, std::make_unique()); + break; + default: + throw std::runtime_error("Unknown connection provider type"); + } +} + +void ConnectionProviderRegistry::Deregister(ConnectionProviderType pt) +{ + registeredProviders_.erase(pt); +} diff --git a/lib/netplay/connection_provider_registry.h b/lib/netplay/connection_provider_registry.h new file mode 100644 index 00000000000..2a760b47536 --- /dev/null +++ b/lib/netplay/connection_provider_registry.h @@ -0,0 +1,54 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +#include "lib/netplay/wz_connection_provider.h" + +/// +/// Available types of connection providers (i.e. network backend implementations). +/// +enum class ConnectionProviderType +{ + TCP_DIRECT +}; + +/// +/// Global singleton registry containing available network connection providers. +/// +class ConnectionProviderRegistry +{ +public: + + static ConnectionProviderRegistry& Instance(); + + WzConnectionProvider& Get(ConnectionProviderType pt); + + void Register(ConnectionProviderType pt); + void Deregister(ConnectionProviderType pt); + +private: + + ConnectionProviderRegistry() = default; + + std::unordered_map> registeredProviders_; +}; diff --git a/lib/netplay/listen_socket.h b/lib/netplay/listen_socket.h new file mode 100644 index 00000000000..04cfdd9fc9e --- /dev/null +++ b/lib/netplay/listen_socket.h @@ -0,0 +1,45 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +class IClientConnection; + +/// +/// Server-side listen socket abstraction. +/// +class IListenSocket +{ +public: + + virtual ~IListenSocket() = default; + + enum class IPVersions : uint8_t + { + IPV4 = 0b00000001, + IPV6 = 0b00000010 + }; + using IPVersionsMask = std::underlying_type_t; + + /// + /// Accept an incoming client connection on the current server-side listen socket. + /// + virtual IClientConnection* accept() = 0; + virtual IPVersionsMask supportedIpVersions() const = 0; +}; diff --git a/lib/netplay/net_result.h b/lib/netplay/net_result.h new file mode 100644 index 00000000000..be558bcc517 --- /dev/null +++ b/lib/netplay/net_result.h @@ -0,0 +1,32 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include + +namespace net +{ + +template +using result = ::tl::expected; + +} // namespace net diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index 2a74a231660..294ac0cfb26 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -48,7 +48,11 @@ #include "netplay.h" #include "netlog.h" #include "netreplay.h" -#include "netsocket.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" +#include "lib/netplay/client_connection.h" +#include "lib/netplay/listen_socket.h" +#include "lib/netplay/connection_poll_group.h" +#include "lib/netplay/connection_provider_registry.h" #include "netpermissions.h" #include "sync_debug.h" #include "port_mapping_manager.h" @@ -71,6 +75,12 @@ # include "lib/framework/cocoa_wrapper.h" #endif +#ifndef WZ_OS_WIN +static const int SOCKET_ERROR = -1; +#else +# include // SOCKET_ERROR +#endif + // WARNING !!! This is initialised via configuration.c !!! char masterserver_name[255] = {'\0'}; static unsigned int masterserver_port = 0, gameserver_port = 0; @@ -161,8 +171,8 @@ class LobbyServerConnectionHandler Connected }; LobbyConnectionState currentState = LobbyConnectionState::Disconnected; - Socket *rs_socket = nullptr; - SocketSet* waitingForConnectionFinalize = nullptr; + IClientConnection* rs_socket = nullptr; + IConnectionPollGroup* waitingForConnectionFinalize = nullptr; uint32_t lastConnectionTime = 0; uint32_t lastServerUpdate = 0; bool queuedServerUpdate = false; @@ -232,15 +242,15 @@ bool netPlayersUpdated; // Server-side socket (host-only) which is used to listen for client connections. // There's also `rs_socket` held by `LobbyServerConnectionHandler`, which is used to communicate with the lobby server. -static Socket* server_listen_socket = nullptr; +static IListenSocket* server_listen_socket = nullptr; -static Socket *bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then client_transient_socket == NULL. -static Socket *connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only). +static IClientConnection* bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then client_transient_socket == NULL. +static IClientConnection* connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only). // Client-side socket set. Contains of only 1 socket at most: `bsocket` (which is a stable client connection to the host). -static SocketSet* client_socket_set = nullptr; +static IConnectionPollGroup* client_socket_set = nullptr; // Server-side socket set. Contains up to `MAX_CONNECTED_PLAYERS` sockets: // `connected_bsocket[i]` - sockets used to communicate with clients during a game session. -static SocketSet* server_socket_set = nullptr; +static IConnectionPollGroup* server_socket_set = nullptr; /** * Used for connections with clients. @@ -302,13 +312,13 @@ struct TmpSocketInfo } }; -static Socket *tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only). +static IClientConnection* tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only). static std::array tmp_connectState; static bool bAsyncJoinApprovalEnabled = false; static std::unordered_map tmp_pendingIPs; static lru11::Cache tmp_badIPs(512, 64); -static SocketSet *tmp_socket_set = nullptr; +static IConnectionPollGroup* tmp_socket_set = nullptr; static int32_t NetGameFlags[4] = { 0, 0, 0, 0 }; char iptoconnect[PATH_MAX] = "\0"; // holds IP/hostname from command line bool cliConnectToIpAsSpectator = false; // for cli option @@ -540,17 +550,17 @@ bool NETsetAsyncJoinApprovalResult(const std::string& uniqueJoinID, AsyncJoinApp // *********** Socket with buffer that read NETMSGs ****************** -static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *bufstart, int bufsize) +static size_t NET_fillBuffer(IClientConnection** pSocket, IConnectionPollGroup* pSocketSet, uint8_t *bufstart, int bufsize) { - Socket *socket = *pSocket; + IClientConnection* socket = *pSocket; - if (!socketReadReady(*socket)) + if (!socket->readReady()) { return 0; } size_t rawBytes; - const auto readResult = readNoInt(*socket, bufstart, bufsize, &rawBytes); + const auto readResult = socket->readNoInt(bufstart, bufsize, &rawBytes); if (readResult.has_value()) { @@ -578,7 +588,7 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b // an error occurred, or the remote host has closed the connection. if (pSocketSet != nullptr) { - SocketSet_DelSocket(*pSocketSet, socket); + pSocketSet->remove(socket); } if (bsocket == socket) { @@ -592,7 +602,7 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b NETclose(); return 0; } - socketClose(socket); + delete socket; *pSocket = nullptr; } @@ -974,8 +984,8 @@ static void NETplayerCloseSocket(UDWORD index, bool quietSocketClose) NETlogEntry("Player has left nicely.", SYNC_FLAG, index); // Although we can get a error result from DelSocket, it don't really matter here. - SocketSet_DelSocket(*server_socket_set, connected_bsocket[index]); - socketClose(connected_bsocket[index]); + server_socket_set->remove(connected_bsocket[index]); + delete connected_bsocket[index]; connected_bsocket[index] = nullptr; } else @@ -1221,7 +1231,7 @@ static constexpr size_t GAMESTRUCTmessageBufSize() * * @see GAMESTRUCT,NETrecvGAMESTRUCT */ -static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourgamestruct) +static net::result NETsendGAMESTRUCT(IClientConnection* sock, const GAMESTRUCT *ourgamestruct) { // A buffer that's guaranteed to have the correct size (i.e. it // circumvents struct padding, which could pose a problem). Initialise @@ -1232,13 +1242,13 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga unsigned int i; auto push32 = [&](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(buffer, &swapped, sizeof(swapped)); buffer += sizeof(swapped); }; auto push16 = [&](uint16_t value) { - uint16_t swapped = htons(value); + uint16_t swapped = wz_htons(value); memcpy(buffer, &swapped, sizeof(swapped)); buffer += sizeof(swapped); }; @@ -1327,7 +1337,7 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga debug(LOG_NET, "sending GAMESTRUCT, size: %u", (unsigned int)sizeof(buf)); // Send over the GAMESTRUCT - const auto writeResult = writeAll(*sock, buf, sizeof(buf)); + const auto writeResult = sock->writeAll(buf, sizeof(buf), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); @@ -1347,7 +1357,7 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga * * @see GAMESTRUCT,NETsendGAMESTRUCT */ -static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) +static bool NETrecvGAMESTRUCT(IClientConnection& sock, GAMESTRUCT *ourgamestruct) { // A buffer that's guaranteed to have the correct size (i.e. it // circumvents struct padding, which could pose a problem). @@ -1358,7 +1368,7 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) auto pop32 = [&]() -> uint32_t { uint32_t value = 0; memcpy(&value, buffer, sizeof(value)); - value = ntohl(value); + value = wz_ntohl(value); buffer += sizeof(value); return value; }; @@ -1366,13 +1376,13 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) auto pop16 = [&]() -> uint16_t { uint16_t value = 0; memcpy(&value, buffer, sizeof(value)); - value = ntohs(value); + value = wz_ntohs(value); buffer += sizeof(value); return value; }; // Read a GAMESTRUCT from the connection - auto readResult = readAll(sock, buf, sizeof(buf), NET_TIMEOUT_DELAY); + auto readResult = sock.readAll(buf, sizeof(buf), NET_TIMEOUT_DELAY); if (!readResult.has_value()) { if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) @@ -1532,7 +1542,8 @@ int NETinit(bool bFirstCall) NETlogEntry("NETinit!", SYNC_FLAG, selectedPlayer); NET_InitPlayers(true, true); - SOCKETinit(); + ConnectionProviderRegistry::Instance().Register(ConnectionProviderType::TCP_DIRECT); + ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).initialize(); if (bFirstCall) { @@ -1580,7 +1591,8 @@ int NETshutdown() } NetPlay.MOTD = nullptr; NETdeleteQueue(); - SOCKETshutdown(); + ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).shutdown(); + ConnectionProviderRegistry::Instance().Deregister(ConnectionProviderType::TCP_DIRECT); // Reset net usage statistics. nStats = nZeroStats; @@ -1611,7 +1623,7 @@ int NETclose() if (connected_bsocket[i]) { debug(LOG_NET, "Closing connected_bsocket[%u], %p", i, static_cast(connected_bsocket[i])); - socketClose(connected_bsocket[i]); + delete connected_bsocket[i]; connected_bsocket[i] = nullptr; } NET_DestroyPlayer(i, true); @@ -1620,7 +1632,7 @@ int NETclose() if (tmp_socket_set) { debug(LOG_NET, "Freeing tmp_socket_set %p", static_cast(tmp_socket_set)); - deleteSocketSet(tmp_socket_set); + delete tmp_socket_set; tmp_socket_set = nullptr; } @@ -1630,7 +1642,7 @@ int NETclose() { // FIXME: need SocketSet_DelSocket() as well, socket_set or tmp_socket_set? debug(LOG_NET, "Closing tmp_socket[%d] %p", i, static_cast(tmp_socket[i])); - socketClose(tmp_socket[i]); + delete tmp_socket[i]; tmp_socket[i] = nullptr; } } @@ -1639,28 +1651,28 @@ int NETclose() { if (bsocket) { - SocketSet_DelSocket(*client_socket_set, bsocket); + client_socket_set->remove(bsocket); } debug(LOG_NET, "Freeing socket_set %p", static_cast(client_socket_set)); - deleteSocketSet(client_socket_set); + delete client_socket_set; client_socket_set = nullptr; } else if (server_socket_set) { debug(LOG_NET, "Freeing socket_set %p", static_cast(server_socket_set)); - deleteSocketSet(server_socket_set); + delete server_socket_set; server_socket_set = nullptr; } if (server_listen_socket) { debug(LOG_NET, "Closing server_listen_socket %p", static_cast(server_listen_socket)); - socketClose(server_listen_socket); + delete server_listen_socket; server_listen_socket = nullptr; } if (bsocket) { debug(LOG_NET, "Closing bsocket %p", static_cast(bsocket)); - socketClose(bsocket); + delete bsocket; bsocket = nullptr; } @@ -1740,7 +1752,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) return true; } - Socket **sockets = connected_bsocket; + IClientConnection** sockets = connected_bsocket; bool isTmpQueue = false; switch (queue.queueType) { @@ -1777,7 +1789,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) } ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - const auto writeResult = writeAll(*sockets[player], rawData, rawLen, &compressedRawLen); + const auto writeResult = sockets[player]->writeAll(rawData, rawLen, &compressedRawLen); const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. @@ -1810,7 +1822,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - const auto writeResult = writeAll(*bsocket, rawData, rawLen, &compressedRawLen); + const auto writeResult = bsocket->writeAll(rawData, rawLen, &compressedRawLen); const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. @@ -1827,8 +1839,8 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) debug(LOG_ERROR, "Failed to send message: %s", writeErrMsg.c_str()); debug(LOG_ERROR, "Host connection was broken, socket %p.", static_cast(bsocket)); NETlogEntry("write error--client disconnect.", SYNC_FLAG, player); - SocketSet_DelSocket(*client_socket_set, bsocket); // mark it invalid - socketClose(bsocket); + client_socket_set->remove(bsocket); // mark it invalid + delete bsocket; bsocket = nullptr; NetPlay.players[NetPlay.hostPlayer].heartbeat = false; // mark host as dead //Game is pretty much over --should just end everything when HOST dies. @@ -1869,7 +1881,7 @@ void NETflush() // We are the host, send directly to player. if (connected_bsocket[player] != nullptr) { - socketFlush(*connected_bsocket[player], player, &compressedRawLen); + connected_bsocket[player]->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -1878,7 +1890,7 @@ void NETflush() // We are the host, send directly to player. if (tmp_socket[player] != nullptr) { - socketFlush(*tmp_socket[player], std::numeric_limits::max(), &compressedRawLen); + tmp_socket[player]->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -1887,7 +1899,7 @@ void NETflush() { if (bsocket != nullptr) { - socketFlush(*bsocket, NetPlay.hostPlayer, &compressedRawLen); + bsocket->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -2852,15 +2864,15 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) NETcheckPlayers(); // make sure players are still alive & well } - SocketSet* sset = NetPlay.isHost ? server_socket_set : client_socket_set; - if (sset == nullptr || checkSockets(*sset, NET_READ_TIMEOUT) <= 0) + IConnectionPollGroup* pollGroup = NetPlay.isHost ? server_socket_set : client_socket_set; + if (pollGroup == nullptr || pollGroup->checkSockets(NET_READ_TIMEOUT) <= 0) { goto checkMessages; } for (current = 0; current < MAX_CONNECTED_PLAYERS; ++current) { - Socket **pSocket = NetPlay.isHost ? &connected_bsocket[current] : &bsocket; + IClientConnection** pSocket = NetPlay.isHost ? &connected_bsocket[current] : &bsocket; uint8_t buffer[NET_BUFFER_SIZE]; size_t dataLen; @@ -2874,7 +2886,7 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) continue; } - dataLen = NET_fillBuffer(pSocket, sset, buffer, sizeof(buffer)); + dataLen = NET_fillBuffer(pSocket, pollGroup, buffer, sizeof(buffer)); if (dataLen > 0) { // we received some data, add to buffer @@ -3254,7 +3266,7 @@ unsigned NETgetDownloadProgress(unsigned player) return static_cast(progress); } -static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) +static ssize_t readLobbyResponse(IClientConnection& sock, unsigned int timeout) { uint32_t lobbyStatusCode; uint32_t MOTDLength; @@ -3262,14 +3274,14 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) ssize_t received = 0; // Get status and message length - auto readResult = readAll(sock, &buffer, sizeof(buffer), timeout); + auto readResult = sock.readAll(&buffer, sizeof(buffer), timeout); if (!readResult.has_value()) { goto error; } received += readResult.value(); - lobbyStatusCode = ntohl(buffer[0]); - MOTDLength = ntohl(buffer[1]); + lobbyStatusCode = wz_ntohl(buffer[0]); + MOTDLength = wz_ntohl(buffer[1]); // Get status message if (NetPlay.MOTD) @@ -3277,7 +3289,7 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) free(NetPlay.MOTD); } NetPlay.MOTD = (char *)malloc(MOTDLength + 1); - readResult = readAll(sock, NetPlay.MOTD, MOTDLength, timeout); + readResult = sock.readAll(NetPlay.MOTD, MOTDLength, timeout); if (!readResult.has_value()) { goto error; @@ -3332,15 +3344,15 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) return SOCKET_ERROR; } -bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function& handleEnumerateGameFunc) +bool readGameStructsList(IClientConnection& sock, unsigned int timeout, const std::function& handleEnumerateGameFunc) { unsigned int gamecount = 0; uint32_t gamesavailable = 0; - const auto readResult = readAll(sock, &gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY); + const auto readResult = sock.readAll(&gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY); if (readResult.has_value()) { - gamesavailable = ntohl(gamesavailable); + gamesavailable = wz_ntohl(gamesavailable); } else { @@ -3372,7 +3384,8 @@ bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function if (tmpGame.desc.host[0] == '\0') { memset(tmpGame.desc.host, 0, sizeof(tmpGame.desc.host)); - strncpy(tmpGame.desc.host, getSocketTextAddress(sock), sizeof(tmpGame.desc.host) - 1); + const auto textAddr = sock.textAddress(); + strncpy(tmpGame.desc.host, textAddr.data(), sizeof(tmpGame.desc.host) - 1); } uint32_t Vmgr = (tmpGame.future4 & 0xFFFF0000) >> 16; @@ -3424,12 +3437,13 @@ bool LobbyServerConnectionHandler::connect() return false; // already connecting or connected } + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + bool bProcessingConnectOrDisconnectThisCall = true; uint32_t gameId = 0; - const auto hostsResult = resolveHost(masterserver_name, masterserver_port); - const auto hosts = hostsResult.value_or(nullptr); + const auto hostsResult = connProvider.resolveHost(masterserver_name, masterserver_port); - if (hosts == nullptr) + if (!hostsResult.has_value()) { const auto hostsErrMsg = hostsResult.error().message(); debug(LOG_ERROR, "Cannot resolve masterserver \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); @@ -3443,16 +3457,17 @@ bool LobbyServerConnectionHandler::connect() return bProcessingConnectOrDisconnectThisCall; } + const auto& hosts = hostsResult.value(); + // Close an existing socket. if (rs_socket != nullptr) { - socketClose(rs_socket); + delete rs_socket; rs_socket = nullptr; } // try each address from resolveHost until we successfully connect. - auto sockResult = socketOpenAny(hosts, 1500); - deleteSocketAddress(hosts); + auto sockResult = connProvider.openClientConnectionAny(hosts, 1500); rs_socket = sockResult.value_or(nullptr); @@ -3472,10 +3487,10 @@ bool LobbyServerConnectionHandler::connect() } // Get a game ID - auto gameIdResult = writeAll(*rs_socket, "gaId", sizeof("gaId")); + auto gameIdResult = rs_socket->writeAll("gaId", sizeof("gaId"), nullptr); if (gameIdResult.has_value()) { - gameIdResult = readAll(*rs_socket, &gameId, sizeof(gameId), 10000); + gameIdResult = rs_socket->readAll(&gameId, sizeof(gameId), 10000); } if (!gameIdResult.has_value()) { @@ -3495,13 +3510,13 @@ bool LobbyServerConnectionHandler::connect() return bProcessingConnectOrDisconnectThisCall; } - gamestruct.gameId = ntohl(gameId); + gamestruct.gameId = wz_ntohl(gameId); debug(LOG_NET, "Using game ID: %u", (unsigned int)gamestruct.gameId); wz_command_interface_output("WZEVENT: lobbyid: %" PRIu32 "\n", gamestruct.gameId); // Register our game with the server - const auto writeAddGameRes = writeAll(*rs_socket, "addg", sizeof("addg")); + const auto writeAddGameRes = rs_socket->writeAll("addg", sizeof("addg"), nullptr); auto sendGamestructRes = ignoreExpectedResultValue(writeAddGameRes); if (sendGamestructRes.has_value()) @@ -3521,8 +3536,8 @@ bool LobbyServerConnectionHandler::connect() queuedServerUpdate = false; lastConnectionTime = realTime; - waitingForConnectionFinalize = allocSocketSet(); - SocketSet_AddSocket(*waitingForConnectionFinalize, rs_socket); + waitingForConnectionFinalize = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).newConnectionPollGroup(); + waitingForConnectionFinalize->add(rs_socket); currentState = LobbyConnectionState::Connecting_WaitingForResponse; return bProcessingConnectOrDisconnectThisCall; @@ -3538,7 +3553,7 @@ bool LobbyServerConnectionHandler::disconnect() if (rs_socket != nullptr) { // we don't need this anymore, so clean up - socketClose(rs_socket); + delete rs_socket; rs_socket = nullptr; server_not_there = true; } @@ -3602,7 +3617,7 @@ void LobbyServerConnectionHandler::sendUpdateNow() void LobbyServerConnectionHandler::sendKeepAlive() { ASSERT_OR_RETURN(, rs_socket != nullptr, "Null socket"); - if (!writeAll(*rs_socket, "keep", sizeof("keep")).has_value()) + if (!rs_socket->writeAll("keep", sizeof("keep"), nullptr).has_value()) { // The socket has been invalidated, so get rid of it. (using them now may cause SIGPIPE). disconnect(); @@ -3625,21 +3640,21 @@ void LobbyServerConnectionHandler::run() bool exceededTimeout = (realTime - lastConnectionTime >= 10000); // We use readLobbyResponse to display error messages and handle state changes if there's no response // So if exceededTimeout, just call it with a low timeout - int checkSocketRet = checkSockets(*waitingForConnectionFinalize, NET_READ_TIMEOUT); + int checkSocketRet = waitingForConnectionFinalize->checkSockets(NET_READ_TIMEOUT); if (checkSocketRet == SOCKET_ERROR) { debug(LOG_ERROR, "Lost connection to lobby server"); disconnect(); break; } - if (exceededTimeout || (checkSocketRet > 0 && socketReadReady(*rs_socket))) + if (exceededTimeout || (checkSocketRet > 0 && rs_socket->readReady())) { if (readLobbyResponse(*rs_socket, NET_TIMEOUT_DELAY) == SOCKET_ERROR) { disconnect(); break; } - deleteSocketSet(waitingForConnectionFinalize); + delete waitingForConnectionFinalize; waitingForConnectionFinalize = nullptr; currentState = LobbyConnectionState::Connected; } @@ -3774,9 +3789,9 @@ static bool quickRejectConnection(const std::string& ip) static void NETcloseTempSocket(unsigned int i) { - std::string rIP = getSocketTextAddress(*tmp_socket[i]); - SocketSet_DelSocket(*tmp_socket_set, tmp_socket[i]); - socketClose(tmp_socket[i]); + std::string rIP = tmp_socket[i]->textAddress(); + tmp_socket_set->remove(tmp_socket[i]); + delete tmp_socket[i]; tmp_socket[i] = nullptr; tmp_connectState[i].reset(); auto it = tmp_pendingIPs.find(rIP); @@ -3795,14 +3810,14 @@ static void NETcloseTempSocket(unsigned int i) static void NEThostPromoteTempSocketToPermanentPlayerConnection(unsigned int tempSocketIdx, uint8_t index) { - std::string rIP = getSocketTextAddress(*tmp_socket[tempSocketIdx]); + std::string rIP = tmp_socket[tempSocketIdx]->textAddress(); debug(LOG_NET, "freeing temp socket %p (%d), creating permanent socket.", static_cast(tmp_socket[tempSocketIdx]), __LINE__); - SocketSet_DelSocket(*tmp_socket_set, tmp_socket[tempSocketIdx]); + tmp_socket_set->remove(tmp_socket[tempSocketIdx]); connected_bsocket[index] = tmp_socket[tempSocketIdx]; tmp_socket[tempSocketIdx] = nullptr; NET_waitingForIndexChangeAckSince[index] = nullopt; - SocketSet_AddSocket(*server_socket_set, connected_bsocket[index]); + server_socket_set->add(connected_bsocket[index]); NETmoveQueue(NETnetTmpQueue(tempSocketIdx), NETnetQueue(index)); // Copy player's IP address @@ -3846,12 +3861,13 @@ static void NETallowJoining() ActivitySink::ListeningInterfaces listeningInterfaces; if (server_listen_socket != nullptr) { - listeningInterfaces.IPv4 = socketHasIPv4(*server_listen_socket); + const auto supportedProtocols = server_listen_socket->supportedIpVersions(); + listeningInterfaces.IPv4 = supportedProtocols & static_cast(IListenSocket::IPVersions::IPV4); if (listeningInterfaces.IPv4) { listeningInterfaces.ipv4_port = NETgetGameserverPort(); } - listeningInterfaces.IPv6 = socketHasIPv6(*server_listen_socket); + listeningInterfaces.IPv6 = supportedProtocols & static_cast(IListenSocket::IPVersions::IPV6); if (listeningInterfaces.IPv6) { listeningInterfaces.ipv6_port = NETgetGameserverPort(); @@ -3875,7 +3891,7 @@ static void NETallowJoining() } ASSERT(tmp_socket_set != nullptr, "Null tmp_socket_set"); - if (checkSockets(*tmp_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { for (i = 0; i < MAX_TMP_SOCKETS; ++i) { @@ -3884,7 +3900,7 @@ static void NETallowJoining() continue; } - if (!socketReadReady(*tmp_socket[i])) + if (!tmp_socket[i]->readReady()) { continue; } @@ -3893,7 +3909,7 @@ static void NETallowJoining() { char *p_buffer = tmp_connectState[i].buffer; - const auto sizeReadResult = readNoInt(*tmp_socket[i], p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer); + const auto sizeReadResult = tmp_socket[i]->readNoInt(p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer, nullptr); if (sizeReadResult.has_value()) { tmp_connectState[i].usedBuffer += sizeReadResult.value(); @@ -3914,10 +3930,10 @@ static void NETallowJoining() // Check these numbers with our own. memcpy(&major, p_buffer, sizeof(uint32_t)); - major = ntohl(major); + major = wz_ntohl(major); p_buffer += sizeof(int32_t); memcpy(&minor, p_buffer, sizeof(uint32_t)); - minor = ntohl(minor); + minor = wz_ntohl(minor); if (major == 0 && minor == 0) { @@ -3927,7 +3943,7 @@ static void NETallowJoining() char buf[(sizeof(char) * 4) + sizeof(uint32_t) + sizeof(uint32_t)] = { 0 }; char *pLobbyRespBuffer = buf; auto push32 = [&pLobbyRespBuffer](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(pLobbyRespBuffer, &swapped, sizeof(swapped)); pLobbyRespBuffer += sizeof(swapped); }; @@ -3943,15 +3959,15 @@ static void NETallowJoining() // Copy gameId (as 32bit large big endian number) push32(gamestruct.gameId); - writeAll(*tmp_socket[i], buf, sizeof(buf)); + tmp_socket[i]->writeAll(buf, sizeof(buf), nullptr); connectFailed = true; } else if (NETisCorrectVersion(major, minor)) { - result = htonl(ERROR_NOERROR); + result = wz_htonl(ERROR_NOERROR); memcpy(&tmp_connectState[i].buffer, &result, sizeof(result)); - writeAll(*tmp_socket[i], &tmp_connectState[i].buffer, sizeof(result)); - socketBeginCompression(*tmp_socket[i]); + tmp_socket[i]->writeAll(&tmp_connectState[i].buffer, sizeof(result), nullptr); + tmp_socket[i]->enableCompression(); // Connection is successful. connectFailed = false; @@ -3966,9 +3982,9 @@ static void NETallowJoining() else { debug(LOG_INFO, "Received an invalid version \"%" PRIu32 ".%" PRIu32 "\".", major, minor); - result = htonl(ERROR_WRONGVERSION); + result = wz_htonl(ERROR_WRONGVERSION); memcpy(&tmp_connectState[i].buffer, &result, sizeof(result)); - writeAll(*tmp_socket[i], &tmp_connectState[i].buffer, sizeof(result)); + tmp_socket[i]->writeAll(&tmp_connectState[i].buffer, sizeof(result), nullptr); NETlogEntry("Invalid game version", SYNC_FLAG, i); NETaddSessionBanBadIP(tmp_connectState[i].ip); connectFailed = true; @@ -4008,7 +4024,7 @@ static void NETallowJoining() else if (tmp_connectState[i].connectState == TmpSocketInfo::TmpConnectState::PendingJoinRequest) { uint8_t buffer[NET_BUFFER_SIZE]; - const auto readResult = readNoInt(*tmp_socket[i], buffer, sizeof(buffer)); + const auto readResult = tmp_socket[i]->readNoInt(buffer, sizeof(buffer), nullptr); uint8_t rejected = 0; if (!readResult.has_value()) @@ -4406,7 +4422,7 @@ static void NETallowJoining() NETpop(tmpQueue); } - std::string rIP = getSocketTextAddress(*tmp_socket[i]); + std::string rIP = tmp_socket[i]->textAddress(); NETaddSessionBanBadIP(rIP); NETcloseTempSocket(i); @@ -4511,14 +4527,16 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator return true; } + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + // Start listening for client connections on `gameserver_port`. // These will initially be assigned to `tmp_socket[i]` until accepted in the game session, // in which case `tmp_socket[i]` will be assigned to `connected_bsocket[i]` and `tmp_socket[i]` // will become nullptr. - net::result serverListenResult = {}; + net::result serverListenResult = {}; if (!server_listen_socket) { - serverListenResult = socketListen(gameserver_port); + serverListenResult = connProvider.openListenSocket(gameserver_port); server_listen_socket = serverListenResult.value_or(nullptr); } if (server_listen_socket == nullptr) @@ -4531,7 +4549,7 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator // Host needs to create a socket set for MAX_PLAYERS if (!server_socket_set) { - server_socket_set = allocSocketSet(); + server_socket_set = connProvider.newConnectionPollGroup(); } // allocate socket storage for all possible players for (unsigned i = 0; i < MAX_CONNECTED_PLAYERS; ++i) @@ -4636,19 +4654,17 @@ bool NETenumerateGames(const std::function& handl debug(LOG_ERROR, "Likely missing NETinit(true) - this won't return any results"); return false; } - const auto hostsResult = resolveHost(masterserver_name, masterserver_port); - SocketAddress* hosts = hostsResult.value_or(nullptr); - if (!hosts) + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + const auto hostsResult = connProvider.resolveHost(masterserver_name, masterserver_port); + if (!hostsResult.has_value()) { const auto hostsErrMsg = hostsResult.error().message(); debug(LOG_ERROR, "Cannot resolve hostname \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); setLobbyError(ERROR_CONNECTION); return false; } - - auto sockResult = socketOpenAny(hosts, 15000); - deleteSocketAddress(hosts); - hosts = nullptr; + const auto& hosts = hostsResult.value(); + auto sockResult = connProvider.openClientConnectionAny(hosts, 15000); if (!sockResult.has_value()) { const auto sockErrMsg = sockResult.error().message(); @@ -4656,18 +4672,18 @@ bool NETenumerateGames(const std::function& handl setLobbyError(ERROR_CONNECTION); return false; } - Socket* sock = sockResult.value(); + IClientConnection* sock = sockResult.value(); debug(LOG_NET, "New socket = %p", static_cast(sock)); debug(LOG_NET, "Sending list cmd"); - const auto writeResult = writeAll(*sock, "list", sizeof("list")); + const auto writeResult = sock->writeAll("list", sizeof("list"), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); debug(LOG_NET, "Server socket encountered error: %s", writeErrMsg.c_str()); // mark it invalid - socketClose(sock); + delete sock; // when we fail to receive a game count, bail out setLobbyError(ERROR_CONNECTION); @@ -4684,7 +4700,7 @@ bool NETenumerateGames(const std::function& handl })) { // mark it invalid - socketClose(sock); + delete sock; setLobbyError(ERROR_CONNECTION); return false; @@ -4694,7 +4710,7 @@ bool NETenumerateGames(const std::function& handl if (readLobbyResponse(*sock, NET_TIMEOUT_DELAY) == SOCKET_ERROR) { // mark it invalid - socketClose(sock); + delete sock; addConsoleMessage(_("Failed to get a lobby response!"), DEFAULT_JUSTIFY, NOTIFY_MESSAGE); // treat as fatal error @@ -4708,10 +4724,10 @@ bool NETenumerateGames(const std::function& handl // Hence as long as we don't treat "0" as signifying any change in behavior, this should be safe + backwards-compatible #define IGNORE_FIRST_BATCH 1 uint32_t responseParameters = 0; - const auto readResult = readAll(*sock, &responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY); + const auto readResult = sock->readAll(&responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY); if (readResult.has_value()) { - responseParameters = ntohl(responseParameters); + responseParameters = wz_ntohl(responseParameters); bool requestSecondBatch = true; bool ignoreFirstBatch = ((responseParameters & IGNORE_FIRST_BATCH) == IGNORE_FIRST_BATCH); @@ -4746,7 +4762,7 @@ bool NETenumerateGames(const std::function& handl debug(LOG_NET, "Second readGameStructsList call failed"); // mark it invalid - socketClose(sock); + delete sock; // when we fail to receive a game count, bail out setLobbyError(ERROR_CONNECTION); @@ -4769,7 +4785,7 @@ bool NETenumerateGames(const std::function& handl } // mark it invalid (we are done with it) - socketClose(sock); + delete sock; return true; } @@ -4825,7 +4841,7 @@ bool NETfindGame(uint32_t gameId, GAMESTRUCT& output) } // "consumes" the sockets and related info -bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, Socket **client_joining_socket, SocketSet **client_joining_socket_set) +bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, IClientConnection** client_joining_socket, IConnectionPollGroup** client_joining_socket_set) { if (hostPlayer >= MAX_CONNECTED_PLAYERS) { @@ -5130,7 +5146,8 @@ void NETacceptIncomingConnections() { // initialize temporary server socket set // FIXME: why is this not done in NETinit()?? - Per - tmp_socket_set = allocSocketSet(); + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + tmp_socket_set = connProvider.newConnectionPollGroup(); // FIXME: I guess initialization of allowjoining is here now... - FlexCoral for (auto& tmpState : tmp_connectState) { @@ -5156,12 +5173,12 @@ void NETacceptIncomingConnections() } // See if there's an incoming connection - tmp_socket[i] = socketAccept(server_listen_socket); + tmp_socket[i] = server_listen_socket->accept(); if (!tmp_socket[i]) { return; } - const std::string rIP = getSocketTextAddress(*tmp_socket[i]); + const std::string rIP = tmp_socket[i]->textAddress(); if (quickRejectConnection(rIP)) { debug(LOG_NET, "freeing temp socket %p (%d)", static_cast(tmp_socket[i]), __LINE__); @@ -5170,7 +5187,7 @@ void NETacceptIncomingConnections() } NETinitQueue(NETnetTmpQueue(i)); - SocketSet_AddSocket(*tmp_socket_set, tmp_socket[i]); + tmp_socket_set->add(tmp_socket[i]); tmp_pendingIPs[rIP]++; @@ -5184,7 +5201,7 @@ void NETacceptIncomingConnections() if (bEnableTCPNoDelay) { - // Enable TCP_NODELAY - socketSetTCPNoDelay(*tmp_socket[i], true); + // Disable use of Nagle Algorithm for the TCP socket (i.e. enable TCP_NODELAY option in case of TCP connection) + tmp_socket[i]->useNagleAlgorithm(false); } } diff --git a/lib/netplay/netplay.h b/lib/netplay/netplay.h index 6cfaf64846c..3147f0b074e 100644 --- a/lib/netplay/netplay.h +++ b/lib/netplay/netplay.h @@ -445,9 +445,11 @@ bool NEThaltJoining(); // stop new players joining this game bool NETenumerateGames(const std::function& handleEnumerateGameFunc); bool NETfindGames(std::vector& results, size_t startingIndex, size_t resultsLimit, bool onlyMatchingLocalVersion = false); bool NETfindGame(uint32_t gameId, GAMESTRUCT& output); -struct Socket; -struct SocketSet; -bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, Socket **client_joining_socket, SocketSet **client_joining_socket_set); + +class IClientConnection; +class IConnectionPollGroup; + +bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char* playername, NETQUEUE joiningQUEUEInfo, IClientConnection** client_joining_socket, IConnectionPollGroup** client_joining_socket_set); bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectatorHost, // host a game uint32_t gameType, uint32_t two, uint32_t three, uint32_t four, UDWORD plyrs); bool NETchangePlayerName(UDWORD player, char *newName);// change a players name. diff --git a/lib/netplay/open_connection_result.h b/lib/netplay/open_connection_result.h new file mode 100644 index 00000000000..ae45ddb9ff5 --- /dev/null +++ b/lib/netplay/open_connection_result.h @@ -0,0 +1,60 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include +#include +#include + +#include "lib/netplay/client_connection.h" + +#include +using nonstd::optional; +using nonstd::nullopt; + +// Higher-level functions for opening a connection / socket +struct OpenConnectionResult +{ +public: + OpenConnectionResult(std::error_code ec, std::string errorString) + : errorCode(ec) + , errorString(std::move(errorString)) + { } + + OpenConnectionResult(IClientConnection* open_socket) + : open_socket(open_socket) + { } + +public: + bool hasError() const { return errorCode.has_value(); } +public: + OpenConnectionResult(const OpenConnectionResult& other) = delete; // non construction-copyable + OpenConnectionResult& operator=(const OpenConnectionResult&) = delete; // non copyable + OpenConnectionResult(OpenConnectionResult&&) = default; + OpenConnectionResult& operator=(OpenConnectionResult&&) = default; +public: + + std::unique_ptr open_socket; + optional errorCode = nullopt; + std::string errorString; +}; + +typedef std::function OpenConnectionToHostResultCallback; diff --git a/lib/netplay/sync_debug.cpp b/lib/netplay/sync_debug.cpp index 536db78c080..e23499a86db 100644 --- a/lib/netplay/sync_debug.cpp +++ b/lib/netplay/sync_debug.cpp @@ -27,9 +27,9 @@ #include "lib/framework/debug.h" #include "lib/framework/physfs_ext.h" #include "lib/gamelib/gtime.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" #include "nettypes.h" #include "netplay.h" -#include "netsocket.h" // solely to bring in `htonl` function #include @@ -76,7 +76,7 @@ struct SyncDebugValueChange : public SyncDebugEntry variableName = vn; newValue = nv; id = i; - uint32_t valueBytes = htonl(newValue); + uint32_t valueBytes = wz_htonl(newValue); crc = wz::crc_update(crc, function, strlen(function) + 1); crc = wz::crc_update(crc, variableName, strlen(variableName) + 1); crc = wz::crc_update(crc, &valueBytes, 4); @@ -105,7 +105,7 @@ struct SyncDebugIntList : public SyncDebugEntry numInts = std::min(num, ARRAY_SIZE(valueBytes)); for (unsigned n = 0; n < numInts; ++n) { - valueBytes[n] = htonl(ints[n]); + valueBytes[n] = wz_htonl(ints[n]); } crc = wz::crc_update(crc, valueBytes, 4 * numInts); } diff --git a/lib/netplay/netsocket.cpp b/lib/netplay/tcp/netsocket.cpp similarity index 96% rename from lib/netplay/netsocket.cpp rename to lib/netplay/tcp/netsocket.cpp index f03be08e7e1..9f844793da8 100644 --- a/lib/netplay/netsocket.cpp +++ b/lib/netplay/tcp/netsocket.cpp @@ -26,7 +26,7 @@ #include "lib/framework/frame.h" #include "lib/framework/wzapp.h" #include "netsocket.h" -#include "error_categories.h" +#include "lib/netplay/error_categories.h" #include #include @@ -47,6 +47,9 @@ // Already included Winsock2.h which defines TCP_NODELAY #endif +namespace tcp +{ + enum { SOCK_CONNECTION, @@ -127,6 +130,8 @@ static void setSockErr(int error) #endif } +} // namespace tcp + #if defined(WZ_OS_WIN) typedef int (WINAPI *GETADDRINFO_DLL_FUNC)(const char *node, const char *service, const struct addrinfo *hints, @@ -216,6 +221,9 @@ static void freeaddrinfo(struct addrinfo *res) } #endif +namespace tcp +{ + static int addressToText(const struct sockaddr *addr, char *buf, size_t size) { auto handleIpv4 = [&](uint32_t addr) { @@ -1749,71 +1757,4 @@ void SOCKETshutdown() #endif } -OpenConnectionResult socketOpenTCPConnectionSync(const char *host, uint32_t port) -{ - const auto hostsResult = resolveHost(host, port); - SocketAddress* hosts = hostsResult.value_or(nullptr); - if (hosts == nullptr) - { - const auto hostsErr = hostsResult.error(); - const auto hostsErrMsg = hostsErr.message(); - return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); - } - - auto sockResult = socketOpenAny(hosts, 15000); - Socket* client_transient_socket = sockResult.value_or(nullptr); - deleteSocketAddress(hosts); - hosts = nullptr; - - if (client_transient_socket == nullptr) - { - const auto errValue = sockResult.error(); - const auto errMsg = errValue.message(); - return OpenConnectionResult(errValue, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, errValue.value(), errMsg.c_str())); - } - - return OpenConnectionResult(client_transient_socket); -} - -struct OpenConnectionRequest -{ - std::string host; - uint32_t port = 0; - OpenConnectionToHostResultCallback callback; -}; - -static int openDirectTCPConnectionAsyncImpl(void* data) -{ - OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; - if (!pRequestInfo) - { - return 1; - } - - pRequestInfo->callback(socketOpenTCPConnectionSync(pRequestInfo->host.c_str(), pRequestInfo->port)); - delete pRequestInfo; - return 0; -} - -bool socketOpenTCPConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) -{ - // spawn background thread to handle this - auto pRequest = new OpenConnectionRequest(); - pRequest->host = host; - pRequest->port = port; - pRequest->callback = callback; - - WZ_THREAD * pOpenConnectionThread = wzThreadCreate(openDirectTCPConnectionAsyncImpl, pRequest); - if (pOpenConnectionThread == nullptr) - { - debug(LOG_ERROR, "Failed to create thread for opening connection"); - delete pRequest; - return false; - } - - wzThreadDetach(pOpenConnectionThread); - // the thread handles deleting pRequest - pOpenConnectionThread = nullptr; - - return true; -} +} // namespace tcp diff --git a/lib/netplay/netsocket.h b/lib/netplay/tcp/netsocket.h similarity index 81% rename from lib/netplay/netsocket.h rename to lib/netplay/tcp/netsocket.h index 01413011cd3..b46a9ce89ce 100644 --- a/lib/netplay/netsocket.h +++ b/lib/netplay/tcp/netsocket.h @@ -21,21 +21,14 @@ #ifndef _net_socket_h #define _net_socket_h +#include "lib/framework/wzglobal.h" #include "lib/framework/types.h" #include #include #include +#include -#include -using nonstd::optional; -using nonstd::nullopt; -#include - -namespace net -{ - template - using result = ::tl::expected; -} // namespace net +#include "lib/netplay/net_result.h" #if defined(WZ_OS_UNIX) # include @@ -84,14 +77,22 @@ static const SOCKET INVALID_SOCKET = -1; # define MSG_NOSIGNAL 0 #endif +namespace tcp +{ + struct Socket; struct SocketSet; + +} // namespace tcp + typedef struct addrinfo SocketAddress; #ifndef WZ_OS_WIN static const int SOCKET_ERROR = -1; #endif +namespace tcp +{ // Init/shutdown. void SOCKETinit(); @@ -137,34 +138,6 @@ WZ_DECL_NONNULL(2) void SocketSet_AddSocket(SocketSet& set, Socket *socket); // WZ_DECL_NONNULL(2) void SocketSet_DelSocket(SocketSet& set, Socket *socket); ///< Removes a Socket from a SocketSet. int checkSockets(const SocketSet& set, unsigned int timeout); ///< Checks which Sockets are ready for reading. Returns the number of ready Sockets, or returns SOCKET_ERROR on error. -// Higher-level functions for opening a connection / socket -struct OpenConnectionResult -{ - OpenConnectionResult(std::error_code ec, std::string errorString) - : errorCode(ec) - , errorString(errorString) - { } - - OpenConnectionResult(Socket* open_socket) - : open_socket(open_socket) - { } - - bool hasError() const { return errorCode.has_value(); } - - OpenConnectionResult( const OpenConnectionResult& other ) = delete; // non construction-copyable - OpenConnectionResult& operator=( const OpenConnectionResult& ) = delete; // non copyable - OpenConnectionResult(OpenConnectionResult&&) = default; - OpenConnectionResult& operator=(OpenConnectionResult&&) = default; - - struct SocketDeleter { - void operator()(Socket* b) { if (b) { socketClose(b); } } - }; - std::unique_ptr open_socket; - optional errorCode = nullopt; - std::string errorString; -}; -typedef std::function OpenConnectionToHostResultCallback; -bool socketOpenTCPConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback); -OpenConnectionResult socketOpenTCPConnectionSync(const char *host, uint32_t port); +} // namespace tcp #endif //_net_socket_h diff --git a/lib/netplay/tcp/tcp_client_connection.cpp b/lib/netplay/tcp/tcp_client_connection.cpp new file mode 100644 index 00000000000..73d9d78c768 --- /dev/null +++ b/lib/netplay/tcp/tcp_client_connection.cpp @@ -0,0 +1,83 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" +#include "lib/framework/string_ext.h" + +namespace tcp +{ + +TCPClientConnection::TCPClientConnection(Socket* rawSocket) + : socket_(rawSocket) +{ + ASSERT(socket_ != nullptr, "Null socket passed to TCPClientConnection ctor"); +} + +TCPClientConnection::~TCPClientConnection() +{ + if (socket_) + { + socketClose(socket_); + } +} + +net::result TCPClientConnection::readAll(void* buf, size_t size, unsigned timeout) +{ + return tcp::readAll(*socket_, buf, size, timeout); +} + +net::result TCPClientConnection::readNoInt(void* buf, size_t maxSize, size_t* rawByteCount) +{ + return tcp::readNoInt(*socket_, buf, maxSize, rawByteCount); +} + +net::result TCPClientConnection::writeAll(const void* buf, size_t size, size_t* rawByteCount) +{ + return tcp::writeAll(*socket_, buf, size, rawByteCount); +} + +bool TCPClientConnection::readReady() const +{ + return socketReadReady(*socket_); +} + +void TCPClientConnection::flush(size_t* rawByteCount) +{ + socketFlush(*socket_, std::numeric_limits::max()/*unused*/, rawByteCount); +} + +void TCPClientConnection::enableCompression() +{ + socketBeginCompression(*socket_); +} + +void TCPClientConnection::useNagleAlgorithm(bool enable) +{ + socketSetTCPNoDelay(*socket_, !enable); +} + +std::string TCPClientConnection::textAddress() const +{ + return getSocketTextAddress(*socket_); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_client_connection.h b/lib/netplay/tcp/tcp_client_connection.h new file mode 100644 index 00000000000..c2d00959d7f --- /dev/null +++ b/lib/netplay/tcp/tcp_client_connection.h @@ -0,0 +1,52 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/client_connection.h" + +namespace tcp +{ + +struct Socket; + +class TCPClientConnection : public IClientConnection +{ +public: + + explicit TCPClientConnection(Socket* rawSocket); + virtual ~TCPClientConnection() override; + + virtual net::result readAll(void* buf, size_t size, unsigned timeout) override; + virtual net::result readNoInt(void* buf, size_t maxSize, size_t* rawByteCount) override; + virtual net::result writeAll(const void* buf, size_t size, size_t* rawByteCount) override; + virtual bool readReady() const override; + virtual void flush(size_t* rawByteCount) override; + virtual void enableCompression() override; + virtual void useNagleAlgorithm(bool enable) override; + virtual std::string textAddress() const override; + +private: + + Socket* socket_; + + friend class TCPConnectionPollGroup; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_poll_group.cpp b/lib/netplay/tcp/tcp_connection_poll_group.cpp new file mode 100644 index 00000000000..7aab0fcbe71 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_poll_group.cpp @@ -0,0 +1,60 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_connection_poll_group.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" + +namespace tcp +{ + +TCPConnectionPollGroup::TCPConnectionPollGroup(SocketSet* sset) + : sset_(sset) +{} + +TCPConnectionPollGroup::~TCPConnectionPollGroup() +{ + if (sset_) + { + deleteSocketSet(sset_); + } +} + +int TCPConnectionPollGroup::checkSockets(unsigned timeout) +{ + return tcp::checkSockets(*sset_, timeout); +} + +void TCPConnectionPollGroup::add(IClientConnection* conn) +{ + auto* tcpConn = dynamic_cast(conn); + ASSERT_OR_RETURN(, tcpConn != nullptr, "Expected to have TCPClientConnection instance"); + SocketSet_AddSocket(*sset_, tcpConn->socket_); +} + +void TCPConnectionPollGroup::remove(IClientConnection* conn) +{ + auto tcpConn = dynamic_cast(conn); + ASSERT_OR_RETURN(, tcpConn != nullptr, "Expected to have TCPClientConnection instance"); + SocketSet_DelSocket(*sset_, tcpConn->socket_); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_poll_group.h b/lib/netplay/tcp/tcp_connection_poll_group.h new file mode 100644 index 00000000000..d9eec5bb909 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_poll_group.h @@ -0,0 +1,45 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/connection_poll_group.h" + +namespace tcp +{ + +struct SocketSet; + +class TCPConnectionPollGroup : public IConnectionPollGroup +{ +public: + + explicit TCPConnectionPollGroup(SocketSet* sset); + virtual ~TCPConnectionPollGroup() override; + + virtual int checkSockets(unsigned timeout) override; + virtual void add(IClientConnection* conn) override; + virtual void remove(IClientConnection* conn) override; + +private: + + SocketSet* sset_; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_provider.cpp b/lib/netplay/tcp/tcp_connection_provider.cpp new file mode 100644 index 00000000000..8a3182a02e7 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_provider.cpp @@ -0,0 +1,153 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "tcp_connection_provider.h" + +#include "lib/netplay/tcp/netsocket.h" +#include "lib/netplay/tcp/tcp_connection_poll_group.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/tcp_listen_socket.h" + +#include "lib/netplay/open_connection_result.h" +#include "lib/framework/wzapp.h" + +namespace tcp +{ + +void TCPConnectionProvider::initialize() +{ + SOCKETinit(); +} + +void TCPConnectionProvider::shutdown() +{ + SOCKETshutdown(); +} + +net::result TCPConnectionProvider::resolveHost(const char* host, uint16_t port) +{ + return ConnectionAddress::parse(host, port); +} + +net::result TCPConnectionProvider::openListenSocket(uint16_t port) +{ + auto res = tcp::socketListen(port); + if (!res.has_value()) + { + return tl::make_unexpected(res.error()); + } + return new TCPListenSocket(res.value()); +} + +net::result TCPConnectionProvider::openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) +{ + const auto* rawAddr = addr.asRawSocketAddress(); + auto res = socketOpenAny(rawAddr, timeout); + if (!res.has_value()) + { + return tl::make_unexpected(res.error()); + } + return new TCPClientConnection(res.value()); +} + +namespace +{ + +OpenConnectionResult socketOpenTCPConnectionSync(const char* host, uint32_t port) +{ + const auto hostsResult = resolveHost(host, port); + SocketAddress* hosts = hostsResult.value_or(nullptr); + if (hosts == nullptr) + { + const auto hostsErr = hostsResult.error(); + const auto hostsErrMsg = hostsErr.message(); + return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); + } + + auto sockResult = socketOpenAny(hosts, 15000); + Socket* client_transient_socket = sockResult.value_or(nullptr); + deleteSocketAddress(hosts); + hosts = nullptr; + + if (client_transient_socket == nullptr) + { + const auto errValue = sockResult.error(); + const auto errMsg = errValue.message(); + return OpenConnectionResult(errValue, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, errValue.value(), errMsg.c_str())); + } + + return OpenConnectionResult(new TCPClientConnection(client_transient_socket)); +} + +struct OpenConnectionRequest +{ + std::string host; + uint32_t port = 0; + OpenConnectionToHostResultCallback callback; +}; + +static int openDirectTCPConnectionAsyncImpl(void* data) +{ + OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; + if (!pRequestInfo) + { + return 1; + } + + pRequestInfo->callback(socketOpenTCPConnectionSync(pRequestInfo->host.c_str(), pRequestInfo->port)); + delete pRequestInfo; + return 0; +} + +} // anonymous namespace + +bool TCPConnectionProvider::openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) +{ + // spawn background thread to handle this + auto pRequest = new OpenConnectionRequest(); + pRequest->host = host; + pRequest->port = port; + pRequest->callback = callback; + + WZ_THREAD* pOpenConnectionThread = wzThreadCreate(openDirectTCPConnectionAsyncImpl, pRequest); + if (pOpenConnectionThread == nullptr) + { + debug(LOG_ERROR, "Failed to create thread for opening connection"); + delete pRequest; + return false; + } + + wzThreadDetach(pOpenConnectionThread); + // the thread handles deleting pRequest + pOpenConnectionThread = nullptr; + + return true; +} + +IConnectionPollGroup* TCPConnectionProvider::newConnectionPollGroup() +{ + auto* sset = allocSocketSet(); + if (!sset) + { + return nullptr; + } + return new TCPConnectionPollGroup(sset); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_provider.h b/lib/netplay/tcp/tcp_connection_provider.h new file mode 100644 index 00000000000..9a1336558b3 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_provider.h @@ -0,0 +1,48 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include +#include + +#include "lib/netplay/wz_connection_provider.h" + +namespace tcp +{ + +class TCPConnectionProvider final : public WzConnectionProvider +{ +public: + + virtual void initialize() override; + virtual void shutdown() override; + + virtual net::result resolveHost(const char* host, uint16_t port) override; + + virtual net::result openListenSocket(uint16_t port) override; + + virtual net::result openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) override; + virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) override; + + virtual IConnectionPollGroup* newConnectionPollGroup() override; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_listen_socket.cpp b/lib/netplay/tcp/tcp_listen_socket.cpp new file mode 100644 index 00000000000..d7ea294b4fa --- /dev/null +++ b/lib/netplay/tcp/tcp_listen_socket.cpp @@ -0,0 +1,70 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_listen_socket.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" + +namespace tcp +{ + +TCPListenSocket::TCPListenSocket(Socket* rawSocket) + : listenSocket_(rawSocket) +{} + +TCPListenSocket::~TCPListenSocket() +{ + if (listenSocket_) + { + socketClose(listenSocket_); + } +} + +IClientConnection* TCPListenSocket::accept() +{ + ASSERT(listenSocket_ != nullptr, "Internal socket handle shouldn't be null!"); + if (!listenSocket_) + { + return nullptr; + } + auto* s = socketAccept(listenSocket_); + if (!s) + { + return nullptr; + } + return new TCPClientConnection(s); +} + +IListenSocket::IPVersionsMask TCPListenSocket::supportedIpVersions() const +{ + IPVersionsMask resMask = 0; + if (socketHasIPv4(*listenSocket_)) + { + resMask |= static_cast(IPVersions::IPV4); + } + if (socketHasIPv6(*listenSocket_)) + { + resMask |= static_cast(IPVersions::IPV6); + } + return resMask; +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_listen_socket.h b/lib/netplay/tcp/tcp_listen_socket.h new file mode 100644 index 00000000000..db78b25a0df --- /dev/null +++ b/lib/netplay/tcp/tcp_listen_socket.h @@ -0,0 +1,46 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include "lib/netplay/listen_socket.h" + +namespace tcp +{ + +struct Socket; + +class TCPListenSocket : public IListenSocket +{ +public: + + explicit TCPListenSocket(tcp::Socket* rawSocket); + virtual ~TCPListenSocket() override; + + virtual IClientConnection* accept() override; + virtual IPVersionsMask supportedIpVersions() const override; + +private: + + tcp::Socket* listenSocket_; +}; + +} // namespace tcp diff --git a/lib/netplay/wz_connection_provider.h b/lib/netplay/wz_connection_provider.h new file mode 100644 index 00000000000..1c9d51dbd8e --- /dev/null +++ b/lib/netplay/wz_connection_provider.h @@ -0,0 +1,80 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include "lib/netplay/connection_address.h" +#include "lib/netplay/net_result.h" +#include "lib/netplay/open_connection_result.h" + +class IListenSocket; +class IClientConnection; +class IConnectionPollGroup; + +/// +/// Abstraction layer to facilitate creating client/server connections and +/// provide host resolution routines for a given network backend. +/// +/// A typical implementation of this interface should at least provide the following +/// things: +/// +/// 1. Initialization/teardown routines (setup some common state, like write/submission +/// queues or service threads, plus initialization of low-level backend code, e.g. +/// calls to init/deinit functions from a 3rd-party library). +/// 2. Host resolution. +/// 3. Opening server-side listen sockets. +/// 4. Opening client-side connections (sync and async). +/// 5. Creating connection poll groups. +/// +class WzConnectionProvider +{ +public: + + virtual ~WzConnectionProvider() = default; + + virtual void initialize() = 0; + virtual void shutdown() = 0; + + /// + /// Resolve host + port combination and return an opaque `ConnectionAddress` handle + /// representing the resolved network address. + /// + virtual net::result resolveHost(const char* host, uint16_t port) = 0; + /// + /// Open a listening socket bound to a specified local port. + /// + virtual net::result openListenSocket(uint16_t port) = 0; + /// + /// Synchronously open a client connection bound to one of the addresses + /// represented by `addr` (the first one that succeeds). + /// + /// Connection address to bind the client connection to. + /// Timeout in milliseconds. + virtual net::result openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) = 0; + /// + /// Async variant of `openClientConnectionAny()`. + /// + virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) = 0; + /// + /// Create a group for polling client connections. + /// + virtual IConnectionPollGroup* newConnectionPollGroup() = 0; +}; diff --git a/po/POTFILES.in b/po/POTFILES.in index 9f8c0cc58b6..1365f738c66 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -303,14 +303,20 @@ lib/ivis_opengl/screen.cpp lib/ivis_opengl/tex.cpp lib/ivis_opengl/textdraw.cpp lib/netplay/error_categories.cpp +lib/netplay/connection_address.cpp +lib/netplay/connection_provider_registry.cpp lib/netplay/netjoin_stub.cpp lib/netplay/netlog.cpp lib/netplay/netpermissions.cpp lib/netplay/netplay.cpp lib/netplay/netqueue.cpp lib/netplay/netreplay.cpp -lib/netplay/netsocket.cpp lib/netplay/nettypes.cpp +lib/netplay/tcp/netsocket.cpp +lib/netplay/tcp/tcp_client_connection.cpp +lib/netplay/tcp/tcp_connection_poll_group.cpp +lib/netplay/tcp/tcp_connection_provider.cpp +lib/netplay/tcp/tcp_listen_socket.cpp lib/netplay/port_mapping_manager.cpp lib/netplay/port_mapping_manager_impl_libplum.cpp lib/netplay/port_mapping_manager_impl_miniupnpc.cpp diff --git a/src/integrations/wzdiscordrpc.cpp b/src/integrations/wzdiscordrpc.cpp index 8cfd245eb18..eaa897d5b16 100644 --- a/src/integrations/wzdiscordrpc.cpp +++ b/src/integrations/wzdiscordrpc.cpp @@ -22,7 +22,7 @@ #include "lib/framework/crc.h" #include "lib/gamelib/gtime.h" #include "lib/netplay/netplay.h" -#include "lib/netplay/netsocket.h" +#include "lib/netplay/tcp/netsocket.h" #include "../activity.h" #include "../frontend.h" #include "../multiint.h" @@ -395,10 +395,10 @@ static void joinGameFromSecret_v1(const std::string joinSecretStr) return; } std::vector ipNetworkBinaryFormat = EmbeddedJSONSignature::b64Decode(b64UrlSafeTob64(connectionDetails[0].toUtf8())); - std::string ipAddressStr = ipv6_NetBinary_To_AddressString(ipNetworkBinaryFormat); + std::string ipAddressStr = tcp::ipv6_NetBinary_To_AddressString(ipNetworkBinaryFormat); if (ipAddressStr.empty()) { - ipAddressStr = ipv4_NetBinary_To_AddressString(ipNetworkBinaryFormat); + ipAddressStr = tcp::ipv4_NetBinary_To_AddressString(ipNetworkBinaryFormat); } if (ipAddressStr.empty()) { @@ -766,7 +766,7 @@ void DiscordRPCActivitySink::setJoinInformation(const ActivitySink::MultiplayerG std::string joinSecretDetails; if (!pExternalIPv4Address->empty()) { - auto ipv4AddressBinaryForm = ipv4_AddressString_To_NetBinary(*pExternalIPv4Address); + auto ipv4AddressBinaryForm = tcp::ipv4_AddressString_To_NetBinary(*pExternalIPv4Address); if (!ipv4AddressBinaryForm.empty()) { joinSecretDetails += b64Tob64UrlSafe(EmbeddedJSONSignature::b64Encode(ipv4AddressBinaryForm)); @@ -775,7 +775,7 @@ void DiscordRPCActivitySink::setJoinInformation(const ActivitySink::MultiplayerG } if (!ipv6Address.empty()) { - auto ipv6AddressBinaryForm = ipv6_AddressString_To_NetBinary(ipv6Address); + auto ipv6AddressBinaryForm = tcp::ipv6_AddressString_To_NetBinary(ipv6Address); if (!ipv6AddressBinaryForm.empty()) { if (!joinSecretDetails.empty()) diff --git a/src/screens/joiningscreen.cpp b/src/screens/joiningscreen.cpp index 1f18c884aff..13f5ee5b7ce 100644 --- a/src/screens/joiningscreen.cpp +++ b/src/screens/joiningscreen.cpp @@ -28,8 +28,12 @@ #include "lib/widget/scrollablelist.h" #include "lib/ivis_opengl/pieblitfunc.h" #include "lib/ivis_opengl/piepalette.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" #include "lib/netplay/netplay.h" -#include "lib/netplay/netsocket.h" +#include "lib/netplay/client_connection.h" +#include "lib/netplay/connection_poll_group.h" +#include "lib/netplay/open_connection_result.h" +#include "lib/netplay/connection_provider_registry.h" #include "../hci.h" #include "../activity.h" @@ -786,8 +790,8 @@ class WzJoiningGameScreen_HandlerRoot : public W_CLICKFORM // state when handling initial connection join uint32_t startTime = 0; - Socket* client_transient_socket = nullptr; - SocketSet* tmp_joining_socket_set = nullptr; + IClientConnection* client_transient_socket = nullptr; + IConnectionPollGroup* tmp_joining_socket_set = nullptr; NETQUEUE tmpJoiningQUEUE = {}; NetQueuePair *tmpJoiningQueuePair = nullptr; char initialAckBuffer[10] = {'\0'}; @@ -1162,22 +1166,22 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect if (NETgetEnableTCPNoDelay()) { - // Enable TCP_NODELAY - socketSetTCPNoDelay(*client_transient_socket, true); + // Disable use of Nagle Algorithm for the TCP socket (i.e. enable TCP_NODELAY option in case of TCP transport) + client_transient_socket->useNagleAlgorithm(false); } // Send initial connection data: NETCODE_VERSION_MAJOR and NETCODE_VERSION_MINOR char buffer[sizeof(int32_t) * 2] = { 0 }; char *p_buffer = buffer; auto pushu32 = [&](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(p_buffer, &swapped, sizeof(swapped)); p_buffer += sizeof(swapped); }; pushu32(NETGetMajorVersion()); pushu32(NETGetMinorVersion()); - const auto writeResult = writeAll(*client_transient_socket, buffer, sizeof(buffer)); + const auto writeResult = client_transient_socket->writeAll(buffer, sizeof(buffer), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); @@ -1186,7 +1190,8 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect return; } - tmp_joining_socket_set = allocSocketSet(); + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + tmp_joining_socket_set = connProvider.newConnectionPollGroup(); if (tmp_joining_socket_set == nullptr) { debug(LOG_ERROR, "Cannot create socket set - out of memory?"); @@ -1196,7 +1201,7 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect debug(LOG_NET, "Created socket_set %p", static_cast(tmp_joining_socket_set)); // `client_transient_socket` is used to talk to host machine - SocketSet_AddSocket(*tmp_joining_socket_set, client_transient_socket); + tmp_joining_socket_set->add(client_transient_socket); // Create temporary NETQUEUE auto NETnetJoinTmpQueue = [&]() @@ -1233,7 +1238,9 @@ void WzJoiningGameScreen_HandlerRoot::attemptToOpenConnection(size_t connectionI description.port = NETgetGameserverPort(); // use default configured port } auto weakSelf = std::weak_ptr(std::dynamic_pointer_cast(shared_from_this())); - socketOpenTCPConnectionAsync(description.host, description.port, [weakSelf, connectionIdx](OpenConnectionResult&& result) { + + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + connProvider.openClientConnectionAsync(description.host, description.port, [weakSelf, connectionIdx](OpenConnectionResult&& result) { auto strongSelf = weakSelf.lock(); if (!strongSelf) { @@ -1254,7 +1261,7 @@ bool WzJoiningGameScreen_HandlerRoot::joiningSocketNETsend() uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen = 0; - const auto writeResult = writeAll(*client_transient_socket, rawData, rawLen, &compressedRawLen); + const auto writeResult = client_transient_socket->writeAll(rawData, rawLen, &compressedRawLen); delete[] rawData; // Done with the data. queue->popMessageForNet(); if (writeResult.has_value()) @@ -1269,7 +1276,7 @@ bool WzJoiningGameScreen_HandlerRoot::joiningSocketNETsend() debug(LOG_ERROR, "Failed to send message (type: %" PRIu8 ", rawLen: %zu, compressedRawLen: %zu) to host: %s", message->type, message->rawLen(), compressedRawLen, writeErrMsg.c_str()); return false; } - socketFlush(*client_transient_socket, NET_HOST_ONLY); // Make sure the message was completely sent. + client_transient_socket->flush(nullptr); // Make sure the message was completely sent. ASSERT(queue->numMessagesForNet() == 0, "Queue not empty (%u messages remaining).", queue->numMessagesForNet()); return true; } @@ -1280,14 +1287,14 @@ void WzJoiningGameScreen_HandlerRoot::closeConnectionAttempt() { if (tmp_joining_socket_set) { - SocketSet_DelSocket(*tmp_joining_socket_set, client_transient_socket); + tmp_joining_socket_set->remove(client_transient_socket); } - socketClose(client_transient_socket); + delete client_transient_socket; client_transient_socket = nullptr; } if (tmp_joining_socket_set) { - deleteSocketSet(tmp_joining_socket_set); + delete tmp_joining_socket_set; tmp_joining_socket_set = nullptr; } if (tmpJoiningQueuePair) @@ -1347,15 +1354,17 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() if (currentJoiningState == JoiningState::AwaitingInitialNetcodeHandshakeAck) { // read in data, if we have it - if (checkSockets(*tmp_joining_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_joining_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { - if (!socketReadReady(*client_transient_socket)) + if (!client_transient_socket->readReady()) { return; // wait for next check } char *p_buffer = initialAckBuffer; - const auto readResult = readNoInt(*client_transient_socket, p_buffer + usedInitialAckBuffer, expectedInitialAckSize - usedInitialAckBuffer); + const auto readResult = client_transient_socket->readNoInt(p_buffer + usedInitialAckBuffer, + expectedInitialAckSize - usedInitialAckBuffer, + nullptr); if (readResult.has_value()) { usedInitialAckBuffer += static_cast(readResult.value()); @@ -1365,7 +1374,7 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() { uint32_t result = ERROR_CONNECTION; memcpy(&result, initialAckBuffer, sizeof(result)); - result = ntohl(result); + result = wz_ntohl(result); if (result != ERROR_NOERROR) { debug(LOG_ERROR, "Received error %d", result); @@ -1387,7 +1396,7 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() } // transition to net message mode (enable compression, wait for messages) - socketBeginCompression(*client_transient_socket); + client_transient_socket->enableCompression(); currentJoiningState = JoiningState::ProcessingJoinMessages; // permit fall-through to currentJoiningState == JoiningState::ProcessingJoinMessage case below } @@ -1398,15 +1407,15 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() if (currentJoiningState == JoiningState::ProcessingJoinMessages) { // read in data, if we have it - if (checkSockets(*tmp_joining_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_joining_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { - if (!socketReadReady(*client_transient_socket)) + if (!client_transient_socket->readReady()) { return; // wait for next check } uint8_t readBuffer[NET_BUFFER_SIZE]; - const auto readResult = readNoInt(*client_transient_socket, readBuffer, sizeof(readBuffer)); + const auto readResult = client_transient_socket->readNoInt(readBuffer, sizeof(readBuffer), nullptr); if (!readResult.has_value()) { // disconnect or programmer error From b0c39f5adb8cf0c61bc79d57283aa2802f5c99a6 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sat, 26 Oct 2024 16:37:40 +0300 Subject: [PATCH 2/3] netplay: introduce inheritance chain for connection address objects This way it looks much cleaner and naturally doesn't leak any implementation details to clients. Also, now it's much easier to provide native connection address implementations for each backend implementation. Signed-off-by: Pavel Solodovnikov --- lib/netplay/connection_address.cpp | 65 --------------------- lib/netplay/connection_address.h | 26 +-------- lib/netplay/netplay.cpp | 4 +- lib/netplay/tcp/tcp_connection_address.cpp | 30 ++++++++++ lib/netplay/tcp/tcp_connection_address.h | 43 ++++++++++++++ lib/netplay/tcp/tcp_connection_provider.cpp | 20 +++++-- lib/netplay/tcp/tcp_connection_provider.h | 4 +- lib/netplay/wz_connection_provider.h | 6 +- po/POTFILES.in | 3 +- 9 files changed, 100 insertions(+), 101 deletions(-) delete mode 100644 lib/netplay/connection_address.cpp create mode 100644 lib/netplay/tcp/tcp_connection_address.cpp create mode 100644 lib/netplay/tcp/tcp_connection_address.h diff --git a/lib/netplay/connection_address.cpp b/lib/netplay/connection_address.cpp deleted file mode 100644 index e4c277b8c50..00000000000 --- a/lib/netplay/connection_address.cpp +++ /dev/null @@ -1,65 +0,0 @@ -/* - This file is part of Warzone 2100. - Copyright (C) 2024 Warzone 2100 Project - - Warzone 2100 is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - Warzone 2100 is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with Warzone 2100; if not, write to the Free Software - Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA -*/ - -#include "lib/netplay/connection_address.h" -#include "lib/netplay/tcp/netsocket.h" // for `resolveHost` - -#include "lib/framework/frame.h" // for `ASSERT` - -struct ConnectionAddress::Impl final -{ - explicit Impl(SocketAddress* addr) - : mAddr_(addr) - {} - - ~Impl() - { - ASSERT(mAddr_ != nullptr, "Invalid addrinfo stored in the connection address"); - freeaddrinfo(mAddr_); - } - - SocketAddress* mAddr_; -}; - -ConnectionAddress::ConnectionAddress() = default; -ConnectionAddress::ConnectionAddress(ConnectionAddress&&) = default; -ConnectionAddress::~ConnectionAddress() = default; - -const SocketAddress* ConnectionAddress::asRawSocketAddress() const -{ - return mPimpl_->mAddr_; -} - - -net::result ConnectionAddress::parse(const char* hostname, uint16_t port) -{ - ConnectionAddress res; - const auto addr = tcp::resolveHost(hostname, port); - if (!addr.has_value()) - { - return tl::make_unexpected(addr.error()); - } - res.mPimpl_ = std::make_unique(addr.value()); - return net::result{std::move(res)}; -} - -net::result ConnectionAddress::parse(const std::string& hostname, uint16_t port) -{ - return parse(hostname.c_str(), port); -} diff --git a/lib/netplay/connection_address.h b/lib/netplay/connection_address.h index 604d9dc12da..876a46ea745 100644 --- a/lib/netplay/connection_address.h +++ b/lib/netplay/connection_address.h @@ -23,13 +23,7 @@ #include "lib/netplay/net_result.h" -#if defined WZ_OS_UNIX -# include -#elif defined WZ_OS_WIN -# include -#endif -typedef struct addrinfo SocketAddress; /// /// Opaque class representing abstract connection address to use with various @@ -50,23 +44,7 @@ typedef struct addrinfo SocketAddress; /// New conversion routines should be introduced for other network backends, /// if deemed necessary. /// -class ConnectionAddress +struct IConnectionAddress { -public: - - ConnectionAddress(); - ConnectionAddress(ConnectionAddress&&); - ConnectionAddress(const ConnectionAddress&) = delete; - ~ConnectionAddress(); - - static net::result parse(const char* hostname, uint16_t port); - static net::result parse(const std::string& hostname, uint16_t port); - - // NOTE: The lifetime of the returned `addrinfo` struct is bounded by the parent object's lifetime! - const SocketAddress* asRawSocketAddress() const; - -private: - - struct Impl; - std::unique_ptr mPimpl_; + virtual ~IConnectionAddress() = default; }; diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index 294ac0cfb26..467f6c1ba26 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -3467,7 +3467,7 @@ bool LobbyServerConnectionHandler::connect() } // try each address from resolveHost until we successfully connect. - auto sockResult = connProvider.openClientConnectionAny(hosts, 1500); + auto sockResult = connProvider.openClientConnectionAny(*hosts, 1500); rs_socket = sockResult.value_or(nullptr); @@ -4664,7 +4664,7 @@ bool NETenumerateGames(const std::function& handl return false; } const auto& hosts = hostsResult.value(); - auto sockResult = connProvider.openClientConnectionAny(hosts, 15000); + auto sockResult = connProvider.openClientConnectionAny(*hosts, 15000); if (!sockResult.has_value()) { const auto sockErrMsg = sockResult.error().message(); diff --git a/lib/netplay/tcp/tcp_connection_address.cpp b/lib/netplay/tcp/tcp_connection_address.cpp new file mode 100644 index 00000000000..66f7d1981e7 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_address.cpp @@ -0,0 +1,30 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_connection_address.h" + +#include "lib/netplay/tcp/netsocket.h" // for `freeaddrinfo` +#include "lib/framework/frame.h" // for `ASSERT` + +TCPConnectionAddress::TCPConnectionAddress(SocketAddress* addr) + : addr_(addr) +{} + +TCPConnectionAddress::~TCPConnectionAddress() +{ + ASSERT(addr_ != nullptr, "Invalid addrinfo stored in the connection address"); + freeaddrinfo(addr_); +} diff --git a/lib/netplay/tcp/tcp_connection_address.h b/lib/netplay/tcp/tcp_connection_address.h new file mode 100644 index 00000000000..fdc7b0b6d76 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_address.h @@ -0,0 +1,43 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/connection_address.h" + +#if defined WZ_OS_UNIX +# include +#elif defined WZ_OS_WIN +# include +#endif + +typedef struct addrinfo SocketAddress; + +class TCPConnectionAddress : public IConnectionAddress +{ +public: + + /// Assumes ownership of `addr` + explicit TCPConnectionAddress(SocketAddress* addr); + virtual ~TCPConnectionAddress() override; + + // NOTE: The lifetime of the returned `addrinfo` struct is bounded by the parent object's lifetime! + const SocketAddress* asRawSocketAddress() const { return addr_; } + +private: + + SocketAddress* addr_; +}; diff --git a/lib/netplay/tcp/tcp_connection_provider.cpp b/lib/netplay/tcp/tcp_connection_provider.cpp index 8a3182a02e7..d801fc05a14 100644 --- a/lib/netplay/tcp/tcp_connection_provider.cpp +++ b/lib/netplay/tcp/tcp_connection_provider.cpp @@ -20,6 +20,7 @@ #include "tcp_connection_provider.h" #include "lib/netplay/tcp/netsocket.h" +#include "lib/netplay/tcp/tcp_connection_address.h" #include "lib/netplay/tcp/tcp_connection_poll_group.h" #include "lib/netplay/tcp/tcp_client_connection.h" #include "lib/netplay/tcp/tcp_listen_socket.h" @@ -40,9 +41,14 @@ void TCPConnectionProvider::shutdown() SOCKETshutdown(); } -net::result TCPConnectionProvider::resolveHost(const char* host, uint16_t port) +net::result> TCPConnectionProvider::resolveHost(const char* host, uint16_t port) { - return ConnectionAddress::parse(host, port); + auto resolved = tcp::resolveHost(host, port); + if (!resolved.has_value()) + { + return tl::make_unexpected(resolved.error()); + } + return std::make_unique(resolved.value()); } net::result TCPConnectionProvider::openListenSocket(uint16_t port) @@ -55,9 +61,15 @@ net::result TCPConnectionProvider::openListenSocket(uint16_t por return new TCPListenSocket(res.value()); } -net::result TCPConnectionProvider::openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) +net::result TCPConnectionProvider::openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) { - const auto* rawAddr = addr.asRawSocketAddress(); + const auto* tcpAddr = dynamic_cast(&addr); + ASSERT(tcpAddr != nullptr, "Expected TCPConnectionAddress instance"); + if (!tcpAddr) + { + throw std::runtime_error("Expected TCPConnectionAddress instance"); + } + const auto* rawAddr = tcpAddr->asRawSocketAddress(); auto res = socketOpenAny(rawAddr, timeout); if (!res.has_value()) { diff --git a/lib/netplay/tcp/tcp_connection_provider.h b/lib/netplay/tcp/tcp_connection_provider.h index 9a1336558b3..c34add18b1e 100644 --- a/lib/netplay/tcp/tcp_connection_provider.h +++ b/lib/netplay/tcp/tcp_connection_provider.h @@ -35,11 +35,11 @@ class TCPConnectionProvider final : public WzConnectionProvider virtual void initialize() override; virtual void shutdown() override; - virtual net::result resolveHost(const char* host, uint16_t port) override; + virtual net::result> resolveHost(const char* host, uint16_t port) override; virtual net::result openListenSocket(uint16_t port) override; - virtual net::result openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) override; + virtual net::result openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) override; virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) override; virtual IConnectionPollGroup* newConnectionPollGroup() override; diff --git a/lib/netplay/wz_connection_provider.h b/lib/netplay/wz_connection_provider.h index 1c9d51dbd8e..9a9d879ff66 100644 --- a/lib/netplay/wz_connection_provider.h +++ b/lib/netplay/wz_connection_provider.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include "lib/netplay/connection_address.h" #include "lib/netplay/net_result.h" @@ -28,6 +29,7 @@ class IListenSocket; class IClientConnection; class IConnectionPollGroup; +struct IConnectionAddress; /// /// Abstraction layer to facilitate creating client/server connections and @@ -57,7 +59,7 @@ class WzConnectionProvider /// Resolve host + port combination and return an opaque `ConnectionAddress` handle /// representing the resolved network address. /// - virtual net::result resolveHost(const char* host, uint16_t port) = 0; + virtual net::result> resolveHost(const char* host, uint16_t port) = 0; /// /// Open a listening socket bound to a specified local port. /// @@ -68,7 +70,7 @@ class WzConnectionProvider /// /// Connection address to bind the client connection to. /// Timeout in milliseconds. - virtual net::result openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) = 0; + virtual net::result openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) = 0; /// /// Async variant of `openClientConnectionAny()`. /// diff --git a/po/POTFILES.in b/po/POTFILES.in index 1365f738c66..78dd4afab84 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -302,9 +302,8 @@ lib/ivis_opengl/png_util_spng.cpp lib/ivis_opengl/screen.cpp lib/ivis_opengl/tex.cpp lib/ivis_opengl/textdraw.cpp -lib/netplay/error_categories.cpp -lib/netplay/connection_address.cpp lib/netplay/connection_provider_registry.cpp +lib/netplay/error_categories.cpp lib/netplay/netjoin_stub.cpp lib/netplay/netlog.cpp lib/netplay/netpermissions.cpp From 1ae4f40930cc61124c58c7fcf685e63014a460e8 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Sat, 16 Nov 2024 17:28:57 +0300 Subject: [PATCH 3/3] netplay: move code for async opening of client connections to the base `WzConnectionProvider` Provide the default implementation for `openClientConnectionAsync()` in `WzConnectionProvider`, which just spawns a new thread and piggibacks on the `resolveHost()` + `openClientConnectionAny()` combination. The `TCPConnectionProvider` is now simpler because it can use the default implementation from the base class. Signed-off-by: Pavel Solodovnikov --- lib/netplay/tcp/tcp_connection_provider.cpp | 74 ---------------- lib/netplay/tcp/tcp_connection_provider.h | 1 - lib/netplay/wz_connection_provider.cpp | 93 +++++++++++++++++++++ lib/netplay/wz_connection_provider.h | 7 +- src/screens/joiningscreen.cpp | 21 +++-- 5 files changed, 110 insertions(+), 86 deletions(-) create mode 100644 lib/netplay/wz_connection_provider.cpp diff --git a/lib/netplay/tcp/tcp_connection_provider.cpp b/lib/netplay/tcp/tcp_connection_provider.cpp index d801fc05a14..1bf1edddf4b 100644 --- a/lib/netplay/tcp/tcp_connection_provider.cpp +++ b/lib/netplay/tcp/tcp_connection_provider.cpp @@ -78,80 +78,6 @@ net::result TCPConnectionProvider::openClientConnectionAny(c return new TCPClientConnection(res.value()); } -namespace -{ - -OpenConnectionResult socketOpenTCPConnectionSync(const char* host, uint32_t port) -{ - const auto hostsResult = resolveHost(host, port); - SocketAddress* hosts = hostsResult.value_or(nullptr); - if (hosts == nullptr) - { - const auto hostsErr = hostsResult.error(); - const auto hostsErrMsg = hostsErr.message(); - return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); - } - - auto sockResult = socketOpenAny(hosts, 15000); - Socket* client_transient_socket = sockResult.value_or(nullptr); - deleteSocketAddress(hosts); - hosts = nullptr; - - if (client_transient_socket == nullptr) - { - const auto errValue = sockResult.error(); - const auto errMsg = errValue.message(); - return OpenConnectionResult(errValue, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, errValue.value(), errMsg.c_str())); - } - - return OpenConnectionResult(new TCPClientConnection(client_transient_socket)); -} - -struct OpenConnectionRequest -{ - std::string host; - uint32_t port = 0; - OpenConnectionToHostResultCallback callback; -}; - -static int openDirectTCPConnectionAsyncImpl(void* data) -{ - OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; - if (!pRequestInfo) - { - return 1; - } - - pRequestInfo->callback(socketOpenTCPConnectionSync(pRequestInfo->host.c_str(), pRequestInfo->port)); - delete pRequestInfo; - return 0; -} - -} // anonymous namespace - -bool TCPConnectionProvider::openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) -{ - // spawn background thread to handle this - auto pRequest = new OpenConnectionRequest(); - pRequest->host = host; - pRequest->port = port; - pRequest->callback = callback; - - WZ_THREAD* pOpenConnectionThread = wzThreadCreate(openDirectTCPConnectionAsyncImpl, pRequest); - if (pOpenConnectionThread == nullptr) - { - debug(LOG_ERROR, "Failed to create thread for opening connection"); - delete pRequest; - return false; - } - - wzThreadDetach(pOpenConnectionThread); - // the thread handles deleting pRequest - pOpenConnectionThread = nullptr; - - return true; -} - IConnectionPollGroup* TCPConnectionProvider::newConnectionPollGroup() { auto* sset = allocSocketSet(); diff --git a/lib/netplay/tcp/tcp_connection_provider.h b/lib/netplay/tcp/tcp_connection_provider.h index c34add18b1e..891bf469f9c 100644 --- a/lib/netplay/tcp/tcp_connection_provider.h +++ b/lib/netplay/tcp/tcp_connection_provider.h @@ -40,7 +40,6 @@ class TCPConnectionProvider final : public WzConnectionProvider virtual net::result openListenSocket(uint16_t port) override; virtual net::result openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) override; - virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) override; virtual IConnectionPollGroup* newConnectionPollGroup() override; }; diff --git a/lib/netplay/wz_connection_provider.cpp b/lib/netplay/wz_connection_provider.cpp new file mode 100644 index 00000000000..89f4b8556b2 --- /dev/null +++ b/lib/netplay/wz_connection_provider.cpp @@ -0,0 +1,93 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "wz_connection_provider.h" + +#include "lib/framework/wzapp.h" + +namespace +{ + +OpenConnectionResult openClientConnectionSyncImpl(const char* host, uint32_t port, std::chrono::milliseconds timeout, WzConnectionProvider* connProvider) +{ + auto addrResult = connProvider->resolveHost(host, port); + if (!addrResult.has_value()) + { + const auto hostsErr = addrResult.error(); + const auto hostsErrMsg = hostsErr.message(); + return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); + } + auto connRes = connProvider->openClientConnectionAny(*addrResult.value(), timeout.count()); + if (!connRes.has_value()) + { + const auto connErr = connRes.error(); + const auto connErrMsg = connErr.message(); + return OpenConnectionResult(connErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, connErr.value(), connErrMsg.c_str())); + } + return OpenConnectionResult(connRes.value()); +} + +struct OpenConnectionRequest +{ + std::string host; + uint32_t port = 0; + std::chrono::milliseconds timeout{ 15000 }; + OpenConnectionToHostResultCallback callback; + WzConnectionProvider* connProvider; +}; + +int openDirectConnectionAsyncImpl(void* data) +{ + OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; + if (!pRequestInfo) + { + return 1; + } + pRequestInfo->callback(openClientConnectionSyncImpl( + pRequestInfo->host.c_str(), + pRequestInfo->port, + pRequestInfo->timeout, + pRequestInfo->connProvider)); + delete pRequestInfo; + return 0; +} + +} // anonymous namespace + +bool WzConnectionProvider::openClientConnectionAsync(const std::string& host, uint32_t port, std::chrono::milliseconds timeout, OpenConnectionToHostResultCallback callback) +{ + // spawn background thread to handle this + auto pRequest = new OpenConnectionRequest(); + pRequest->host = host; + pRequest->port = port; + pRequest->timeout = timeout; + pRequest->callback = callback; + pRequest->connProvider = this; + WZ_THREAD* pOpenConnectionThread = wzThreadCreate(openDirectConnectionAsyncImpl, pRequest); + if (pOpenConnectionThread == nullptr) + { + debug(LOG_ERROR, "Failed to create thread for opening connection"); + delete pRequest; + return false; + } + wzThreadDetach(pOpenConnectionThread); + // the thread handles deleting pRequest + pOpenConnectionThread = nullptr; + return true; +} diff --git a/lib/netplay/wz_connection_provider.h b/lib/netplay/wz_connection_provider.h index 9a9d879ff66..91683e3081b 100644 --- a/lib/netplay/wz_connection_provider.h +++ b/lib/netplay/wz_connection_provider.h @@ -20,7 +20,9 @@ #pragma once #include +#include #include +#include #include "lib/netplay/connection_address.h" #include "lib/netplay/net_result.h" @@ -72,9 +74,10 @@ class WzConnectionProvider /// Timeout in milliseconds. virtual net::result openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) = 0; /// - /// Async variant of `openClientConnectionAny()`. + /// Async variant of `openClientConnectionAny()` with the default implementation, which + /// spawns a new thread and piggybacks on the `resolveHost()` and `openClientConnectionAny()` combination. /// - virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) = 0; + virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, std::chrono::milliseconds timeout, OpenConnectionToHostResultCallback callback); /// /// Create a group for polling client connections. /// diff --git a/src/screens/joiningscreen.cpp b/src/screens/joiningscreen.cpp index 13f5ee5b7ce..bf77aafe2fb 100644 --- a/src/screens/joiningscreen.cpp +++ b/src/screens/joiningscreen.cpp @@ -1239,16 +1239,19 @@ void WzJoiningGameScreen_HandlerRoot::attemptToOpenConnection(size_t connectionI } auto weakSelf = std::weak_ptr(std::dynamic_pointer_cast(shared_from_this())); + constexpr std::chrono::milliseconds CLIENT_OPEN_ASYNC_TIMEOUT{ 15000 }; // Default timeout of 15s + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); - connProvider.openClientConnectionAsync(description.host, description.port, [weakSelf, connectionIdx](OpenConnectionResult&& result) { - auto strongSelf = weakSelf.lock(); - if (!strongSelf) - { - // background thread ultimately returned after the requester has gone away (join was cancelled?) - just return - return; - } - strongSelf->processOpenConnectionResultOnMainThread(connectionIdx, std::move(result)); - }); + connProvider.openClientConnectionAsync(description.host, description.port, CLIENT_OPEN_ASYNC_TIMEOUT, + [weakSelf, connectionIdx](OpenConnectionResult&& result) { + auto strongSelf = weakSelf.lock(); + if (!strongSelf) + { + // background thread ultimately returned after the requester has gone away (join was cancelled?) - just return + return; + } + strongSelf->processOpenConnectionResultOnMainThread(connectionIdx, std::move(result)); + }); break; } updateJoiningStatus(_("Establishing connection with host"));