Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

netplay.cpp: separate global socket_set into client_socket_set and server_socket_set #3780

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 40 additions & 32 deletions lib/netplay/netplay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ static Socket *tcp_socket = nullptr; ///< Socket used to talk to a

static Socket *bsocket = nullptr; ///< Socket used to talk to the host (clients only). If bsocket != NULL, then tcp_socket == NULL.
static Socket *connected_bsocket[MAX_CONNECTED_PLAYERS] = { nullptr }; ///< Sockets used to talk to clients (host only).
static SocketSet *socket_set = nullptr;
static SocketSet* client_socket_set = nullptr;
static SocketSet* server_socket_set = nullptr;

static struct UPNPUrls urls;
static struct IGDdatas data;
Expand Down Expand Up @@ -979,7 +980,7 @@ static void NETplayerCloseSocket(UDWORD index, bool quietSocketClose)
NETlogEntry("Player has left nicely.", SYNC_FLAG, index);

// Although we can get a error result from DelSocket, it don't really matter here.
SocketSet_DelSocket(*socket_set, connected_bsocket[index]);
SocketSet_DelSocket(*server_socket_set, connected_bsocket[index]);
socketClose(connected_bsocket[index]);
connected_bsocket[index] = nullptr;
}
Expand Down Expand Up @@ -1732,20 +1733,26 @@ int NETclose()
}
}

if (socket_set)
if (client_socket_set)
{
// checking to make sure tcp_socket is still valid
if (tcp_socket)
{
SocketSet_DelSocket(*socket_set, tcp_socket);
SocketSet_DelSocket(*client_socket_set, tcp_socket);
}
if (bsocket)
{
SocketSet_DelSocket(*socket_set, bsocket);
SocketSet_DelSocket(*client_socket_set, bsocket);
}
debug(LOG_NET, "Freeing socket_set %p", static_cast<void *>(socket_set));
deleteSocketSet(socket_set);
socket_set = nullptr;
debug(LOG_NET, "Freeing socket_set %p", static_cast<void *>(client_socket_set));
deleteSocketSet(client_socket_set);
client_socket_set = nullptr;
}
else if (server_socket_set)
{
debug(LOG_NET, "Freeing socket_set %p", static_cast<void*>(server_socket_set));
deleteSocketSet(server_socket_set);
server_socket_set = nullptr;
}
if (tcp_socket)
{
Expand Down Expand Up @@ -1919,7 +1926,7 @@ bool NETsend(NETQUEUE queue, NetMessage const *message)
debug(LOG_ERROR, "Failed to send message: %s", strSockError(getSockErr()));
debug(LOG_ERROR, "Host connection was broken, socket %p.", static_cast<void *>(bsocket));
NETlogEntry("write error--client disconnect.", SYNC_FLAG, player);
SocketSet_DelSocket(*socket_set, bsocket); // mark it invalid
SocketSet_DelSocket(*client_socket_set, bsocket); // mark it invalid
socketClose(bsocket);
bsocket = nullptr;
NetPlay.players[NetPlay.hostPlayer].heartbeat = false; // mark host as dead
Expand Down Expand Up @@ -2949,7 +2956,8 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type)
NETcheckPlayers(); // make sure players are still alive & well
}

if (socket_set == nullptr || checkSockets(*socket_set, NET_READ_TIMEOUT) <= 0)
SocketSet* sset = NetPlay.isHost ? server_socket_set : client_socket_set;
if (sset == nullptr || checkSockets(*sset, NET_READ_TIMEOUT) <= 0)
{
goto checkMessages;
}
Expand All @@ -2970,7 +2978,7 @@ bool NETrecvNet(NETQUEUE *queue, uint8_t *type)
continue;
}

