Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send TCP packets on separate thread #7226

Merged
merged 15 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions protocol/test/protocol_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1916,24 +1916,24 @@ TEST_F(ProtocolTest, QueueTest) {
item = 1;
EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
item = 2;
EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
EXPECT_EQ(fifo_queue_enqueue_item_timeout(fifo_queue, &item, -1), 0);
EXPECT_EQ(fifo_queue_dequeue_item(fifo_queue, &item), 0);
EXPECT_EQ(item, 1);
item = 3;
EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
item = 4;
EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
item = 5;
EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
EXPECT_EQ(fifo_queue_enqueue_item_timeout(fifo_queue, &item, 50), 0);
item = 6;
EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
item = 7;
EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), -1);
EXPECT_EQ(fifo_queue_enqueue_item_timeout(fifo_queue, &item, 50), -1);
EXPECT_EQ(fifo_queue_dequeue_item(fifo_queue, &item), 0);
EXPECT_EQ(item, 2);
EXPECT_EQ(fifo_queue_dequeue_item_timeout(fifo_queue, &item, 100), 0);
EXPECT_EQ(item, 3);
EXPECT_EQ(fifo_queue_dequeue_item_timeout(fifo_queue, &item, 100), 0);
EXPECT_EQ(fifo_queue_dequeue_item_timeout(fifo_queue, &item, -1), 0);
EXPECT_EQ(item, 4);
EXPECT_EQ(fifo_queue_dequeue_item(fifo_queue, &item), 0);
EXPECT_EQ(item, 5);
Expand Down
141 changes: 120 additions & 21 deletions protocol/whist/network/tcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Includes
#include <whist/utils/clock.h>
#include <whist/network/throttle.h>
#include "whist/core/features.h"
#include "whist/utils/queue.h"

#if !OS_IS(OS_WIN32)
#include <fcntl.h>
Expand All @@ -30,6 +31,11 @@ Defines
// Currently set to the "large enough" 1GB
#define MAX_TCP_PAYLOAD_SIZE 1000000000

// How many packets to allow to be queued up on
// a single TCP sending thread before queueing
// up the next packet will block.
#define TCP_SEND_QUEUE_SIZE 16

