From 53534676037fbf76cd84b5b155df6fbdf2d777d4 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Fri, 25 Oct 2024 14:55:33 +0300 Subject: [PATCH] netplay: introduce abstractions for client/server-side sockets and connection providers This patch introduces multiple high-level abstractions over raw low-level sockets, which are necessary for supporting network backends other than default legacy `TCP_DIRECT` implementation: 1. `WzConnectionProvider` - abstracts the way WZ establishes server-side and client-side connections. This thing effectively provides usable listen sockets and client connections to work with, hence the name. 2. `IListenSocket` - abstraction over listen sockets. 3. `IClientConnection` - abstraction over client-side sockets (and also server-side connections to the game clients). 4. `IConnectionPollGroup` - generalization of socket sets for polling multiple connections in one go. 5. `ConnectionProviderRegistry` - trivial singleton class providing storage for connection providers. 6. `ConnectionAddress` - opaque connection address object, aimed to replace direct uses of `addrinfo` and provide a bit more abstract way to represent connection credentials. Still looks like a crutch right now, but it's better than nothing, nonetheless. The existing implementation in `netplay/netsocket.h(.cpp)` has been moved to the `tcp` subfolder and wrapped entirely into the `tcp` namespace. The patch provides `TCP*`-prefixed implementations of the base interfaces mentioned above, which are implemented in terms of the old `netsocket` code. There's now a `ConnectionProviderType::TCP_DIRECT` enumeration descriptor for accessing the default connection provider. All uses in the high-level code (`netplay.cpp`, `joiningscreen.cpp`) are amended appropriately to use the all-new high-level abstractions instead of old low-level tcp-specific `Socket` and `SocketSet`. NOTE: there are still a few functions from the `tcp::` namespace used directly in the Discord RPC integration code, but these shouldn't pose any problem to either extract these into a more generic abstraction layer or to be rewritten not to use these functions at all, because they don't actually use any low-level stuff that's hard to refactor. Signed-off-by: Pavel Solodovnikov --- lib/netplay/CMakeLists.txt | 12 +- lib/netplay/byteorder_funcs_wrapper.cpp | 50 ++++ lib/netplay/byteorder_funcs_wrapper.h | 34 +++ lib/netplay/client_connection.h | 120 +++++++++ lib/netplay/connection_address.cpp | 65 +++++ lib/netplay/connection_address.h | 72 +++++ lib/netplay/connection_poll_group.h | 42 +++ lib/netplay/connection_provider_registry.cpp | 38 +++ lib/netplay/connection_provider_registry.h | 75 ++++++ lib/netplay/listen_socket.h | 45 ++++ lib/netplay/net_result.h | 32 +++ lib/netplay/netplay.cpp | 247 ++++++++++-------- lib/netplay/netplay.h | 8 +- lib/netplay/open_connection_result.h | 56 ++++ lib/netplay/sync_debug.cpp | 6 +- lib/netplay/{ => tcp}/netsocket.cpp | 79 +----- lib/netplay/{ => tcp}/netsocket.h | 49 +--- lib/netplay/tcp/tcp_client_connection.cpp | 83 ++++++ lib/netplay/tcp/tcp_client_connection.h | 52 ++++ lib/netplay/tcp/tcp_connection_poll_group.cpp | 60 +++++ lib/netplay/tcp/tcp_connection_poll_group.h | 45 ++++ lib/netplay/tcp/tcp_connection_provider.cpp | 153 +++++++++++ lib/netplay/tcp/tcp_connection_provider.h | 48 ++++ lib/netplay/tcp/tcp_listen_socket.cpp | 70 +++++ lib/netplay/tcp/tcp_listen_socket.h | 46 ++++ lib/netplay/wz_connection_provider.h | 80 ++++++ po/POTFILES.in | 8 +- src/integrations/wzdiscordrpc.cpp | 10 +- src/screens/joiningscreen.cpp | 55 ++-- 29 files changed, 1482 insertions(+), 258 deletions(-) create mode 100644 lib/netplay/byteorder_funcs_wrapper.cpp create mode 100644 lib/netplay/byteorder_funcs_wrapper.h create mode 100644 lib/netplay/client_connection.h create mode 100644 lib/netplay/connection_address.cpp create mode 100644 lib/netplay/connection_address.h create mode 100644 lib/netplay/connection_poll_group.h create mode 100644 lib/netplay/connection_provider_registry.cpp create mode 100644 lib/netplay/connection_provider_registry.h create mode 100644 lib/netplay/listen_socket.h create mode 100644 lib/netplay/net_result.h create mode 100644 lib/netplay/open_connection_result.h rename lib/netplay/{ => tcp}/netsocket.cpp (96%) rename lib/netplay/{ => tcp}/netsocket.h (81%) create mode 100644 lib/netplay/tcp/tcp_client_connection.cpp create mode 100644 lib/netplay/tcp/tcp_client_connection.h create mode 100644 lib/netplay/tcp/tcp_connection_poll_group.cpp create mode 100644 lib/netplay/tcp/tcp_connection_poll_group.h create mode 100644 lib/netplay/tcp/tcp_connection_provider.cpp create mode 100644 lib/netplay/tcp/tcp_connection_provider.h create mode 100644 lib/netplay/tcp/tcp_listen_socket.cpp create mode 100644 lib/netplay/tcp/tcp_listen_socket.h create mode 100644 lib/netplay/wz_connection_provider.h diff --git a/lib/netplay/CMakeLists.txt b/lib/netplay/CMakeLists.txt index 12cb14cb648..65962e17e91 100644 --- a/lib/netplay/CMakeLists.txt +++ b/lib/netplay/CMakeLists.txt @@ -26,8 +26,16 @@ add_dependencies(autorevision_netcodeversion autorevision) # Ensure ordering and ############################ # netplay library -file(GLOB HEADERS "*.h") -file(GLOB SRC "*.cpp") +file(GLOB_RECURSE HEADERS "*.h") +file(GLOB_RECURSE SRC "*.cpp") + +if(MSVC AND CMAKE_VERSION VERSION_GREATER 3.7) + # Automatic detection of source groups via `source_group(TREE )` syntax + # has been introduced in CMake 3.8. + # Please consult https://cmake.org/cmake/help/latest/command/source_group.html for additional info. + source_group(TREE "${CMAKE_CURRENT_LIST_DIR}" PREFIX "Sources" FILES ${SRC}) + source_group(TREE "${CMAKE_CURRENT_LIST_DIR}" PREFIX "Headers" FILES ${HEADERS}) +endif() find_package (Threads REQUIRED) find_package (ZLIB REQUIRED) diff --git a/lib/netplay/byteorder_funcs_wrapper.cpp b/lib/netplay/byteorder_funcs_wrapper.cpp new file mode 100644 index 00000000000..79366a8949b --- /dev/null +++ b/lib/netplay/byteorder_funcs_wrapper.cpp @@ -0,0 +1,50 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 1999-2004 Eidos Interactive + Copyright (C) 2005-2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/byteorder_funcs_wrapper.h" + +#include "lib/framework/wzglobal.h" + +// bring in the original `htonl`/`htons`/`ntohs`/`htohl` functions +#if defined WZ_OS_WIN +# include +#else // *NIX / *BSD variants +# include +#endif + +uint32_t wz_htonl(uint32_t hostlong) +{ + return htonl(hostlong); +} + +uint16_t wz_htons(uint16_t hostshort) +{ + return htons(hostshort); +} + +uint32_t wz_ntohl(uint32_t netlong) +{ + return ntohl(netlong); +} + +uint16_t wz_ntohs(uint16_t netshort) +{ + return ntohs(netshort); +} diff --git a/lib/netplay/byteorder_funcs_wrapper.h b/lib/netplay/byteorder_funcs_wrapper.h new file mode 100644 index 00000000000..c1348fd6505 --- /dev/null +++ b/lib/netplay/byteorder_funcs_wrapper.h @@ -0,0 +1,34 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 1999-2004 Eidos Interactive + Copyright (C) 2005-2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +/// +/// byteorder functions wrappers for WZ just to avoid polluting all places, +/// where these functions are needed, with conditional includes of +/// and winsock headers. +/// + +uint32_t wz_htonl(uint32_t hostlong); +uint16_t wz_htons(uint16_t hostshort); +uint32_t wz_ntohl(uint32_t netlong); +uint16_t wz_ntohs(uint16_t netshort); diff --git a/lib/netplay/client_connection.h b/lib/netplay/client_connection.h new file mode 100644 index 00000000000..986fdc04b4a --- /dev/null +++ b/lib/netplay/client_connection.h @@ -0,0 +1,120 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +#include "lib/framework/types.h" // bring in `ssize_t` for MSVC +#include "lib/netplay/net_result.h" + +/// +/// Basic abstraction over client connection sockets. +/// +/// These are capable of reading (`readAll` and `readNoInt`) and +/// writing data (via `writeAll()` + `flush()` combination). +/// +/// The internal implementation may also implement advanced compression mechanisms +/// on top of these connections by providing non-trivial `enableCompression()` overload. +/// +/// In this case, `writeAll()` should somehow accumulate the data into a write queue, +/// compressing the outcoming data on-the-fly; and `flush()` should empty the write queue +/// and actually post a message to the transmission queue, which, in turn, will be emptied +/// by the internal connection interface in a timely manner, when there are enough messages +/// to be sent over the network. +/// +class IClientConnection +{ +public: + + virtual ~IClientConnection() = default; + + /// + /// Read exactly `size` bytes into `buf` buffer. + /// Supports setting a timeout value in milliseconds. + /// + /// Destination buffer to read the data into. + /// The size of data to be read in bytes. + /// Timeout value in milliseconds. + /// On success, returns the number of bytes read; + /// On failure, returns an `std::error_code` (having `GenericSystemErrorCategory` error category) + /// describing the actual error. + virtual net::result readAll(void* buf, size_t size, unsigned timeout) = 0; + /// + /// Reads at most `max_size` bytes into `buf` buffer. + /// Raw count of bytes (after compression) is returned in `rawByteCount`. + /// + /// Destination buffer to read the data into. + /// The maximum number of bytes to read from the client socket. + /// Output parameter: Raw count of bytes (after compression). + /// On success, returns the number of bytes read; + /// On failure, returns an `std::error_code` (having `GenericSystemErrorCategory` error category) + /// describing the actual error. + virtual net::result readNoInt(void* buf, size_t max_size, size_t* rawByteCount) = 0; + /// + /// Nonblocking write of `size` bytes to the socket. The data will be written to a + /// separate write queue in asynchronous manner, possibly by a separate thread. + /// Raw count of bytes (after compression) will be returned in `rawByteCount`, which + /// will often be 0 until the socket is flushed. + /// + /// The reason for this method to be async is that in some cases we want + /// client connections to have compression mechanism enabled. This naturally + /// introduces the 2-phase write process, which involves a write queue (accumulating + /// the data for compression on-the-fly) and a submission (transmission) + /// queue (for transmitting of compressed and assembled messages), + /// which is managed by the network backend implementation. + /// + /// Source buffer to read the data from. + /// The number of bytes to write to the socket. + /// Output parameter: raw count of bytes (after compression) written. + /// The total number of bytes written. + virtual net::result writeAll(const void* buf, size_t size, size_t* rawByteCount) = 0; + /// + /// This method indicates whether the socket has some data ready to be read (i.e. + /// whether the next `readAll/readNoInt` operation will execute without blocking or not). + /// + virtual bool readReady() const = 0; + /// + /// Actually sends the data written with `writeAll()`. Only useful with sockets + /// which have compression enabled. + /// Note that flushing too often makes compression less effective. + /// Raw count of bytes (after compression) is returned in `rawByteCount`. + /// + /// Raw count of bytes (after compression) as written + /// to the submission queue by the flush operation. + virtual void flush(size_t* rawByteCount) = 0; + /// + /// Enables compression for the current socket. + /// + /// This makes all subsequent write operations asynchronous, plus + /// the written data will need to be flushed explicitly at some point. + /// + virtual void enableCompression() = 0; + /// + /// Enables or disables the use of Nagle algorithm for the socket. + /// + /// For direct TCP connections this is equivalent to setting `TCP_NODELAY` to the + /// appropriate value (i.e.: + /// `enable == true` <=> `TCP_NODELAY == false`; + /// `enable == false` <=> `TCP_NODELAY == true`). + /// + virtual void useNagleAlgorithm(bool enable) = 0; + /// + /// Returns textual representation of the socket's connection address. + /// + virtual std::string textAddress() const = 0; +}; diff --git a/lib/netplay/connection_address.cpp b/lib/netplay/connection_address.cpp new file mode 100644 index 00000000000..e4c277b8c50 --- /dev/null +++ b/lib/netplay/connection_address.cpp @@ -0,0 +1,65 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/connection_address.h" +#include "lib/netplay/tcp/netsocket.h" // for `resolveHost` + +#include "lib/framework/frame.h" // for `ASSERT` + +struct ConnectionAddress::Impl final +{ + explicit Impl(SocketAddress* addr) + : mAddr_(addr) + {} + + ~Impl() + { + ASSERT(mAddr_ != nullptr, "Invalid addrinfo stored in the connection address"); + freeaddrinfo(mAddr_); + } + + SocketAddress* mAddr_; +}; + +ConnectionAddress::ConnectionAddress() = default; +ConnectionAddress::ConnectionAddress(ConnectionAddress&&) = default; +ConnectionAddress::~ConnectionAddress() = default; + +const SocketAddress* ConnectionAddress::asRawSocketAddress() const +{ + return mPimpl_->mAddr_; +} + + +net::result ConnectionAddress::parse(const char* hostname, uint16_t port) +{ + ConnectionAddress res; + const auto addr = tcp::resolveHost(hostname, port); + if (!addr.has_value()) + { + return tl::make_unexpected(addr.error()); + } + res.mPimpl_ = std::make_unique(addr.value()); + return net::result{std::move(res)}; +} + +net::result ConnectionAddress::parse(const std::string& hostname, uint16_t port) +{ + return parse(hostname.c_str(), port); +} diff --git a/lib/netplay/connection_address.h b/lib/netplay/connection_address.h new file mode 100644 index 00000000000..604d9dc12da --- /dev/null +++ b/lib/netplay/connection_address.h @@ -0,0 +1,72 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include +#include + +#include "lib/netplay/net_result.h" + +#if defined WZ_OS_UNIX +# include +#elif defined WZ_OS_WIN +# include +#endif + +typedef struct addrinfo SocketAddress; + +/// +/// Opaque class representing abstract connection address to use with various +/// network backend implementations. The internal representation is made +/// hidden on purpose since we don't want to actually leak internal data layout +/// to clients. +/// +/// Instead, we would like to introduce "conversion routines" yielding +/// various representations for convenient consumption with various network +/// backends. +/// +/// NOTE: this class may or may not represent a chain of resolved network addresses +/// instead of just a single one, much like a `addrinfo` structure. +/// +/// Currently, only knows how to convert itself to `addrinfo` struct, +/// which is used with the `TCP_DIRECT` network backend. +/// +/// New conversion routines should be introduced for other network backends, +/// if deemed necessary. +/// +class ConnectionAddress +{ +public: + + ConnectionAddress(); + ConnectionAddress(ConnectionAddress&&); + ConnectionAddress(const ConnectionAddress&) = delete; + ~ConnectionAddress(); + + static net::result parse(const char* hostname, uint16_t port); + static net::result parse(const std::string& hostname, uint16_t port); + + // NOTE: The lifetime of the returned `addrinfo` struct is bounded by the parent object's lifetime! + const SocketAddress* asRawSocketAddress() const; + +private: + + struct Impl; + std::unique_ptr mPimpl_; +}; diff --git a/lib/netplay/connection_poll_group.h b/lib/netplay/connection_poll_group.h new file mode 100644 index 00000000000..a2b66bcb801 --- /dev/null +++ b/lib/netplay/connection_poll_group.h @@ -0,0 +1,42 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +class IClientConnection; + +/// +/// Abstract representation of a poll group comprised of several client connections. +/// +class IConnectionPollGroup +{ +public: + + virtual ~IConnectionPollGroup() = default; + + /// + /// Polls the sockets in the poll group for updates. + /// + /// Timeout value after which the internal implementation should abandon + /// polling the client connections and return. + /// On success, returns the number of connection descriptors in the poll group. + /// On failure, `0` can returned if the timeout expired before any connection descriptors + /// became ready, or `-1` if there was an error during the internal poll operation. + virtual int checkSockets(unsigned timeout) = 0; + + virtual void add(IClientConnection* conn) = 0; + virtual void remove(IClientConnection* conn) = 0; +}; diff --git a/lib/netplay/connection_provider_registry.cpp b/lib/netplay/connection_provider_registry.cpp new file mode 100644 index 00000000000..9c436a555eb --- /dev/null +++ b/lib/netplay/connection_provider_registry.cpp @@ -0,0 +1,38 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include + +#include "lib/netplay/connection_provider_registry.h" + +ConnectionProviderRegistry& ConnectionProviderRegistry::Instance() +{ + static ConnectionProviderRegistry instance; + return instance; +} + +WzConnectionProvider& ConnectionProviderRegistry::Get(ConnectionProviderType pt) +{ + const auto it = registeredProviders_.find(pt); + if (it == registeredProviders_.end()) + { + throw std::runtime_error("Attempt to get nonexistent connection provider"); + } + return *it->second; +} diff --git a/lib/netplay/connection_provider_registry.h b/lib/netplay/connection_provider_registry.h new file mode 100644 index 00000000000..7b9f9c30af5 --- /dev/null +++ b/lib/netplay/connection_provider_registry.h @@ -0,0 +1,75 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +#include "lib/netplay/wz_connection_provider.h" +#include "lib/netplay/tcp/tcp_connection_provider.h" + +/// +/// Available types of connection providers (i.e. network backend implementations). +/// +enum class ConnectionProviderType +{ + TCP_DIRECT +}; + +template +struct ProviderHelperTraits; + +template <> +struct ProviderHelperTraits +{ + using ConcreteType = tcp::TCPConnectionProvider; +}; + +/// +/// Global singleton registry containing available network connection providers. +/// +class ConnectionProviderRegistry +{ +public: + + static ConnectionProviderRegistry& Instance(); + + WzConnectionProvider& Get(ConnectionProviderType pt); + + template + void Register() + { + using ProviderConcreteType = typename ProviderHelperTraits::ConcreteType; + // No-op in case this provider has been already registered. + registeredProviders_.emplace(PT, std::make_unique()); + } + + template + void Deregister() + { + registeredProviders_.erase(PT); + } + +private: + + ConnectionProviderRegistry() = default; + + std::unordered_map> registeredProviders_; +}; diff --git a/lib/netplay/listen_socket.h b/lib/netplay/listen_socket.h new file mode 100644 index 00000000000..04cfdd9fc9e --- /dev/null +++ b/lib/netplay/listen_socket.h @@ -0,0 +1,45 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include + +class IClientConnection; + +/// +/// Server-side listen socket abstraction. +/// +class IListenSocket +{ +public: + + virtual ~IListenSocket() = default; + + enum class IPVersions : uint8_t + { + IPV4 = 0b00000001, + IPV6 = 0b00000010 + }; + using IPVersionsMask = std::underlying_type_t; + + /// + /// Accept an incoming client connection on the current server-side listen socket. + /// + virtual IClientConnection* accept() = 0; + virtual IPVersionsMask supportedIpVersions() const = 0; +}; diff --git a/lib/netplay/net_result.h b/lib/netplay/net_result.h new file mode 100644 index 00000000000..be558bcc517 --- /dev/null +++ b/lib/netplay/net_result.h @@ -0,0 +1,32 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include + +namespace net +{ + +template +using result = ::tl::expected; + +} // namespace net diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index fa2cd4a5432..8020dda8651 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -47,7 +47,11 @@ #include "netplay.h" #include "netlog.h" #include "netreplay.h" -#include "netsocket.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" +#include "lib/netplay/client_connection.h" +#include "lib/netplay/listen_socket.h" +#include "lib/netplay/connection_poll_group.h" +#include "lib/netplay/connection_provider_registry.h" #include "netpermissions.h" #include "sync_debug.h" #include "port_mapping_manager.h" @@ -70,6 +74,12 @@ # include "lib/framework/cocoa_wrapper.h" #endif +#ifndef WZ_OS_WIN +static const int SOCKET_ERROR = -1; +#else +# include // SOCKET_ERROR +#endif + // WARNING !!! This is initialised via configuration.c !!! char masterserver_name[255] = {'\0'}; static unsigned int masterserver_port = 0, gameserver_port = 0; @@ -160,8 +170,8 @@ class LobbyServerConnectionHandler Connected }; LobbyConnectionState currentState = LobbyConnectionState::Disconnected; - Socket *rs_socket = nullptr; - SocketSet* waitingForConnectionFinalize = nullptr; + IClientConnection* rs_socket = nullptr; + IConnectionPollGroup* waitingForConnectionFinalize = nullptr; uint32_t lastConnectionTime = 0; uint32_t lastServerUpdate = 0; bool queuedServerUpdate = false; @@ -231,15 +241,15 @@ bool netPlayersUpdated; // Server-side socket (host-only) which is used to listen for client connections. // There's also `rs_socket` held by `LobbyServerConnectionHandler`, which is used to communicate with the lobby server. -static Socket* server_listen_socket = nullptr; +static IListenSocket* server_listen_socket = nullptr; -static Socket *bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then client_transient_socket == NULL. -static Socket *connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only). +static IClientConnection* bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then client_transient_socket == NULL. +static IClientConnection* connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only). // Client-side socket set. Contains of only 1 socket at most: `bsocket` (which is a stable client connection to the host). -static SocketSet* client_socket_set = nullptr; +static IConnectionPollGroup* client_socket_set = nullptr; // Server-side socket set. Contains up to `MAX_CONNECTED_PLAYERS` sockets: // `connected_bsocket[i]` - sockets used to communicate with clients during a game session. -static SocketSet* server_socket_set = nullptr; +static IConnectionPollGroup* server_socket_set = nullptr; /** * Used for connections with clients. @@ -301,13 +311,13 @@ struct TmpSocketInfo } }; -static Socket *tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only). +static IClientConnection* tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only). static std::array tmp_connectState; static bool bAsyncJoinApprovalEnabled = false; static std::unordered_map tmp_pendingIPs; static lru11::Cache tmp_badIPs(512, 64); -static SocketSet *tmp_socket_set = nullptr; +static IConnectionPollGroup* tmp_socket_set = nullptr; static int32_t NetGameFlags[4] = { 0, 0, 0, 0 }; char iptoconnect[PATH_MAX] = "\0"; // holds IP/hostname from command line bool cliConnectToIpAsSpectator = false; // for cli option @@ -534,17 +544,17 @@ bool NETsetAsyncJoinApprovalResult(const std::string& uniqueJoinID, AsyncJoinApp // *********** Socket with buffer that read NETMSGs ****************** -static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *bufstart, int bufsize) +static size_t NET_fillBuffer(IClientConnection** pSocket, IConnectionPollGroup* pSocketSet, uint8_t *bufstart, int bufsize) { - Socket *socket = *pSocket; + IClientConnection* socket = *pSocket; - if (!socketReadReady(*socket)) + if (!socket->readReady()) { return 0; } size_t rawBytes; - const auto readResult = readNoInt(*socket, bufstart, bufsize, &rawBytes); + const auto readResult = socket->readNoInt(bufstart, bufsize, &rawBytes); if (readResult.has_value()) { @@ -572,7 +582,7 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b // an error occurred, or the remote host has closed the connection. if (pSocketSet != nullptr) { - SocketSet_DelSocket(*pSocketSet, socket); + pSocketSet->remove(socket); } if (bsocket == socket) { @@ -586,7 +596,7 @@ static size_t NET_fillBuffer(Socket **pSocket, SocketSet *pSocketSet, uint8_t *b NETclose(); return 0; } - socketClose(socket); + delete socket; *pSocket = nullptr; } @@ -966,8 +976,8 @@ static void NETplayerCloseSocket(UDWORD index, bool quietSocketClose) NETlogEntry("Player has left nicely.", SYNC_FLAG, index); // Although we can get a error result from DelSocket, it don't really matter here. - SocketSet_DelSocket(*server_socket_set, connected_bsocket[index]); - socketClose(connected_bsocket[index]); + server_socket_set->remove(connected_bsocket[index]); + delete connected_bsocket[index]; connected_bsocket[index] = nullptr; } else @@ -1199,7 +1209,7 @@ static constexpr size_t GAMESTRUCTmessageBufSize() * * @see GAMESTRUCT,NETrecvGAMESTRUCT */ -static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourgamestruct) +static net::result NETsendGAMESTRUCT(IClientConnection* sock, const GAMESTRUCT *ourgamestruct) { // A buffer that's guaranteed to have the correct size (i.e. it // circumvents struct padding, which could pose a problem). Initialise @@ -1210,13 +1220,13 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga unsigned int i; auto push32 = [&](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(buffer, &swapped, sizeof(swapped)); buffer += sizeof(swapped); }; auto push16 = [&](uint16_t value) { - uint16_t swapped = htons(value); + uint16_t swapped = wz_htons(value); memcpy(buffer, &swapped, sizeof(swapped)); buffer += sizeof(swapped); }; @@ -1305,7 +1315,7 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga debug(LOG_NET, "sending GAMESTRUCT, size: %u", (unsigned int)sizeof(buf)); // Send over the GAMESTRUCT - const auto writeResult = writeAll(*sock, buf, sizeof(buf)); + const auto writeResult = sock->writeAll(buf, sizeof(buf), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); @@ -1325,7 +1335,7 @@ static net::result NETsendGAMESTRUCT(Socket *sock, const GAMESTRUCT *ourga * * @see GAMESTRUCT,NETsendGAMESTRUCT */ -static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) +static bool NETrecvGAMESTRUCT(IClientConnection& sock, GAMESTRUCT *ourgamestruct) { // A buffer that's guaranteed to have the correct size (i.e. it // circumvents struct padding, which could pose a problem). @@ -1336,7 +1346,7 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) auto pop32 = [&]() -> uint32_t { uint32_t value = 0; memcpy(&value, buffer, sizeof(value)); - value = ntohl(value); + value = wz_ntohl(value); buffer += sizeof(value); return value; }; @@ -1344,13 +1354,13 @@ static bool NETrecvGAMESTRUCT(Socket& sock, GAMESTRUCT *ourgamestruct) auto pop16 = [&]() -> uint16_t { uint16_t value = 0; memcpy(&value, buffer, sizeof(value)); - value = ntohs(value); + value = wz_ntohs(value); buffer += sizeof(value); return value; }; // Read a GAMESTRUCT from the connection - auto readResult = readAll(sock, buf, sizeof(buf), NET_TIMEOUT_DELAY); + auto readResult = sock.readAll(buf, sizeof(buf), NET_TIMEOUT_DELAY); if (!readResult.has_value()) { if (readResult.error() == std::errc::timed_out || readResult.error() == std::errc::connection_reset) @@ -1510,7 +1520,8 @@ int NETinit(bool bFirstCall) NETlogEntry("NETinit!", SYNC_FLAG, selectedPlayer); NET_InitPlayers(true, true); - SOCKETinit(); + ConnectionProviderRegistry::Instance().Register(); + ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).initialize(); if (bFirstCall) { @@ -1558,7 +1569,8 @@ int NETshutdown() } NetPlay.MOTD = nullptr; NETdeleteQueue(); - SOCKETshutdown(); + ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).shutdown(); + ConnectionProviderRegistry::Instance().Deregister(); // Reset net usage statistics. nStats = nZeroStats; @@ -1589,7 +1601,7 @@ int NETclose() if (connected_bsocket[i]) { debug(LOG_NET, "Closing connected_bsocket[%u], %p", i, static_cast(connected_bsocket[i])); - socketClose(connected_bsocket[i]); + delete connected_bsocket[i]; connected_bsocket[i] = nullptr; } NET_DestroyPlayer(i, true); @@ -1598,7 +1610,7 @@ int NETclose() if (tmp_socket_set) { debug(LOG_NET, "Freeing tmp_socket_set %p", static_cast(tmp_socket_set)); - deleteSocketSet(tmp_socket_set); + delete tmp_socket_set; tmp_socket_set = nullptr; } @@ -1608,7 +1620,7 @@ int NETclose() { // FIXME: need SocketSet_DelSocket() as well, socket_set or tmp_socket_set? debug(LOG_NET, "Closing tmp_socket[%d] %p", i, static_cast(tmp_socket[i])); - socketClose(tmp_socket[i]); + delete tmp_socket[i]; tmp_socket[i] = nullptr; } } @@ -1617,28 +1629,28 @@ int NETclose() { if (bsocket) { - SocketSet_DelSocket(*client_socket_set, bsocket); + client_socket_set->remove(bsocket); } debug(LOG_NET, "Freeing socket_set %p", static_cast(client_socket_set)); - deleteSocketSet(client_socket_set); + delete client_socket_set; client_socket_set = nullptr; } else if (server_socket_set) { debug(LOG_NET, "Freeing socket_set %p", static_cast(server_socket_set)); - deleteSocketSet(server_socket_set); + delete server_socket_set; server_socket_set = nullptr; } if (server_listen_socket) { debug(LOG_NET, "Closing server_listen_socket %p", static_cast(server_listen_socket)); - socketClose(server_listen_socket); + delete server_listen_socket; server_listen_socket = nullptr; } if (bsocket) { debug(LOG_NET, "Closing bsocket %p", static_cast(bsocket)); - socketClose(bsocket); + delete bsocket; bsocket = nullptr; } @@ -1718,7 +1730,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) return true; } - Socket **sockets = connected_bsocket; + IClientConnection** sockets = connected_bsocket; bool isTmpQueue = false; switch (queue.queueType) { @@ -1755,7 +1767,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) } ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - const auto writeResult = writeAll(*sockets[player], rawData, rawLen, &compressedRawLen); + const auto writeResult = sockets[player]->writeAll(rawData, rawLen, &compressedRawLen); const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. @@ -1788,7 +1800,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen; - const auto writeResult = writeAll(*bsocket, rawData, rawLen, &compressedRawLen); + const auto writeResult = bsocket->writeAll(rawData, rawLen, &compressedRawLen); const auto res = writeResult.value_or(SOCKET_ERROR); delete[] rawData; // Done with the data. @@ -1805,8 +1817,8 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) debug(LOG_ERROR, "Failed to send message: %s", writeErrMsg.c_str()); debug(LOG_ERROR, "Host connection was broken, socket %p.", static_cast(bsocket)); NETlogEntry("write error--client disconnect.", SYNC_FLAG, player); - SocketSet_DelSocket(*client_socket_set, bsocket); // mark it invalid - socketClose(bsocket); + client_socket_set->remove(bsocket); // mark it invalid + delete bsocket; bsocket = nullptr; NetPlay.players[NetPlay.hostPlayer].heartbeat = false; // mark host as dead //Game is pretty much over --should just end everything when HOST dies. @@ -1847,7 +1859,7 @@ void NETflush() // We are the host, send directly to player. if (connected_bsocket[player] != nullptr) { - socketFlush(*connected_bsocket[player], player, &compressedRawLen); + connected_bsocket[player]->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -1856,7 +1868,7 @@ void NETflush() // We are the host, send directly to player. if (tmp_socket[player] != nullptr) { - socketFlush(*tmp_socket[player], std::numeric_limits::max(), &compressedRawLen); + tmp_socket[player]->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -1865,7 +1877,7 @@ void NETflush() { if (bsocket != nullptr) { - socketFlush(*bsocket, NetPlay.hostPlayer, &compressedRawLen); + bsocket->flush(&compressedRawLen); nStats.rawBytes.sent += compressedRawLen; } } @@ -2826,15 +2838,15 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) NETcheckPlayers(); // make sure players are still alive & well } - SocketSet* sset = NetPlay.isHost ? server_socket_set : client_socket_set; - if (sset == nullptr || checkSockets(*sset, NET_READ_TIMEOUT) <= 0) + IConnectionPollGroup* pollGroup = NetPlay.isHost ? server_socket_set : client_socket_set; + if (pollGroup == nullptr || pollGroup->checkSockets(NET_READ_TIMEOUT) <= 0) { goto checkMessages; } for (current = 0; current < MAX_CONNECTED_PLAYERS; ++current) { - Socket **pSocket = NetPlay.isHost ? &connected_bsocket[current] : &bsocket; + IClientConnection** pSocket = NetPlay.isHost ? &connected_bsocket[current] : &bsocket; uint8_t buffer[NET_BUFFER_SIZE]; size_t dataLen; @@ -2848,7 +2860,7 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) continue; } - dataLen = NET_fillBuffer(pSocket, sset, buffer, sizeof(buffer)); + dataLen = NET_fillBuffer(pSocket, pollGroup, buffer, sizeof(buffer)); if (dataLen > 0) { // we received some data, add to buffer @@ -3228,7 +3240,7 @@ unsigned NETgetDownloadProgress(unsigned player) return static_cast(progress); } -static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) +static ssize_t readLobbyResponse(IClientConnection& sock, unsigned int timeout) { uint32_t lobbyStatusCode; uint32_t MOTDLength; @@ -3236,7 +3248,7 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) ssize_t received = 0; // Get status and message length - auto readResult = readAll(sock, &buffer, sizeof(buffer), timeout); + auto readResult = sock.readAll(&buffer, sizeof(buffer), timeout); if (!readResult.has_value()) { goto error; @@ -3251,7 +3263,7 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) free(NetPlay.MOTD); } NetPlay.MOTD = (char *)malloc(MOTDLength + 1); - readResult = readAll(sock, NetPlay.MOTD, MOTDLength, timeout); + readResult = sock.readAll(NetPlay.MOTD, MOTDLength, timeout); if (!readResult.has_value()) { goto error; @@ -3306,11 +3318,11 @@ static ssize_t readLobbyResponse(Socket& sock, unsigned int timeout) return SOCKET_ERROR; } -bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function& handleEnumerateGameFunc) +bool readGameStructsList(IClientConnection& sock, unsigned int timeout, const std::function& handleEnumerateGameFunc) { unsigned int gamecount = 0; uint32_t gamesavailable = 0; - const auto readResult = readAll(sock, &gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY); + const auto readResult = sock.readAll(&gamesavailable, sizeof(gamesavailable), NET_TIMEOUT_DELAY); if (readResult.has_value()) { @@ -3346,7 +3358,8 @@ bool readGameStructsList(Socket& sock, unsigned int timeout, const std::function if (tmpGame.desc.host[0] == '\0') { memset(tmpGame.desc.host, 0, sizeof(tmpGame.desc.host)); - strncpy(tmpGame.desc.host, getSocketTextAddress(sock), sizeof(tmpGame.desc.host) - 1); + const auto textAddr = sock.textAddress(); + strncpy(tmpGame.desc.host, textAddr.data(), sizeof(tmpGame.desc.host) - 1); } uint32_t Vmgr = (tmpGame.future4 & 0xFFFF0000) >> 16; @@ -3398,12 +3411,13 @@ bool LobbyServerConnectionHandler::connect() return false; // already connecting or connected } + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + bool bProcessingConnectOrDisconnectThisCall = true; uint32_t gameId = 0; - const auto hostsResult = resolveHost(masterserver_name, masterserver_port); - const auto hosts = hostsResult.value_or(nullptr); + const auto hostsResult = connProvider.resolveHost(masterserver_name, masterserver_port); - if (hosts == nullptr) + if (!hostsResult.has_value()) { const auto hostsErrMsg = hostsResult.error().message(); debug(LOG_ERROR, "Cannot resolve masterserver \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); @@ -3417,16 +3431,17 @@ bool LobbyServerConnectionHandler::connect() return bProcessingConnectOrDisconnectThisCall; } + const auto& hosts = hostsResult.value(); + // Close an existing socket. if (rs_socket != nullptr) { - socketClose(rs_socket); + delete rs_socket; rs_socket = nullptr; } // try each address from resolveHost until we successfully connect. - auto sockResult = socketOpenAny(hosts, 1500); - deleteSocketAddress(hosts); + auto sockResult = connProvider.openClientConnectionAny(hosts, 1500); rs_socket = sockResult.value_or(nullptr); @@ -3446,10 +3461,10 @@ bool LobbyServerConnectionHandler::connect() } // Get a game ID - auto gameIdResult = writeAll(*rs_socket, "gaId", sizeof("gaId")); + auto gameIdResult = rs_socket->writeAll("gaId", sizeof("gaId"), nullptr); if (gameIdResult.has_value()) { - gameIdResult = readAll(*rs_socket, &gameId, sizeof(gameId), 10000); + gameIdResult = rs_socket->readAll(&gameId, sizeof(gameId), 10000); } if (!gameIdResult.has_value()) { @@ -3475,7 +3490,7 @@ bool LobbyServerConnectionHandler::connect() wz_command_interface_output("WZEVENT: lobbyid: %" PRIu32 "\n", gamestruct.gameId); // Register our game with the server - const auto writeAddGameRes = writeAll(*rs_socket, "addg", sizeof("addg")); + const auto writeAddGameRes = rs_socket->writeAll("addg", sizeof("addg"), nullptr); auto sendGamestructRes = ignoreExpectedResultValue(writeAddGameRes); if (sendGamestructRes.has_value()) @@ -3495,8 +3510,8 @@ bool LobbyServerConnectionHandler::connect() queuedServerUpdate = false; lastConnectionTime = realTime; - waitingForConnectionFinalize = allocSocketSet(); - SocketSet_AddSocket(*waitingForConnectionFinalize, rs_socket); + waitingForConnectionFinalize = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).newConnectionPollGroup(); + waitingForConnectionFinalize->add(rs_socket); currentState = LobbyConnectionState::Connecting_WaitingForResponse; return bProcessingConnectOrDisconnectThisCall; @@ -3512,7 +3527,7 @@ bool LobbyServerConnectionHandler::disconnect() if (rs_socket != nullptr) { // we don't need this anymore, so clean up - socketClose(rs_socket); + delete rs_socket; rs_socket = nullptr; server_not_there = true; } @@ -3576,7 +3591,7 @@ void LobbyServerConnectionHandler::sendUpdateNow() void LobbyServerConnectionHandler::sendKeepAlive() { ASSERT_OR_RETURN(, rs_socket != nullptr, "Null socket"); - if (!writeAll(*rs_socket, "keep", sizeof("keep")).has_value()) + if (!rs_socket->writeAll("keep", sizeof("keep"), nullptr).has_value()) { // The socket has been invalidated, so get rid of it. (using them now may cause SIGPIPE). disconnect(); @@ -3599,21 +3614,21 @@ void LobbyServerConnectionHandler::run() bool exceededTimeout = (realTime - lastConnectionTime >= 10000); // We use readLobbyResponse to display error messages and handle state changes if there's no response // So if exceededTimeout, just call it with a low timeout - int checkSocketRet = checkSockets(*waitingForConnectionFinalize, NET_READ_TIMEOUT); + int checkSocketRet = waitingForConnectionFinalize->checkSockets(NET_READ_TIMEOUT); if (checkSocketRet == SOCKET_ERROR) { debug(LOG_ERROR, "Lost connection to lobby server"); disconnect(); break; } - if (exceededTimeout || (checkSocketRet > 0 && socketReadReady(*rs_socket))) + if (exceededTimeout || (checkSocketRet > 0 && rs_socket->readReady())) { if (readLobbyResponse(*rs_socket, NET_TIMEOUT_DELAY) == SOCKET_ERROR) { disconnect(); break; } - deleteSocketSet(waitingForConnectionFinalize); + delete waitingForConnectionFinalize; waitingForConnectionFinalize = nullptr; currentState = LobbyConnectionState::Connected; } @@ -3748,9 +3763,9 @@ static bool quickRejectConnection(const std::string& ip) static void NETcloseTempSocket(unsigned int i) { - std::string rIP = getSocketTextAddress(*tmp_socket[i]); - SocketSet_DelSocket(*tmp_socket_set, tmp_socket[i]); - socketClose(tmp_socket[i]); + std::string rIP = tmp_socket[i]->textAddress(); + tmp_socket_set->remove(tmp_socket[i]); + delete tmp_socket[i]; tmp_socket[i] = nullptr; tmp_connectState[i].reset(); auto it = tmp_pendingIPs.find(rIP); @@ -3769,14 +3784,14 @@ static void NETcloseTempSocket(unsigned int i) static void NEThostPromoteTempSocketToPermanentPlayerConnection(unsigned int tempSocketIdx, uint8_t index) { - std::string rIP = getSocketTextAddress(*tmp_socket[tempSocketIdx]); + std::string rIP = tmp_socket[tempSocketIdx]->textAddress(); debug(LOG_NET, "freeing temp socket %p (%d), creating permanent socket.", static_cast(tmp_socket[tempSocketIdx]), __LINE__); - SocketSet_DelSocket(*tmp_socket_set, tmp_socket[tempSocketIdx]); + tmp_socket_set->remove(tmp_socket[tempSocketIdx]); connected_bsocket[index] = tmp_socket[tempSocketIdx]; tmp_socket[tempSocketIdx] = nullptr; NET_waitingForIndexChangeAckSince[index] = nullopt; - SocketSet_AddSocket(*server_socket_set, connected_bsocket[index]); + server_socket_set->add(connected_bsocket[index]); NETmoveQueue(NETnetTmpQueue(tempSocketIdx), NETnetQueue(index)); // Copy player's IP address @@ -3820,12 +3835,13 @@ static void NETallowJoining() ActivitySink::ListeningInterfaces listeningInterfaces; if (server_listen_socket != nullptr) { - listeningInterfaces.IPv4 = socketHasIPv4(*server_listen_socket); + const auto supportedProtocols = server_listen_socket->supportedIpVersions(); + listeningInterfaces.IPv4 = supportedProtocols & static_cast(IListenSocket::IPVersions::IPV4); if (listeningInterfaces.IPv4) { listeningInterfaces.ipv4_port = NETgetGameserverPort(); } - listeningInterfaces.IPv6 = socketHasIPv6(*server_listen_socket); + listeningInterfaces.IPv6 = supportedProtocols & static_cast(IListenSocket::IPVersions::IPV6); if (listeningInterfaces.IPv6) { listeningInterfaces.ipv6_port = NETgetGameserverPort(); @@ -3849,7 +3865,7 @@ static void NETallowJoining() } ASSERT(tmp_socket_set != nullptr, "Null tmp_socket_set"); - if (checkSockets(*tmp_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { for (i = 0; i < MAX_TMP_SOCKETS; ++i) { @@ -3858,7 +3874,7 @@ static void NETallowJoining() continue; } - if (!socketReadReady(*tmp_socket[i])) + if (!tmp_socket[i]->readReady()) { continue; } @@ -3867,7 +3883,7 @@ static void NETallowJoining() { char *p_buffer = tmp_connectState[i].buffer; - const auto sizeReadResult = readNoInt(*tmp_socket[i], p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer); + const auto sizeReadResult = tmp_socket[i]->readNoInt(p_buffer + tmp_connectState[i].usedBuffer, 8 - tmp_connectState[i].usedBuffer, nullptr); if (sizeReadResult.has_value()) { tmp_connectState[i].usedBuffer += sizeReadResult.value(); @@ -3901,7 +3917,7 @@ static void NETallowJoining() char buf[(sizeof(char) * 4) + sizeof(uint32_t) + sizeof(uint32_t)] = { 0 }; char *pLobbyRespBuffer = buf; auto push32 = [&pLobbyRespBuffer](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(pLobbyRespBuffer, &swapped, sizeof(swapped)); pLobbyRespBuffer += sizeof(swapped); }; @@ -3917,15 +3933,15 @@ static void NETallowJoining() // Copy gameId (as 32bit large big endian number) push32(gamestruct.gameId); - writeAll(*tmp_socket[i], buf, sizeof(buf)); + tmp_socket[i]->writeAll(buf, sizeof(buf), nullptr); connectFailed = true; } else if (NETisCorrectVersion(major, minor)) { - result = htonl(ERROR_NOERROR); + result = wz_htonl(ERROR_NOERROR); memcpy(&tmp_connectState[i].buffer, &result, sizeof(result)); - writeAll(*tmp_socket[i], &tmp_connectState[i].buffer, sizeof(result)); - socketBeginCompression(*tmp_socket[i]); + tmp_socket[i]->writeAll(&tmp_connectState[i].buffer, sizeof(result), nullptr); + tmp_socket[i]->enableCompression(); // Connection is successful. connectFailed = false; @@ -3940,9 +3956,9 @@ static void NETallowJoining() else { debug(LOG_INFO, "Received an invalid version \"%" PRIu32 ".%" PRIu32 "\".", major, minor); - result = htonl(ERROR_WRONGVERSION); + result = wz_htonl(ERROR_WRONGVERSION); memcpy(&tmp_connectState[i].buffer, &result, sizeof(result)); - writeAll(*tmp_socket[i], &tmp_connectState[i].buffer, sizeof(result)); + tmp_socket[i]->writeAll(&tmp_connectState[i].buffer, sizeof(result), nullptr); NETlogEntry("Invalid game version", SYNC_FLAG, i); NETaddSessionBanBadIP(tmp_connectState[i].ip); connectFailed = true; @@ -3982,7 +3998,7 @@ static void NETallowJoining() else if (tmp_connectState[i].connectState == TmpSocketInfo::TmpConnectState::PendingJoinRequest) { uint8_t buffer[NET_BUFFER_SIZE]; - const auto readResult = readNoInt(*tmp_socket[i], buffer, sizeof(buffer)); + const auto readResult = tmp_socket[i]->readNoInt(buffer, sizeof(buffer), nullptr); uint8_t rejected = 0; if (!readResult.has_value()) @@ -4377,7 +4393,7 @@ static void NETallowJoining() NETpop(tmpQueue); } - std::string rIP = getSocketTextAddress(*tmp_socket[i]); + std::string rIP = tmp_socket[i]->textAddress(); NETaddSessionBanBadIP(rIP); NETcloseTempSocket(i); @@ -4482,14 +4498,16 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator return true; } + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + // Start listening for client connections on `gameserver_port`. // These will initially be assigned to `tmp_socket[i]` until accepted in the game session, // in which case `tmp_socket[i]` will be assigned to `connected_bsocket[i]` and `tmp_socket[i]` // will become nullptr. - net::result serverListenResult = {}; + net::result serverListenResult = {}; if (!server_listen_socket) { - serverListenResult = socketListen(gameserver_port); + serverListenResult = connProvider.openListenSocket(gameserver_port); server_listen_socket = serverListenResult.value_or(nullptr); } if (server_listen_socket == nullptr) @@ -4502,7 +4520,7 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator // Host needs to create a socket set for MAX_PLAYERS if (!server_socket_set) { - server_socket_set = allocSocketSet(); + server_socket_set = connProvider.newConnectionPollGroup(); } // allocate socket storage for all possible players for (unsigned i = 0; i < MAX_CONNECTED_PLAYERS; ++i) @@ -4607,19 +4625,17 @@ bool NETenumerateGames(const std::function& handl debug(LOG_ERROR, "Likely missing NETinit(true) - this won't return any results"); return false; } - const auto hostsResult = resolveHost(masterserver_name, masterserver_port); - SocketAddress* hosts = hostsResult.value_or(nullptr); - if (!hosts) + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + const auto hostsResult = connProvider.resolveHost(masterserver_name, masterserver_port); + if (!hostsResult.has_value()) { const auto hostsErrMsg = hostsResult.error().message(); debug(LOG_ERROR, "Cannot resolve hostname \"%s\": %s", masterserver_name, hostsErrMsg.c_str()); setLobbyError(ERROR_CONNECTION); return false; } - - auto sockResult = socketOpenAny(hosts, 15000); - deleteSocketAddress(hosts); - hosts = nullptr; + const auto& hosts = hostsResult.value(); + auto sockResult = connProvider.openClientConnectionAny(hosts, 15000); if (!sockResult.has_value()) { const auto sockErrMsg = sockResult.error().message(); @@ -4627,18 +4643,18 @@ bool NETenumerateGames(const std::function& handl setLobbyError(ERROR_CONNECTION); return false; } - Socket* sock = sockResult.value(); + IClientConnection* sock = sockResult.value(); debug(LOG_NET, "New socket = %p", static_cast(sock)); debug(LOG_NET, "Sending list cmd"); - const auto writeResult = writeAll(*sock, "list", sizeof("list")); + const auto writeResult = sock->writeAll("list", sizeof("list"), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); debug(LOG_NET, "Server socket encountered error: %s", writeErrMsg.c_str()); // mark it invalid - socketClose(sock); + delete sock; // when we fail to receive a game count, bail out setLobbyError(ERROR_CONNECTION); @@ -4655,7 +4671,7 @@ bool NETenumerateGames(const std::function& handl })) { // mark it invalid - socketClose(sock); + delete sock; setLobbyError(ERROR_CONNECTION); return false; @@ -4665,7 +4681,7 @@ bool NETenumerateGames(const std::function& handl if (readLobbyResponse(*sock, NET_TIMEOUT_DELAY) == SOCKET_ERROR) { // mark it invalid - socketClose(sock); + delete sock; addConsoleMessage(_("Failed to get a lobby response!"), DEFAULT_JUSTIFY, NOTIFY_MESSAGE); // treat as fatal error @@ -4679,7 +4695,7 @@ bool NETenumerateGames(const std::function& handl // Hence as long as we don't treat "0" as signifying any change in behavior, this should be safe + backwards-compatible #define IGNORE_FIRST_BATCH 1 uint32_t responseParameters = 0; - const auto readResult = readAll(*sock, &responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY); + const auto readResult = sock->readAll(&responseParameters, sizeof(responseParameters), NET_TIMEOUT_DELAY); if (readResult.has_value()) { responseParameters = ntohl(responseParameters); @@ -4717,7 +4733,7 @@ bool NETenumerateGames(const std::function& handl debug(LOG_NET, "Second readGameStructsList call failed"); // mark it invalid - socketClose(sock); + delete sock; // when we fail to receive a game count, bail out setLobbyError(ERROR_CONNECTION); @@ -4740,7 +4756,7 @@ bool NETenumerateGames(const std::function& handl } // mark it invalid (we are done with it) - socketClose(sock); + delete sock; return true; } @@ -4796,7 +4812,7 @@ bool NETfindGame(uint32_t gameId, GAMESTRUCT& output) } // "consumes" the sockets and related info -bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, Socket **client_joining_socket, SocketSet **client_joining_socket_set) +bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, IClientConnection** client_joining_socket, IConnectionPollGroup** client_joining_socket_set) { if (hostPlayer >= MAX_CONNECTED_PLAYERS) { @@ -5101,7 +5117,8 @@ void NETacceptIncomingConnections() { // initialize temporary server socket set // FIXME: why is this not done in NETinit()?? - Per - tmp_socket_set = allocSocketSet(); + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + tmp_socket_set = connProvider.newConnectionPollGroup(); // FIXME: I guess initialization of allowjoining is here now... - FlexCoral for (auto& tmpState : tmp_connectState) { @@ -5127,12 +5144,12 @@ void NETacceptIncomingConnections() } // See if there's an incoming connection - tmp_socket[i] = socketAccept(server_listen_socket); + tmp_socket[i] = server_listen_socket->accept(); if (!tmp_socket[i]) { return; } - const std::string rIP = getSocketTextAddress(*tmp_socket[i]); + const std::string rIP = tmp_socket[i]->textAddress(); if (quickRejectConnection(rIP)) { debug(LOG_NET, "freeing temp socket %p (%d)", static_cast(tmp_socket[i]), __LINE__); @@ -5141,7 +5158,7 @@ void NETacceptIncomingConnections() } NETinitQueue(NETnetTmpQueue(i)); - SocketSet_AddSocket(*tmp_socket_set, tmp_socket[i]); + tmp_socket_set->add(tmp_socket[i]); tmp_pendingIPs[rIP]++; @@ -5155,7 +5172,7 @@ void NETacceptIncomingConnections() if (bEnableTCPNoDelay) { - // Enable TCP_NODELAY - socketSetTCPNoDelay(*tmp_socket[i], true); + // Disable use of Nagle Algorithm for the TCP socket (i.e. enable TCP_NODELAY option in case of TCP connection) + tmp_socket[i]->useNagleAlgorithm(false); } } diff --git a/lib/netplay/netplay.h b/lib/netplay/netplay.h index e16e5d21b34..75db62d6947 100644 --- a/lib/netplay/netplay.h +++ b/lib/netplay/netplay.h @@ -443,9 +443,11 @@ bool NEThaltJoining(); // stop new players joining this game bool NETenumerateGames(const std::function& handleEnumerateGameFunc); bool NETfindGames(std::vector& results, size_t startingIndex, size_t resultsLimit, bool onlyMatchingLocalVersion = false); bool NETfindGame(uint32_t gameId, GAMESTRUCT& output); -struct Socket; -struct SocketSet; -bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char *playername, NETQUEUE joiningQUEUEInfo, Socket **client_joining_socket, SocketSet **client_joining_socket_set); + +class IClientConnection; +class IConnectionPollGroup; + +bool NETpromoteJoinAttemptToEstablishedConnectionToHost(uint32_t hostPlayer, uint8_t index, const char* playername, NETQUEUE joiningQUEUEInfo, IClientConnection** client_joining_socket, IConnectionPollGroup** client_joining_socket_set); bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectatorHost, // host a game uint32_t gameType, uint32_t two, uint32_t three, uint32_t four, UDWORD plyrs); bool NETchangePlayerName(UDWORD player, char *newName);// change a players name. diff --git a/lib/netplay/open_connection_result.h b/lib/netplay/open_connection_result.h new file mode 100644 index 00000000000..b9b1c9aac7d --- /dev/null +++ b/lib/netplay/open_connection_result.h @@ -0,0 +1,56 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include +#include +#include + +#include "lib/netplay/client_connection.h" + +// Higher-level functions for opening a connection / socket +struct OpenConnectionResult +{ +public: + OpenConnectionResult(std::error_code ec, std::string errorString) + : errorCode(ec) + , errorString(std::move(errorString)) + { } + + OpenConnectionResult(IClientConnection* open_socket) + : open_socket(open_socket) + { } + +public: + bool hasError() const { return static_cast(errorCode); } +public: + OpenConnectionResult(const OpenConnectionResult& other) = delete; // non construction-copyable + OpenConnectionResult& operator=(const OpenConnectionResult&) = delete; // non copyable + OpenConnectionResult(OpenConnectionResult&&) = default; + OpenConnectionResult& operator=(OpenConnectionResult&&) = default; +public: + + std::unique_ptr open_socket; + std::error_code errorCode; + std::string errorString; +}; + +typedef std::function OpenConnectionToHostResultCallback; diff --git a/lib/netplay/sync_debug.cpp b/lib/netplay/sync_debug.cpp index 536db78c080..e23499a86db 100644 --- a/lib/netplay/sync_debug.cpp +++ b/lib/netplay/sync_debug.cpp @@ -27,9 +27,9 @@ #include "lib/framework/debug.h" #include "lib/framework/physfs_ext.h" #include "lib/gamelib/gtime.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" #include "nettypes.h" #include "netplay.h" -#include "netsocket.h" // solely to bring in `htonl` function #include @@ -76,7 +76,7 @@ struct SyncDebugValueChange : public SyncDebugEntry variableName = vn; newValue = nv; id = i; - uint32_t valueBytes = htonl(newValue); + uint32_t valueBytes = wz_htonl(newValue); crc = wz::crc_update(crc, function, strlen(function) + 1); crc = wz::crc_update(crc, variableName, strlen(variableName) + 1); crc = wz::crc_update(crc, &valueBytes, 4); @@ -105,7 +105,7 @@ struct SyncDebugIntList : public SyncDebugEntry numInts = std::min(num, ARRAY_SIZE(valueBytes)); for (unsigned n = 0; n < numInts; ++n) { - valueBytes[n] = htonl(ints[n]); + valueBytes[n] = wz_htonl(ints[n]); } crc = wz::crc_update(crc, valueBytes, 4 * numInts); } diff --git a/lib/netplay/netsocket.cpp b/lib/netplay/tcp/netsocket.cpp similarity index 96% rename from lib/netplay/netsocket.cpp rename to lib/netplay/tcp/netsocket.cpp index 47a45a053a0..51798b11afd 100644 --- a/lib/netplay/netsocket.cpp +++ b/lib/netplay/tcp/netsocket.cpp @@ -26,7 +26,7 @@ #include "lib/framework/frame.h" #include "lib/framework/wzapp.h" #include "netsocket.h" -#include "error_categories.h" +#include "lib/netplay/error_categories.h" #include #include @@ -47,6 +47,9 @@ // Already included Winsock2.h which defines TCP_NODELAY #endif +namespace tcp +{ + enum { SOCK_CONNECTION, @@ -127,6 +130,8 @@ static void setSockErr(int error) #endif } +} // namespace tcp + #if defined(WZ_OS_WIN) typedef int (WINAPI *GETADDRINFO_DLL_FUNC)(const char *node, const char *service, const struct addrinfo *hints, @@ -216,6 +221,9 @@ static void freeaddrinfo(struct addrinfo *res) } #endif +namespace tcp +{ + static int addressToText(const struct sockaddr *addr, char *buf, size_t size) { auto handleIpv4 = [&](uint32_t addr) { @@ -1749,71 +1757,4 @@ void SOCKETshutdown() #endif } -OpenConnectionResult socketOpenTCPConnectionSync(const char *host, uint32_t port) -{ - const auto hostsResult = resolveHost(host, port); - SocketAddress* hosts = hostsResult.value_or(nullptr); - if (hosts == nullptr) - { - const auto hostsErr = hostsResult.error(); - const auto hostsErrMsg = hostsErr.message(); - return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); - } - - auto sockResult = socketOpenAny(hosts, 15000); - Socket* client_transient_socket = sockResult.value_or(nullptr); - deleteSocketAddress(hosts); - hosts = nullptr; - - if (client_transient_socket == nullptr) - { - const auto errValue = sockResult.error(); - const auto errMsg = errValue.message(); - return OpenConnectionResult(errValue, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, errValue.value(), errMsg.c_str())); - } - - return OpenConnectionResult(client_transient_socket); -} - -struct OpenConnectionRequest -{ - std::string host; - uint32_t port = 0; - OpenConnectionToHostResultCallback callback; -}; - -static int openDirectTCPConnectionAsyncImpl(void* data) -{ - OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; - if (!pRequestInfo) - { - return 1; - } - - pRequestInfo->callback(socketOpenTCPConnectionSync(pRequestInfo->host.c_str(), pRequestInfo->port)); - delete pRequestInfo; - return 0; -} - -bool socketOpenTCPConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) -{ - // spawn background thread to handle this - auto pRequest = new OpenConnectionRequest(); - pRequest->host = host; - pRequest->port = port; - pRequest->callback = callback; - - WZ_THREAD * pOpenConnectionThread = wzThreadCreate(openDirectTCPConnectionAsyncImpl, pRequest); - if (pOpenConnectionThread == nullptr) - { - debug(LOG_ERROR, "Failed to create thread for opening connection"); - delete pRequest; - return false; - } - - wzThreadDetach(pOpenConnectionThread); - // the thread handles deleting pRequest - pOpenConnectionThread = nullptr; - - return true; -} +} // namespace tcp diff --git a/lib/netplay/netsocket.h b/lib/netplay/tcp/netsocket.h similarity index 81% rename from lib/netplay/netsocket.h rename to lib/netplay/tcp/netsocket.h index 32fff7019d8..b46a9ce89ce 100644 --- a/lib/netplay/netsocket.h +++ b/lib/netplay/tcp/netsocket.h @@ -21,18 +21,14 @@ #ifndef _net_socket_h #define _net_socket_h +#include "lib/framework/wzglobal.h" #include "lib/framework/types.h" #include #include #include +#include -#include - -namespace net -{ - template - using result = ::tl::expected; -} // namespace net +#include "lib/netplay/net_result.h" #if defined(WZ_OS_UNIX) # include @@ -81,14 +77,22 @@ static const SOCKET INVALID_SOCKET = -1; # define MSG_NOSIGNAL 0 #endif +namespace tcp +{ + struct Socket; struct SocketSet; + +} // namespace tcp + typedef struct addrinfo SocketAddress; #ifndef WZ_OS_WIN static const int SOCKET_ERROR = -1; #endif +namespace tcp +{ // Init/shutdown. void SOCKETinit(); @@ -134,35 +138,6 @@ WZ_DECL_NONNULL(2) void SocketSet_AddSocket(SocketSet& set, Socket *socket); // WZ_DECL_NONNULL(2) void SocketSet_DelSocket(SocketSet& set, Socket *socket); ///< Removes a Socket from a SocketSet. int checkSockets(const SocketSet& set, unsigned int timeout); ///< Checks which Sockets are ready for reading. Returns the number of ready Sockets, or returns SOCKET_ERROR on error. -// Higher-level functions for opening a connection / socket -struct OpenConnectionResult -{ -public: - OpenConnectionResult(std::error_code ec, std::string errorString) - : errorCode(ec) - , errorString(errorString) - { } - - OpenConnectionResult(Socket* open_socket) - : open_socket(open_socket) - { } -public: - bool hasError() const { return static_cast(errorCode); } -public: - OpenConnectionResult( const OpenConnectionResult& other ) = delete; // non construction-copyable - OpenConnectionResult& operator=( const OpenConnectionResult& ) = delete; // non copyable - OpenConnectionResult(OpenConnectionResult&&) = default; - OpenConnectionResult& operator=(OpenConnectionResult&&) = default; -public: - struct SocketDeleter { - void operator()(Socket* b) { if (b) { socketClose(b); } } - }; - std::unique_ptr open_socket; - std::error_code errorCode; - std::string errorString; -}; -typedef std::function OpenConnectionToHostResultCallback; -bool socketOpenTCPConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback); -OpenConnectionResult socketOpenTCPConnectionSync(const char *host, uint32_t port); +} // namespace tcp #endif //_net_socket_h diff --git a/lib/netplay/tcp/tcp_client_connection.cpp b/lib/netplay/tcp/tcp_client_connection.cpp new file mode 100644 index 00000000000..73d9d78c768 --- /dev/null +++ b/lib/netplay/tcp/tcp_client_connection.cpp @@ -0,0 +1,83 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" +#include "lib/framework/string_ext.h" + +namespace tcp +{ + +TCPClientConnection::TCPClientConnection(Socket* rawSocket) + : socket_(rawSocket) +{ + ASSERT(socket_ != nullptr, "Null socket passed to TCPClientConnection ctor"); +} + +TCPClientConnection::~TCPClientConnection() +{ + if (socket_) + { + socketClose(socket_); + } +} + +net::result TCPClientConnection::readAll(void* buf, size_t size, unsigned timeout) +{ + return tcp::readAll(*socket_, buf, size, timeout); +} + +net::result TCPClientConnection::readNoInt(void* buf, size_t maxSize, size_t* rawByteCount) +{ + return tcp::readNoInt(*socket_, buf, maxSize, rawByteCount); +} + +net::result TCPClientConnection::writeAll(const void* buf, size_t size, size_t* rawByteCount) +{ + return tcp::writeAll(*socket_, buf, size, rawByteCount); +} + +bool TCPClientConnection::readReady() const +{ + return socketReadReady(*socket_); +} + +void TCPClientConnection::flush(size_t* rawByteCount) +{ + socketFlush(*socket_, std::numeric_limits::max()/*unused*/, rawByteCount); +} + +void TCPClientConnection::enableCompression() +{ + socketBeginCompression(*socket_); +} + +void TCPClientConnection::useNagleAlgorithm(bool enable) +{ + socketSetTCPNoDelay(*socket_, !enable); +} + +std::string TCPClientConnection::textAddress() const +{ + return getSocketTextAddress(*socket_); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_client_connection.h b/lib/netplay/tcp/tcp_client_connection.h new file mode 100644 index 00000000000..d8f45716a9d --- /dev/null +++ b/lib/netplay/tcp/tcp_client_connection.h @@ -0,0 +1,52 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/client_connection.h" + +namespace tcp +{ + +struct Socket; + +class TCPClientConnection : public IClientConnection +{ +public: + + TCPClientConnection(Socket* rawSocket); + virtual ~TCPClientConnection() override; + + virtual net::result readAll(void* buf, size_t size, unsigned timeout) override; + virtual net::result readNoInt(void* buf, size_t maxSize, size_t* rawByteCount) override; + virtual net::result writeAll(const void* buf, size_t size, size_t* rawByteCount) override; + virtual bool readReady() const override; + virtual void flush(size_t* rawByteCount) override; + virtual void enableCompression() override; + virtual void useNagleAlgorithm(bool enable) override; + virtual std::string textAddress() const override; + +private: + + Socket* socket_; + + friend class TCPConnectionPollGroup; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_poll_group.cpp b/lib/netplay/tcp/tcp_connection_poll_group.cpp new file mode 100644 index 00000000000..7aab0fcbe71 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_poll_group.cpp @@ -0,0 +1,60 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_connection_poll_group.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" + +namespace tcp +{ + +TCPConnectionPollGroup::TCPConnectionPollGroup(SocketSet* sset) + : sset_(sset) +{} + +TCPConnectionPollGroup::~TCPConnectionPollGroup() +{ + if (sset_) + { + deleteSocketSet(sset_); + } +} + +int TCPConnectionPollGroup::checkSockets(unsigned timeout) +{ + return tcp::checkSockets(*sset_, timeout); +} + +void TCPConnectionPollGroup::add(IClientConnection* conn) +{ + auto* tcpConn = dynamic_cast(conn); + ASSERT_OR_RETURN(, tcpConn != nullptr, "Expected to have TCPClientConnection instance"); + SocketSet_AddSocket(*sset_, tcpConn->socket_); +} + +void TCPConnectionPollGroup::remove(IClientConnection* conn) +{ + auto tcpConn = dynamic_cast(conn); + ASSERT_OR_RETURN(, tcpConn != nullptr, "Expected to have TCPClientConnection instance"); + SocketSet_DelSocket(*sset_, tcpConn->socket_); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_poll_group.h b/lib/netplay/tcp/tcp_connection_poll_group.h new file mode 100644 index 00000000000..7146f0997dd --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_poll_group.h @@ -0,0 +1,45 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include "lib/netplay/connection_poll_group.h" + +namespace tcp +{ + +struct SocketSet; + +class TCPConnectionPollGroup : public IConnectionPollGroup +{ +public: + + TCPConnectionPollGroup(SocketSet* sset); + virtual ~TCPConnectionPollGroup() override; + + virtual int checkSockets(unsigned timeout) override; + virtual void add(IClientConnection* conn) override; + virtual void remove(IClientConnection* conn) override; + +private: + + SocketSet* sset_; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_provider.cpp b/lib/netplay/tcp/tcp_connection_provider.cpp new file mode 100644 index 00000000000..8a3182a02e7 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_provider.cpp @@ -0,0 +1,153 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "tcp_connection_provider.h" + +#include "lib/netplay/tcp/netsocket.h" +#include "lib/netplay/tcp/tcp_connection_poll_group.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/tcp_listen_socket.h" + +#include "lib/netplay/open_connection_result.h" +#include "lib/framework/wzapp.h" + +namespace tcp +{ + +void TCPConnectionProvider::initialize() +{ + SOCKETinit(); +} + +void TCPConnectionProvider::shutdown() +{ + SOCKETshutdown(); +} + +net::result TCPConnectionProvider::resolveHost(const char* host, uint16_t port) +{ + return ConnectionAddress::parse(host, port); +} + +net::result TCPConnectionProvider::openListenSocket(uint16_t port) +{ + auto res = tcp::socketListen(port); + if (!res.has_value()) + { + return tl::make_unexpected(res.error()); + } + return new TCPListenSocket(res.value()); +} + +net::result TCPConnectionProvider::openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) +{ + const auto* rawAddr = addr.asRawSocketAddress(); + auto res = socketOpenAny(rawAddr, timeout); + if (!res.has_value()) + { + return tl::make_unexpected(res.error()); + } + return new TCPClientConnection(res.value()); +} + +namespace +{ + +OpenConnectionResult socketOpenTCPConnectionSync(const char* host, uint32_t port) +{ + const auto hostsResult = resolveHost(host, port); + SocketAddress* hosts = hostsResult.value_or(nullptr); + if (hosts == nullptr) + { + const auto hostsErr = hostsResult.error(); + const auto hostsErrMsg = hostsErr.message(); + return OpenConnectionResult(hostsErr, astringf("Cannot resolve host \"%s\": [%d]: %s", host, hostsErr.value(), hostsErrMsg.c_str())); + } + + auto sockResult = socketOpenAny(hosts, 15000); + Socket* client_transient_socket = sockResult.value_or(nullptr); + deleteSocketAddress(hosts); + hosts = nullptr; + + if (client_transient_socket == nullptr) + { + const auto errValue = sockResult.error(); + const auto errMsg = errValue.message(); + return OpenConnectionResult(errValue, astringf("Cannot connect to [%s]:%d, [%d]:%s", host, port, errValue.value(), errMsg.c_str())); + } + + return OpenConnectionResult(new TCPClientConnection(client_transient_socket)); +} + +struct OpenConnectionRequest +{ + std::string host; + uint32_t port = 0; + OpenConnectionToHostResultCallback callback; +}; + +static int openDirectTCPConnectionAsyncImpl(void* data) +{ + OpenConnectionRequest* pRequestInfo = (OpenConnectionRequest*)data; + if (!pRequestInfo) + { + return 1; + } + + pRequestInfo->callback(socketOpenTCPConnectionSync(pRequestInfo->host.c_str(), pRequestInfo->port)); + delete pRequestInfo; + return 0; +} + +} // anonymous namespace + +bool TCPConnectionProvider::openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) +{ + // spawn background thread to handle this + auto pRequest = new OpenConnectionRequest(); + pRequest->host = host; + pRequest->port = port; + pRequest->callback = callback; + + WZ_THREAD* pOpenConnectionThread = wzThreadCreate(openDirectTCPConnectionAsyncImpl, pRequest); + if (pOpenConnectionThread == nullptr) + { + debug(LOG_ERROR, "Failed to create thread for opening connection"); + delete pRequest; + return false; + } + + wzThreadDetach(pOpenConnectionThread); + // the thread handles deleting pRequest + pOpenConnectionThread = nullptr; + + return true; +} + +IConnectionPollGroup* TCPConnectionProvider::newConnectionPollGroup() +{ + auto* sset = allocSocketSet(); + if (!sset) + { + return nullptr; + } + return new TCPConnectionPollGroup(sset); +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_connection_provider.h b/lib/netplay/tcp/tcp_connection_provider.h new file mode 100644 index 00000000000..9a1336558b3 --- /dev/null +++ b/lib/netplay/tcp/tcp_connection_provider.h @@ -0,0 +1,48 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include +#include +#include + +#include "lib/netplay/wz_connection_provider.h" + +namespace tcp +{ + +class TCPConnectionProvider final : public WzConnectionProvider +{ +public: + + virtual void initialize() override; + virtual void shutdown() override; + + virtual net::result resolveHost(const char* host, uint16_t port) override; + + virtual net::result openListenSocket(uint16_t port) override; + + virtual net::result openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) override; + virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) override; + + virtual IConnectionPollGroup* newConnectionPollGroup() override; +}; + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_listen_socket.cpp b/lib/netplay/tcp/tcp_listen_socket.cpp new file mode 100644 index 00000000000..d7ea294b4fa --- /dev/null +++ b/lib/netplay/tcp/tcp_listen_socket.cpp @@ -0,0 +1,70 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#include "lib/netplay/tcp/tcp_listen_socket.h" +#include "lib/netplay/tcp/tcp_client_connection.h" +#include "lib/netplay/tcp/netsocket.h" +#include "lib/framework/wzapp.h" +#include "lib/framework/debug.h" + +namespace tcp +{ + +TCPListenSocket::TCPListenSocket(Socket* rawSocket) + : listenSocket_(rawSocket) +{} + +TCPListenSocket::~TCPListenSocket() +{ + if (listenSocket_) + { + socketClose(listenSocket_); + } +} + +IClientConnection* TCPListenSocket::accept() +{ + ASSERT(listenSocket_ != nullptr, "Internal socket handle shouldn't be null!"); + if (!listenSocket_) + { + return nullptr; + } + auto* s = socketAccept(listenSocket_); + if (!s) + { + return nullptr; + } + return new TCPClientConnection(s); +} + +IListenSocket::IPVersionsMask TCPListenSocket::supportedIpVersions() const +{ + IPVersionsMask resMask = 0; + if (socketHasIPv4(*listenSocket_)) + { + resMask |= static_cast(IPVersions::IPV4); + } + if (socketHasIPv6(*listenSocket_)) + { + resMask |= static_cast(IPVersions::IPV6); + } + return resMask; +} + +} // namespace tcp diff --git a/lib/netplay/tcp/tcp_listen_socket.h b/lib/netplay/tcp/tcp_listen_socket.h new file mode 100644 index 00000000000..94602b5d031 --- /dev/null +++ b/lib/netplay/tcp/tcp_listen_socket.h @@ -0,0 +1,46 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include "lib/netplay/listen_socket.h" + +namespace tcp +{ + +struct Socket; + +class TCPListenSocket : public IListenSocket +{ +public: + + TCPListenSocket(tcp::Socket* rawSocket); + virtual ~TCPListenSocket() override; + + virtual IClientConnection* accept() override; + virtual IPVersionsMask supportedIpVersions() const override; + +private: + + tcp::Socket* listenSocket_; +}; + +} // namespace tcp diff --git a/lib/netplay/wz_connection_provider.h b/lib/netplay/wz_connection_provider.h new file mode 100644 index 00000000000..1c9d51dbd8e --- /dev/null +++ b/lib/netplay/wz_connection_provider.h @@ -0,0 +1,80 @@ +/* + This file is part of Warzone 2100. + Copyright (C) 2024 Warzone 2100 Project + + Warzone 2100 is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + Warzone 2100 is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Warzone 2100; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#pragma once + +#include + +#include "lib/netplay/connection_address.h" +#include "lib/netplay/net_result.h" +#include "lib/netplay/open_connection_result.h" + +class IListenSocket; +class IClientConnection; +class IConnectionPollGroup; + +/// +/// Abstraction layer to facilitate creating client/server connections and +/// provide host resolution routines for a given network backend. +/// +/// A typical implementation of this interface should at least provide the following +/// things: +/// +/// 1. Initialization/teardown routines (setup some common state, like write/submission +/// queues or service threads, plus initialization of low-level backend code, e.g. +/// calls to init/deinit functions from a 3rd-party library). +/// 2. Host resolution. +/// 3. Opening server-side listen sockets. +/// 4. Opening client-side connections (sync and async). +/// 5. Creating connection poll groups. +/// +class WzConnectionProvider +{ +public: + + virtual ~WzConnectionProvider() = default; + + virtual void initialize() = 0; + virtual void shutdown() = 0; + + /// + /// Resolve host + port combination and return an opaque `ConnectionAddress` handle + /// representing the resolved network address. + /// + virtual net::result resolveHost(const char* host, uint16_t port) = 0; + /// + /// Open a listening socket bound to a specified local port. + /// + virtual net::result openListenSocket(uint16_t port) = 0; + /// + /// Synchronously open a client connection bound to one of the addresses + /// represented by `addr` (the first one that succeeds). + /// + /// Connection address to bind the client connection to. + /// Timeout in milliseconds. + virtual net::result openClientConnectionAny(const ConnectionAddress& addr, unsigned timeout) = 0; + /// + /// Async variant of `openClientConnectionAny()`. + /// + virtual bool openClientConnectionAsync(const std::string& host, uint32_t port, OpenConnectionToHostResultCallback callback) = 0; + /// + /// Create a group for polling client connections. + /// + virtual IConnectionPollGroup* newConnectionPollGroup() = 0; +}; diff --git a/po/POTFILES.in b/po/POTFILES.in index 37815360597..8006cb1b029 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -302,14 +302,20 @@ lib/ivis_opengl/png_util_spng.cpp lib/ivis_opengl/screen.cpp lib/ivis_opengl/tex.cpp lib/ivis_opengl/textdraw.cpp +lib/netplay/connection_address.cpp +lib/netplay/connection_provider_registry.cpp lib/netplay/netjoin_stub.cpp lib/netplay/netlog.cpp lib/netplay/netpermissions.cpp lib/netplay/netplay.cpp lib/netplay/netqueue.cpp lib/netplay/netreplay.cpp -lib/netplay/netsocket.cpp lib/netplay/nettypes.cpp +lib/netplay/tcp/netsocket.cpp +lib/netplay/tcp/tcp_client_connection.cpp +lib/netplay/tcp/tcp_connection_poll_group.cpp +lib/netplay/tcp/tcp_connection_provider.cpp +lib/netplay/tcp/tcp_listen_socket.cpp lib/netplay/port_mapping_manager.cpp lib/netplay/port_mapping_manager_impl_libplum.cpp lib/netplay/port_mapping_manager_impl_miniupnpc.cpp diff --git a/src/integrations/wzdiscordrpc.cpp b/src/integrations/wzdiscordrpc.cpp index 8cfd245eb18..eaa897d5b16 100644 --- a/src/integrations/wzdiscordrpc.cpp +++ b/src/integrations/wzdiscordrpc.cpp @@ -22,7 +22,7 @@ #include "lib/framework/crc.h" #include "lib/gamelib/gtime.h" #include "lib/netplay/netplay.h" -#include "lib/netplay/netsocket.h" +#include "lib/netplay/tcp/netsocket.h" #include "../activity.h" #include "../frontend.h" #include "../multiint.h" @@ -395,10 +395,10 @@ static void joinGameFromSecret_v1(const std::string joinSecretStr) return; } std::vector ipNetworkBinaryFormat = EmbeddedJSONSignature::b64Decode(b64UrlSafeTob64(connectionDetails[0].toUtf8())); - std::string ipAddressStr = ipv6_NetBinary_To_AddressString(ipNetworkBinaryFormat); + std::string ipAddressStr = tcp::ipv6_NetBinary_To_AddressString(ipNetworkBinaryFormat); if (ipAddressStr.empty()) { - ipAddressStr = ipv4_NetBinary_To_AddressString(ipNetworkBinaryFormat); + ipAddressStr = tcp::ipv4_NetBinary_To_AddressString(ipNetworkBinaryFormat); } if (ipAddressStr.empty()) { @@ -766,7 +766,7 @@ void DiscordRPCActivitySink::setJoinInformation(const ActivitySink::MultiplayerG std::string joinSecretDetails; if (!pExternalIPv4Address->empty()) { - auto ipv4AddressBinaryForm = ipv4_AddressString_To_NetBinary(*pExternalIPv4Address); + auto ipv4AddressBinaryForm = tcp::ipv4_AddressString_To_NetBinary(*pExternalIPv4Address); if (!ipv4AddressBinaryForm.empty()) { joinSecretDetails += b64Tob64UrlSafe(EmbeddedJSONSignature::b64Encode(ipv4AddressBinaryForm)); @@ -775,7 +775,7 @@ void DiscordRPCActivitySink::setJoinInformation(const ActivitySink::MultiplayerG } if (!ipv6Address.empty()) { - auto ipv6AddressBinaryForm = ipv6_AddressString_To_NetBinary(ipv6Address); + auto ipv6AddressBinaryForm = tcp::ipv6_AddressString_To_NetBinary(ipv6Address); if (!ipv6AddressBinaryForm.empty()) { if (!joinSecretDetails.empty()) diff --git a/src/screens/joiningscreen.cpp b/src/screens/joiningscreen.cpp index d249fe76c90..5bf493ca84f 100644 --- a/src/screens/joiningscreen.cpp +++ b/src/screens/joiningscreen.cpp @@ -28,8 +28,12 @@ #include "lib/widget/scrollablelist.h" #include "lib/ivis_opengl/pieblitfunc.h" #include "lib/ivis_opengl/piepalette.h" +#include "lib/netplay/byteorder_funcs_wrapper.h" #include "lib/netplay/netplay.h" -#include "lib/netplay/netsocket.h" +#include "lib/netplay/client_connection.h" +#include "lib/netplay/connection_poll_group.h" +#include "lib/netplay/open_connection_result.h" +#include "lib/netplay/connection_provider_registry.h" #include "../hci.h" #include "../activity.h" @@ -782,8 +786,8 @@ class WzJoiningGameScreen_HandlerRoot : public W_CLICKFORM // state when handling initial connection join uint32_t startTime = 0; - Socket* client_transient_socket = nullptr; - SocketSet* tmp_joining_socket_set = nullptr; + IClientConnection* client_transient_socket = nullptr; + IConnectionPollGroup* tmp_joining_socket_set = nullptr; NETQUEUE tmpJoiningQUEUE = {}; NetQueuePair *tmpJoiningQueuePair = nullptr; char initialAckBuffer[10] = {'\0'}; @@ -1096,22 +1100,22 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect if (NETgetEnableTCPNoDelay()) { - // Enable TCP_NODELAY - socketSetTCPNoDelay(*client_transient_socket, true); + // Disable use of Nagle Algorithm for the TCP socket (i.e. enable TCP_NODELAY option in case of TCP transport) + client_transient_socket->useNagleAlgorithm(false); } // Send initial connection data: NETCODE_VERSION_MAJOR and NETCODE_VERSION_MINOR char buffer[sizeof(int32_t) * 2] = { 0 }; char *p_buffer = buffer; auto pushu32 = [&](uint32_t value) { - uint32_t swapped = htonl(value); + uint32_t swapped = wz_htonl(value); memcpy(p_buffer, &swapped, sizeof(swapped)); p_buffer += sizeof(swapped); }; pushu32(NETGetMajorVersion()); pushu32(NETGetMinorVersion()); - const auto writeResult = writeAll(*client_transient_socket, buffer, sizeof(buffer)); + const auto writeResult = client_transient_socket->writeAll(buffer, sizeof(buffer), nullptr); if (!writeResult.has_value()) { const auto writeErrMsg = writeResult.error().message(); @@ -1120,7 +1124,8 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect return; } - tmp_joining_socket_set = allocSocketSet(); + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + tmp_joining_socket_set = connProvider.newConnectionPollGroup(); if (tmp_joining_socket_set == nullptr) { debug(LOG_ERROR, "Cannot create socket set - out of memory?"); @@ -1130,7 +1135,7 @@ void WzJoiningGameScreen_HandlerRoot::processOpenConnectionResult(size_t connect debug(LOG_NET, "Created socket_set %p", static_cast(tmp_joining_socket_set)); // `client_transient_socket` is used to talk to host machine - SocketSet_AddSocket(*tmp_joining_socket_set, client_transient_socket); + tmp_joining_socket_set->add(client_transient_socket); // Create temporary NETQUEUE auto NETnetJoinTmpQueue = [&]() @@ -1167,7 +1172,9 @@ void WzJoiningGameScreen_HandlerRoot::attemptToOpenConnection(size_t connectionI description.port = NETgetGameserverPort(); // use default configured port } auto weakSelf = std::weak_ptr(std::dynamic_pointer_cast(shared_from_this())); - socketOpenTCPConnectionAsync(description.host, description.port, [weakSelf, connectionIdx](OpenConnectionResult&& result) { + + auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT); + connProvider.openClientConnectionAsync(description.host, description.port, [weakSelf, connectionIdx](OpenConnectionResult&& result) { auto strongSelf = weakSelf.lock(); if (!strongSelf) { @@ -1188,7 +1195,7 @@ bool WzJoiningGameScreen_HandlerRoot::joiningSocketNETsend() uint8_t *rawData = message->rawDataDup(); ssize_t rawLen = message->rawLen(); size_t compressedRawLen = 0; - const auto writeResult = writeAll(*client_transient_socket, rawData, rawLen, &compressedRawLen); + const auto writeResult = client_transient_socket->writeAll(rawData, rawLen, &compressedRawLen); delete[] rawData; // Done with the data. queue->popMessageForNet(); if (writeResult.has_value()) @@ -1203,7 +1210,7 @@ bool WzJoiningGameScreen_HandlerRoot::joiningSocketNETsend() debug(LOG_ERROR, "Failed to send message (type: %" PRIu8 ", rawLen: %zu, compressedRawLen: %zu) to host: %s", message->type, message->rawLen(), compressedRawLen, writeErrMsg.c_str()); return false; } - socketFlush(*client_transient_socket, NET_HOST_ONLY); // Make sure the message was completely sent. + client_transient_socket->flush(nullptr); // Make sure the message was completely sent. ASSERT(queue->numMessagesForNet() == 0, "Queue not empty (%u messages remaining).", queue->numMessagesForNet()); return true; } @@ -1214,14 +1221,14 @@ void WzJoiningGameScreen_HandlerRoot::closeConnectionAttempt() { if (tmp_joining_socket_set) { - SocketSet_DelSocket(*tmp_joining_socket_set, client_transient_socket); + tmp_joining_socket_set->remove(client_transient_socket); } - socketClose(client_transient_socket); + delete client_transient_socket; client_transient_socket = nullptr; } if (tmp_joining_socket_set) { - deleteSocketSet(tmp_joining_socket_set); + delete tmp_joining_socket_set; tmp_joining_socket_set = nullptr; } if (tmpJoiningQueuePair) @@ -1281,15 +1288,17 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() if (currentJoiningState == JoiningState::AwaitingInitialNetcodeHandshakeAck) { // read in data, if we have it - if (checkSockets(*tmp_joining_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_joining_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { - if (!socketReadReady(*client_transient_socket)) + if (!client_transient_socket->readReady()) { return; // wait for next check } char *p_buffer = initialAckBuffer; - const auto readResult = readNoInt(*client_transient_socket, p_buffer + usedInitialAckBuffer, expectedInitialAckSize - usedInitialAckBuffer); + const auto readResult = client_transient_socket->readNoInt(p_buffer + usedInitialAckBuffer, + expectedInitialAckSize - usedInitialAckBuffer, + nullptr); if (readResult.has_value()) { usedInitialAckBuffer += static_cast(readResult.value()); @@ -1299,7 +1308,7 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() { uint32_t result = ERROR_CONNECTION; memcpy(&result, initialAckBuffer, sizeof(result)); - result = ntohl(result); + result = wz_ntohl(result); if (result != ERROR_NOERROR) { debug(LOG_ERROR, "Received error %d", result); @@ -1321,7 +1330,7 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() } // transition to net message mode (enable compression, wait for messages) - socketBeginCompression(*client_transient_socket); + client_transient_socket->enableCompression(); currentJoiningState = JoiningState::ProcessingJoinMessages; // permit fall-through to currentJoiningState == JoiningState::ProcessingJoinMessage case below } @@ -1332,15 +1341,15 @@ void WzJoiningGameScreen_HandlerRoot::processJoining() if (currentJoiningState == JoiningState::ProcessingJoinMessages) { // read in data, if we have it - if (checkSockets(*tmp_joining_socket_set, NET_READ_TIMEOUT) > 0) + if (tmp_joining_socket_set->checkSockets(NET_READ_TIMEOUT) > 0) { - if (!socketReadReady(*client_transient_socket)) + if (!client_transient_socket->readReady()) { return; // wait for next check } uint8_t readBuffer[NET_BUFFER_SIZE]; - const auto readResult = readNoInt(*client_transient_socket, readBuffer, sizeof(readBuffer)); + const auto readResult = client_transient_socket->readNoInt(readBuffer, sizeof(readBuffer), nullptr); if (!readResult.has_value()) { // disconnect or programmer error