diff --git a/lib/netplay/netplay.cpp b/lib/netplay/netplay.cpp index 2d4135b4271..cc8a495b44f 100644 --- a/lib/netplay/netplay.cpp +++ b/lib/netplay/netplay.cpp @@ -264,6 +264,8 @@ static char externalIPAddress[40]; /** * Used for connections with clients. */ +#define NET_PING_TMP_PING_CHALLENGE_SIZE 128 +static std::array, MAX_TMP_SOCKETS> tmp_challenges{}; static Socket *tmp_socket[MAX_TMP_SOCKETS] = { nullptr }; ///< Sockets used to talk to clients which have not yet been assigned a player number (host only). static SocketSet *tmp_socket_set = nullptr; @@ -3668,6 +3670,11 @@ static void NETallowJoining() debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr())); return; } + // FIXME: I guess initialization of allowjoining is here now... - FlexCoral + for (auto& challenge : tmp_challenges) + { + challenge.clear(); + } } // Find the first empty socket slot @@ -3762,6 +3769,13 @@ static void NETallowJoining() // Connection is successful. connectFailed = false; + + // Give client a challenge to solve before connecting + tmp_challenges[i].resize(NET_PING_TMP_PING_CHALLENGE_SIZE); + genSecRandomBytes(tmp_challenges[i].data(), tmp_challenges[i].size()); + NETbeginEncode(NETnetTmpQueue(i), NET_PING); + NETbytes(&(tmp_challenges[i])); + NETend(); } else { @@ -3799,6 +3813,7 @@ static void NETallowJoining() SocketSet_DelSocket(tmp_socket_set, tmp_socket[i]); socketClose(tmp_socket[i]); tmp_socket[i] = nullptr; + tmp_challenges[i].clear(); } } @@ -3844,14 +3859,42 @@ static void NETallowJoining() char ModList[modlist_string_size] = { '\0' }; char GamePassword[password_string_size] = { '\0' }; uint8_t playerType = 0; + EcKey::Key pkey; + EcKey identity; + EcKey::Sig challengeResponse; NETbeginDecode(NETnetTmpQueue(i), NET_JOIN); NETstring(name, sizeof(name)); NETstring(ModList, sizeof(ModList)); NETstring(GamePassword, sizeof(GamePassword)); NETuint8_t(&playerType); + NETbytes(&pkey); + NETbytes(&challengeResponse); NETend(); + identity.fromBytes(pkey, EcKey::Public); + // verify signature that player is joining with, reject him if he can not do that + if (!identity.verify(challengeResponse, tmp_challenges[i].data(), tmp_challenges[i].size())) + { + debug(LOG_ERROR, "freeing temp socket %p, couldn't create player!", static_cast(tmp_socket[i])); + + rejected = ERROR_WRONGDATA; + NETbeginEncode(NETnetTmpQueue(i), NET_REJECTED); + NETuint8_t(&rejected); + NETend(); + NETflush(); + NETpop(NETnetTmpQueue(i)); + + SocketSet_DelSocket(tmp_socket_set, tmp_socket[i]); + socketClose(tmp_socket[i]); + tmp_socket[i] = nullptr; + tmp_challenges[i].clear(); + sync_counter.cantjoin++; + return; + } + + tmp_challenges[i].clear(); + if ((playerType == NET_JOIN_SPECTATOR) || (int)NetPlay.playercount <= gamestruct.desc.dwMaxPlayers) { tmp = NET_CreatePlayer(name, false, (playerType == NET_JOIN_SPECTATOR)); @@ -3959,6 +4002,7 @@ static void NETallowJoining() snprintf(buf, sizeof(buf), "%s[%" PRIu8 "] %s has joined, IP is: %s", pPlayerType, index, name, NetPlay.players[index].IPtextAddress); debug(LOG_INFO, "%s", buf); NETlogEntry(buf, SYNC_FLAG, index); + wz_command_interface_output("WZEVENT: player join: %u %s %s %s\n", i, base64Encode(pkey).c_str(), identity.publicHashString().c_str(), NetPlay.players[i].IPtextAddress); debug(LOG_NET, "%s, %s, with index of %u has joined using socket %p", pPlayerType, name, (unsigned int)index, static_cast(connected_bsocket[index])); @@ -3967,6 +4011,8 @@ static void NETallowJoining() MultiPlayerJoin(index); + ingame.VerifiedIdentity[index] = true; + // Narrowcast to new player that everyone has joined. for (j = 0; j < MAX_CONNECTED_PLAYERS; ++j) { @@ -4402,7 +4448,7 @@ bool NETfindGame(uint32_t gameId, GAMESTRUCT& output) // //////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////// // Functions used to setup and join games. -bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool asSpectator /*= false*/) +bool NETjoinGame(const char *host, uint32_t port, const char *playername, const EcKey& playerIdentity, bool asSpectator /*= false*/) { SocketAddress *hosts = nullptr; unsigned int i; @@ -4496,14 +4542,7 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool a socketBeginCompression(bsocket); uint8_t playerType = (!asSpectator) ? NET_JOIN_PLAYER : NET_JOIN_SPECTATOR; - - // Send a join message to the host - NETbeginEncode(NETnetQueue(NET_HOST_ONLY), NET_JOIN); - NETstring(playername, 64); - NETstring(getModList().c_str(), modlist_string_size); - NETstring(NetPlay.gamePassword, sizeof(NetPlay.gamePassword)); - NETuint8_t(&playerType); - NETend(); + if (bsocket == nullptr) { return false; // Connection dropped while sending NET_JOIN. @@ -4592,6 +4631,27 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool a NETclose(); return false; } + else if (type == NET_PING) + { + std::vector challenge(NET_PING_TMP_PING_CHALLENGE_SIZE, 0); + NETbeginDecode(NETnetQueue(NET_HOST_ONLY), NET_PING); + NETbytes(&challenge, NET_PING_TMP_PING_CHALLENGE_SIZE * 4); + NETend(); + NETpop(queue); + + EcKey::Sig challengeResponse = playerIdentity.sign(challenge.data(), challenge.size()); + EcKey::Key identity = playerIdentity.toBytes(EcKey::Public); + + NETbeginEncode(NETnetQueue(NET_HOST_ONLY), NET_JOIN); + NETstring(playername, 64); + NETstring(getModList().c_str(), modlist_string_size); + NETstring(NetPlay.gamePassword, sizeof(NetPlay.gamePassword)); + NETuint8_t(&playerType); + NETbytes(&identity); + NETbytes(&challengeResponse); + NETend(); + NETflush(); + } else { debug(LOG_ERROR, "Unexpected %s.", messageTypeToString(type)); diff --git a/lib/netplay/netplay.h b/lib/netplay/netplay.h index 46df7c3848b..d81068f1a0e 100644 --- a/lib/netplay/netplay.h +++ b/lib/netplay/netplay.h @@ -456,7 +456,7 @@ bool NEThaltJoining(); // stop new players joining this game bool NETenumerateGames(const std::function& handleEnumerateGameFunc); bool NETfindGames(std::vector& results, size_t startingIndex, size_t resultsLimit, bool onlyMatchingLocalVersion = false); bool NETfindGame(uint32_t gameId, GAMESTRUCT& output); -bool NETjoinGame(const char *host, uint32_t port, const char *playername, bool asSpectator = false); // join game given with playername +bool NETjoinGame(const char *host, uint32_t port, const char *playername, const EcKey& playerIdentity, bool asSpectator = false); // join game given with playername bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectatorHost, // host a game uint32_t gameType, uint32_t two, uint32_t three, uint32_t four, UDWORD plyrs); bool NETchangePlayerName(UDWORD player, char *newName);// change a players name. diff --git a/lib/netplay/nettypes.cpp b/lib/netplay/nettypes.cpp index de6ee74ef8f..cbe97754dd6 100644 --- a/lib/netplay/nettypes.cpp +++ b/lib/netplay/nettypes.cpp @@ -802,7 +802,7 @@ void NETbytes(std::vector *vec, unsigned maxLen) if (len > maxLen) { - debug(LOG_ERROR, "NETstring: %s packet, length %u truncated at %u", NETgetPacketDir() == PACKET_ENCODE ? "Encoding" : "Decoding", len, maxLen); + debug(LOG_ERROR, "NETbytes: %s packet, length %u truncated at %u", NETgetPacketDir() == PACKET_ENCODE ? "Encoding" : "Decoding", len, maxLen); } len = std::min(len, maxLen); // Truncate length if necessary. diff --git a/src/multiint.cpp b/src/multiint.cpp index 2666759541f..b97261a0b2c 100644 --- a/src/multiint.cpp +++ b/src/multiint.cpp @@ -1049,13 +1049,14 @@ static JoinGameResult joinGameInternalConnect(const char *host, uint32_t port, s { // oldUI may get captured for use in the password dialog, among other things. PLAYERSTATS playerStats; + loadMultiStats(sPlayer, &playerStats); if (ingame.localJoiningInProgress) { return JoinGameResult::FAILED; } - if (!NETjoinGame(host, port, (char *)sPlayer, asSpectator)) // join + if (!NETjoinGame(host, port, (char *)sPlayer, playerStats.identity, asSpectator)) // join { switch (getLobbyError()) { @@ -1084,7 +1085,6 @@ static JoinGameResult joinGameInternalConnect(const char *host, uint32_t port, s } ingame.localJoiningInProgress = true; - loadMultiStats(sPlayer, &playerStats); setMultiStats(selectedPlayer, playerStats, false); setMultiStats(selectedPlayer, playerStats, true);