dataLen = NET_fillBuffer(pSocket, socket_set, buffer, sizeof(buffer));
dataLen = NET_fillBuffer(pSocket, sset, buffer, sizeof(buffer));
if (dataLen > 0)
{
// we received some data, add to buffer
Expand Down Expand Up @@ -4398,7 +4406,7 @@ static void NETallowJoining()
connected_bsocket[index] = tmp_socket[i];
NET_waitingForIndexChangeAckSince[index] = nullopt;
tmp_socket[i] = nullptr;
SocketSet_AddSocket(*socket_set, connected_bsocket[index]);
SocketSet_AddSocket(*server_socket_set, connected_bsocket[index]);
NETmoveQueue(NETnetTmpQueue(i), NETnetQueue(index));

// Copy player's IP address.
Expand Down Expand Up @@ -4544,11 +4552,11 @@ bool NEThostGame(const char *SessionName, const char *PlayerName, bool spectator
}
debug(LOG_NET, "New tcp_socket = %p", static_cast<void *>(tcp_socket));
// Host needs to create a socket set for MAX_PLAYERS
if (!socket_set)
if (!server_socket_set)
{
socket_set = allocSocketSet();
server_socket_set = allocSocketSet();
}
if (socket_set == nullptr)
if (server_socket_set == nullptr)
{
debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr()));
return false;
Expand Down Expand Up @@ -4672,9 +4680,9 @@ bool NETenumerateGames(const std::function<bool (const GAMESTRUCT& game)>& handl
if (tcp_socket != nullptr)
{
debug(LOG_NET, "Deleting tcp_socket %p", static_cast<void *>(tcp_socket));
if (socket_set)
if (client_socket_set)
{
SocketSet_DelSocket(*socket_set, tcp_socket);
SocketSet_DelSocket(*client_socket_set, tcp_socket);
}
socketClose(tcp_socket);
tcp_socket = nullptr;
Expand All @@ -4693,23 +4701,23 @@ bool NETenumerateGames(const std::function<bool (const GAMESTRUCT& game)>& handl
}
debug(LOG_NET, "New tcp_socket = %p", static_cast<void *>(tcp_socket));
// client machines only need 1 socket set
socket_set = allocSocketSet();
if (socket_set == nullptr)
client_socket_set = allocSocketSet();
if (client_socket_set == nullptr)
{
debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr()));
setLobbyError(ERROR_CONNECTION);
return false;
}
debug(LOG_NET, "Created socket_set %p", static_cast<void *>(socket_set));
debug(LOG_NET, "Created socket_set %p", static_cast<void *>(client_socket_set));

SocketSet_AddSocket(*socket_set, tcp_socket);
SocketSet_AddSocket(*client_socket_set, tcp_socket);

debug(LOG_NET, "Sending list cmd");

if (writeAll(*tcp_socket, "list", sizeof("list")) == SOCKET_ERROR)
{
debug(LOG_NET, "Server socket encountered error: %s", strSockError(getSockErr()));
SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid
SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid
socketClose(tcp_socket);
tcp_socket = nullptr;

Expand All @@ -4727,7 +4735,7 @@ bool NETenumerateGames(const std::function<bool (const GAMESTRUCT& game)>& handl
return true; // continue enumerating
}))
{
SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid
SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid
socketClose(tcp_socket);
tcp_socket = nullptr;

Expand Down Expand Up @@ -4789,7 +4797,7 @@ bool NETenumerateGames(const std::function<bool (const GAMESTRUCT& game)>& handl
// if ignoring the first batch, treat this as a fatal error
debug(LOG_NET, "Second readGameStructsList call failed");

SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid
SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid
socketClose(tcp_socket);
tcp_socket = nullptr;

Expand All @@ -4813,7 +4821,7 @@ bool NETenumerateGames(const std::function<bool (const GAMESTRUCT& game)>& handl
}
}

SocketSet_DelSocket(*socket_set, tcp_socket); // mark it invalid (we are done with it)
SocketSet_DelSocket(*client_socket_set, tcp_socket); // mark it invalid (we are done with it)
socketClose(tcp_socket);
tcp_socket = nullptr;

Expand Down Expand Up @@ -4915,16 +4923,16 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, const
}

// client machines only need 1 socket set
socket_set = allocSocketSet();
if (socket_set == nullptr)
client_socket_set = allocSocketSet();
if (client_socket_set == nullptr)
{
debug(LOG_ERROR, "Cannot create socket set: %s", strSockError(getSockErr()));
return false;
}
debug(LOG_NET, "Created socket_set %p", static_cast<void *>(socket_set));
debug(LOG_NET, "Created socket_set %p", static_cast<void *>(client_socket_set));

// tcp_socket is used to talk to host machine
SocketSet_AddSocket(*socket_set, tcp_socket);
SocketSet_AddSocket(*client_socket_set, tcp_socket);

// Send NETCODE_VERSION_MAJOR and NETCODE_VERSION_MINOR
p_buffer = buffer;
Expand All @@ -4948,11 +4956,11 @@ bool NETjoinGame(const char *host, uint32_t port, const char *playername, const
{
debug(LOG_ERROR, "Received error %d", result);

SocketSet_DelSocket(*socket_set, tcp_socket);
SocketSet_DelSocket(*client_socket_set, tcp_socket);
socketClose(tcp_socket);
tcp_socket = nullptr;
deleteSocketSet(socket_set);
socket_set = nullptr;
deleteSocketSet(client_socket_set);
client_socket_set = nullptr;

setLobbyError((LOBBY_ERROR_TYPES)result);
return false;
Expand Down
Loading