diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index 3c52e69e2d7..d63699f85b9 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -251,7 +251,8 @@ static Socket *tcp_socket = nullptr; ///< Socket used to talk to a static Socket *bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then tcp_socket == NULL. static Socket *connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only). -static SocketSet *socket_set = nullptr; +static SocketSet* client_socket_set = nullptr; +static SocketSet* server_socket_set = nullptr; static struct UPNPUrls urls; static struct IGDdatas data; @@ -979,7 +980,7 @@ static void NETplayerCloseSocket(UDWORD index, bool quietSocketClose) NETlogEntry("Player has left nicely.", SYNC_FLAG, index); // Although we can get a error result from DelSocket, it don't really matter here. - SocketSet_DelSocket(*socket_set, connected_bsocket[index]); + SocketSet_DelSocket(*server_socket_set, connected_bsocket[index]); socketClose(connected_bsocket[index]); connected_bsocket[index] = nullptr; } @@ -1732,20 +1733,26 @@ int NETclose() } } - if (socket_set) + if (client_socket_set) { // checking to make sure tcp_socket is still valid if (tcp_socket) { - SocketSet_DelSocket(*socket_set, tcp_socket); + SocketSet_DelSocket(*client_socket_set, tcp_socket); } if (bsocket) { - SocketSet_DelSocket(*socket_set, bsocket); + SocketSet_DelSocket(*client_socket_set, bsocket); } - debug(LOG_NET, "Freeing socket_set %p", static_cast(socket_set)); - deleteSocketSet(socket_set); - socket_set = nullptr; + debug(LOG_NET, "Freeing socket_set %p", static_cast(client_socket_set)); + deleteSocketSet(client_socket_set); + client_socket_set = nullptr; + } + else if (server_socket_set) + { + debug(LOG_NET, "Freeing socket_set %p", static_cast(server_socket_set)); + deleteSocketSet(server_socket_set); + server_socket_set = nullptr; } if (tcp_socket) { @@ -1919,7 +1926,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message) debug(LOG_ERROR, "Failed to send message: %s", strSockError(getSockErr())); debug(LOG_ERROR, "Host connection was broken, socket %p.", static_cast(bsocket)); NETlogEntry("write error--client disconnect.", SYNC_FLAG, player); - SocketSet_DelSocket(*socket_set, bsocket); // mark it invalid + SocketSet_DelSocket(*client_socket_set, bsocket); // mark it invalid socketClose(bsocket); bsocket = nullptr; NetPlay.players[NetPlay.hostPlayer].heartbeat = false; // mark host as dead @@ -2949,7 +2956,8 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) NETcheckPlayers(); // make sure players are still alive & well } - if (socket_set == nullptr || checkSockets(*socket_set, NET_READ_TIMEOUT) <= 0) + SocketSet* sset = NetPlay.isHost ? server_socket_set : client_socket_set; + if (sset == nullptr || checkSockets(*sset, NET_READ_TIMEOUT) <= 0) { goto checkMessages; } @@ -2970,7 +2978,7 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type) continue; } - dataLen = NET_fillBuffer(pSocket, socket_set, buffer, sizeof(buffer)); + dataLen = NET_fillBuffer(pSocket, sset, buffer, sizeof(buffer)); if (dataLen > 0) { // we received some data, add to buffer @@ -4398,7 +4406,7 @@ static void NETallowJoining() connected_bsocket[index] = tmp_socket[i]; NET_waitingForIndexChangeAckSince[index] = nullopt; tmp_socket[i] = nullptr; - SocketSet_AddSocket(*socket_set, connected_bsocket[index]); + SocketSet_AddSocket(*server_socket_set, connected_bsocket[index]); NETmoveQueue(NETnetTmpQueue(i), NETnetQueue(index)); // Copy player's IP address. @@ -4544,11 +4552,11 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator } debug(LOG_NET, "New tcp_socket = %p", static_cast(tcp_socket)); // Host needs to create a socket set for MAX_PLAYERS - if (!socket_set) + if (!server_socket_set) { - socket_set = allocSocketSet(); + server_socket_set = allocSocketSet(); } - if (socket_set == nullptr) + if (server_socket_set == nullptr) { debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr())); return false; @@ -4672,9 +4680,9 @@ bool NETenumerateGames(const std::function& handl if (tcp_socket != nullptr) { debug(LOG_NET, "Deleting tcp_socket %p", static_cast(tcp_socket)); - if (socket_set) + if (client_socket_set) { - SocketSet_DelSocket(*socket_set, tcp_socket); + SocketSet_DelSocket(*client_socket_set, tcp_socket); } socketClose(tcp_socket); tcp_socket = nullptr; @@ -4693,23 +4701,23 @@ bool NETenumerateGames(const std::function& handl } debug(LOG_NET, "New tcp_socket = %p", static_cast(tcp_socket)); // client machines only need 1 socket set - socket_set = allocSocketSet(); - if (socket_set == nullptr) + client_socket_set = allocSocketSet(); + if (client_socket_set == nullptr) { debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr())); setLobbyError(ERROR_CONNECTION); return false; } - debug(LOG_NET, "Created socket_set %p", static_cast(socket_set)); + debug(LOG_NET, "Created socket_set %p", static_cast(client_socket_set)); - SocketSet_AddSocket(*socket_set, tcp_socket); + SocketSet_AddSocket(*client_socket_set, tcp_socket); debug(LOG_NET, "Sending list cmd"); if (writeAll(*tcp_socket, "list", sizeof("list")) == SOCKET_ERROR) { debug(LOG_NET, "Server socket encountered error: %s", strSockError(getSockErr())); - SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid + SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid socketClose(tcp_socket); tcp_socket = nullptr; @@ -4727,7 +4735,7 @@ bool NETenumerateGames(const std::function& handl return true; // continue enumerating })) { - SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid + SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid socketClose(tcp_socket); tcp_socket = nullptr; @@ -4789,7 +4797,7 @@ bool NETenumerateGames(const std::function& handl // if ignoring the first batch, treat this as a fatal error debug(LOG_NET, "Second readGameStructsList call failed"); - SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid + SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid socketClose(tcp_socket); tcp_socket = nullptr; @@ -4813,7 +4821,7 @@ bool NETenumerateGames(const std::function& handl } } - SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid (we are done with it) + SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid (we are done with it) socketClose(tcp_socket); tcp_socket = nullptr; @@ -4915,16 +4923,16 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, const } // client machines only need 1 socket set - socket_set = allocSocketSet(); - if (socket_set == nullptr) + client_socket_set = allocSocketSet(); + if (client_socket_set == nullptr) { debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr())); return false; } - debug(LOG_NET, "Created socket_set %p", static_cast(socket_set)); + debug(LOG_NET, "Created socket_set %p", static_cast(client_socket_set)); // tcp_socket is used to talk to host machine - SocketSet_AddSocket(*socket_set, tcp_socket); + SocketSet_AddSocket(*client_socket_set, tcp_socket); // Send NETCODE_VERSION_MAJOR and NETCODE_VERSION_MINOR p_buffer = buffer; @@ -4948,11 +4956,11 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, const { debug(LOG_ERROR, "Received error %d", result); - SocketSet_DelSocket(*socket_set, tcp_socket); + SocketSet_DelSocket(*client_socket_set, tcp_socket); socketClose(tcp_socket); tcp_socket = nullptr; - deleteSocketSet(socket_set); - socket_set = nullptr; + deleteSocketSet(client_socket_set); + client_socket_set = nullptr; setLobbyError((LOBBY_ERROR_TYPES)result); return false;