Skip to content

Commit

Permalink
udp gateway done
Browse files Browse the repository at this point in the history
  • Loading branch information
victoryang00 committed Jan 15, 2024
1 parent e40e273 commit 40acdad
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 212 deletions.
21 changes: 7 additions & 14 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,27 @@
[submodule "lib/wasm-micro-runtime"]
path = lib/wasm-micro-runtime
url = https://github.com/Multi-V-VM/wasm-micro-runtime.git
branch = pthread
[submodule "bench/linpack"]
path = bench/linpack
url = https://github.com/victoryang00/linpack
url = https://github.com/Multi-V-VM/linpack
[submodule "bench/hdastar"]
path = bench/hdastar
url = https://github.com/victoryang00/HDAStar
url = https://github.com/Multi-V-VM/HDAStar
[submodule "lib/libcrafter"]
path = lib/libcrafter
url = https://github.com/victoryang00/libcrafter
url = https://github.com/Multi-V-VM/libcrafter
[submodule "bench/nas"]
path = bench/nas
url = https://github.com/victoryang00/NPB3.0-omp-C
url = https://github.com/Multi-V-VM/NPB3.0-omp-C
[submodule "bench/nginx-module"]
path = bench/nginx-module
url = https://github.com/nginx/unit-wasm
[submodule "opendal"]
path = opendal
url = https://github.com/apache/incubator-opendal
[submodule "bench/opendal"]
path = bench/opendal
url = https://github.com/apache/incubator-opendal
[submodule "bench/ftp"]
path = bench/ftp
url = https://github.com/victoryang00/FTP-implementation-over-TCP
url = https://github.com/Multi-V-VM/FTP-implementation-over-TCP
[submodule "bench/orb_slam3"]
path = bench/orb_slam3
url = https://github.com/victoryang00/orb-slam-expts/
url = https://github.com/Multi-V-VM/orb-slam-expts/
[submodule "bench/lammps"]
path = bench/lammps
url = https://github.com/victoryang00/lammps
url = https://github.com/Multi-V-VM/lammps
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ add_executable(MVVM_restore src/restore.cpp ${UNCOMMON_SHARED_SOURCE})
add_executable(MVVM_checkpoint src/checkpoint.cpp ${UNCOMMON_SHARED_SOURCE})

target_link_libraries(MVVM_export fmt::fmt -lm -ldl -lpthread ${BLAS_LIBRARIES})
target_link_libraries(MVVM_restore fmt::fmt cxxopts::cxxopts -lm -ldl -lpthread ${BLAS_LIBRARIES} MVVM_export vmlib)
target_link_libraries(MVVM_checkpoint fmt::fmt cxxopts::cxxopts -lm -ldl -lpthread ${BLAS_LIBRARIES} MVVM_export vmlib)
target_link_libraries(MVVM_restore fmt::fmt cxxopts::cxxopts -lm -ldl -lpthread ${BLAS_LIBRARIES} MVVM_export vmlib ${LLVM_LLDB_LIB})
target_link_libraries(MVVM_checkpoint fmt::fmt cxxopts::cxxopts -lm -ldl -lpthread ${BLAS_LIBRARIES} MVVM_export vmlib ${LLVM_LLDB_LIB})
if (MVVM_BUILD_MPI)
add_executable(MVVM_mpi_test test/mpi.cpp ${UNCOMMON_SHARED_SOURCE})
target_link_libraries(MVVM_mpi_test fmt::fmt cxxopts::cxxopts -lm -ldl -lpthread ${BLAS_LIBRARIES} MVVM_export vmlib)
target_link_libraries(MVVM_mpi_test fmt::fmt cxxopts::cxxopts -lm -ldl -lpthread ${BLAS_LIBRARIES} MVVM_export vmlib ${LLVM_LLDB_LIB})
endif()
add_definitions(-DCXXOPTS_NO_RTTI=1)
107 changes: 59 additions & 48 deletions gateway/main.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "crafter/Protocols/RawLayer.h"
#include "logging.h"
#include <chrono>
#include <crafter.h>
Expand Down Expand Up @@ -25,7 +26,7 @@ int fd;
std::vector<std::jthread> backend_thread;
std::vector<std::tuple<std::string, std::string, std::string>> forward_pair;
bool is_forward = false;

