Skip to content

Commit

Permalink
use futex for IPC notification
Browse files Browse the repository at this point in the history
  • Loading branch information
deanlee committed Jul 4, 2024
1 parent 74074d6 commit b597955
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 67 deletions.
1 change: 1 addition & 0 deletions SConscript
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ msgq_objects = env.SharedObject([
'msgq/impl_zmq.cc',
'msgq/impl_msgq.cc',
'msgq/impl_fake.cc',
'msgq/futex.cc',
'msgq/msgq.cc',
])
msgq = env.Library('msgq', msgq_objects)
Expand Down
62 changes: 62 additions & 0 deletions msgq/futex.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "msgq/futex.h"

#include <fcntl.h>
#include <limits.h>
#include <linux/futex.h>
#include <stdio.h>
#include <sys/mman.h>
#include <syscall.h>
#include <unistd.h>

#include <cassert>
#include <stdexcept>

Futex::Futex(const std::string &path) {
auto fd = open(path.c_str(), O_RDWR | O_CREAT, 0664);
if (fd < 0) {
throw std::runtime_error("Failed to open file: " + path);
}

if (ftruncate(fd, sizeof(uint32_t)) < 0) {
close(fd);
throw std::runtime_error("Failed to truncate file: " + path);
}

int *mem = (int *)mmap(NULL, sizeof(uint32_t), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
close(fd);
if (mem == MAP_FAILED) {
throw std::runtime_error("Failed to mmap file: " + path);
}

futex = reinterpret_cast<std::atomic<uint32_t> *>(mem);
}

Futex::~Futex() {
munmap(futex, sizeof(uint32_t));
}

void Futex::broadcast() {
// Increment the futex value to signal waiting threads
futex->fetch_add(1, std::memory_order_relaxed);

// Wake up all threads waiting on the futex
syscall(SYS_futex, futex, FUTEX_WAKE, INT_MAX, NULL, NULL, 0);
}

bool Futex::wait(uint32_t expected, int timeout_ms) {
if (futex->load(std::memory_order_relaxed) != expected) {
return true; // Already not equal, no need to wait
}

if (timeout_ms <= 0) {
return false; // Timeout immediately
}

// Perform the futex wait syscall
struct timespec ts;
ts.tv_sec = timeout_ms / 1000;
ts.tv_nsec = (timeout_ms % 1000) * 1000 * 1000;
syscall(SYS_futex, futex, FUTEX_WAIT, expected, &ts, nullptr, 0);

return futex->load(std::memory_order_relaxed) != expected;
}
18 changes: 18 additions & 0 deletions msgq/futex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <cstdint>
#include <atomic>
#include <string>


class Futex {
public:
Futex(const std::string &path);
~Futex();
void broadcast();
bool wait(uint32_t expected, int timeout_ms);
inline uint32_t value() const { return futex->load(); }

private:
std::atomic<uint32_t> *futex = nullptr;
};
11 changes: 2 additions & 9 deletions msgq/impl_msgq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ Message * MSGQSubSocket::receive(bool non_blocking){
}

msgq_msg_t msg;

MSGQMessage *r = NULL;

int rc = msgq_msg_recv(&msg, q);
Expand All @@ -93,21 +92,15 @@ Message * MSGQSubSocket::receive(bool non_blocking){
items[0].q = q;

int t = (timeout != -1) ? timeout : 100;

int n = msgq_poll(items, 1, t);
rc = msgq_msg_recv(&msg, q);

// The poll indicated a message was ready, but the receive failed. Try again
if (n == 1 && rc == 0){
continue;
if (msgq_poll(items, 1, t) > 0) {
rc = msgq_msg_recv(&msg, q);
}

if (timeout != -1){
break;
}
}


if (!non_blocking){
std::signal(SIGINT, prev_handler_sigint);
std::signal(SIGTERM, prev_handler_sigterm);
Expand Down
78 changes: 20 additions & 58 deletions msgq/msgq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,21 @@
#include <cerrno>
#include <cmath>
#include <cstring>
#include <cstdint>
#include <chrono>
#include <algorithm>
#include <cstdlib>
#include <csignal>
#include <random>
#include <string>
#include <limits>

#include <poll.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/syscall.h>
#include <fcntl.h>
#include <unistd.h>

#include <stdio.h>

#include "msgq/futex.h"
#include "msgq/msgq.h"

void sigusr2_handler(int signal) {
assert(signal == SIGUSR2);
}
Futex g_futex("/dev/shm/msgq_futex");

uint64_t msgq_get_uid(void){
std::random_device rd("/dev/urandom");
Expand Down Expand Up @@ -85,7 +75,6 @@ void msgq_wait_for_subscriber(msgq_queue_t *q){

int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){
assert(size < 0xFFFFFFFF); // Buffer must be smaller than 2^32 bytes
std::signal(SIGUSR2, sigusr2_handler);

std::string full_path = "/dev/shm/";
const char* prefix = std::getenv("OPENPILOT_PREFIX");
Expand Down Expand Up @@ -142,7 +131,6 @@ void msgq_close_queue(msgq_queue_t *q){
}
}


void msgq_init_publisher(msgq_queue_t * q) {
//std::cout << "Starting publisher" << std::endl;
uint64_t uid = msgq_get_uid();
Expand All @@ -158,15 +146,6 @@ void msgq_init_publisher(msgq_queue_t * q) {
q->write_uid_local = uid;
}

static void thread_signal(uint32_t tid) {
#ifndef SYS_tkill
// TODO: this won't work for multithreaded programs
kill(tid, SIGUSR2);
#else
syscall(SYS_tkill, tid, SIGUSR2);
#endif
}

void msgq_init_subscriber(msgq_queue_t * q) {
assert(q != NULL);
assert(q->num_readers != NULL);
Expand All @@ -185,14 +164,11 @@ void msgq_init_subscriber(msgq_queue_t * q) {

for (size_t i = 0; i < NUM_READERS; i++){
*q->read_valids[i] = false;

uint64_t old_uid = *q->read_uids[i];
*q->read_uids[i] = 0;

// Wake up reader in case they are in a poll
thread_signal(old_uid & 0xFFFFFFFF);
}

// Notify readers
g_futex.broadcast();
continue;
}

Expand Down Expand Up @@ -293,10 +269,7 @@ int msgq_msg_send(msgq_msg_t * msg, msgq_queue_t *q){
PACK64(*q->write_pointer, write_cycles, new_ptr);

// Notify readers
for (uint64_t i = 0; i < num_readers; i++){
uint64_t reader_uid = *q->read_uids[i];
thread_signal(reader_uid & 0xFFFFFFFF);
}
g_futex.broadcast();

return msg->size;
}
Expand Down Expand Up @@ -414,42 +387,31 @@ int msgq_msg_recv(msgq_msg_t * msg, msgq_queue_t * q){
goto start;
}


return msg->size;
}



int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout){
int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout) {
int num = 0;
int timeout_ms = (timeout == -1) ? 100 : timeout;
uint32_t current_futex_value = 0;

// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
items[i].revents = msgq_msg_ready(items[i].q);
if (items[i].revents) num++;
}

int ms = (timeout == -1) ? 100 : timeout;
struct timespec ts;
ts.tv_sec = ms / 1000;
ts.tv_nsec = (ms % 1000) * 1000 * 1000;


auto start_time = std::chrono::high_resolution_clock::now();
while (num == 0) {
int ret;

ret = nanosleep(&ts, &ts);
if (g_futex.wait(current_futex_value, timeout_ms)) {
current_futex_value = g_futex.value();

// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
if (items[i].revents == 0 && msgq_msg_ready(items[i].q)){
num += 1;
items[i].revents = 1;
// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
items[i].revents = msgq_msg_ready(items[i].q);
if (items[i].revents) ++num;
}
}

// exit if we had a timeout and the sleep finished
if (timeout != -1 && ret == 0){
// Update the remaining timeout
auto current_time = std::chrono::high_resolution_clock::now();
timeout_ms -= std::chrono::duration_cast<std::chrono::milliseconds>(current_time - start_time).count();
start_time = current_time;
if (timeout_ms <= 0) {
break;
}
}
Expand Down

0 comments on commit b597955

Please sign in to comment.