From 25008522b87f0d3783c866ac470923bfedd1d756 Mon Sep 17 00:00:00 2001 From: victoryang00 Date: Sun, 22 Dec 2024 04:06:53 -0800 Subject: [PATCH] finish nn context --- include/wamr_interp_frame.h | 1 + include/wamr_wasi_nn_context.h | 26 ++------- src/wamr_interp_frame.cpp | 83 ++++++++++++++++++--------- src/wamr_wasi_nn_context.cpp | 102 ++++++++++++++++++++++----------- 4 files changed, 130 insertions(+), 82 deletions(-) diff --git a/include/wamr_interp_frame.h b/include/wamr_interp_frame.h index 9c22364..79f3c8d 100644 --- a/include/wamr_interp_frame.h +++ b/include/wamr_interp_frame.h @@ -23,6 +23,7 @@ struct WAMRInterpFrame { /* Instruction pointer of the bytecode array. */ uint32 ip{}; int32 function_index{}; + std::string function_name{}; /* Operand stack top pointer of the current frame. The bottom of the stack is the next cell after the last local variable. */ uint32 sp{}; diff --git a/include/wamr_wasi_nn_context.h b/include/wamr_wasi_nn_context.h index 7fd34c8..8f3b770 100644 --- a/include/wamr_wasi_nn_context.h +++ b/include/wamr_wasi_nn_context.h @@ -25,31 +25,17 @@ /* Maximum number of graph execution context per WASM instance*/ #define MAX_GRAPH_EXEC_CONTEXTS_PER_INST 10 -struct WAMRWASINNGraph { - std::vector buffer; -}; - -struct WAMRWASINNInterpreter { - // std::unique_ptr interpreter; - std::vector interpreter; - uint32_t a; // placeholder -}; - struct WAMRWASINNModel { - // std::unique_ptr model; - std::vector model; - execution_target target; + std::string model_name; + std::vector input_tensor; + std::vector dims; }; struct WAMRWASINNContext { - bool is_initialized; - graph_encoding current_encoding; - std::vector graph; // TODO: support multiple graph - uint32_t current_models; + bool is_initialized = false; + graph_encoding current_encoding = graph_encoding::tensorflow; + uint32_t current_models = 0; WAMRWASINNModel models[MAX_GRAPHS_PER_INST]; - uint32_t current_interpreters; - WAMRWASINNInterpreter interpreters[MAX_GRAPH_EXEC_CONTEXTS_PER_INST]; - void dump_impl(WASINNContext *env); void restore_impl(WASINNContext *env); }; diff --git a/src/wamr_interp_frame.cpp b/src/wamr_interp_frame.cpp index 3a42b97..7a4c006 100644 --- a/src/wamr_interp_frame.cpp +++ b/src/wamr_interp_frame.cpp @@ -18,41 +18,69 @@ #include extern WAMRInstance *wamr; void WAMRInterpFrame::dump_impl(WASMInterpFrame *env) { - SPDLOG_ERROR("not impl"); - exit(-1); + if (env->function) { + wamr->set_func(env->function->u.func); + + if (env->ip) + ip = env->ip - env->function->u.func->code; // here we need to get the + // offset from the code start. +#if WASM_ENABLE_FAST_INTERP == 0 + if (env->sp) { + sp = reinterpret_cast(env->sp) - + ((uint8 *)wamr->get_exec_env()->wasm_stack.s.bottom); // offset to the wasm_stack_top + } +#endif + if (env->function->u.func->field_name) + function_name = env->function->u.func->field_name; + else + function_name = env->function->u.func_import->field_name; + } } std::vector> wasm_replay_csp_bytecode(WASMExecEnv *exec_env, WASMInterpFrame *frame, const uint8 *target_addr); void WAMRInterpFrame::restore_impl(WASMInterpFrame *env) { auto module_inst = (WASMModuleInstance *)wamr->get_exec_env()->module_inst; - if (0 <= function_index && function_index < module_inst->e->function_count) { - // LOGV(INFO) << fmt::format("function_index {} restored", function_index); + if (0 < function_index && function_index < module_inst->e->function_count) { + // LOGV(INFO) << fmt::format("function_index {} restored", + // function_index); env->function = &module_inst->e->functions[function_index]; + if (env->function->is_import_func) { + LOG_DEBUG("is_import_func"); + exit(-1); + } } else { - // LOGV(ERROR) << fmt::format("function_index {} invalid", function_index); - exit(-1); - } - - if (env->function->is_import_func) { - SPDLOG_ERROR("is_import_func"); - exit(-1); + auto target_module = wamr->get_module_instance()->e; + for (uint32 i = 0; i < target_module->function_count; i++) { + auto cur_func = &target_module->functions[i]; + if (!cur_func->is_import_func) { + if (!strcmp(cur_func->u.func->field_name, function_name.c_str())) { + function_index = i; + env->function = cur_func; + break; + } + } else { + if (!strcmp(cur_func->u.func_import->field_name, function_name.c_str())) { + function_index = i; + env->function = cur_func; + break; + } + } + } } - wamr->set_func(env->function->u.func); auto cur_func = env->function; WASMFunction *cur_wasm_func = cur_func->u.func; - SPDLOG_INFO("ip_offset {} sp_offset {}, code start {}", ip, sp, (void *)wasm_get_func_code(env->function)); + LOG_DEBUG("ip_offset %d sp_offset %d, code start %p", ip, sp, (void *)wasm_get_func_code(env->function)); env->ip = wasm_get_func_code(env->function) + ip; - + memcpy(env->lp, stack_frame.data(), stack_frame.size() * sizeof(uint32)); +#if WASM_ENABLE_FAST_INTERP == 0 env->sp_bottom = env->lp + cur_func->param_cell_num + cur_func->local_cell_num; env->sp = env->lp + sp; env->sp_boundary = env->sp_bottom + cur_wasm_func->max_stack_cell_num; - memcpy(env->lp, stack_frame.data(), stack_frame.size() * sizeof(uint32)); - // print_csps(csp); - SPDLOG_DEBUG("wasm_replay_csp_bytecode {} {} {}", (void *)wamr->get_exec_env(), (void *)env, (void *)env->ip); + LOG_DEBUG("wasm_replay_csp_bytecode %d %d %d", (void *)wamr->get_exec_env(), (void *)env, (void *)env->ip); env->csp_bottom = (WASMBranchBlock *)env->sp_boundary; if (env->function->u.func && !env->function->is_import_func && env->sp_bottom) { @@ -62,15 +90,16 @@ void WAMRInterpFrame::restore_impl(WASMInterpFrame *env) { int i = 0; for (auto &&csp_item : csp) { restore(csp_item.get(), env->csp_bottom + i); - SPDLOG_ERROR("csp_bottom {}", ((uint8 *)env->csp_bottom + i) - wamr->get_exec_env()->wasm_stack.s.bottom); + LOG_DEBUG("csp_bottom %d", ((uint8 *)env->csp_bottom + i) - wamr->get_exec_env()->wasm_stack.s.bottom); i++; } env->csp = env->csp_bottom + csp.size(); env->csp_boundary = env->csp_bottom + env->function->u.func->max_block_num; } - SPDLOG_INFO("func_idx {} ip {} sp {} stack bottom {}", function_index, (void *)env->ip, (void *)env->sp, - (void *)wamr->get_exec_env()->wasm_stack.s.bottom); +#endif + LOG_DEBUG("func_idx %d ip %p sp %p stack bottom %p", function_index, (void *)env->ip, (void *)env->sp, + (void *)wamr->get_exec_env()->wasm_stack.s.bottom); } #if WASM_ENABLE_AOT != 0 @@ -79,8 +108,8 @@ void WAMRInterpFrame::dump_impl(AOTFrame *env) { ip = env->ip_offset; sp = env->sp - env->lp; // offset to the wasm_stack_top - SPDLOG_INFO("function_index {} ip_offset {} lp {} sp {} sp_offset {}", env->func_index, ip, (void *)env->lp, - (void *)env->sp, sp); + LOG_DEBUG("function_index %d ip_offset %d lp %p sp %p sp_offset %lu", env->func_index, ip, (void *)env->lp, + (void *)env->sp, sp); stack_frame = std::vector(env->lp, env->sp); @@ -90,7 +119,7 @@ void WAMRInterpFrame::dump_impl(AOTFrame *env) { std::cout << std::endl; } void WAMRInterpFrame::restore_impl(AOTFrame *env) { - SPDLOG_ERROR("not impl"); + LOG_DEBUG("not impl"); exit(-1); } @@ -239,7 +268,7 @@ std::vector> wasm_replay_csp_bytecode(WASMExecE auto e = std::make_unique(); \ e->cell_num = cell_num; \ e->begin_addr = frame_ip - cur_func->u.func->code; \ - e->target_addr = (_target_addr)-cur_func->u.func->code; \ + e->target_addr = (_target_addr) - cur_func->u.func->code; \ e->frame_sp = reinterpret_cast(frame_sp - (param_cell_num)) - exec_env->wasm_stack.s.bottom; \ csp.emplace_back(std::move(e)); \ } @@ -840,15 +869,15 @@ std::vector> wasm_replay_csp_bytecode(WASMExecE case SIMD_v128_load64_splat: case SIMD_v128_store: /* memarg align */ - skip_leb_uint32(frame_ip,frame_ip_end); + skip_leb_uint32(frame_ip, frame_ip_end); /* memarg offset*/ - skip_leb_uint32(frame_ip,frame_ip_end); + skip_leb_uint32(frame_ip, frame_ip_end); break; case SIMD_v128_const: case SIMD_v8x16_shuffle: /* immByte[16] immLaneId[16] */ - CHECK_BUF1(frame_ip,frame_ip_end, 16); + CHECK_BUF1(frame_ip, frame_ip_end, 16); frame_ip += 16; break; diff --git a/src/wamr_wasi_nn_context.cpp b/src/wamr_wasi_nn_context.cpp index 227e316..0ac0976 100644 --- a/src/wamr_wasi_nn_context.cpp +++ b/src/wamr_wasi_nn_context.cpp @@ -13,7 +13,9 @@ #if WASM_ENABLE_WASI_NN != 0 #include "wamr_wasi_nn_context.h" #include "wamr.h" -#include +#include "wasi_nn.h" +#include +#include #define MAX_MODEL_SIZE 85000000 #define MAX_OUTPUT_TENSOR_SIZE 1000000 @@ -23,50 +25,80 @@ extern WAMRInstance *wamr; void WAMRWASINNContext::dump_impl(WASINNContext *env) { fprintf(stderr, "dump_impl\n"); // take the model + // get target } void WAMRWASINNContext::restore_impl(WASINNContext *env) { fprintf(stderr, "restore_impl\n"); - // replay the graph initialization - wasm_get_output(env->wasm_instance, env->wasm_instance->output_index, env->wasm_instance->output_size); - FILE *pFile = fopen(model_name, "r"); - if (pFile == NULL) - return invalid_argument; + // replay the graph initialization + if (!is_initialized) { + for (const auto &model : models) { + FILE *pFile = fopen(model.model_name.c_str(), "r"); + if (pFile == nullptr) + return; - uint8_t *buffer; - size_t result; + uint8_t *buffer; + size_t result; - // allocate memory to contain the whole file: - buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE); - if (buffer == NULL) { - fclose(pFile); - return missing_memory; - } + // allocate memory to contain the whole file: + buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE); + if (buffer == nullptr) { + fclose(pFile); + return; + } - result = fread(buffer, 1, MAX_MODEL_SIZE, pFile); - if (result <= 0) { - fclose(pFile); - free(buffer); - return missing_memory; - } + result = fread(buffer, 1, MAX_MODEL_SIZE, pFile); + if (result <= 0) { + fclose(pFile); + free(buffer); + return; + } - graph_builder_array arr; + graph_builder_array arr; - arr.size = 1; - arr.buf = (graph_builder *)malloc(sizeof(graph_builder)); - if (arr.buf == NULL) { - fclose(pFile); - free(buffer); - return missing_memory; - } + arr.size = 1; + arr.buf = (graph_builder *)malloc(sizeof(graph_builder)); + if (arr.buf == nullptr) { + fclose(pFile); + free(buffer); + return; + } + + arr.buf[0].size = result; + arr.buf[0].buf = buffer; + graph g; + error res = load(&arr, tensorflowlite, execution_target::gpu, &g); + if (res != error::success) { + res = load(&arr, tensorflowlite, execution_target::tpu, &g); + } + if (res != error::success) { + res = load(&arr, tensorflowlite, execution_target::cpu, &g); + } - arr.buf[0].size = result; - arr.buf[0].buf = buffer; + graph_execution_context ctx; + if (init_execution_context(g, &ctx) != success) { + return; + } + tensor_dimensions dims; + dims.size = INPUT_TENSOR_DIMS; + dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t)); + if (dims.buf == NULL) + return; - error res = load(&arr, tensorflowlite, target, g); - fclose(pFile); - free(buffer); - free(arr.buf); - return res; + tensor tensor; + tensor.dimensions = &dims; + for (int i = 0; i < tensor.dimensions->size; ++i) + tensor.dimensions->buf[i] = model.dims[i]; + tensor.type = fp32; + tensor.data = (uint8_t *)model.input_tensor.data(); + error err = set_input(ctx, 0, &tensor); + + free(dims.buf); + + fclose(pFile); + free(buffer); + free(arr.buf); + } + } } #endif \ No newline at end of file