int id = 0;
// Function to recalculate the IP checksum
unsigned short in_cksum(unsigned short *buf, int len) {
unsigned long sum = 0;
Expand All @@ -39,6 +40,7 @@ unsigned short in_cksum(unsigned short *buf, int len) {
sum += (sum >> 16);
return (unsigned short)(~sum);
}
// https://github.com/pellegre/libcrafter-examples/blob/03832c5c6f68b55a714877bf53aaba2fc33c43ff/SimpleHijackConnection/main.cpp#L113
void forward(const unsigned char *buf, int len) {
int sock;
struct sockaddr_in dst {};
Expand Down Expand Up @@ -169,8 +171,11 @@ void packet_handler(u_char *user, const struct pcap_pkthdr *header, const u_char
ntohs(udphdr->uh_dport));
LOGV(INFO) << fmt::format("ID:{} TOS:0x{}, TTL:{} IpLen:{} DgLen:{}", ntohs(iphdr->ip_id), iphdr->ip_tos,
iphdr->ip_ttl, 4 * iphdr->ip_hl, ntohs(iphdr->ip_len));
id = ntohs(iphdr->ip_id) + 1;
LOGV(ERROR) << "id" << id;
LOGV(INFO) << fmt::format("+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+");
packets += 1;

break;

case IPPROTO_ICMP:
Expand Down Expand Up @@ -221,7 +226,6 @@ void keep_alive(const std::stop_token &stopToken, std::string source_ip, int sou
int dest_port) {
// Send keep alive message to socket
// Initialize Libcrafter
InitCrafter();

while (!stopToken.stop_requested()) {
// Create an IP layer
Expand All @@ -243,18 +247,16 @@ void keep_alive(const std::stop_token &stopToken, std::string source_ip, int sou
std::this_thread::sleep_for(std::chrono::seconds(1));
}
// Clean up Libcrafter
CleanCrafter();
}
void send_fin(std::string source_ip, int source_port, std::string dest_ip,
int dest_port, const char * payload) {
void send_fin(std::string source_ip, int source_port, std::string dest_ip, int dest_port, const char *payload) {
// Send keep alive message to socket
// Initialize Libcrafter
InitCrafter();
// Create an IP layer
IP ip_layer;
ip_layer.SetSourceIP(source_ip);
ip_layer.SetDestinationIP(dest_ip);

ip_layer.SetIdentification(21180);
// Create a UDP layer
UDP udp_layer;
udp_layer.SetSrcPort(source_port);
Expand All @@ -263,23 +265,21 @@ void send_fin(std::string source_ip, int source_port, std::string dest_ip,

// Craft the packet
Packet packet = ip_layer / udp_layer;

ip_layer.SetCheckSum(in_cksum((unsigned short *)packet.GetRawPtr(), packet.GetSize()));
// Send the packet
packet.Send();
packet.Send(MVVM_SOCK_INTERFACE);
// Clean up Libcrafter
CleanCrafter();
}
void send_fin_tcp(std::string source_ip, int source_port, std::string dest_ip,
int dest_port, const char * payload) {
void send_fin_tcp(std::string source_ip, int source_port, std::string dest_ip, int dest_port, const char *payload) {
// Send keep alive message to socket
// Initialize Libcrafter
InitCrafter();
// Create an IP layer
IP ip_layer;
ip_layer.SetSourceIP(source_ip);
ip_layer.SetDestinationIP(dest_ip);

// Create a UDP layer
// Create a TCP layer
TCP tcp_layer;
tcp_layer.SetSrcPort(source_port);
tcp_layer.SetDstPort(dest_port);
Expand All @@ -290,9 +290,8 @@ void send_fin_tcp(std::string source_ip, int source_port, std::string dest_ip,
Packet packet = ip_layer / tcp_layer;

// Send the packet
packet.Send();
packet.Send(MVVM_SOCK_INTERFACE);
// Clean up Libcrafter
CleanCrafter();
}
void sigterm_handler(int sig) {
struct pcap_stat stats {};
Expand All @@ -308,7 +307,11 @@ void sigterm_handler(int sig) {
LOGV(INFO) << "Bye";
exit(0);
}

// int main(){
// struct mvvm_op_data *op_data = (struct mvvm_op_data *)malloc(sizeof(struct mvvm_op_data));
// op_data->op = MVVM_SOCK_FIN;
// send_fin("172.17.0.2",42466,"172.17.0.3",12346,((char *)op_data));
// }
int main() {
struct sockaddr_in address {};
int opt = 1;
Expand All @@ -321,7 +324,7 @@ int main() {
char filter_exp[] = "net 172.17.0.0/24";
bpf_u_int32 netmask;

struct mvvm_op_data op_data {};
auto op_data = (struct mvvm_op_data *)malloc(sizeof(struct mvvm_op_data));

signal(SIGTERM, sigterm_handler);
signal(SIGQUIT, sigterm_handler);
Expand Down Expand Up @@ -383,61 +386,69 @@ int main() {
}
// offload info from client
if ((rc = recv(client_fd, buffer, sizeof(buffer), 0)) > 0) {
memcpy(&op_data, buffer, sizeof(op_data));
switch (op_data.op) {
memcpy(op_data, buffer, sizeof(*op_data));
switch (op_data->op) {
case MVVM_SOCK_SUSPEND:
// suspend
LOGV(ERROR) << "suspend"; // capture all the packets from dest to source

for (int idx = 0; idx < op_data.size; idx++) {
if (op_data.addr[idx][0].is_4) {
server_ip = fmt::format("{}.{}.{}.{}", op_data.addr[idx][0].ip4[0], op_data.addr[idx][0].ip4[1],
op_data.addr[idx][0].ip4[2], op_data.addr[idx][0].ip4[3]);
client_ip = fmt::format("{}.{}.{}.{}", op_data.addr[idx][1].ip4[0], op_data.addr[idx][1].ip4[1],
op_data.addr[idx][1].ip4[2], op_data.addr[idx][1].ip4[3]);
for (int idx = 0; idx < op_data->size; idx++) {
if (op_data->addr[idx][0].is_4) {
server_ip =
fmt::format("{}.{}.{}.{}", op_data->addr[idx][0].ip4[0], op_data->addr[idx][0].ip4[1],
op_data->addr[idx][0].ip4[2], op_data->addr[idx][0].ip4[3]);
client_ip =
fmt::format("{}.{}.{}.{}", op_data->addr[idx][1].ip4[0], op_data->addr[idx][1].ip4[1],
op_data->addr[idx][1].ip4[2], op_data->addr[idx][1].ip4[3]);
} else {
server_ip = fmt::format("{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}",
op_data.addr[idx][0].ip6[0], op_data.addr[idx][0].ip6[1],
op_data.addr[idx][0].ip6[2], op_data.addr[idx][0].ip6[3],
op_data.addr[idx][0].ip6[4], op_data.addr[idx][0].ip6[5],
op_data.addr[idx][0].ip6[6], op_data.addr[idx][0].ip6[7]);
op_data->addr[idx][0].ip6[0], op_data->addr[idx][0].ip6[1],
op_data->addr[idx][0].ip6[2], op_data->addr[idx][0].ip6[3],
op_data->addr[idx][0].ip6[4], op_data->addr[idx][0].ip6[5],
op_data->addr[idx][0].ip6[6], op_data->addr[idx][0].ip6[7]);
client_ip = fmt::format("{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}:{:04x}",
op_data.addr[idx][1].ip6[0], op_data.addr[idx][1].ip6[1],
op_data.addr[idx][1].ip6[2], op_data.addr[idx][1].ip6[3],
op_data.addr[idx][1].ip6[4], op_data.addr[idx][1].ip6[5],
op_data.addr[idx][1].ip6[6], op_data.addr[idx][1].ip6[7]);
op_data->addr[idx][1].ip6[0], op_data->addr[idx][1].ip6[1],
op_data->addr[idx][1].ip6[2], op_data->addr[idx][1].ip6[3],
op_data->addr[idx][1].ip6[4], op_data->addr[idx][1].ip6[5],
op_data->addr[idx][1].ip6[6], op_data->addr[idx][1].ip6[7]);
}
LOGV(INFO) << "server_ip:" << server_ip << " client_ip:" << client_ip;
if (op_data.is_tcp)
backend_thread.emplace_back(keep_alive, server_ip, op_data.addr[idx][0].port, client_ip,
op_data.addr[idx][1].port); // server to client? client to server?
LOGV(INFO) << "server_ip:" << server_ip << ":" << op_data->addr[idx][0].port
<< " client_ip:" << client_ip << ":" << op_data->addr[idx][1].port;
if (op_data->is_tcp)
backend_thread.emplace_back(keep_alive, server_ip, op_data->addr[idx][0].port, client_ip,
op_data->addr[idx][1].port); // server to client? client to server?
forward_pair.emplace_back(server_ip, client_ip, "");
// send the fin to server
op_data->op = MVVM_SOCK_FIN;
sleep(2);
if (!op_data->is_tcp) {
LOGV(INFO) << "send fin";
// send (client_fd, &op_data, sizeof(op_data), 0);
// sendto(, &op_data, sizeof(op_data), 0, (struct sockaddr *)&address, sizeof(address));
send_fin(client_ip, op_data->addr[idx][1].port, server_ip, op_data->addr[idx][0].port,
(char *)op_data);
} else {
// send(, &op_data, sizeof(op_data), 0); // if tcp, require continue the syn?
send_fin_tcp(client_ip, op_data->addr[idx][1].port, server_ip, op_data->addr[idx][0].port,
(char *)op_data);
}
}
// send fin
op_data.op = MVVM_SOCK_FIN;
if (!op_data.is_tcp) {
// send (client_fd, &op_data, sizeof(op_data), 0);
// sendto(, &op_data, sizeof(op_data), 0, (struct sockaddr *)&address, sizeof(address));
send_fin(client_ip, op_data.addr[0][1].port, server_ip, op_data.addr[0][0].port, (char*)&op_data);
} else {
// send(, &op_data, sizeof(op_data), 0); // if tcp, require continue the syn?
send_fin_tcp(client_ip, op_data.addr[0][1].port, server_ip, op_data.addr[0][0].port, (char*)&op_data);
}

break;
case MVVM_SOCK_RESUME:
// resume
LOGV(ERROR) << "resume";
auto tmp_tuple = forward_pair[forward_pair.size() - 1];
std::get<2>(tmp_tuple) =
fmt::format("{}.{}.{}.{}", op_data.addr[0][0].ip4[0], op_data.addr[0][0].ip4[1],
op_data.addr[0][0].ip4[2], op_data.addr[0][0].ip4[3]);
fmt::format("{}.{}.{}.{}", op_data->addr[0][0].ip4[0], op_data->addr[0][0].ip4[1],
op_data->addr[0][0].ip4[2], op_data->addr[0][0].ip4[3]);
LOGV(ERROR) << "forward_pair[forward_pair.size()]" << std::get<0>(forward_pair[forward_pair.size() - 1])
<< std::get<1>(forward_pair[forward_pair.size() - 1]);
is_forward = true;
// for udp forward from source to remote
// stop keep_alive
if (op_data.is_tcp) {
if (op_data->is_tcp) {
backend_thread.pop_back();
}
sleep(1);
Expand Down
2 changes: 2 additions & 0 deletions include/wamr.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class WAMRInstance {
int invoke_ftell(uint32 fd, uint32 offset, uint32 whench);
int invoke_preopen(uint32 fd, const std::string &path);
int invoke_sock_open(uint32_t poolfd, int af, int socktype, uint32_t *sockfd);
int invoke_recv(int sockfd, uint8 **buf, size_t len, int flags);
int invoke_recvfrom(int sockfd, uint8 **buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen);
~WAMRInstance();
};
#endif // MVVM_WAMR_H
1 change: 1 addition & 0 deletions include/wamr_wasi_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct SocketMetaData {
SocketAddrPool socketAddress{};
WasiSockOpenData socketOpenData{};
int replay_start_index{};
bool is_collection=false;
#if !defined(_WIN32)
WasiSockSendToData socketSentToData{}; //
std::vector<WasiSockRecvFromData> socketRecvFromDatas;
Expand Down
8 changes: 4 additions & 4 deletions src/checkpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ std::mutex as_mtx;
void serialize_to_file(WASMExecEnv *instance) {
// gateway
#if !defined(_WIN32)
if (!wamr->socket_fd_map_.empty()) {
if (!wamr->socket_fd_map_.empty() && gettid() == getpid()) {
// tell gateway to keep alive the server
struct sockaddr_in addr {};
int fd = 0;
ssize_t rc;
SocketAddrPool src_addr{};

for (auto [idx, socks] : enumerate(wamr->socket_fd_map_)) {
auto [tmp_fd, sock_data] = socks;
for (auto [tmp_fd, sock_data] : wamr->socket_fd_map_) {
int idx =wamr->op_data.size;
src_addr = sock_data.socketAddress;
auto tmp_ip4 =
fmt::format("{}.{}.{}.{}", src_addr.ip4[0], src_addr.ip4[1], src_addr.ip4[2], src_addr.ip4[3]);
Expand Down
Loading

0 comments on commit 40acdad

Please sign in to comment.