typedef enum {
TCP_PING,
TCP_PONG,
Expand Down Expand Up @@ -88,8 +94,20 @@ typedef struct {
// Only recvp every RECV_INTERVAL_MS, to keep CPU usage low.
// This is because a recvp takes ~8ms sometimes
WhistTimer last_recvp;

// TCP send is not atomic, so we have to hold packets in a queue and send on a separate thread
WhistThread send_thread;
QueueContext* send_queue;
WhistSemaphore send_semaphore;
bool run_sender;
} TCPContext;

// One entry in a TCPContext's send queue. tcp_send_constructed_packet fills
// one of these in and enqueues it by value; the sender thread
// (multithreaded_tcp_send) dequeues it, sends the packet over the socket, and
// then releases the packet with deallocate_region.
typedef struct TCPQueueItem {
// The fully constructed network packet to be sent by the sender thread.
TCPNetworkPacket* packet;
// Size of the WhistPacket payload as computed by the enqueuing side; the
// sender thread only uses it for logging (the on-wire size is recomputed
// there with get_tcp_network_packet_size).
int packet_size;
} TCPQueueItem;

// Time between consecutive pings
#define TCP_PING_INTERVAL_SEC 2.0
// Time before a ping to be considered "lost", and reconnection starts
Expand Down Expand Up @@ -160,6 +178,19 @@ int create_tcp_client_context(TCPContext* context, char* destination, int port,
*/
int tcp_send_constructed_packet(TCPContext* context, TCPPacket* packet);

/**
 * @brief                          Thread entry point that sends every queued TCP
 *                                 packet for one socket context from a single
 *                                 thread. Serializing the sends this way prevents
 *                                 garbled TCP messages, since large TCP sends are
 *                                 not atomic.
 *
 *                                 The function waits on the context's send
 *                                 semaphore, dequeues one TCPQueueItem from the
 *                                 context's send queue, sends it over the socket
 *                                 and releases the packet, looping until the
 *                                 context's run_sender flag is cleared.
 *
 * @param opaque                   Pointer to the associated socket context
 *                                 (interpreted as a TCPContext*)
 *
 * @returns                        0 on exit
 */
int multithreaded_tcp_send(void* opaque);

/**
* @brief Returns the size, in bytes, of the relevant part of
* the TCPPacket, that must be sent over the network
Expand Down Expand Up @@ -492,6 +523,13 @@ static void tcp_destroy_socket_context(void* raw_context) {
FATAL_ASSERT(raw_context != NULL);
TCPContext* context = raw_context;

// Destroy TCP send queue resources
context->run_sender = false;

// Any pending TCP packets will be dropped
whist_wait_thread(context->send_thread, NULL);
fifo_queue_destroy(context->send_queue);

closesocket(context->socket);
closesocket(context->listen_socket);
whist_destroy_mutex(context->mutex);
Expand Down Expand Up @@ -560,6 +598,9 @@ bool create_tcp_socket_context(SocketContext* network_context, char* destination
context->last_pong_id = -1;
start_timer(&context->last_ping_timer);
context->connection_lost = false;
context->send_queue = NULL;
context->send_semaphore = NULL;
context->send_thread = NULL;
start_timer(&context->last_recvp);

int ret;
Expand All @@ -578,6 +619,22 @@ bool create_tcp_socket_context(SocketContext* network_context, char* destination
return false;
}

// Set up TCP send queue
context->run_sender = true;
if ((context->send_queue = fifo_queue_create(sizeof(TCPQueueItem), TCP_SEND_QUEUE_SIZE)) ==
NULL ||
(context->send_semaphore = whist_create_semaphore(0)) == NULL ||
(context->send_thread = whist_create_thread(multithreaded_tcp_send,
"multithreaded_tcp_send", context)) == NULL) {
// If any of the created resources are NULL, there was a failure and we need to clean up and
// return false
if (context->send_queue) fifo_queue_destroy(context->send_queue);
if (context->send_semaphore) whist_destroy_semaphore(context->send_semaphore);
free(context);
network_context->context = NULL;
return false;
}

// Restore the original timeout
set_timeout(context->socket, context->timeout);

Expand Down Expand Up @@ -763,33 +820,75 @@ int tcp_send_constructed_packet(TCPContext* context, TCPPacket* packet) {
memcpy(network_packet->payload, packet, packet_size);
}

int tcp_packet_size = get_tcp_network_packet_size(network_packet);
// Add TCPNetworkPacket to the queue to be sent on the TCP send thread
TCPQueueItem queue_item;
queue_item.packet = network_packet;
queue_item.packet_size = packet_size;
if (fifo_queue_enqueue_item_timeout(context->send_queue, &queue_item, -1) < 0) return -1;
whist_post_semaphore(context->send_semaphore);
return 0;
}

// For now, the TCP network throttler is NULL, so this is a no-op.
network_throttler_wait_byte_allocation(context->network_throttler, tcp_packet_size);
int multithreaded_tcp_send(void* opaque) {
TCPQueueItem queue_item;
TCPNetworkPacket* network_packet = NULL;
TCPContext* context = (TCPContext*)opaque;
while (true) {
whist_wait_semaphore(context->send_semaphore);
// Check to see if the sender thread needs to stop running
if (!context->run_sender) break;
// If connection is lost, then wait for up to TCP_PING_MAX_RECONNECTION_TIME_SEC
// before continuing.
if (context->connection_lost) {
// Need to re-increment semaphore because wait_semaphore at the top of the loop
// will have decremented semaphore for a packet we are not sending yet.
whist_post_semaphore(context->send_semaphore);
sardination marked this conversation as resolved.
Show resolved Hide resolved
// If the wait for another packet times out, then we return to the top of the loop
if (!whist_wait_timeout_semaphore(context->send_semaphore,
TCP_PING_MAX_RECONNECTION_TIME_SEC * 1000))
continue;
}

// This is useful enough to print, even outside of LOG_NETWORKING GUARDS
LOG_INFO("Sending a WhistPacket of size %d (Total %d bytes), over TCP", packet_size,
tcp_packet_size);
// If there is no item to be dequeued, continue
if (fifo_queue_dequeue_item(context->send_queue, &queue_item) < 0) continue;

// Send the packet
bool failed = false;
int ret = send(context->socket, (const char*)network_packet, tcp_packet_size, 0);
if (ret < 0) {
int error = get_last_network_error();
if (error == WHIST_ECONNRESET) {
LOG_WARNING("TCP Connection reset by peer");
context->connection_lost = true;
} else {
LOG_WARNING("Unexpected TCP Packet Error: %d", error);
network_packet = queue_item.packet;

int tcp_packet_size = get_tcp_network_packet_size(network_packet);

// For now, the TCP network throttler is NULL, so this is a no-op.
network_throttler_wait_byte_allocation(context->network_throttler, tcp_packet_size);

// This is useful enough to print, even outside of LOG_NETWORKING GUARDS
LOG_INFO("Sending a WhistPacket of size %d (Total %d bytes), over TCP",
queue_item.packet_size, tcp_packet_size);

// Send the packet. If a partial packet is sent, keep sending until full packet has been
// sent.
int total_sent = 0;
while (total_sent < tcp_packet_size) {
int ret = send(context->socket, (const char*)(network_packet + total_sent),
tcp_packet_size, 0);
if (ret < 0) {
int error = get_last_network_error();
if (error == WHIST_ECONNRESET) {
LOG_WARNING("TCP Connection reset by peer");
context->connection_lost = true;
} else {
LOG_WARNING("Unexpected TCP Packet Error: %d", error);
}
// Don't attempt to send the rest of the packet if there was a failure
break;
} else {
total_sent += ret;
}
}
failed = true;
}

// Free the encrypted allocation
deallocate_region(network_packet);
// Free the encrypted allocation
deallocate_region(network_packet);
}

return failed ? -1 : 0;
return 0;
}

int get_tcp_packet_size(TCPPacket* tcp_packet) {
Expand Down
2 changes: 1 addition & 1 deletion protocol/whist/network/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ bool create_tcp_socket_context(SocketContext* context, char* destination, int po
char* binary_aes_private_key);

/**
* @brief Creates a tcp listen socket, that can be used in SocketContext
* @brief Creates a tcp listen socket, that can be used in SocketContext
*
* @param sock The socket that will be initialized
* @param port The port to listen on
Expand Down
93 changes: 79 additions & 14 deletions protocol/whist/utils/queue.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ typedef struct QueueContext {
int num_items;
int max_items;
WhistMutex mutex;
WhistCondition cond;
WhistCondition avail_items_cond;
WhistCondition avail_space_cond;
void *data;
bool destroying;
} QueueContext;

static void increment_idx(QueueContext *context, int *idx) {
Expand All @@ -29,6 +31,15 @@ static void dequeue_item(QueueContext *context, void *item) {
void *source_item = (uint8_t *)context->data + (context->item_size * context->read_idx);
memcpy(item, source_item, context->item_size);
increment_idx(context, &context->read_idx);
whist_broadcast_cond(context->avail_space_cond);
}

// Copies `item` into the slot at the current write index, advances the write
// index, and wakes any threads waiting for an item to become available.
// Performs no locking and no capacity check of its own.
static void enqueue_item(QueueContext *context, const void *item) {
    uint8_t *storage = (uint8_t *)context->data;

    memcpy(&storage[context->item_size * context->write_idx], item, context->item_size);
    context->num_items += 1;
    increment_idx(context, &context->write_idx);

    whist_broadcast_cond(context->avail_items_cond);
}

QueueContext *fifo_queue_create(size_t item_size, int max_items) {
Expand All @@ -49,14 +60,21 @@ QueueContext *fifo_queue_create(size_t item_size, int max_items) {
return NULL;
}

context->cond = whist_create_cond();
if (context->cond == NULL) {
context->avail_items_cond = whist_create_cond();
if (context->avail_items_cond == NULL) {
fifo_queue_destroy(context);
return NULL;
}

context->avail_space_cond = whist_create_cond();
if (context->avail_space_cond == NULL) {
fifo_queue_destroy(context);
return NULL;
}

context->item_size = item_size;
context->max_items = max_items;
context->destroying = false;
return context;
}

Expand All @@ -69,11 +87,39 @@ int fifo_queue_enqueue_item(QueueContext *context, const void *item) {
whist_unlock_mutex(context->mutex);
return -1;
}
context->num_items++;
void *target_item = (uint8_t *)context->data + (context->item_size * context->write_idx);
memcpy(target_item, item, context->item_size);
increment_idx(context, &context->write_idx);
whist_broadcast_cond(context->cond);
enqueue_item(context, item);
whist_unlock_mutex(context->mutex);
return 0;
}

// Enqueues `item`, waiting for free space if the queue is full.
//   timeout_ms >= 0 : wait at most that many milliseconds in total.
//   timeout_ms <  0 : block until space becomes available.
// Returns 0 on success; -1 if `context` is NULL, the wait timed out, or the
// queue is being destroyed while still full.
int fifo_queue_enqueue_item_timeout(QueueContext *context, const void *item, int timeout_ms) {
    if (context == NULL) {
        return -1;
    }

    WhistTimer timer;
    start_timer(&timer);

    const bool wait_forever = timeout_ms < 0;
    int remaining_ms = timeout_ms;
    int result = 0;

    whist_lock_mutex(context->mutex);
    while (context->num_items >= context->max_items) {
        if (context->destroying) {
            result = -1;
            break;
        }
        if (wait_forever) {
            // Negative timeout_ms indicates block until available, not timeout
            whist_wait_cond(context->avail_space_cond, context->mutex);
            continue;
        }
        if (!whist_timedwait_cond(context->avail_space_cond, context->mutex, remaining_ms)) {
            // In case of a timeout simply exit
            result = -1;
            break;
        }
        // Woken without timing out: shrink the budget by the time already spent
        // before re-checking for free space.
        int elapsed_ms = (int)(get_timer(&timer) * MS_IN_SECOND);
        remaining_ms = max(timeout_ms - elapsed_ms, 0);
    }

    if (result == 0) {
        enqueue_item(context, item);
    }
    whist_unlock_mutex(context->mutex);
    return result;
}
Expand Down Expand Up @@ -101,13 +147,23 @@ int fifo_queue_dequeue_item_timeout(QueueContext *context, void *item, int timeo
int current_timeout_ms = timeout_ms;
whist_lock_mutex(context->mutex);
while (context->num_items <= 0) {
bool res = whist_timedwait_cond(context->cond, context->mutex, current_timeout_ms);
if (res == false) { // In case of a timeout simply exit
if (context->destroying) {
whist_unlock_mutex(context->mutex);
return -1;
}
int elapsed_ms = (int)(get_timer(&timer) * MS_IN_SECOND);
current_timeout_ms = max(timeout_ms - elapsed_ms, 0);
if (timeout_ms >= 0) {
bool res =
whist_timedwait_cond(context->avail_items_cond, context->mutex, current_timeout_ms);
if (res == false) { // In case of a timeout simply exit
whist_unlock_mutex(context->mutex);
return -1;
}
int elapsed_ms = (int)(get_timer(&timer) * MS_IN_SECOND);
current_timeout_ms = max(timeout_ms - elapsed_ms, 0);
} else {
// Negative timeout_ms indicates block until available, not timeout
whist_wait_cond(context->avail_items_cond, context->mutex);
}
}
dequeue_item(context, item);
whist_unlock_mutex(context->mutex);
Expand All @@ -118,14 +174,23 @@ void fifo_queue_destroy(QueueContext *context) {
if (context == NULL) {
return;
}

// Make sure that all blocking calls release
context->destroying = true;
whist_broadcast_cond(context->avail_items_cond);
whist_broadcast_cond(context->avail_space_cond);

if (context->data != NULL) {
free(context->data);
}
if (context->mutex != NULL) {
whist_destroy_mutex(context->mutex);
}
if (context->cond != NULL) {
whist_destroy_cond(context->cond);
if (context->avail_items_cond != NULL) {
whist_destroy_cond(context->avail_items_cond);
}
if (context->avail_space_cond != NULL) {
whist_destroy_cond(context->avail_space_cond);
}
free(context);
}
Loading