Skip to content

Commit

Permalink
fix multi-thread
Browse files Browse the repository at this point in the history
Signed-off-by: victoryang00 <[email protected]>
  • Loading branch information
victoryang00 committed Feb 1, 2024
1 parent d6aacc6 commit 79a52e5
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 82 deletions.
25 changes: 13 additions & 12 deletions include/wamr.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class WAMRInstance {
std::vector<const char *> arg_{};
std::vector<const char *> addr_{};
std::vector<const char *> ns_pool_{};
std::vector<WAMRExecEnv*> execEnv{};
std::map<int, std::tuple<std::string, std::vector<std::tuple<int, int, fd_op>>>> fd_map_{};
// add offset to pair->tuple, 3rd param 'int'
std::map<int, int> new_sock_map_{};
Expand All @@ -51,19 +52,21 @@ class WAMRInstance {
// lwcp is LightWeight CheckPoint
std::map<ssize_t, int> lwcp_list;
size_t ready = 0;
std::mutex as_mtx;
std::mutex as_mtx{};
std::vector<struct sync_op_t> sync_ops;
bool should_snapshot{};
WASMMemoryInstance **tmp_buf;
uint32 tmp_buf_size;
void replay_sync_ops(bool, wasm_exec_env_t);
void register_tid_map();
uint32 tmp_buf_size{};
std::vector<struct sync_op_t>::iterator sync_iter;
std::mutex sync_op_mutex;
std::condition_variable sync_op_cv;
std::map<ssize_t, ssize_t> tid_map;
std::map<ssize_t, ssize_t> child_tid_map;
std::map<ssize_t, size_t> tid_start_arg_map;
uint32 id{};
size_t cur_thread;
std::chrono::time_point<std::chrono::high_resolution_clock> time;
std::vector<long long> latencies;
bool is_jit{};
bool is_aot{};
char error_buf[128]{};
Expand All @@ -76,31 +79,29 @@ class WAMRInstance {
explicit WAMRInstance(const char *wasm_path, bool is_jit);

void instantiate();
void recover(std::vector<std::unique_ptr<WAMRExecEnv>> *execEnv);
void recover(std::vector<std::unique_ptr<WAMRExecEnv>> *);
bool load_wasm_binary(const char *wasm_path, char **buffer_ptr);
bool get_int3_addr();
bool replace_int3_with_nop();
bool replace_mfence_with_nop();
bool replace_nop_with_int3();
std::chrono::time_point<std::chrono::high_resolution_clock> time;
std::vector<long long> latencies;
void replay_sync_ops(bool, wasm_exec_env_t);
void register_tid_map();
WASMFunction *get_func();
void set_func(WASMFunction *);
#if WASM_ENABLE_AOT != 0
std::vector<uint32> get_args();
AOTFunctionInstance *get_func(int index);
[[nodiscard]] AOTModule *get_module() const;
#endif
WASMExecEnv *get_exec_env();
WASMModuleInstance *get_module_instance() const;

#if WASM_ENABLE_AOT != 0
AOTModule *get_module() const;
#endif
[[nodiscard]] WASMModuleInstance *get_module_instance() const;

void set_wasi_args(WAMRWASIContext &addrs);
void set_wasi_args(const std::vector<std::string> &dir_list, const std::vector<std::string> &map_dir_list,
const std::vector<std::string> &env_list, const std::vector<std::string> &arg_list,
const std::vector<std::string> &addr_list, const std::vector<std::string> &ns_lookup_pool);
void spawn_child(WASMExecEnv *main_env);

int invoke_main();
void invoke_init_c();
Expand Down
2 changes: 2 additions & 0 deletions include/wamr_export.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ int get_sock_fd(int);
void insert_sync_op(wasm_exec_env_t exec_env, const uint32 *mutex, enum sync_op locking);
void restart_execution(uint32 targs);
void insert_tid_start_arg(ssize_t, size_t);
void change_thread_id_to_child(ssize_t, ssize_t);
void insert_parent_child(ssize_t, ssize_t);
extern int pthread_create_wrapper(wasm_exec_env_t exec_env, uint32 *thread, const void *attr, uint32 elem_index,
uint32 arg);
extern int32 pthread_mutex_lock_wrapper(wasm_exec_env_t, uint32 *);
Expand Down
18 changes: 2 additions & 16 deletions include/wamr_memory_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,15 @@ struct WAMRMemoryInstance {
void dump_impl(WASMMemoryInstance *env) {
module_type = env->module_type;
ref_count = env->ref_count;
LOGV(ERROR)<< "ref_count:" << ref_count;
LOGV(ERROR) << "ref_count:" << ref_count;
num_bytes_per_page = env->num_bytes_per_page;
cur_page_count = env->cur_page_count;
max_page_count = env->max_page_count;
memory_data.resize(env->memory_data_size);
memcpy(memory_data.data(), env->memory_data, env->memory_data_size);
heap_data = std::vector<uint8>(env->heap_data, env->heap_data_end);
};
void restore_impl(WASMMemoryInstance *env) {
env->module_type = module_type;
env->ref_count = ref_count+1;
LOGV(ERROR)<< "ref_count:" << env->ref_count;
env->num_bytes_per_page = num_bytes_per_page;
env->cur_page_count = cur_page_count;
env->max_page_count = max_page_count;
env->memory_data_size = memory_data.size();
env->memory_data = (uint8 *)malloc(env->memory_data_size);
memcpy(env->memory_data, memory_data.data(), env->memory_data_size);
env->memory_data_end = env->memory_data + memory_data.size();
env->heap_data = (uint8 *)malloc(heap_data.size());
memcpy(env->heap_data, heap_data.data(), heap_data.size());
env->heap_data_end = env->heap_data + heap_data.size();
};
void restore_impl(WASMMemoryInstance *env);
};

template <SerializerTrait<WASMMemoryInstance *> T> void dump(T t, WASMMemoryInstance *env) { t->dump_impl(env); }
Expand Down
1 change: 1 addition & 0 deletions include/wamr_wasi_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct WAMRWASIContext {
std::vector<std::string> addr_pool;
std::vector<std::string> ns_lookup_list;
std::map<int, int> tid_start_arg_map;
std::map<int, int> child_tid_map;
uint32_t exit_code;

void dump_impl(WASIArguments *env);
Expand Down
10 changes: 6 additions & 4 deletions src/restore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,26 @@ int main(int argc, char **argv) {
auto a = struct_pack::deserialize<std::vector<std::unique_ptr<WAMRExecEnv>>>(*reader).value();
// is server for all and the is server?
#if !defined(_WIN32)
if (!a[a.size()-1]->module_inst.wasi_ctx.socket_fd_map.empty()) { // new ip, old ip // only if tcp requires keepalive
if (!a[a.size() - 1]
->module_inst.wasi_ctx.socket_fd_map.empty()) { // new ip, old ip // only if tcp requires keepalive
// tell gateway to stop keep alive the server
struct sockaddr_in addr {};
int fd = 0;
bool is_tcp_server;
SocketAddrPool src_addr = wamr->local_addr;
LOGV(INFO) << "new ip is "
<< fmt::format("{}.{}.{}.{}:{}", src_addr.ip4[0], src_addr.ip4[1], src_addr.ip4[2], src_addr.ip4[3], src_addr.port);
<< fmt::format("{}.{}.{}.{}:{}", src_addr.ip4[0], src_addr.ip4[1], src_addr.ip4[2], src_addr.ip4[3],
src_addr.port);
// got from wamr
for (auto &[fd, socketMetaData] : a[a.size()-1]->module_inst.wasi_ctx.socket_fd_map) {
for (auto &[fd, socketMetaData] : a[a.size() - 1]->module_inst.wasi_ctx.socket_fd_map) {
wamr->op_data.is_tcp |= socketMetaData.type;
is_tcp_server |= socketMetaData.is_server;
}
is_tcp_server &= wamr->op_data.is_tcp;

wamr->op_data.op = is_tcp_server ? MVVM_SOCK_RESUME_TCP_SERVER : MVVM_SOCK_RESUME;
wamr->op_data.addr[0][0] = src_addr;

// Create a socket
if ((fd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
LOGV(ERROR) << "socket error";
Expand Down
107 changes: 61 additions & 46 deletions src/wamr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,25 +643,26 @@ void WAMRInstance::replay_sync_ops(bool main, wasm_exec_env_t exec_env) {
#endif
WAMRExecEnv *child_env;
// will call pthread create wrapper if needed?
void WAMRInstance::recover(std::vector<std::unique_ptr<WAMRExecEnv>> *execEnv) {
void WAMRInstance::recover(std::vector<std::unique_ptr<WAMRExecEnv>> *e_) {
execEnv.reserve(e_->size());
std::transform(e_->begin(), e_->end(), std::back_inserter(execEnv),
[](const std::unique_ptr<WAMRExecEnv> &uniquePtr) { return uniquePtr ? uniquePtr.get() : nullptr; });
// got this done tommorrow.
// order threads by id (descending)
std::sort(execEnv->begin(), execEnv->end(),
[](const std::unique_ptr<WAMRExecEnv> &a, const std::unique_ptr<WAMRExecEnv> &b) {
return a->frames.back()->function_index > b->frames.back()->function_index;
;
});
std::sort(execEnv.begin(), execEnv.end(), [](const auto &a, const auto &b) {
return a->frames.back()->function_index > b->frames.back()->function_index;
});

argptr = (ThreadArgs **)malloc(sizeof(void *) * execEnv->size());
uint32 id = 0;
set_wasi_args(execEnv->back()->module_inst.wasi_ctx);
argptr = (ThreadArgs **)malloc(sizeof(void *) * execEnv.size());
set_wasi_args(execEnv.back()->module_inst.wasi_ctx);

instantiate();
auto mi = module_inst;

get_int3_addr();
replace_int3_with_nop();

restore(execEnv->back().get(), cur_env);
restore(execEnv.back(), cur_env);
auto main_env = cur_env;
auto main_saved_call_chain = main_env->restore_call_chain;
cur_thread = main_env->cur_count;
Expand All @@ -675,43 +676,8 @@ void WAMRInstance::recover(std::vector<std::unique_ptr<WAMRExecEnv>> *execEnv) {

invoke_init_c();
invoke_preopen(1, "/dev/stdout");

#if !defined(_WIN32)
for (auto [idx, exec_] : *execEnv | enumerate) {
if (idx + 1 == execEnv->size()) {
// the last one should be the main thread
break;
}
child_env = exec_.get();

// requires to record the args and callback for the pthread.
auto thread_arg = ThreadArgs{main_env};
main_env->restore_call_chain = NULL;

argptr[id] = &thread_arg;

// restart thread execution
fprintf(stderr, "pthread_create_wrapper, func %d\n", child_env->cur_count);
// module_inst = wasm_runtime_instantiate(module, stack_size, heap_size, error_buf, sizeof(error_buf));
exec_env->is_restore = false;
auto s = exec_env->restore_call_chain;
exec_env->restore_call_chain = NULL;
invoke_init_c();
invoke_preopen(1, "/dev/stdout");
exec_env->is_restore = true;
exec_env->restore_call_chain = s;
if (tid_start_arg_map.find(child_env->cur_count) != tid_start_arg_map.end())
thread_spawn_wrapper(main_env, tid_start_arg_map[child_env->cur_count]);
else
pthread_create_wrapper(main_env, nullptr, nullptr, id, id); // tid_map

thread_init.acquire();
id++;
}
module_inst = mi;
fprintf(stderr, "child spawned %p\n\n", main_env);
spawn_child(main_env);
// restart main thread execution
#endif
if (!is_aot) {
wasm_interp_call_func_bytecode(get_module_instance(), get_exec_env(), get_exec_env()->cur_frame->function,
get_exec_env()->cur_frame->prev_frame);
Expand All @@ -723,6 +689,7 @@ void WAMRInstance::recover(std::vector<std::unique_ptr<WAMRExecEnv>> *execEnv) {
// invoke_init_c();
// invoke_preopen(1, "/dev/stdout");
fprintf(stderr, "wakeup.release\n");
sleep(1);
wakeup.release(100);

cur_env->is_restore = true;
Expand All @@ -740,6 +707,52 @@ void WAMRInstance::recover(std::vector<std::unique_ptr<WAMRExecEnv>> *execEnv) {
}
}

void WAMRInstance::spawn_child(WASMExecEnv *cur_env) {
#if !defined(_WIN32)
this->as_mtx.lock();
for (auto [idx, exec_] : execEnv | enumerate) {
if (idx + 1 == execEnv.size()) {
// the last one should be the main thread
break;
}
// std::memcpy(child_env, exec_.get(), sizeof(WASMExecEnv));
child_env = exec_;
// requires to record the args and callback for the pthread.
auto thread_arg = ThreadArgs{cur_env};
//cur_env->restore_call_chain = NULL;
// cur_env->is_restore = false;

argptr[id] = &thread_arg;
auto parent = child_tid_map[child_env->cur_count];
if (parent == cur_env->cur_count) {
LOGV(ERROR) << parent << " " << child_env->cur_count;
// restart thread execution
fprintf(stderr, "pthread_create_wrapper, func %d\n", child_env->cur_count);
// module_inst = wasm_runtime_instantiate(module, stack_size, heap_size, error_buf, sizeof(error_buf));
exec_env->is_restore = false;
auto s = exec_env->restore_call_chain;
exec_env->restore_call_chain = NULL;
invoke_init_c();
invoke_preopen(1, "/dev/stdout");
exec_env->is_restore = true;
exec_env->restore_call_chain = s;
if (tid_start_arg_map.find(child_env->cur_count) != tid_start_arg_map.end()) {
// find the parent env

// main thread
thread_spawn_wrapper(cur_env, tid_start_arg_map[child_env->cur_count]);

} else {
pthread_create_wrapper(cur_env, nullptr, nullptr, id, id); // tid_map
}
fprintf(stderr, "child spawned %p\n\n", cur_env);
this->as_mtx.unlock();

thread_init.acquire();
}
}
#endif
}
WASMFunction *WAMRInstance::get_func() { return static_cast<WASMFunction *>(func); }
void WAMRInstance::set_func(WASMFunction *f) { func = static_cast<WASMFunction *>(f); }
void WAMRInstance::set_wasi_args(const std::vector<std::string> &dir_list, const std::vector<std::string> &map_dir_list,
Expand Down Expand Up @@ -767,7 +780,9 @@ void WAMRInstance::set_wasi_args(WAMRWASIContext &context) {
extern WAMRInstance *wamr;
extern "C" { // stop name mangling so it can be linked externally
void wamr_wait(wasm_exec_env_t exec_env) {

fprintf(stderr, "child getting ready to wait\n");
wamr->spawn_child(exec_env);
// register thread id mapping
wamr->register_tid_map();
thread_init.release(1);
Expand Down
12 changes: 12 additions & 0 deletions src/wamr_export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,18 @@ void insert_sync_op(wasm_exec_env_t exec_env, const uint32 *mutex, enum sync_op
void insert_tid_start_arg(ssize_t tid, size_t start_arg){
wamr->tid_start_arg_map[tid] = start_arg;
};
void change_thread_id_to_child(ssize_t tid, ssize_t child_tid){
for (auto &[k, v] : wamr->child_tid_map) {
if (k == child_tid) {
wamr->child_tid_map[tid] = v;
wamr->child_tid_map.erase(k);
break;
}
}
};
void insert_parent_child(ssize_t tid, ssize_t child_tid){
wamr->child_tid_map[child_tid] = tid;
};
void lightweight_checkpoint(WASMExecEnv *exec_env) {
int fid = -1;
if (((AOTFrame *)exec_env->cur_frame)) {
Expand Down
22 changes: 22 additions & 0 deletions src/wamr_mmeory_instance.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "wamr.h"
#include "wamr_memory_instance.h"
extern WAMRInstance *wamr;
void WAMRMemoryInstance::restore_impl(WASMMemoryInstance *env) {
env->module_type = module_type;
env->ref_count = ref_count + 1;
LOGV(ERROR) << "ref_count:" << env->ref_count;
env->num_bytes_per_page = num_bytes_per_page;
env->cur_page_count = cur_page_count;
env->max_page_count = max_page_count;
env->memory_data_size = memory_data.size();
if (env->ref_count > 0) // shared memory
env->memory_data =
(uint8 *)mmap(NULL, wamr->heap_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0);
else
env->memory_data = (uint8 *)malloc(env->memory_data_size);
memcpy(env->memory_data, memory_data.data(), env->memory_data_size);
env->memory_data_end = env->memory_data + memory_data.size();
env->heap_data = (uint8 *)malloc(heap_data.size());
memcpy(env->heap_data, heap_data.data(), heap_data.size());
env->heap_data_end = env->heap_data + heap_data.size();
};
8 changes: 8 additions & 0 deletions src/wamr_wasi_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ void WAMRWASIContext::dump_impl(WASIArguments *env) {
for (auto &[k, v] : wamr->tid_start_arg_map) {
tid_start_arg_map[k] = v;
}
for (auto &[k, v] : wamr->child_tid_map) {
child_tid_map[k] = v;
LOGV(ERROR) << "child_tid_map: " << k << " " << v;
}
// only one thread has fd_map
if (wamr->should_snapshot)
for (auto [fd, res] : wamr->fd_map_) {
Expand Down Expand Up @@ -129,6 +133,10 @@ void WAMRWASIContext::restore_impl(WASIArguments *env) {
for (auto &[k, v] : tid_start_arg_map) {
wamr->tid_start_arg_map[k] = v;
}
for (auto &[k, v] : child_tid_map) {
wamr->child_tid_map[k] = v;
LOGV(ERROR) << "child_tid_map: " << k << " " << v;
}
for (auto [fd, res] : this->fd_map) {
// differ from path from file
auto path = std::get<0>(res);
Expand Down
4 changes: 2 additions & 2 deletions test/multi-thread.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pthread_mutex_t g_count_lock;
static void *thread(void *arg) {
for (int i = 0; i < NUM_ITER; i++) {
__atomic_fetch_add(&g_count, 1, __ATOMIC_SEQ_CST);
if (i % 100 == 0)
// if (i % 100 == 0)
printf("print!!!%d\n", i);
}
printf("Value of g_count is %d\n", g_count);
Expand Down Expand Up @@ -44,6 +44,6 @@ int main(int argc, char **argv) {
// if (g_count != MAX_NUM_THREADS * NUM_ITER) {
// __builtin_trap();
// }
__wasilibc_nocwd_openat_nomode(1,"/dev/stdout",0);
// __wasilibc_nocwd_openat_nomode(1,"/dev/stdout",0);
exit(0);
}
2 changes: 1 addition & 1 deletion test/mutex.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ int main(int argc, char **argv) {
if (g_count != MAX_NUM_THREADS * NUM_ITER) {
__builtin_trap();
}
//__wasilibc_nocwd_openat_nomode(1,"/dev/stdout",0);
// __wasilibc_nocwd_openat_nomode(1,"/dev/stdout",0);
return -1;
}

0 comments on commit 79a52e5

Please sign in to comment.